Replace CkVec with map in the satisfiability program
[charm.git] / examples / charm++ / satisfiability / par_Solver.C
1 #include "main.decl.h"
2 #include "par_SolverTypes.h"
3 #include "par_Solver.h"
4 #include <map>
5 #include <vector>
6
7 #ifdef MINISAT
8 #include "Solver.h"
9 #endif
10
11 #ifdef TNM
12 #include "TNM.h"
13 #endif
14
15 using namespace std;
16
17 extern CProxy_Main mainProxy;
18 extern int grainsize;
19
20 par_SolverState::par_SolverState()
21 {
22     var_frequency = 0;
23 }
24
25
26 int par_SolverState::clausesSize()
27 {
28     return clauses.size();
29 }
30
31 void par_SolverState::attachClause(par_Clause& c)
32 {
33
34 }
35
36 int par_SolverState::unsolvedClauses()
37 {
38     int unsolved = 0;
39     for(int i=0; i< clauses.size(); i++)
40     {
41         if(clauses[i].size() > 0)
42             unsolved++;
43     }
44
45     return unsolved;
46 }
47
48 void par_SolverState::printSolution()
49 {
50     for(int _i=0; _i<var_size; _i++)
51     {
52         CkPrintf("%d\n", occurrence[_i]);
53     }
54  
55 }
56
57 /* add clause, before adding, check unit conflict */
58 bool par_SolverState::addClause(CkVec<par_Lit>& ps)
59 {
60     /*TODO precheck is needed here */
61     if(ps.size() == 1)//unit clause
62     {
63         for(int _i=0; _i<unit_clause_index.size(); _i++)
64         {
65             int index = unit_clause_index[_i];
66
67             par_Clause unit = clauses[index];
68             par_Lit     unit_lit = unit[0];
69
70             if(unit_lit == ~ps[0])
71             {
72    
73                 CkPrintf("clause conflict between %d and %d\n", index, clauses.size());
74                 return false;
75             }
76         }
77        /* all unit clauses are checked, no repeat, no conflict */
78         unit_clause_index.push_back(clauses.size());
79     }else{
80     //check whether the clause is already satisfiable in any case
81     /* if x and ~x exist in the same clause, this clause is already satisfiable, we do not have to deal with it */
82    
83     for(int i=0; i< ps.size(); i++)
84     {
85         par_Lit lit = ps[i];
86
87         for(int j=0; j<i; j++)
88         {
89             if(lit == ~ps[j])
90             {
91                 CkPrintf("This clause is already satisfiable\n");
92                 for(int _i=0; _i<ps.size(); _i++)
93                 {
94                    occurrence[abs(toInt(ps[_i])-1)]--; 
95                    if(toInt(ps[_i]) > 0)
96                        positive_occurrence[abs(toInt(ps[_i])-1)]--; 
97                 }
98                 return true;
99             }
100         }
101     }
102     }
103     clauses.push_back(par_Clause());
104     clauses[clauses.size()-1].attachdata(ps, false);
105     // build the linklist for each literal pointing to the clause, where the literal occurs
106     for(int i=0; i< ps.size(); i++)
107     {
108         par_Lit lit = ps[i];
109
110         if(toInt(lit) > 0)
111             whichClauses[2*toInt(lit)-2].insert(pair<int, int>(clauses.size()-1, 1));
112         else
113             whichClauses[-2*toInt(lit)-1].insert(pair<int, int> (clauses.size()-1, 1));
114
115     }
116
117     return true;
118 }
119
120 void par_SolverState::assignclause(CkVec<par_Clause>& cls )
121 {
122     clauses.removeAll();
123     for(int i=0; i<cls.size(); i++)
124     {
125         clauses.push_back( cls[i]);
126     }
127 }
128
129 /* *********** Solver chare */
130
131 mySolver::mySolver(par_SolverState* state_msg)
132 {
133
134     /* Which variable get assigned  */
135     par_Lit lit = state_msg->assigned_lit;
136 #ifdef DEBUG    
137     CkPrintf("\n\nNew chare: literal = %d, occurrence size=%d, level=%d \n", toInt(lit), state_msg->occurrence.size(), state_msg->level);
138 #endif    
139     par_SolverState *next_state = copy_solverstate(state_msg);
140
141     //Unit clauses
142     /* use this value to propagate the clauses */
143     // deal with the clauses where this literal is
144     int _unit_ = -1;
145     while(1){
146         int pp_ = 1;
147         int pp_i_ = 2;
148         int pp_j_ = 1;
149
150         if(toInt(lit) < 0)
151         {
152             pp_ = -1;
153             pp_i_ = 1;
154             pp_j_ = 2;
155         }
156
157         next_state->occurrence[pp_*toInt(lit)-1] = -pp_i_;
158     
159
160         map_int_int &inClauses = next_state->whichClauses[pp_*2*toInt(lit)-pp_i_];
161         map_int_int  &inClauses_opposite = next_state->whichClauses[pp_*2*toInt(lit)-pp_j_];
162
163     // literal with same sign, remove all these clauses
164        for( map_int_int::iterator iter = inClauses.begin(); iter!=inClauses.end(); iter++)
165        {
166            int cl_index = (*iter).first;
167
168            par_Clause& cl_ = next_state->clauses[cl_index];
169            //for all the literals in this clauses, the occurrence decreases by 1
170            for(int k=0; k< cl_.size(); k++)
171            {
172                par_Lit lit_ = cl_[k];
173                if(toInt(lit_) == toInt(lit))
174                    continue;
175                next_state->occurrence[abs(toInt(lit_)) - 1]--;
176                if(toInt(lit_) > 0)
177                {
178                    next_state->positive_occurrence[toInt(lit_)-1]--;
179                    map_int_int::iterator one_it = next_state->whichClauses[2*toInt(lit_)-2].find(cl_index);
180                    next_state->whichClauses[2*toInt(lit_)-2].erase(one_it);
181                }else
182                {
183                    map_int_int::iterator one_it = next_state->whichClauses[-2*toInt(lit_)-1].find(cl_index);
184                    next_state->whichClauses[-2*toInt(lit_)-1].erase(one_it);
185                }
186                
187            } //finish dealing with all literal in the clause
188            next_state->clauses[cl_index].resize(0);
189        } //finish dealing with clauses where the literal occur the same
190        /* opposite to the literal */
191        for(map_int_int::iterator iter= inClauses_opposite.begin(); iter!=inClauses_opposite.end(); iter++)
192        {
193            int cl_index_ = (*iter).first;
194            par_Clause& cl_neg = next_state->clauses[cl_index_];
195            cl_neg.remove(-toInt(lit));
196
197            /*becomes a unit clause */
198            if(cl_neg.size() == 1)
199            {
200                next_state->unit_clause_index.push_back(cl_index_);
201            }else if (cl_neg.size() == 0)
202            {
203                return;
204            }
205        }
206
207        _unit_++;
208        if(_unit_ == next_state->unit_clause_index.size())
209            break;
210        par_Clause cl = next_state->clauses[next_state->unit_clause_index[_unit_]];
211
212        while(cl.size() == 0){
213            _unit_++;
214            if(_unit_ == next_state->unit_clause_index.size())
215                break;
216            cl = next_state->clauses[next_state->unit_clause_index[_unit_]];
217        };
218
219        if(_unit_ == next_state->unit_clause_index.size())
220            break;
221        lit = cl[0];
222
223
224     }
225     /***************/
226
227     int unsolved = next_state->unsolvedClauses();
228     if(unsolved == 0)
229     {
230         CkPrintf("One solution found in parallel processing \n");
231         //next_state->printSolution();
232         mainProxy.done(next_state->occurrence);
233         return;
234     }
235     int max_index = get_max_element(next_state->occurrence);
236 #ifdef DEBUG
237     CkPrintf("max index = %d\n", max_index);
238 #endif
239     //if() left literal unassigned is larger than a grain size, parallel 
240     ///* When we start sequential 3SAT Grainsize problem*/
241    
242     /* the other branch */
243     par_SolverState *new_msg2 = copy_solverstate(next_state);;
244
245     next_state->level = state_msg->level+1;
246
247     int lower = state_msg->lower;
248     int higher = state_msg->higher;
249     int middle = (lower+higher)/2;
250     int positive_max = next_state->positive_occurrence[max_index];
251     if(positive_max >= next_state->occurrence[max_index] - positive_max)
252     {
253         next_state->assigned_lit = par_Lit(max_index+1);
254         next_state->occurrence[max_index] = -2;
255     }
256     else
257     {
258         next_state->assigned_lit = par_Lit(-max_index-1);
259         next_state->occurrence[max_index] = -1;
260     }
261     bool satisfiable_1 = true;
262
263     if(unsolved > grainsize)
264     {
265         next_state->lower = lower + 1;
266         next_state->higher = middle;
267         *((int *)CkPriorityPtr(next_state)) = lower+1;
268         CkSetQueueing(next_state, CK_QUEUEING_IFIFO);
269         CProxy_mySolver::ckNew(next_state);
270     }
271     else //sequential
272     {
273         /* This code is urgly. Need to revise it later. Convert par data structure to sequential 
274          */
275         vector< vector<int> > seq_clauses;
276         //seq_clauses.resize(next_state->clauses.size());
277         for(int _i_=0; _i_<next_state->clauses.size(); _i_++)
278         {
279             if(next_state->clauses[_i_].size() > 0)
280             {
281                 vector<int> unsolvedclaus;
282                 par_Clause& cl = next_state->clauses[_i_];
283                 unsolvedclaus.resize(cl.size());
284                 for(int _j_=0; _j_<cl.size(); _j_++)
285                 {
286                     unsolvedclaus[_j_] = toInt(cl[_j_]);
287                 }
288                 seq_clauses.push_back(unsolvedclaus);
289             }
290         }
291         satisfiable_1 = seq_processing(next_state->var_size, seq_clauses);//seq_solve(next_state);
292         if(satisfiable_1)
293         {
294             CkPrintf("One solution found by sequential processing \n");
295             mainProxy.done(next_state->occurrence);
296             return;
297         }
298     }
299
300     new_msg2->level = state_msg->level+1;
301     if(positive_max >= new_msg2->occurrence[max_index] - positive_max)
302     {
303         new_msg2->assigned_lit = par_Lit(-max_index-1);
304         new_msg2->occurrence[max_index] = -1;
305     }
306     else
307     {
308         new_msg2->assigned_lit = par_Lit(max_index+1);
309         new_msg2->occurrence[max_index] = -2;
310     }
311     unsolved = new_msg2->unsolvedClauses();
312     if(unsolved > grainsize)
313     {
314         new_msg2->lower = middle + 1;
315         new_msg2->higher = higher-1;
316         *((int *)CkPriorityPtr(new_msg2)) = middle+1;
317         CkSetQueueing(new_msg2, CK_QUEUEING_IFIFO);
318         CProxy_mySolver::ckNew(new_msg2);
319     }
320     else
321     {
322        
323         vector< vector<int> > seq_clauses;
324         for(int _i_=0; _i_<new_msg2->clauses.size(); _i_++)
325         {
326             par_Clause& cl = new_msg2->clauses[_i_];
327             if(cl.size() > 0)
328             {
329                 vector<int> unsolvedclaus;
330                 unsolvedclaus.resize(cl.size());
331                 for(int _j_=0; _j_<cl.size(); _j_++)
332                 {
333                     unsolvedclaus[_j_] = toInt(cl[_j_]);
334                 }
335                 seq_clauses.push_back(unsolvedclaus);
336             }
337         }
338
339         bool ret = seq_processing(new_msg2->var_size,  seq_clauses);//seq_solve(next_state);
340         //bool ret = Solver::seq_processing(new_msg2->clauses);//seq_solve(new_msg2);
341         if(ret)
342         {
343             CkPrintf("One solution found by sequential processing \n");
344             mainProxy.done(new_msg2->occurrence);
345         }
346         return;
347     }
348
349 }
350
351 /* Which literals are already assigned, which is assigned this interation, the unsolved clauses */
352 /* should all these be passed as function parameters */
353 /* solve the 3sat in sequence */
354
355 long long int computes = 0;
356 bool mySolver::seq_solve(par_SolverState* state_msg)
357 {
358     /* Which variable get assigned  */
359     par_Lit lit = state_msg->assigned_lit;
360        
361 #ifdef DEBUG
362     CkPrintf("\n\n Computes=%d Sequential SAT New chare: literal = %d,  level=%d, unsolved clauses=%d\n", computes++, toInt(assigned_var), state_msg->level, state_msg->clauses.size());
363     //CkPrintf("\n\n Computes=%d Sequential SAT New chare: literal = %d, occurrence size=%d, level=%d \n", computes++, toInt(assigned_var), state_msg->occurrence.size(), state_msg->level);
364 #endif
365     par_SolverState *next_state = copy_solverstate(state_msg);
366     
367     //Unit clauses
368     /* use this value to propagate the clauses */
369 #ifdef DEBUG
370     CkPrintf(" remainning clause size is %d\n", (state_msg->clauses).size());
371 #endif
372
373     int _unit_ = -1;
374     while(1){
375     int pp_ = 1;
376     int pp_i_ = 2;
377     int pp_j_ = 1;
378
379     if(toInt(lit) < 0)
380     {
381         pp_ = -1;
382         pp_i_ = 1;
383         pp_j_ = 2;
384     }
385
386     next_state->occurrence[pp_*toInt(lit)-1] = -pp_i_;
387     
388     map_int_int &inClauses = next_state->whichClauses[pp_*2*toInt(lit)-pp_i_];
389     map_int_int &inClauses_opposite = next_state->whichClauses[pp_*2*toInt(lit)-pp_j_];
390
391     // literal with same sign, remove all these clauses
392     
393     for( map_int_int::iterator iter = inClauses.begin(); iter!=inClauses.end(); iter++)
394     {
395         int cl_index = (*iter).first;
396         par_Clause& cl_ = next_state->clauses[cl_index];
397         //for all the literals in this clauses, the occurrence decreases by 1
398         for(int k=0; k< cl_.size(); k++)
399         {
400             par_Lit lit_ = cl_[k];
401             if(toInt(lit_) == toInt(lit))
402                 continue;
403             next_state->occurrence[abs(toInt(lit_)) - 1]--;
404             if(toInt(lit_) > 0)
405             {
406                 next_state->positive_occurrence[toInt(lit_)-1]--;
407                 map_int_int::iterator one_it = next_state->whichClauses[2*toInt(lit_)-2].find(cl_index);
408                 next_state->whichClauses[2*toInt(lit_)-2].erase(one_it);
409             }else
410             {
411                 map_int_int::iterator one_it = next_state->whichClauses[-2*toInt(lit_)-1].find(cl_index);
412                 next_state->whichClauses[-2*toInt(lit_)-1].erase(one_it);
413             }
414
415         }
416         next_state->clauses[cl_index].resize(0);
417     }
418    
419     for(map_int_int::iterator iter= inClauses_opposite.begin(); iter!=inClauses_opposite.end(); iter++)
420     {
421         int cl_index_ = (*iter).first;
422         par_Clause& cl_neg = next_state->clauses[cl_index_];
423         cl_neg.remove(-toInt(lit));
424             //becomes a unit clause
425          if(cl_neg.size() == 1)
426          {
427              next_state->unit_clause_index.push_back(cl_index_);
428          }else if (cl_neg.size() == 0)
429          {
430                 return false;
431          }
432     }
433    
434     _unit_++;
435     if(_unit_ == next_state->unit_clause_index.size())
436         break;
437     par_Clause cl = next_state->clauses[next_state->unit_clause_index[_unit_]];
438     
439     while(cl.size() == 0){
440         _unit_++;
441         if(_unit_ == next_state->unit_clause_index.size())
442             break;
443         cl = next_state->clauses[next_state->unit_clause_index[_unit_]];
444         
445     };
446
447     if(_unit_ == next_state->unit_clause_index.size())
448         break;
449     
450     lit = cl[0];
451     }
452    
453     int unsolved = next_state->unsolvedClauses();
454     if(unsolved == 0)
455     {
456         CkPrintf("One solution found in sequential processing, check the output file for assignment\n");
457         mainProxy.done(next_state->occurrence);
458         return true;
459     }
460     
461     /**********************/
462     
463         /* it would be better to insert the unit literal in order of their occurrence */
464         /* make a decision and then fire new tasks */
465         /* if there is unit clause, should choose these first??? TODO */
466         /* TODO which variable to pick up */
467         /*unit clause literal and also which occurrs most times */
468         int max_index =  get_max_element(next_state->occurrence);
469 #ifdef DEBUG
470         CkPrintf("max index = %d\n", max_index);
471 #endif
472         next_state->level = state_msg->level+1;
473
474         par_SolverState *new_msg2 = copy_solverstate(next_state);;
475         
476         int positive_max = next_state->positive_occurrence[max_index];
477         if(positive_max >= next_state->occurrence[max_index] - positive_max)
478         {
479             next_state->occurrence[max_index] = -2;
480             next_state->assigned_lit = par_Lit(max_index+1);
481         }
482         else
483         {
484             next_state->occurrence[max_index] = -1;
485             next_state->assigned_lit = par_Lit(-max_index-1);
486         } 
487
488         bool   satisfiable_1 = seq_solve(next_state);
489         if(satisfiable_1)
490         {
491             return true;
492         }
493         
494         new_msg2->level = state_msg->level+1;
495        
496         if(positive_max >= next_state->occurrence[max_index] - positive_max)
497         {
498             new_msg2->occurrence[max_index] = -1;
499             new_msg2->assigned_lit = par_Lit(-max_index-1);
500         }
501         else
502         {
503             new_msg2->occurrence[max_index] = -2;
504             new_msg2->assigned_lit = par_Lit(max_index+1);
505         } 
506             
507         bool satisfiable_0 = seq_solve(new_msg2);
508         if(satisfiable_0)
509         {
510             return true;
511         }
512
513         //CkPrintf("Unsatisfiable through sequential\n");
514         return false;
515
516 }