Fix direct calls to SDAG entry methods
[namd.git] / src / PmeSolver.h
1 #ifndef PMESOLVER_H
2 #define PMESOLVER_H
3 #include <vector>
4 #include "ReductionMgr.h"
5 #include "PatchMap.h"
6 #include "PmeSolverUtil.h"
7 #include "PmeSolver.decl.h"
8
9 class PmePencilXYZMap : public CBase_PmePencilXYZMap {
10 public:
11   PmePencilXYZMap(int pe) : pe(pe) {
12   }
13   //PmePencilXYZMap(CkMigrateMessage *m) {}
14   int registerArray(CkArrayIndex& numElements, CkArrayID aid) {
15     return 0;
16   }
17   virtual int procNum(int, const CkArrayIndex& idx) {
18     return pe;
19   }
20   virtual void populateInitial(int, CkArrayOptions &, void *msg, CkArrMgr *mgr) {
21     if (pe == CkMyPe()) {
22       if ( ! msg ) NAMD_bug("PmePencilXYZMap::populateInitial, multiple pencils on a pe?");
23       mgr->insertInitial(CkArrayIndex1D(0), msg);
24       msg = NULL;
25     }
26     mgr->doneInserting();
27     if (msg != NULL) CkFreeMsg(msg);
28   }
29 private:
30   const int pe;
31 };
32
33 class PmePencilXMap : public CBase_PmePencilXMap {
34 public:
35   PmePencilXMap(int ia, int ib, int width, const std::vector<int>& pes) : ia(ia), ib(ib), width(width), pes(pes) {}
36   int registerArray(CkArrayIndex& numElements, CkArrayID aid) {
37     return 0;
38   }
39   virtual int procNum(int, const CkArrayIndex& idx) {
40     int ind = idx.data()[ia] + idx.data()[ib] * width;
41     if (ind < 0 || ind >= pes.size())
42       NAMD_bug("PmePencilXMap::procNum, index out of bounds");
43     return pes[ind];
44   }
45   virtual void populateInitial(int, CkArrayOptions &, void *msg, CkArrMgr *mgr) {
46     for (int i=0;i < pes.size();i++) {
47       if (pes[i] == CkMyPe()) {
48         if ( msg == NULL ) NAMD_bug("PmePencilXMap::populateInitial, multiple pencils on a pe?");
49         CkArrayIndex3D ai(0,0,0);
50         ai.data()[ib] = i / width;
51         ai.data()[ia] = i % width;
52         //fprintf(stderr, "Pe %d i %d at %d %d\n", pes[i], i, ai.data()[ia], ai.data()[ib]);
53         if ( procNum(0,ai) != CkMyPe() ) NAMD_bug("PmePencilXMap::populateInitial, map is inconsistent");
54         mgr->insertInitial(ai,msg);
55         msg = NULL;
56       }
57     }
58     mgr->doneInserting();
59     if (msg != NULL) CkFreeMsg(msg);
60   }
61 private:
62   // Index of CkArrayIndex data()
63   const int ia, ib;
64   // Width of the 2D array in pes
65   const int width;
66   // List of Pes. Index is given by pes[i + j*width]
67   const std::vector<int> pes;
68 };
69
70 class PmePencilXYMap : public CBase_PmePencilXYMap {
71 public:
72   PmePencilXYMap(const std::vector<int>& pes) : pes(pes) {}
73   int registerArray(CkArrayIndex& numElements, CkArrayID aid) {
74     return 0;
75   }
76   virtual int procNum(int, const CkArrayIndex& idx) {
77     int ind = idx.data()[2];
78     if (ind < 0 || ind >= pes.size())
79       NAMD_bug("PmePencilXYMap::procNum, index out of bounds");
80     return pes[ind];
81   }
82   virtual void populateInitial(int, CkArrayOptions &, void *msg, CkArrMgr *mgr) {
83     for (int i=0;i < pes.size();i++) {
84       if (pes[i] == CkMyPe()) {
85         if ( msg == NULL ) NAMD_bug("PmePencilXYMap::populateInitial, multiple pencils on a pe?");
86         CkArrayIndex3D ai(0,0,0);
87         ai.data()[2] = i;
88         if ( procNum(0,ai) != CkMyPe() ) NAMD_bug("PmePencilXYMap::populateInitial, map is inconsistent");
89         mgr->insertInitial(ai, msg);
90         msg = NULL;
91       }
92     }
93     mgr->doneInserting();
94     if (msg != NULL) CkFreeMsg(msg);
95   }
96 private:
97   // List of Pes.
98   const std::vector<int> pes;
99 };
100
101 class PmeStartMsg : public CMessage_PmeStartMsg {
102 public:
103   float *data;
104   int dataSize;
105   int device;
106 };
107
108 class PmeRunMsg : public CMessage_PmeRunMsg {
109 public:
110   bool doEnergy, doVirial;
111   int numStrayAtoms;
112   Lattice lattice;
113 };
114
115 class PmeDoneMsg : public CMessage_PmeDoneMsg {
116 public:
117   PmeDoneMsg(int i, int j) : i(i), j(j) {}
118   int i, j;
119 };
120
121 class PmeBlockMsg : public CMessage_PmeBlockMsg {
122 public:
123   float2 *data;
124   int dataSize;
125   int x, y, z;
126   bool doEnergy, doVirial;
127   int numStrayAtoms;
128   Lattice lattice;
129 };
130
131 class PmePencilXYZ : public CBase_PmePencilXYZ {
132 public:
133   PmePencilXYZ_SDAG_CODE
134   PmePencilXYZ();
135   PmePencilXYZ(CkMigrateMessage *m);
136   virtual ~PmePencilXYZ();
137   void skip();
138 protected:
139   PmeGrid pmeGrid;
140   bool doEnergy, doVirial;
141   FFTCompute* fftCompute;
142   PmeKSpaceCompute* pmeKSpaceCompute;
143   Lattice lattice;
144   int numStrayAtoms;
145   virtual void backwardDone();
146   void submitReductions();
147 private:
148   void forwardFFT();
149   void backwardFFT();
150   void forwardDone();
151   void initFFT(PmeStartMsg *msg);
152
153   SubmitReduction* reduction;
154
155 };
156
157 class PmePencilXY : public CBase_PmePencilXY {
158 public:
159   PmePencilXY_SDAG_CODE
160   PmePencilXY();
161   PmePencilXY(CkMigrateMessage *m);
162   virtual ~PmePencilXY();
163 protected:
164   PmeGrid pmeGrid;
165   bool doEnergy, doVirial;
166   FFTCompute* fftCompute;
167   PmeTranspose* pmeTranspose;
168   std::vector<int> blockSizes;
169   Lattice lattice;
170   int numStrayAtoms;
171   void initBlockSizes();
172   int imsg;
173 private:
174   void forwardFFT();
175   void backwardFFT();
176   void initFFT(PmeStartMsg *msg);
177   virtual void forwardDone();
178   virtual void backwardDone();
179   virtual void recvDataFromZ(PmeBlockMsg *msg);
180   virtual void start(const CkCallback &);
181
182 };
183
184 class PmePencilX : public CBase_PmePencilX {
185 public:
186   PmePencilX_SDAG_CODE
187   PmePencilX();
188   PmePencilX(CkMigrateMessage *m);
189   virtual ~PmePencilX();
190 protected:
191   PmeGrid pmeGrid;
192   bool doEnergy, doVirial;
193   FFTCompute* fftCompute;
194   PmeTranspose* pmeTranspose;
195   std::vector<int> blockSizes;
196   Lattice lattice;
197   int numStrayAtoms;
198   void initBlockSizes();
199   int imsg;
200 private:
201   void forwardFFT();
202   void backwardFFT();
203   void initFFT(PmeStartMsg *msg);
204   virtual void forwardDone();
205   virtual void backwardDone();
206   virtual void recvDataFromY(PmeBlockMsg *msg);
207   virtual void start(const CkCallback &);
208
209 };
210
211 class PmePencilY : public CBase_PmePencilY {
212 public:
213   PmePencilY_SDAG_CODE
214   PmePencilY();
215   PmePencilY(CkMigrateMessage *m);
216   virtual ~PmePencilY();
217 protected:
218   PmeGrid pmeGrid;
219   bool doEnergy, doVirial;
220   FFTCompute* fftCompute;
221   PmeTranspose* pmeTranspose;
222   std::vector<int> blockSizes;
223   Lattice lattice;
224   int numStrayAtoms;
225   void initBlockSizes();
226   int imsg;
227 private:
228   void forwardFFT();
229   void backwardFFT();
230   void initFFT(PmeStartMsg *msg);
231   virtual void forwardDone();
232   virtual void backwardDone();
233   virtual void recvDataFromX(PmeBlockMsg *msg);
234   virtual void recvDataFromZ(PmeBlockMsg *msg);
235   virtual void start(const CkCallback &);
236
237 };
238
239 class PmePencilZ : public CBase_PmePencilZ {
240 public:
241   PmePencilZ_SDAG_CODE
242   PmePencilZ();
243   PmePencilZ(CkMigrateMessage *m);
244   virtual ~PmePencilZ();
245   void skip();
246 protected:
247   PmeGrid pmeGrid;
248   bool doEnergy, doVirial;
249   FFTCompute* fftCompute;
250   PmeTranspose* pmeTranspose;
251   PmeKSpaceCompute* pmeKSpaceCompute;
252   std::vector<int> blockSizes;
253   Lattice lattice;
254   int numStrayAtoms;
255   void initBlockSizes();
256   void submitReductions();
257   int imsg;
258 private:
259   void forwardFFT();
260   void backwardFFT();
261   void forwardDone();
262   void initFFT(PmeStartMsg *msg);
263   virtual void backwardDone();
264   virtual void recvDataFromY(PmeBlockMsg *msg);
265   virtual void start(const CkCallback &);
266
267   SubmitReduction* reduction;
268
269 };
270
271 // #define CK_TEMPLATES_ONLY
272 // #include "PmeSolver.def.h"
273 // #undef CK_TEMPLATES_ONLY
274
275 #endif // PMESOLVER_H