0%

回溯法搜索总结

最近在看回溯搜索算法,觉得值得总结一下。

回溯法简单说就是,为了搜到合适的解做出尝试,尝试失败时就退回到之前状态,然后尝试其他的可能性。这类算法套路比较固定。总的思想我概括为:搜到就输出,搜不到就返回,符合条件就搜。

流程用文字描述,大概分为两类情况:

1.当搜索过程每前进一步,会遇到多种情况时:这类情况的代码形式上类似图的深度遍历

void search(int k){
    if (达到目标) 输出解
    else
       for each(所有可能情况)
       if (满足递归条件){
         保存状态
         search(k+1)
         恢复之前的状态
       }
}

2.当搜索过程每前进一步,只有两种选择时:这类情况的代码形式上类似二叉树的先根遍历

void search(int k){
   if (达到目标) 输出解
   else{
      if(符合选择1条件){
        保存状态和数据
        search(k+1, 参数1)
        恢复
      }
      if(符合选择2条件){
        保存状态和数据
        search(k+1, 参数2)
        恢复
      }
   }        
}

当然不一定所有问题都严格遵守上面的步骤。比如,有的时候搜索的分支可能并没改变数据(之后有举例),这时保存数据这步是可以没有的。

回溯法的优点在于省内存,搜索过程中产生解空间。由于采用了递归和深度优先的策略,最大耗费空间仅仅和搜索的最大深度有关。

第一个例子,给出n和m,从1到n中挑出m个数,产生所有可能的排列数和组合数,n小于10:

#include <iostream>
#include <vector>
using namespace std;

bool used[10]={false};//数字i是否使用过
int num[10];//n,m smaller than 10

int id1=0, id2=0;

vector<int> num2;
bool used2[10]={false};

void p(int n, int r, int k){
    if(k==0){
        id1++;
        cout<<id1<<": ";
        for(int i=r; i>=1; i--)cout<<num[i];
        cout<<"\t";
    }else{
        for(int i=1; i<=n; i++){
            if(!used[i]){
                num[k]=i;
                used[i]=true;
                p(n,r,k-1);
                used[i]=false;
            }
        }
    }
}

void permutation(int n, int r){
    p(n,r,r);
}

void c(int n, int r, int k){
    if((int)num2.size()==r){
        id2++;
        cout<<id2<<": ";
        for(int i=0; i<num2.size(); i++)cout<<num2[i];
        cout<<"\t";
    }else{
        for(int i=k; i<=n; i++){//i=k,保证排列结果是由小到大输出
            if(!used2[i]){
                num2.push_back(i);
                used2[i]=true;//save
                c(n,r,i+1);//search
                used2[i]=false;
                num2.pop_back();//go back

            }
        }
    }
}

void combination(int n, int r){
    c(n,r,1);
}

int main(){
    int n,m;
    cin>>n>>m;
    cout<<"permutation:"<<endl;
    permutation(n,m);
    cout<<endl;
    cout<<"combination"<<endl;
    combination(n,m);
    return 0;
}

当然上面的例子是输出所有的排列和组合。大部分时候我们只是搜索特定的一个结果,或某些结果。因此舍弃不必要的解是很重要的。这就涉及限界,或者说剪枝。如果搜索到某一步时再接着搜肯定没有结果,那么就停止在这个分支的搜索,通常用一个限制条件或外加判断函数来判断。

第二个例子,给定一个集合,知道它有n个元素,希望从n个数中取出若干个使得它们的和为c。第一行输入n和c,第二行输入集合内的数。把可行的数字组合(1组就行)输出,如果没有符合条件的组合,就输出NO SOLUTION。

#include <iostream>
#include <cstdio>
#include <cstdlib>
using namespace std;

int n,c;
int a[9000];
int r=0, mi[9000],s[9000], l[9000];
//s=sum, mi=min, l=answer list
void work(int x, int z){
    if(z==c){
        for(int i=1; i<=r; i++){
            cout<<l[i]<<" ";
            if(i==r) cout<<endl;
        }
        exit(0);
    }else{
    //限制条件,如果当前的和加上剩下所有的数超过目标值,且当前的和加上剩下最小的
    也不超过目标值则继续搜索
        if( x<=n && z+mi[x]<=c && z+s[x]>=c){
            if(z+a[x]<=c){//do1, add a[x], when a[x] can be added
                r++;
                l[r]=a[x];//save

                work(x+1,z+a[x]);//do1

                r--;//back
            }
            work(x+1,z);//do2, not to add a[x]
        }
    }
}
/* 
    x增加时有两个选择,将下一个数挑出,或者不挑出,所以有do1和do2两个选择
*/

int main(){
    scanf("%d%d", &n, &c);
    for(int i=1; i<=n; i++)
        scanf("%d", &a[i]);

    mi[n]=s[n]=a[n];
    for(int i=n-1; i>=1; i--){
        if(a[i]<mi[i+1]) mi[i]=a[i];
        else mi[i]=mi[i+1];
        s[i]=s[i+1]+a[i];
    }
    work(1,0);

    cout<<"No solution!"<<endl;
    return 0;
}

有的时候在递归过程中一些结果可能会被重复搜索(即:走了不同的路线,但是到达了相同目的地 -_-|| ),我们还可以建立一个表储存搜索过的结果,如果发现搜过这个结果了,就不要继续递归,直接从表里读取,这样能避免大量的重复递归。这个实现起来相对容易,就不再举例了。