General Backtracking Algorithm
- The General Format
- Example: Permute
- Example: Combination Sum
- Example: Letter Combinations of a Phone Number
I’ll describe a general format for solving backtracking problems on LeetCode or in interviews. Not every part is needed every time, and sometimes it’s challenging to figure out where a piece of logic belongs, but so far this has been working well for me.
The General Format
This is the general format of the backtracking solution. The state variable is a particular solution in the search space that is built up one unit at a time with each call to the solve(state) function. The get_candidates(state) function generates the new units that can extend that partial solution on each call to solve(state). The is_solution(state) function is our escape hatch out of this recursive process. The process_solution(state) function does any transformations required to change the state object into the expected format of a solution and records it; this usually means something like collapsing a list of characters into a string. The is_valid(state) function is an additional filter to skip any invalid solution candidates. (This could be wrapped into the get_candidates(state) function, but having it separate reminds you about this possibility.)
from typing import List

class Solution:
    def problem(self, nums: List[int]) -> List[List[int]]:
        # The four hooks below are placeholders; fill them in for each problem.
        def is_solution(state) -> bool:
            pass

        def process_solution(state) -> None:
            pass

        def get_candidates(state) -> List:
            pass

        def is_valid(state) -> bool:
            pass

        def solve(state):
            if is_solution(state):
                process_solution(state)
                return
            candidates = get_candidates(state)
            for candidate in candidates:
                # state + [candidate] builds a new list, so no explicit undo step is needed
                if is_valid(state + [candidate]):
                    solve(state + [candidate])

        solutions = []
        solve([])
        return solutions

nums = [1, 2, 3]  # whatever input
Solution().problem(nums)  # only meaningful once the hooks above are filled in
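The stubs in the template do nothing until they are filled in. As a minimal sketch of how the four hooks fit together, here is a toy problem that is not from the examples in this post: generate every binary string of length n that has no two consecutive 1s. The function name no_adjacent_ones and the problem itself are assumptions chosen only for illustration.

```python
from typing import List


def no_adjacent_ones(n: int) -> List[str]:
    """Toy problem (hypothetical): binary strings of length n without '11'."""

    def is_solution(state: List[str]) -> bool:
        # A full-length state is a finished string.
        return len(state) == n

    def process_solution(state: List[str]) -> None:
        # Collapse the list of characters into the expected output format.
        solutions.append("".join(state))

    def get_candidates(state: List[str]) -> List[str]:
        # Every position can be extended with either digit.
        return ["0", "1"]

    def is_valid(state: List[str]) -> bool:
        # Reject a state whose last two characters are both '1'.
        return state[-2:] != ["1", "1"]

    def solve(state: List[str]) -> None:
        if is_solution(state):
            process_solution(state)
            return
        for candidate in get_candidates(state):
            if is_valid(state + [candidate]):
                solve(state + [candidate])

    solutions: List[str] = []
    solve([])
    return solutions


print(no_adjacent_ones(3))
# ['000', '001', '010', '100', '101']
```

Here is_valid does real work, which is the case where keeping it separate from get_candidates pays off: the candidate list stays trivial and the pruning rule lives in one place.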
Example: Permute
Problem statement: Given an array nums of distinct integers, return all the possible permutations. You can return the answer in any order.
In this case, the is_valid(state) function always returns True. The get_candidates(state) function provides a list of integers not already present in the current state, since repeats are not allowed in permutations. The is_solution(state) function pops out of the recursion once the state is the same length as the original list; there’s no need for any further checks.
from typing import List

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        def is_solution(state):
            # A state as long as the input uses every number exactly once.
            return len(state) == len(nums)

        def process_solution(state):
            solutions.append(state)

        def get_candidates(state):
            # Relies on the numbers being distinct, as the problem guarantees.
            candidates = [i for i in nums if i not in state]
            return candidates

        def is_valid(state):
            return True

        def solve(state):
            if is_solution(state):
                process_solution(state)
                return
            candidates = get_candidates(state)
            for candidate in candidates:
                if is_valid(state + [candidate]):
                    solve(state + [candidate])

        solutions = []
        solve([])
        return solutions

Solution().permute([1, 2, 3])
# >>> [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
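Because is_valid always returns True for this problem, the hooks can be collapsed inline without changing the search. The following is a rough sketch of that condensed form, written as a standalone function (the name permute and the inline layout are my own choices for illustration), with a sanity check against itertools.permutations from the standard library.

```python
from itertools import permutations
from typing import List


def permute(nums: List[int]) -> List[List[int]]:
    """Condensed form of the template; assumes nums holds distinct integers."""
    solutions: List[List[int]] = []

    def solve(state: List[int]) -> None:
        if len(state) == len(nums):       # is_solution
            solutions.append(state)       # process_solution
            return
        for candidate in nums:
            if candidate not in state:    # get_candidates filter
                solve(state + [candidate])

    solve([])
    return solutions


nums = [1, 2, 3]
expected = sorted(list(p) for p in permutations(nums))
# Compare order-independently, since the problem accepts any order.
assert sorted(permute(nums)) == expected
print(len(permute(nums)))
# 6
```

The condensed form is shorter, but the named hooks make it easier to see which piece to change when moving to the next problem.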
Example: Combination Sum
Problem statement: Combination Sum — Given an array of distinct integers candidates and a target integer target, return a list of all unique combinations of candidates where the chosen numbers sum to target. You may return the combinations in any order. The same number may be chosen from candidates an unlimited number of times. Two combinations are unique if the frequency of at least one of the chosen numbers is different.
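In this case, is_valid(state) makes sure we’re not “re-trying” any combination we’ve already attempted. It sorts the state and rejects it if the sorted form has already been recorded in attempts, so [2, 3] and [3, 2] are explored only once.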
from typing import List

class Solution:
    def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]:
        def is_solution(state):
            return sum(state) == target

        def process_solution(state):
            solutions.append(state)

        def get_candidates(state):
            # Only offer numbers that keep the running sum within the target.
            items = [i for i in candidates if sum(state) + i <= target]
            return items

        def is_valid(state):
            # Sorting makes [2, 3] and [3, 2] compare equal.
            state = sorted(state)
            if state in attempts:
                return False
            attempts.append(state)
            return True

        def solve(state):
            """
            General backtracking algorithm.
            """
            if is_solution(state):
                process_solution(state)
                return
            # Local name; it shadows the outer parameter only inside solve.
            candidates = get_candidates(state)
            for candidate in candidates:
                if is_valid(state + [candidate]):
                    solve(state + [candidate])

        attempts = []
        solutions = []
        solve([])
        return solutions

Solution().combinationSum([2,3,6,7], 7)
# >>> [[2, 2, 3], [7]]
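Looking up a state in the attempts list scans the whole list each time. One possible variation, sketched here as an assumption rather than as part of the original solution, stores each attempted state as a sorted tuple in a set so the membership test is a hash lookup. The function name combination_sum and the key variable are hypothetical, and the sketch assumes every candidate is a positive integer (otherwise the recursion would not terminate).

```python
from typing import List, Set, Tuple


def combination_sum(candidates: List[int], target: int) -> List[List[int]]:
    """Variation on the solution above: attempts is a set of sorted tuples."""
    attempts: Set[Tuple[int, ...]] = set()
    solutions: List[List[int]] = []

    def is_valid(state: List[int]) -> bool:
        # Tuples are hashable, so the sorted state can be used as a set key.
        key = tuple(sorted(state))
        if key in attempts:
            return False
        attempts.add(key)
        return True

    def solve(state: List[int]) -> None:
        total = sum(state)
        if total == target:
            solutions.append(state)
            return
        for candidate in candidates:
            # Same pruning as get_candidates: never overshoot the target.
            if total + candidate <= target and is_valid(state + [candidate]):
                solve(state + [candidate])

    solve([])
    return solutions


print(combination_sum([2, 3, 6, 7], 7))
# [[2, 2, 3], [7]]
```

The search visits the same states in the same order as the list-based version; only the cost of the duplicate check changes.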
Example: Letter Combinations of a Phone Number
Problem statement: Letter Combinations of a Phone Number — Given a string containing digits from 2-9 inclusive, return all possible letter combinations that the number could represent. Return the answer in any order.
from typing import List

class Solution:
    def letterCombinations(self, digits: str) -> List[str]:
        if not digits:
            return []

        def is_solution(state) -> bool:
            return len(state) == len(digits)

        def process_solution(state):
            # Collapse the list of letters into a string.
            solutions.append("".join(state))

        def get_candidates(state):
            if len(state) == len(digits):
                return []
            digit_to_letters = {
                '2': 'abc',
                '3': 'def',
                '4': 'ghi',
                '5': 'jkl',
                '6': 'mno',
                '7': 'pqrs',
                '8': 'tuv',
                '9': 'wxyz'
            }
            # The next digit to expand is the one at the current state length.
            digit = digits[len(state)]
            return list(digit_to_letters[digit])

        def is_valid(state):
            return True

        def solve(state):
            if is_solution(state):
                process_solution(state)
                return
            candidates = get_candidates(state)
            for candidate in candidates:
                if is_valid(state + [candidate]):
                    solve(state + [candidate])

        solutions = []
        solve([])
        return solutions

Solution().letterCombinations("23")
# >>> ['ad', 'ae', 'af', 'bd', 'be', 'bf', 'cd', 'ce', 'cf']
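In this problem the candidates for each position depend only on the digit at that position, never on earlier choices, and is_valid rejects nothing. That means the answer is the Cartesian product of the per-digit letter sets. The sketch below is not backtracking; it uses itertools.product from the standard library as a quick cross-check on the output, with the names letter_combinations and DIGIT_TO_LETTERS chosen here for illustration.

```python
from itertools import product
from typing import List

# Same keypad mapping as in the solution above.
DIGIT_TO_LETTERS = {
    '2': 'abc', '3': 'def', '4': 'ghi', '5': 'jkl',
    '6': 'mno', '7': 'pqrs', '8': 'tuv', '9': 'wxyz',
}


def letter_combinations(digits: str) -> List[str]:
    """Cross-check via Cartesian product; assumes digits contains only 2-9."""
    if not digits:
        return []
    letter_sets = [DIGIT_TO_LETTERS[d] for d in digits]
    # product picks one letter per digit, in left-to-right order.
    return ["".join(letters) for letters in product(*letter_sets)]


print(letter_combinations("23"))
# ['ad', 'ae', 'af', 'bd', 'be', 'bf', 'cd', 'ce', 'cf']
```

The backtracking version remains the better starting point when a later variant needs pruning, since that only requires changing is_valid.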