Fix direct calls to SDAG entry methods
[namd.git] / src / CudaPmeSolver.h
1 #ifndef CUDAPMESOLVER_H
2 #define CUDAPMESOLVER_H
3 #include "PmeSolver.h"
4 #include "CudaPmeSolver.decl.h"
5
6 #ifdef NAMD_CUDA
7 class CudaPmeXYZInitMsg : public CMessage_CudaPmeXYZInitMsg {
8 public:
9         CudaPmeXYZInitMsg(PmeGrid& pmeGrid) : pmeGrid(pmeGrid) {}
10         PmeGrid pmeGrid;
11 };
12
13 class CudaPmeXYInitMsg : public CMessage_CudaPmeXYInitMsg {
14 public:
15         CudaPmeXYInitMsg(PmeGrid& pmeGrid, CProxy_CudaPmePencilXY& pmePencilXY, CProxy_CudaPmePencilZ& pmePencilZ,
16                 CProxy_PmePencilXYMap& xyMap, CProxy_PmePencilXMap& zMap) : 
17                 pmeGrid(pmeGrid), pmePencilXY(pmePencilXY), pmePencilZ(pmePencilZ), xyMap(xyMap), zMap(zMap) {}
18         PmeGrid pmeGrid;
19   CProxy_CudaPmePencilXY pmePencilXY;
20   CProxy_CudaPmePencilZ pmePencilZ;
21   CProxy_PmePencilXMap zMap;
22   CProxy_PmePencilXYMap xyMap;
23 };
24
25 class CudaPmeXInitMsg : public CMessage_CudaPmeXInitMsg {
26 public:
27         CudaPmeXInitMsg(PmeGrid& pmeGrid,
28                 CProxy_CudaPmePencilX& pmePencilX, CProxy_CudaPmePencilY& pmePencilY, CProxy_CudaPmePencilZ& pmePencilZ,
29                 CProxy_PmePencilXMap& xMap, CProxy_PmePencilXMap& yMap, CProxy_PmePencilXMap& zMap) : 
30                 pmeGrid(pmeGrid), pmePencilX(pmePencilX), pmePencilY(pmePencilY), pmePencilZ(pmePencilZ),
31                 xMap(xMap), yMap(yMap), zMap(zMap) {}
32         PmeGrid pmeGrid;
33   CProxy_CudaPmePencilX pmePencilX;
34   CProxy_CudaPmePencilY pmePencilY;
35   CProxy_CudaPmePencilZ pmePencilZ;
36   CProxy_PmePencilXMap xMap;
37   CProxy_PmePencilXMap yMap;
38   CProxy_PmePencilXMap zMap;
39 };
40
41 class InitDeviceMsg : public CMessage_InitDeviceMsg {
42 public:
43         InitDeviceMsg(CProxy_ComputePmeCUDADevice deviceProxy) : deviceProxy(deviceProxy) {}
44         CProxy_ComputePmeCUDADevice deviceProxy;
45 };
46
47 class InitDeviceMsg2 : public CMessage_InitDeviceMsg2 {
48 public:
49         InitDeviceMsg2(int deviceID, cudaStream_t stream, CProxy_ComputePmeCUDAMgr mgrProxy) : 
50         deviceID(deviceID), stream(stream), mgrProxy(mgrProxy) {}
51         int deviceID;
52         cudaStream_t stream;
53         CProxy_ComputePmeCUDAMgr mgrProxy;
54 };
55
56 class CudaPmePencilXYZ : public CBase_CudaPmePencilXYZ {
57 public:
58         CudaPmePencilXYZ() {}
59         CudaPmePencilXYZ(CkMigrateMessage *m) {}
60         void initialize(CudaPmeXYZInitMsg *msg);
61         void initializeDevice(InitDeviceMsg *msg);
62         void energyAndVirialDone();
63 private:
64         void backwardDone();
65   CProxy_ComputePmeCUDADevice deviceProxy;
66 };
67
68 struct DeviceBuffer {
69         DeviceBuffer(int deviceID, bool isPeerDevice, float2* data) : deviceID(deviceID), isPeerDevice(isPeerDevice), data(data) {}
70         bool isPeerDevice;
71         int deviceID;
72         cudaEvent_t event;
73         float2 *data;
74 };
75
76 class DeviceDataMsg : public CMessage_DeviceDataMsg {
77 public:
78         DeviceDataMsg(int i, cudaEvent_t event, float2 *data) : i(i), event(event), data(data) {}
79         int i;
80         cudaEvent_t event;
81         float2 *data;
82 };
83
84 class CudaPmePencilXY : public CBase_CudaPmePencilXY {
85 public:
86         CudaPmePencilXY_SDAG_CODE
87         CudaPmePencilXY() : numGetDeviceBuffer(0), eventCreated(false) {}
88         CudaPmePencilXY(CkMigrateMessage *m) : numGetDeviceBuffer(0), eventCreated(false) {}
89         ~CudaPmePencilXY();
90         void initialize(CudaPmeXYInitMsg *msg);
91         void initializeDevice(InitDeviceMsg *msg);
92 private:
93         void forwardDone();
94         void backwardDone();
95         void recvDataFromZ(PmeBlockMsg *msg);
96         void start(const CkCallback &);
97         void setDeviceBuffers();
98         float2* getData(const int i, const bool sameDevice);
99         int deviceID;
100         cudaStream_t stream;
101         cudaEvent_t event;
102         bool eventCreated;
103         int imsgZ;
104         int numDeviceBuffers;
105         int numGetDeviceBuffer;
106         std::vector<DeviceBuffer> deviceBuffers;
107   CProxy_ComputePmeCUDADevice deviceProxy;
108   CProxy_CudaPmePencilZ pmePencilZ;
109   CProxy_PmePencilXMap zMap;
110 };
111
112 class CudaPmePencilX : public CBase_CudaPmePencilX {
113 public:
114         CudaPmePencilX_SDAG_CODE
115         CudaPmePencilX() : numGetDeviceBuffer(0), eventCreated(false) {}
116         CudaPmePencilX(CkMigrateMessage *m) : numGetDeviceBuffer(0), eventCreated(false) {}
117         ~CudaPmePencilX();
118         void initialize(CudaPmeXInitMsg *msg);
119         void initializeDevice(InitDeviceMsg *msg);
120 private:
121         void forwardDone();
122         void backwardDone();
123         void recvDataFromY(PmeBlockMsg *msg);
124         void start(const CkCallback &);
125         void setDeviceBuffers();
126         float2* getData(const int i, const bool sameDevice);
127         int deviceID;
128         cudaStream_t stream;
129         cudaEvent_t event;
130         bool eventCreated;
131         int imsgY;
132         int numDeviceBuffers;
133         int numGetDeviceBuffer;
134         std::vector<DeviceBuffer> deviceBuffers;
135   CProxy_ComputePmeCUDADevice deviceProxy;
136   CProxy_CudaPmePencilY pmePencilY;
137   CProxy_PmePencilXMap yMap;
138 };
139
140 class CudaPmePencilY : public CBase_CudaPmePencilY {
141 public:
142         CudaPmePencilY_SDAG_CODE
143         CudaPmePencilY() : numGetDeviceBufferZ(0), numGetDeviceBufferX(0), eventCreated(false) {}
144         CudaPmePencilY(CkMigrateMessage *m) : numGetDeviceBufferZ(0), numGetDeviceBufferX(0), eventCreated(false) {}
145         ~CudaPmePencilY();
146         void initialize(CudaPmeXInitMsg *msg);
147         void initializeDevice(InitDeviceMsg2 *msg);
148 private:
149         void forwardDone();
150         void backwardDone();
151         void recvDataFromX(PmeBlockMsg *msg);
152         void recvDataFromZ(PmeBlockMsg *msg);
153         void start(const CkCallback &);
154         void setDeviceBuffers();
155         float2* getDataForX(const int i, const bool sameDevice);
156         float2* getDataForZ(const int i, const bool sameDevice);
157         int deviceID;
158         cudaStream_t stream;
159         cudaEvent_t event;
160         bool eventCreated;
161         int imsgZ, imsgX;
162         int imsgZZ, imsgXX;
163         int numGetDeviceBufferZ;
164         int numGetDeviceBufferX;
165         int numDeviceBuffersZ;
166         int numDeviceBuffersX;
167         std::vector<DeviceBuffer> deviceBuffersZ;
168         std::vector<DeviceBuffer> deviceBuffersX;
169   CProxy_CudaPmePencilX pmePencilX;
170   CProxy_CudaPmePencilZ pmePencilZ;
171   CProxy_PmePencilXMap xMap;
172   CProxy_PmePencilXMap zMap;
173 };
174
175 class CudaPmePencilZ : public CBase_CudaPmePencilZ {
176 public:
177         CudaPmePencilZ_SDAG_CODE
178         CudaPmePencilZ() : numGetDeviceBufferY(0), numGetDeviceBufferXY(0), eventCreated(false) {}
179         CudaPmePencilZ(CkMigrateMessage *m) : numGetDeviceBufferY(0), numGetDeviceBufferXY(0), eventCreated(false) {}
180         ~CudaPmePencilZ();
181         void initialize(CudaPmeXInitMsg *msg);
182         void initialize(CudaPmeXYInitMsg *msg);
183         void initializeDevice(InitDeviceMsg2 *msg);
184         void energyAndVirialDone();
185 private:
186         void backwardDone();
187         void recvDataFromY(PmeBlockMsg *msg);
188         void start(const CkCallback &);
189         void setDeviceBuffers();
190         float2* getData(const int i, const bool sameDevice);
191         int deviceID;
192         cudaStream_t stream;
193         cudaEvent_t event;
194         bool eventCreated;
195         int imsgY;
196         int numDeviceBuffers;
197         int numGetDeviceBufferY;
198         std::vector<DeviceBuffer> deviceBuffers;
199   CProxy_CudaPmePencilY pmePencilY;
200   CProxy_PmePencilXMap yMap;
201
202         bool useXYslab;
203         int numGetDeviceBufferXY;
204   CProxy_CudaPmePencilXY pmePencilXY;
205   CProxy_PmePencilXYMap xyMap;
206 };
207
208 #endif // NAMD_CUDA
209 #endif //CUDAPMESOLVER_H