Adapted the fft-trans to the newly polished API.
[charm.git] / example / fft-trans / fft1d.C
1 #include "fft1d.decl.h"
2 #include <fftw3.h>
3 #include <limits>
4 #include "fileio.h"
5 #include "NodeHelperAPI.h"
6
7 #define TWOPI 6.283185307179586
8
9 /*readonly*/ CProxy_Main mainProxy;
10 /*readonly*/ int numChunks;
11 /*readonly*/ int numThreads;
12 /*readonly*/ uint64_t N;
13 static CmiNodeLock fft_plan_lock;
14 #include "fftmacro.h"
15
16 CProxy_FuncNodeHelper nodeHelperProxy;
17 /** called by initnode once per node to support node level locking for
18     fftw plan create/destroy operations */
19 #define NODEHELPER_MODE NODEHELPER_STATIC 
20
21 extern "C" void doCalc(int first,int last, void *result, int paramNum, void * param)
22 {
23   //result=first;
24   fft_execute(((fft_plan*)param)[first]);
25 }
26
27 void initplanlock ()
28
29 {
30   fft_plan_lock=CmiCreateLock();
31 }
32
33 struct fftMsg : public CMessage_fftMsg {
34   int source;
35   fft_complex *data;
36 };
37
38 struct Main : public CBase_Main {
39   double start;
40   CProxy_fft fftProxy;
41
42   Main(CkArgMsg* m) {
43     numChunks = atoi(m->argv[1]);
44     N = atol(m->argv[2]);
45     if(m->argc>=4)
46       numThreads = atol(m->argv[3]);
47     else
48       numThreads = CmiMyNodeSize();  //default to 1/core
49     delete m;
50     
51     mainProxy = thisProxy;
52
53     /* how to nodify this computation? */
54     /* We make one block alloc per chare and divide the work evenly
55        across the number of threads.
56      * cache locality issues... 
57
58      *       The NodeHelper scheme presents a problem in cache
59      *       ignorance.  We push tasks into the queue as the remote
60      *       message dependencies are met, however the execution of
61      *       dequeued tasks performance will have significant cache
62      *       limitations obfuscated to our scheduler.  Our helper
63      *       threads will block while fetching data into cache local
64      *       to the thread.  If we only have 1 thread per core, we
65      *       have no way to self overlap those operations.  This
66      *       implies that there are probably conditions under which
67      *       more than one nodehelper thread per core will result in
68      *       better performance.  A natural sweet spot for these
69      *       should be explored in the SMT case where one thread per
70      *       SMT will allow for natural overlap of execution based on
71      *       cache availability, as controlled by the OS without
72      *       additional pthread context switching overhead.  A further
73      *       runtime based virtualized overthreading may provide
74      *       further benefits depending on thread overhead.
75      */
76     if (N % numChunks != 0)
77       CkAbort("numChunks not a factor of N\n");
78
79     // Construct an array of fft chares to do the calculation
80     fftProxy = CProxy_fft::ckNew(numChunks);
81
82     // Construct a nodehelper to do the calculation
83     nodeHelperProxy = NodeHelper_Init(NODEHELPER_MODE, numThreads);
84     
85     CkStartQD(CkIndex_Main::initDone((CkQdMsg *)0), &thishandle);
86   }
87
88   void initDone(CkQdMsg *msg){
89     delete msg;
90     startFFT();
91   }
92   
93   void startFFT() {
94     start = CkWallTimer();
95     // Broadcast the 'go' signal to the fft chare array
96     fftProxy.doFFT();
97   }
98
99   void doneFFT() {
100     double time = CkWallTimer() - start;
101     double gflops = 5 * (double)N*N * log2((double)N*N) / (time * 1000000000);
102     CkPrintf("chares: %d\ncores: %d\nThreads: %d\nsize: %ld\ntime: %f sec\nrate: %f GFlop/s\n",
103              numChunks, CkNumPes(), numThreads, N*N, time, gflops);
104
105     fftProxy.initValidation();
106   }
107
108   void printResidual(realType r) {
109     CkPrintf("residual = %g\n", r);
110     CkExit();
111   }
112
113 };
114
115 struct fft : public CBase_fft {
116   fft_SDAG_CODE
117
118   int iteration, count;
119   uint64_t n;
120   fft_plan *plan;
121   fft_plan p1;
122   fftMsg **msgs;
123   fft_complex *in, *out;
124   bool validating;
125   int nPerThread;
126   fft() {
127     __sdag_init();
128
129     validating = false;
130
131     n = N*N/numChunks;
132
133     in = (fft_complex*) fft_malloc(sizeof(fft_complex) * n);
134     out = (fft_complex*) fft_malloc(sizeof(fft_complex) * n);
135     nPerThread= n/numThreads;
136     int length[] = {nPerThread};
137     CmiLock(fft_plan_lock);
138     size_t offset=0;
139     plan= new fft_plan[numThreads];
140     for(int i=0; i < numThreads; i++,offset+=nPerThread)
141       {
142         /* ??? should the dist be nPerThread as the fft is performed as 1d of length nPerThread?? */
143         plan[i] = fft_plan_many_dft(1, length, N/numChunks/numThreads, out+offset, length, 1, N/numThreads,
144                             out+offset, length, 1, N/numThreads, FFTW_FORWARD, FFTW_ESTIMATE);
145       }
146     CmiUnlock(fft_plan_lock);
147     srand48(thisIndex);
148     for(int i = 0; i < n; i++) {
149       in[i][0] = drand48();
150       in[i][1] = drand48();
151     }
152
153     msgs = new fftMsg*[numChunks];
154     for(int i = 0; i < numChunks; i++) {
155       msgs[i] = new (n/numChunks) fftMsg;
156       msgs[i]->source = thisIndex;
157     }
158
159     // Reduction to the mainchare to signal that initialization is complete
160     //contribute(CkCallback(CkIndex_Main::startFFT(), mainProxy));
161   }
162
163   void sendTranspose(fft_complex *src_buf) {
164     // All-to-all transpose by constructing and sending
165     // point-to-point messages to each chare in the array.
166     for(int i = thisIndex; i < thisIndex+numChunks; i++) {
167       //  Stagger communication order to avoid hotspots and the
168       //  associated contention.
169       int k = i % numChunks;
170       for(int j = 0, l = 0; j < N/numChunks; j++)
171         memcpy(msgs[k]->data[(l++)*N/numChunks], src_buf[k*N/numChunks+j*N], sizeof(fft_complex)*N/numChunks);
172
173       // Tag each message with the iteration in which it was
174       // generated, to prevent mis-matched messages from chares that
175       // got all of their input quickly and moved to the next step.
176       CkSetRefNum(msgs[k], iteration);
177       thisProxy[k].getTranspose(msgs[k]);
178       // Runtime system takes ownership of messages once they're sent
179       msgs[k] = NULL;
180     }
181   }
182
183   void applyTranspose(fftMsg *m) {
184     int k = m->source;
185     for(int j = 0, l = 0; j < N/numChunks; j++)
186       for(int i = 0; i < N/numChunks; i++) {
187         out[k*N/numChunks+(i*N+j)][0] = m->data[l][0];
188         out[k*N/numChunks+(i*N+j)][1] = m->data[l++][1];
189       }
190
191     // Save just-received messages to reuse for later sends, to
192     // avoid reallocation
193     delete msgs[k];
194     msgs[k] = m;
195     msgs[k]->source = thisIndex;
196   }
197
198   void twiddle(realType sign) {
199     realType a, c, s, re, im;
200
201     int k = thisIndex;
202     for(int i = 0; i < N/numChunks; i++)
203       for(int j = 0; j < N; j++) {
204         a = sign * (TWOPI*(i+k*N/numChunks)*j)/(N*N);
205         c = cos(a);
206         s = sin(a);
207
208         int idx = i*N+j;
209
210         re = c*out[idx][0] - s*out[idx][1];
211         im = s*out[idx][0] + c*out[idx][1];
212         out[idx][0] = re;
213         out[idx][1] = im;
214       }
215   }
216   void fftHelperLaunch()
217   {
218     //kick off thread computation
219     //FuncNodeHelper *nth = nodeHelperProxy[CkMyNode()].ckLocalBranch();
220     //nth->parallelizeFunc(doCalc, numThreads, numThreads, thisIndex, numThreads, 1, 1, plan, 0, NULL);
221     NodeHelper_Parallelize(nodeHelperProxy, doCalc, 0, NULL, 0, numChunks, 0, numChunks-1);
222     
223   }
224
225   void initValidation() {
226     memcpy(in, out, sizeof(fft_complex) * n);
227
228     validating = true;
229     int length[] = {nPerThread};
230     CmiLock(fft_plan_lock);
231     size_t offset=0;
232     plan= new fft_plan[numThreads];
233     for(int i=0; i < numThreads; i++,offset+=nPerThread)
234       {
235         //      fft_destroy_plan(plan[i]);
236         plan[i] = fft_plan_many_dft(1, length, N/numChunks/numThreads, out+offset, length, 1, N/numThreads,
237                             out+offset, length, 1, N/numThreads, FFTW_BACKWARD, FFTW_ESTIMATE);
238       }
239     CmiUnlock(fft_plan_lock);
240     contribute(CkCallback(CkIndex_Main::startFFT(), mainProxy));
241   }
242
243   void calcResidual() {
244     double infNorm = 0.0;
245
246     srand48(thisIndex);
247     for(int i = 0; i < n; i++) {
248       out[i][0] = out[i][0]/(N*N) - drand48();
249       out[i][1] = out[i][1]/(N*N) - drand48();
250
251       double mag = sqrt(pow(out[i][0], 2) + pow(out[i][1], 2));
252       if(mag > infNorm) infNorm = mag;
253     }
254
255     double r = infNorm / (std::numeric_limits<double>::epsilon() * log((double)N * N));
256
257     CkCallback cb(CkReductionTarget(Main, printResidual), mainProxy);
258     contribute(sizeof(double), &r, CkReduction::max_double, cb);
259   }
260
261   fft(CkMigrateMessage* m) {}
262   ~fft() {}
263 };
264
265 #include "fft1d.def.h"