最近在看回溯搜索算法,觉得值得总结一下。
回溯法简单说就是,为了搜到合适的解做出尝试,尝试失败时就退回到之前状态,然后尝试其他的可能性。这类算法套路比较固定。总的思想我概括为:搜到就输出,搜不到就返回,符合条件就搜。
流程用文字描述,大概分为两类情况:
1.当搜索过程每前进一步,会遇到多种情况时:这类情况的代码形式上类似图的深度遍历
void search(int k){
if (达到目标) 输出解
else
for each(所有可能情况)
if (满足递归条件){
保存状态
search(k+1)
恢复之前的状态
}
}
2.当搜索过程每前进一步,只有两种选择时:这类情况的代码形式上类似二叉树的先根遍历
void search(int k){
if (达到目标) 输出解
else{
if(符合选择1条件){
保存状态和数据
search(k+1, 参数1)
恢复
}
if(符合选择2条件){
保存状态和数据
search(k+1, 参数2)
恢复
}
}
}
当然不一定所有问题都严格遵守上面的步骤。比如,有的时候搜索的分支可能并没改变数据(之后有举例),这时保存数据这步是可以没有的。
回溯法的优点在于省内存,搜索过程中产生解空间。由于采用了递归和深度优先的策略,最大耗费空间仅仅和搜索的最大深度有关。
第一个例子,给出n和m,从1到n中挑出m个数,产生所有可能的排列数和组合数,n小于10:
#include <iostream>
#include <vector>
using namespace std;
bool used[10]={false};//数字i是否使用过
int num[10];//n,m smaller than 10
int id1=0, id2=0;
vector<int> num2;
bool used2[10]={false};
void p(int n, int r, int k){
if(k==0){
id1++;
cout<<id1<<": ";
for(int i=r; i>=1; i--)cout<<num[i];
cout<<"\t";
}else{
for(int i=1; i<=n; i++){
if(!used[i]){
num[k]=i;
used[i]=true;
p(n,r,k-1);
used[i]=false;
}
}
}
}
void permutation(int n, int r){
p(n,r,r);
}
void c(int n, int r, int k){
if((int)num2.size()==r){
id2++;
cout<<id2<<": ";
for(int i=0; i<num2.size(); i++)cout<<num2[i];
cout<<"\t";
}else{
for(int i=k; i<=n; i++){//i=k,保证排列结果是由小到大输出
if(!used2[i]){
num2.push_back(i);
used2[i]=true;//save
c(n,r,i+1);//search
used2[i]=false;
num2.pop_back();//go back
}
}
}
}
void combination(int n, int r){
c(n,r,1);
}
int main(){
int n,m;
cin>>n>>m;
cout<<"permutation:"<<endl;
permutation(n,m);
cout<<endl;
cout<<"combination"<<endl;
combination(n,m);
return 0;
}
当然上面的例子是输出所有的排列和组合。大部分时候我们只是搜索特定的一个结果,或某些结果。因此舍弃不必要的解是很重要的。这就涉及限界,或者说剪枝。如果搜索到某一步时再接着搜肯定没有结果,那么就停止在这个分支的搜索,通常用一个限制条件或外加判断函数来判断。
第二个例子,给定一个集合,知道它有n个元素,希望从n个数中取出若干个使得它们的和为c。第一行输入n和c,第二行输入集合内的数。把可行的数字组合(1组就行)输出,如果没有符合条件的组合,就输出NO SOLUTION。
#include <iostream>
#include <cstdio>
#include <cstdlib>
using namespace std;
int n,c;
int a[9000];
int r=0, mi[9000],s[9000], l[9000];
//s=sum, mi=min, l=answer list
void work(int x, int z){
if(z==c){
for(int i=1; i<=r; i++){
cout<<l[i]<<" ";
if(i==r) cout<<endl;
}
exit(0);
}else{
//限制条件,如果当前的和加上剩下所有的数超过目标值,且当前的和加上剩下最小的
也不超过目标值则继续搜索
if( x<=n && z+mi[x]<=c && z+s[x]>=c){
if(z+a[x]<=c){//do1, add a[x], when a[x] can be added
r++;
l[r]=a[x];//save
work(x+1,z+a[x]);//do1
r--;//back
}
work(x+1,z);//do2, not to add a[x]
}
}
}
/*
x增加时有两个选择,将下一个数挑出,或者不挑出,所以有do1和do2两个选择
*/
int main(){
scanf("%d%d", &n, &c);
for(int i=1; i<=n; i++)
scanf("%d", &a[i]);
mi[n]=s[n]=a[n];
for(int i=n-1; i>=1; i--){
if(a[i]<mi[i+1]) mi[i]=a[i];
else mi[i]=mi[i+1];
s[i]=s[i+1]+a[i];
}
work(1,0);
cout<<"No solution!"<<endl;
return 0;
}
有的时候在递归过程中一些结果可能会被重复搜索(即:走了不同的路线,但是到达了相同目的地 -_-|| ),我们还可以建立一个表储存搜索过的结果,如果发现搜过这个结果了,就不要继续递归,直接从表里读取,这样能避免大量的重复递归。这个实现起来相对容易,就不再举例了。