ab4794e785a0aff1d63f79e8637d9d018c05f429
[charm.git] / examples / charm++ / satisfiability / main.C
1 #include <stdio.h>
2 #include <cstring>
3 #include <stdint.h>
4 #include <errno.h>
5 #include <limits.h>
6
7 #include <signal.h>
8 #include <zlib.h>
9
10 #include <vector>
11
12 #include "main.decl.h"
13 #include "main.h"
14
15 #include "par_SolverTypes.h"
16 #include "par_Solver.h"
17
18 #ifdef MINISAT
19 #include "Solver.h"
20 #endif
21
22 #ifdef TNM
23 #include "TNM.h"
24 #endif
25
26 using namespace std;
27
28 //#include "util.h"
29
30 typedef map<int, int> map_int_int;
31
32 CProxy_Main mainProxy;
33
34 #define CHUNK_LIMIT 1048576
35
36 char inputfile[50];
37
38 class StreamBuffer {
39     gzFile  in;
40     char    buf[CHUNK_LIMIT];
41     int     pos;
42     int     size;
43
44     void assureLookahead() {
45         if (pos >= size) {
46             pos  = 0;
47             size = gzread(in, buf, sizeof(buf)); } }
48
49 public:
50             StreamBuffer(gzFile i) : in(i), pos(0), size(0) {
51                 assureLookahead(); }
52
53             int  operator *  () { return (pos >= size) ? EOF : buf[pos]; }
54             void operator ++ () { pos++; assureLookahead(); }
55 };
56
57 static bool match(StreamBuffer& in, char* str) {
58     for (; *str != 0; ++str, ++in)
59         if (*str != *in)
60             return false;
61     return true;
62 }
63
64
65 void error_exit(char *error)
66 {
67     printf("%s\n", error);
68     CkExit();
69 }
70
71 void skipWhitespace(StreamBuffer& in) 
72 {
73     while ((*in >= 9 && *in <= 13) || *in == 32)
74         ++in;
75 }
76
77 void skipLine(StreamBuffer& in) {
78     for (;;){
79         if (*in == EOF || *in == '\0') return;
80         if (*in == '\n')
81         { ++in; return; }
82         ++in;
83     } 
84 }
85
86
87 int parseInt(StreamBuffer& in) {
88     int     val = 0;
89     bool    neg = false;
90     skipWhitespace(in);
91     if      (*in == '-') neg = true, ++in;
92     else if (*in == '+') ++in;
93     if (*in < '0' || *in > '9')
94         error_exit((char*)"ParseInt error\n");
95
96     while (*in >= '0' && *in <= '9')
97     {
98         val = val*10 + (*in - '0');
99         ++in;
100     }
101     return neg ? -val : val; 
102 }
103
104 static void readClause(StreamBuffer& in, par_SolverState& S, CkVec<par_Lit>& lits) {
105     int     parsed_lit, var;
106     lits.removeAll();
107     for (;;){
108         parsed_lit = parseInt(in);
109         if (parsed_lit == 0) break;
110         var = abs(parsed_lit)-1;
111         
112         S.occurrence[var]++;
113         if(parsed_lit>0)
114             S.positive_occurrence[var]++;
115         lits.push_back( par_Lit(parsed_lit));
116     }
117 }
118
119 /* unit propagation before real computing */
120
121 static void simplify(par_SolverState& S)
122 {
123     for(int i=0; i< S.unit_clause_index.size(); i++)
124     {
125 #ifdef DEBUG
126         CkPrintf("Inside simplify before processing, unit clause number:%d, i=%d\n", S.unit_clause_index.size(), i);
127 #endif
128        
129         par_Clause cl = S.clauses[S.unit_clause_index[i]];
130         //only one element in unit clause
131         par_Lit lit = cl[0];
132         S.clauses[S.unit_clause_index[i]].resize(0);
133
134         int pp_ = 1;
135         int pp_i_ = 2;
136         int pp_j_ = 1;
137
138        if(toInt(lit) < 0)
139        {
140            pp_ = -1;
141            pp_i_ = 1;
142            pp_j_ = 2;
143        }
144        S.occurrence[pp_*toInt(lit)-1] = -pp_i_;
145        map_int_int &inClauses = S.whichClauses[pp_*2*toInt(lit)-pp_i_];
146        map_int_int &inClauses_opposite = S.whichClauses[pp_*2*toInt(lit)-pp_j_];
147        
148        // literal with same sign
149        for( map_int_int::iterator iter = inClauses.begin(); iter!=inClauses.end(); iter++)
150        {
151            int cl_index = (*iter).first;
152 #ifdef DEBUG
153            CkPrintf(" %d \n \t \t literals in this clauses: ", cl_index);
154 #endif
155            par_Clause& cl_ = S.clauses[cl_index];
156            //for all the literals in this clauses, the occurrence decreases by 1
157            for(int k=0; k< cl_.size(); k++)
158            {
159                par_Lit lit_ = cl_[k];
160                if(toInt(lit_) == toInt(lit))
161                    continue;
162 #ifdef DEBUG
163                CkPrintf(" %d  ", toInt(lit_));
164 #endif
165                S.occurrence[abs(toInt(lit_)) - 1]--;
166                if(toInt(lit_) > 0)
167                {
168                    S.positive_occurrence[toInt(lit_)-1]--;
169                    map_int_int::iterator one_it = S.whichClauses[2*toInt(lit_)-2].find(cl_index);
170                    S.whichClauses[2*toInt(lit_)-2].erase(one_it);
171                }else
172                {
173                    map_int_int::iterator one_it = S.whichClauses[-2*toInt(lit_)-1].find(cl_index);
174                    S.whichClauses[-2*toInt(lit_)-1].erase(one_it);
175                }
176
177            }
178            
179            S.clauses[cl_index].resize(0); //this clause goes away. In order to keep index unchanged, resize it as 0
180        }
181        
182        for(map_int_int::iterator iter= inClauses_opposite.begin(); iter!=inClauses_opposite.end(); iter++)
183        {
184            int cl_index_ = (*iter).first;
185            par_Clause& cl_neg = S.clauses[cl_index_];
186            cl_neg.remove(-toInt(lit));
187            //becomes a unit clause
188            if(cl_neg.size() == 1)
189            {
190                S.unit_clause_index.push_back(cl_index_);
191            }
192        }
193
194     }
195
196     S.unit_clause_index.removeAll();
197 }
198
199
200
201 static void parse_confFile(gzFile input_stream, par_SolverState& S) {                  
202     StreamBuffer in(input_stream);    
203       CkVec<par_Lit> lits;                                                 
204     int i  = 0;
205     
206     for (;;){                                                      
207         //printf(" + on %d\n", i++);
208         skipWhitespace(in);                                        
209         if (*in == EOF)                                            
210             break;                                                 
211         else if (*in == 'p'){                                      
212             if (match(in, (char*)"p cnf")){                               
213                 int vars    = parseInt(in);                        
214                 int clauses = parseInt(in);                        
215                 printf("|  Number of variables:  %-12d                                         |\n", vars);
216                 printf("|  Number of clauses:    %-12d                                         |\n", clauses);
217                 
218
219                 S.var_size = vars;
220                 S.occurrence.resize(vars);
221                 S.positive_occurrence.resize(vars);
222                 S.whichClauses.resize(2*vars);
223                 for(int __i=0; __i<vars; __i++)
224                 {
225                     S.occurrence[__i] = 0;
226                     S.positive_occurrence[__i] = 0;
227                 }
228             }else{
229                 printf("PARSE ERROR! Unexpected char: %c\n", *in);
230                 error_exit((char*)"Parse Error\n");
231             }
232         } else if (*in == 'c' || *in == 'p')
233             skipLine(in);
234         else{
235             readClause(in, S, lits);
236             if( !S.addClause(lits))
237             {
238                 CkPrintf("conflict detected by addclauses\n");
239                 CkExit();
240             }
241             
242         }
243     
244     }
245 }
246
247
248 Main::Main(CkArgMsg* msg)
249 {
250
251     grainsize = 1;
252     par_SolverState* solver_msg = new (8 * sizeof(int))par_SolverState;
253     if(msg->argc < 2)
254     {
255         error_exit((char*)"Usage: sat filename grainsize\n");
256     }else
257         grainsize = atoi(msg->argv[2]);
258
259
260     CkPrintf("problem file:\t\t%s\ngrainsize:\t\t%d\nprocessor number:\t\t%d\n", msg->argv[1], grainsize, CkNumPes()); 
261
262     /* read file */
263
264     starttimer = CkWallTimer();
265
266     /*read information from file */
267     gzFile in = gzopen(msg->argv[1], "rb");
268
269     strcpy(inputfile, msg->argv[1]);
270
271     if(in == NULL)
272     {
273         error_exit((char*)"Invalid input filename\n");
274     }
275
276     parse_confFile(in, *solver_msg);
277
278     solver_msg->printSolution();
279     /*  unit propagation */ 
280     simplify(*solver_msg);
281 #ifdef DEBUG
282     for(int __i = 0; __i<solver_msg->occurrence.size(); __i++)
283     {
284         FILE *file;
285         char outputfile[50];
286         sprintf(outputfile, "%s.sat", inputfile);
287         file = fopen(outputfile, "w");
288         for(int i=0; i<assignment.size(); i++)
289         {
290             fprintf(file, "%d\n", assignment[i]);
291         }
292     }
293
294 #endif
295
296     int unsolved = solver_msg->unsolvedClauses();
297
298     if(unsolved == 0)
299     {
300         CkPrintf(" This problem is solved by pre-processing\n");
301         CkExit();
302     }
303     readfiletimer = CkWallTimer();
304     /*fire the first chare */
305     /* 1)Which variable is assigned which value this time, (variable, 1), current clauses status vector(), literal array activities */
306
307
308     /***  If grain size is larger than the clauses size, that means 'sequential' */
309     if(grainsize > solver_msg->clauses.size())
310     {
311         vector< vector<int> > seq_clauses;
312         for(int _i_=0; _i_<solver_msg->clauses.size(); _i_++)
313         {
314             if(solver_msg->clauses[_i_].size() > 0)
315             {
316                 vector<int> unsolvedclaus;
317                 par_Clause& cl = solver_msg->clauses[_i_];
318                 unsolvedclaus.resize(cl.size());
319                 for(int _j_=0; _j_<cl.size(); _j_++)
320                 {
321                     unsolvedclaus[_j_] = toInt(cl[_j_]);
322                 }
323                 seq_clauses.push_back(unsolvedclaus);
324             }
325         }
326         bool satisfiable_1 = seq_processing(solver_msg->var_size, seq_clauses);//seq_solve(next_state);
327
328         if(satisfiable_1)
329         {
330             CkPrintf("One solution found without using any parallel\n");
331         }else
332         {
333        
334             CkPrintf(" Unsatisfiable\n");
335         }
336         done(solver_msg->occurrence);
337         return;
338     }
339     mainProxy = thisProxy;
340     int max_index = get_max_element(solver_msg->occurrence);
341     
342     solver_msg->assigned_lit = par_Lit(max_index+1);
343     solver_msg->level = 0;
344     par_SolverState *not_msg = copy_solverstate(solver_msg);
345     
346     solver_msg->occurrence[max_index] = -2;
347     not_msg->assigned_lit = par_Lit(-max_index-1);
348     not_msg->occurrence[max_index] = -1;
349     
350     int positive_max = solver_msg->positive_occurrence[max_index];
351     if(positive_max >= solver_msg->occurrence[max_index] - positive_max)
352     {
353
354         // assign true first and then false
355         *((int *)CkPriorityPtr(solver_msg)) = INT_MIN;
356         CkSetQueueing(solver_msg, CK_QUEUEING_IFIFO);
357         solver_msg->lower = INT_MIN;
358         solver_msg->higher = 0;
359         CProxy_mySolver::ckNew(solver_msg);
360         
361         *((int *)CkPriorityPtr(not_msg)) = 0;
362         CkSetQueueing(not_msg, CK_QUEUEING_IFIFO);
363         not_msg->lower = 0;
364         not_msg->higher = INT_MAX;
365         CProxy_mySolver::ckNew(not_msg);
366     }else
367     {
368         *((int *)CkPriorityPtr(not_msg)) = INT_MIN;
369         CkSetQueueing(not_msg, CK_QUEUEING_IFIFO);
370         not_msg->lower = INT_MIN;
371         not_msg->higher = 0;
372         CProxy_mySolver::ckNew(not_msg);
373         
374         *((int *)CkPriorityPtr(solver_msg)) = 0;
375         CkSetQueueing(solver_msg, CK_QUEUEING_IFIFO);
376         solver_msg->lower = 0;
377         solver_msg->higher = INT_MAX;
378         CProxy_mySolver::ckNew(solver_msg);
379
380     }
381 }
382
383 Main::Main(CkMigrateMessage* msg) {}
384
385 void Main::done(CkVec<int> assignment)
386 {
387
388     double endtimer = CkWallTimer();
389
390     CkPrintf("\nFile reading and processing time:         %f\n", readfiletimer-starttimer);
391     CkPrintf("Solving time:                             %f\n", endtimer - readfiletimer);
392  
393     FILE *file;
394     char outputfile[50];
395     sprintf(outputfile, "%s.sat", inputfile);
396     file = fopen(outputfile, "w");
397
398     for(int i=0; i<assignment.size(); i++)
399     {
400         fprintf(file, "%d\n", assignment[i]);
401     }
402     CkExit();
403 }
404 #include "main.def.h"
405