68d459b023009f5a3e3f5e0f997f8a29f42093d5
[charm.git] / examples / charm++ / satisfiability / main.C
1 #include <stdio.h>
2 #include <cstring>
3 #include <stdint.h>
4 #include <errno.h>
5 #include <limits.h>
6
7 #include <signal.h>
8 #include <zlib.h>
9
10 #include <vector>
11
12 #include "main.decl.h"
13 #include "main.h"
14
15 #include "par_SolverTypes.h"
16 #include "par_Solver.h"
17
18 #ifdef MINISAT
19 #include "Solver.h"
20 #endif
21
22 #ifdef TNM
23 #include "TNM.h"
24 #endif
25
26 using namespace std;
27
28 //#include "util.h"
29
30 typedef map<int, int> map_int_int;
31
32 CProxy_Main mainProxy;
33
34 #define CHUNK_LIMIT 1048576
35
36 char inputfile[50];
37
38 class StreamBuffer {
39     gzFile  in;
40     char    buf[CHUNK_LIMIT];
41     int     pos;
42     int     size;
43
44     void assureLookahead() {
45         if (pos >= size) {
46             pos  = 0;
47             size = gzread(in, buf, sizeof(buf)); } }
48
49 public:
50             StreamBuffer(gzFile i) : in(i), pos(0), size(0) {
51                 assureLookahead(); }
52
53             int  operator *  () { return (pos >= size) ? EOF : buf[pos]; }
54             void operator ++ () { pos++; assureLookahead(); }
55 };
56
57 static bool match(StreamBuffer& in, char* str) {
58     for (; *str != 0; ++str, ++in)
59         if (*str != *in)
60             return false;
61     return true;
62 }
63
64
65 void error_exit(char *error)
66 {
67     printf("%s\n", error);
68     CkExit();
69 }
70
71 void skipWhitespace(StreamBuffer& in) 
72 {
73     while ((*in >= 9 && *in <= 13) || *in == 32)
74         ++in;
75 }
76
77 void skipLine(StreamBuffer& in) {
78     for (;;){
79         if (*in == EOF || *in == '\0') return;
80         if (*in == '\n')
81         { ++in; return; }
82         ++in;
83     } 
84 }
85
86
87 int parseInt(StreamBuffer& in) {
88     int     val = 0;
89     bool    neg = false;
90     skipWhitespace(in);
91     if      (*in == '-') neg = true, ++in;
92     else if (*in == '+') ++in;
93     if (*in < '0' || *in > '9')
94         error_exit((char*)"ParseInt error\n");
95
96     while (*in >= '0' && *in <= '9')
97     {
98         val = val*10 + (*in - '0');
99         ++in;
100     }
101     return neg ? -val : val; 
102 }
103
104 static void readClause(StreamBuffer& in, par_SolverState& S, CkVec<par_Lit>& lits) {
105     int     parsed_lit, var;
106     lits.removeAll();
107     for (;;){
108         parsed_lit = parseInt(in);
109         if (parsed_lit == 0) break;
110         var = abs(parsed_lit)-1;
111         
112         S.occurrence[var]++;
113         if(parsed_lit>0)
114             S.positive_occurrence[var]++;
115         lits.push_back( par_Lit(parsed_lit));
116     }
117 }
118
119 /* unit propagation before real computing */
120
121 static void simplify(par_SolverState& S)
122 {
123     for(int i=0; i< S.unit_clause_index.size(); i++)
124     {
125 #ifdef DEBUG
126         CkPrintf("Inside simplify before processing, unit clause number:%d, i=%d\n", S.unit_clause_index.size(), i);
127 #endif
128        
129         par_Clause cl = S.clauses[S.unit_clause_index[i]];
130         //only one element in unit clause
131         par_Lit lit = cl[0];
132         S.clauses[S.unit_clause_index[i]].resize(0);
133
134         int pp_ = 1;
135         int pp_i_ = 2;
136         int pp_j_ = 1;
137
138        if(toInt(lit) < 0)
139        {
140            pp_ = -1;
141            pp_i_ = 1;
142            pp_j_ = 2;
143        }
144        S.occurrence[pp_*toInt(lit)-1] = -pp_i_;
145        map_int_int &inClauses = S.whichClauses[pp_*2*toInt(lit)-pp_i_];
146        map_int_int &inClauses_opposite = S.whichClauses[pp_*2*toInt(lit)-pp_j_];
147        
148        // literal with same sign
149        for( map_int_int::iterator iter = inClauses.begin(); iter!=inClauses.end(); iter++)
150        {
151            int cl_index = (*iter).first;
152 #ifdef DEBUG
153            CkPrintf(" %d \n \t \t literals in this clauses: ", cl_index);
154 #endif
155            par_Clause& cl_ = S.clauses[cl_index];
156            //for all the literals in this clauses, the occurrence decreases by 1
157            for(int k=0; k< cl_.size(); k++)
158            {
159                par_Lit lit_ = cl_[k];
160                if(toInt(lit_) == toInt(lit))
161                    continue;
162 #ifdef DEBUG
163                CkPrintf(" %d  ", toInt(lit_));
164 #endif
165                S.occurrence[abs(toInt(lit_)) - 1]--;
166                if(toInt(lit_) > 0)
167                {
168                    S.positive_occurrence[toInt(lit_)-1]--;
169                    map_int_int::iterator one_it = S.whichClauses[2*toInt(lit_)-2].find(cl_index);
170                    S.whichClauses[2*toInt(lit_)-2].erase(one_it);
171                }else
172                {
173                    map_int_int::iterator one_it = S.whichClauses[-2*toInt(lit_)-1].find(cl_index);
174                    S.whichClauses[-2*toInt(lit_)-1].erase(one_it);
175                }
176
177            }
178            
179            S.clauses[cl_index].resize(0); //this clause goes away. In order to keep index unchanged, resize it as 0
180        }
181        
182        for(map_int_int::iterator iter= inClauses_opposite.begin(); iter!=inClauses_opposite.end(); iter++)
183        {
184            int cl_index_ = (*iter).first;
185            par_Clause& cl_neg = S.clauses[cl_index_];
186            cl_neg.remove(-toInt(lit));
187            //becomes a unit clause
188            if(cl_neg.size() == 1)
189            {
190                S.unit_clause_index.push_back(cl_index_);
191            }
192        }
193
194     }
195
196     S.unit_clause_index.removeAll();
197 }
198
199
200
201 static void parse_confFile(gzFile input_stream, par_SolverState& S) {                  
202     StreamBuffer in(input_stream);    
203       CkVec<par_Lit> lits;                                                 
204     int i  = 0;
205     
206     for (;;){                                                      
207         //printf(" + on %d\n", i++);
208         skipWhitespace(in);                                        
209         if (*in == EOF)                                            
210             break;                                                 
211         else if (*in == 'p'){                                      
212             if (match(in, (char*)"p cnf")){                               
213                 int vars    = parseInt(in);                        
214                 int clauses = parseInt(in);                        
215                 printf("|  Number of variables:  %-12d                                         |\n", vars);
216                 printf("|  Number of clauses:    %-12d                                         |\n", clauses);
217                 
218
219                 S.var_size = vars;
220                 S.occurrence.resize(vars);
221                 S.positive_occurrence.resize(vars);
222                 S.whichClauses.resize(2*vars);
223                 for(int __i=0; __i<vars; __i++)
224                 {
225                     S.occurrence[__i] = 0;
226                     S.positive_occurrence[__i] = 0;
227                 }
228             }else{
229                 printf("PARSE ERROR! Unexpected char: %c\n", *in);
230                 error_exit((char*)"Parse Error\n");
231             }
232         } else if (*in == 'c' || *in == 'p')
233             skipLine(in);
234         else{
235             readClause(in, S, lits);
236             if( !S.addClause(lits))
237             {
238                 CkPrintf("conflict detected by addclauses\n");
239                 CkExit();
240             }
241             
242         }
243     
244     }
245 }
246
247
248 Main::Main(CkArgMsg* msg)
249 {
250
251     grainsize = 1;
252     par_SolverState* solver_msg = new (8 * sizeof(int))par_SolverState;
253     if(msg->argc < 2)
254     {
255         error_exit((char*)"Usage: sat filename grainsize\n");
256     }else
257         grainsize = atoi(msg->argv[2]);
258
259
260     CkPrintf("problem file:\t\t%s\ngrainsize:\t\t%d\nprocessor number:\t\t%d\n", msg->argv[1], grainsize, CkNumPes()); 
261
262     /* read file */
263
264     starttimer = CmiWallTimer();
265
266     /*read information from file */
267     gzFile in = gzopen(msg->argv[1], "rb");
268
269     strcpy(inputfile, msg->argv[1]);
270
271     if(in == NULL)
272     {
273         error_exit((char*)"Invalid input filename\n");
274     }
275
276     parse_confFile(in, *solver_msg);
277
278     /*  unit propagation */ 
279     simplify(*solver_msg);
280 #ifdef DEBUG
281     for(int __i = 0; __i<solver_msg->occurrence.size(); __i++)
282     {
283         FILE *file;
284         char outputfile[50];
285         sprintf(outputfile, "%s.sat", inputfile);
286         file = fopen(outputfile, "w");
287         for(int i=0; i<assignment.size(); i++)
288         {
289             fprintf(file, "%d\n", assignment[i]);
290         }
291     }
292
293 #endif
294
295     int unsolved = solver_msg->unsolvedClauses();
296
297     if(unsolved == 0)
298     {
299         CkPrintf(" This problem is solved by pre-processing\n");
300         CkExit();
301     }
302     readfiletimer = CmiWallTimer();
303     /*fire the first chare */
304     /* 1)Which variable is assigned which value this time, (variable, 1), current clauses status vector(), literal array activities */
305
306
307     /***  If grain size is larger than the clauses size, that means 'sequential' */
308     if(grainsize > solver_msg->clauses.size())
309     {
310         vector< vector<int> > seq_clauses;
311         for(int _i_=0; _i_<solver_msg->clauses.size(); _i_++)
312         {
313             if(solver_msg->clauses[_i_].size() > 0)
314             {
315                 vector<int> unsolvedclaus;
316                 par_Clause& cl = solver_msg->clauses[_i_];
317                 unsolvedclaus.resize(cl.size());
318                 for(int _j_=0; _j_<cl.size(); _j_++)
319                 {
320                     unsolvedclaus[_j_] = toInt(cl[_j_]);
321                 }
322                 seq_clauses.push_back(unsolvedclaus);
323             }
324         }
325         bool satisfiable_1 = seq_processing(solver_msg->var_size, seq_clauses);//seq_solve(next_state);
326
327         if(satisfiable_1)
328         {
329             CkPrintf("One solution found without using any parallel\n");
330         }else
331         {
332        
333             CkPrintf(" Unsatisfiable\n");
334         }
335         done(solver_msg->occurrence);
336         return;
337     }
338     mainProxy = thisProxy;
339     int max_index = get_max_element(solver_msg->occurrence);
340     
341     solver_msg->assigned_lit = par_Lit(max_index+1);
342     solver_msg->level = 0;
343     par_SolverState *not_msg = copy_solverstate(solver_msg);
344     
345     solver_msg->occurrence[max_index] = -2;
346     not_msg->assigned_lit = par_Lit(-max_index-1);
347     not_msg->occurrence[max_index] = -1;
348     
349     int positive_max = solver_msg->positive_occurrence[max_index];
350     if(positive_max >= solver_msg->occurrence[max_index] - positive_max)
351     {
352
353         // assign true first and then false
354         *((int *)CkPriorityPtr(solver_msg)) = INT_MIN;
355         CkSetQueueing(solver_msg, CK_QUEUEING_IFIFO);
356         solver_msg->lower = INT_MIN;
357         solver_msg->higher = 0;
358         CProxy_mySolver::ckNew(solver_msg);
359         
360         *((int *)CkPriorityPtr(not_msg)) = 0;
361         CkSetQueueing(not_msg, CK_QUEUEING_IFIFO);
362         not_msg->lower = 0;
363         not_msg->higher = INT_MAX;
364         CProxy_mySolver::ckNew(not_msg);
365     }else
366     {
367         *((int *)CkPriorityPtr(not_msg)) = INT_MIN;
368         CkSetQueueing(not_msg, CK_QUEUEING_IFIFO);
369         not_msg->lower = INT_MIN;
370         not_msg->higher = 0;
371         CProxy_mySolver::ckNew(not_msg);
372         
373         *((int *)CkPriorityPtr(solver_msg)) = 0;
374         CkSetQueueing(solver_msg, CK_QUEUEING_IFIFO);
375         solver_msg->lower = 0;
376         solver_msg->higher = INT_MAX;
377         CProxy_mySolver::ckNew(solver_msg);
378
379     }
380 }
381
382 Main::Main(CkMigrateMessage* msg) {}
383
384 void Main::done(CkVec<int> assignment)
385 {
386
387     double endtimer = CmiWallTimer();
388
389     CkPrintf("\nFile reading and processing time:         %f\n", readfiletimer-starttimer);
390     CkPrintf("Solving time:                             %f\n", endtimer - readfiletimer);
391  
392     FILE *file;
393     char outputfile[50];
394     sprintf(outputfile, "%s.sat", inputfile);
395     file = fopen(outputfile, "w");
396
397     for(int i=0; i<assignment.size(); i++)
398     {
399         fprintf(file, "%d\n", assignment[i]);
400     }
401     CkExit();
402 }
403 #include "main.def.h"
404