526c12389c3c196965cf3638ab0b5170677fd197
[charm.git] / example / fft-trans / fft1d.C
1 #include "fft1d.decl.h"
2 #include <fftw3.h>
3 #include <limits>
4 #include "fileio.h"
5 #include "NodeHelper.h"
6
7 #define TWOPI 6.283185307179586
8
9 /*readonly*/ CProxy_Main mainProxy;
10 /*readonly*/ int numChunks;
11 /*readonly*/ int numThreads;
12 /*readonly*/ uint64_t N;
13 static CmiNodeLock fft_plan_lock;
14 #include "fftmacro.h"
15 CProxy_FuncNodeHelper nodeHelperProxy;
16 /** called by initnode once per node to support node level locking for
17     fftw plan create/destroy operations */
18 #define MODE 1
19
20 extern "C" void doCalc(int first,int last, int & result, int paramNum, void * param)
21 {
22   result=first;
23   fft_execute(((fft_plan*)param)[first]);
24 }
25
26 void initplanlock ()
27
28 {
29   fft_plan_lock=CmiCreateLock();
30 }
31
32 struct fftMsg : public CMessage_fftMsg {
33   int source;
34   fft_complex *data;
35 };
36
37 struct Main : public CBase_Main {
38   double start;
39   CProxy_fft fftProxy;
40
41   Main(CkArgMsg* m) {
42     numChunks = atoi(m->argv[1]);
43     N = atol(m->argv[2]);
44     if(m->argc>=4)
45       numThreads = atol(m->argv[3]);
46     else
47       numThreads = CmiMyNodeSize();  //default to 1/core
48     delete m;
49     
50     mainProxy = thisProxy;
51
52     /* how to nodify this computation? */
53     /* We make one block alloc per chare and divide the work evenly
54        across the number of threads.
55      * cache locality issues... 
56
57      *       The NodeHelper scheme presents a problem in cache
58      *       ignorance.  We push tasks into the queue as the remote
59      *       message dependencies are met, however the execution of
60      *       dequeued tasks performance will have significant cache
61      *       limitations obfuscated to our scheduler.  Our helper
62      *       threads will block while fetching data into cache local
63      *       to the thread.  If we only have 1 thread per core, we
64      *       have no way to self overlap those operations.  This
65      *       implies that there are probably conditions under which
66      *       more than one nodehelper thread per core will result in
67      *       better performance.  A natural sweet spot for these
68      *       should be explored in the SMT case where one thread per
69      *       SMT will allow for natural overlap of execution based on
70      *       cache availability, as controlled by the OS without
71      *       additional pthread context switching overhead.  A further
72      *       runtime based virtualized overthreading may provide
73      *       further benefits depending on thread overhead.
74      */
75     if (N % numChunks != 0)
76       CkAbort("numChunks not a factor of N\n");
77
78     // Construct an array of fft chares to do the calculation
79     fftProxy = CProxy_fft::ckNew(numChunks);
80     // Construct a nodehelper to do the calculation
81     nodeHelperProxy = CProxy_FuncNodeHelper::ckNew(MODE, numChunks, numThreads);
82     
83   }
84
85   void startFFT() {
86     start = CkWallTimer();
87     // Broadcast the 'go' signal to the fft chare array
88     fftProxy.doFFT();
89   }
90
91   void doneFFT() {
92     double time = CkWallTimer() - start;
93     double gflops = 5 * (double)N*N * log2((double)N*N) / (time * 1000000000);
94     CkPrintf("chares: %d\ncores: %d\nThreads: %d\nsize: %ld\ntime: %f sec\nrate: %f GFlop/s\n",
95              numChunks, CkNumPes(), numThreads, N*N, time, gflops);
96
97     fftProxy.initValidation();
98   }
99
100   void printResidual(realType r) {
101     CkPrintf("residual = %g\n", r);
102     CkExit();
103   }
104
105 };
106
107 struct fft : public CBase_fft {
108   fft_SDAG_CODE
109
110   int iteration, count;
111   uint64_t n;
112   fft_plan *plan;
113   fft_plan p1;
114   fftMsg **msgs;
115   fft_complex *in, *out;
116   bool validating;
117   int nPerThread;
118   fft() {
119     __sdag_init();
120
121     validating = false;
122
123     n = N*N/numChunks;
124
125     in = (fft_complex*) fft_malloc(sizeof(fft_complex) * n);
126     out = (fft_complex*) fft_malloc(sizeof(fft_complex) * n);
127     nPerThread= n/numThreads;
128     int length[] = {nPerThread};
129     CmiLock(fft_plan_lock);
130     size_t offset=0;
131     plan= new fft_plan[numThreads];
132     for(int i=0; i < numThreads; i++,offset+=nPerThread)
133       {
134         /* ??? should the dist be nPerThread as the fft is performed as 1d of length nPerThread?? */
135         plan[i] = fft_plan_many_dft(1, length, N/numChunks/numThreads, out+offset, length, 1, N/numThreads,
136                             out+offset, length, 1, N/numThreads, FFTW_FORWARD, FFTW_ESTIMATE);
137       }
138     CmiUnlock(fft_plan_lock);
139     srand48(thisIndex);
140     for(int i = 0; i < n; i++) {
141       in[i][0] = drand48();
142       in[i][1] = drand48();
143     }
144
145     msgs = new fftMsg*[numChunks];
146     for(int i = 0; i < numChunks; i++) {
147       msgs[i] = new (n/numChunks) fftMsg;
148       msgs[i]->source = thisIndex;
149     }
150
151     // Reduction to the mainchare to signal that initialization is complete
152     contribute(CkCallback(CkIndex_Main::startFFT(), mainProxy));
153   }
154
155   void sendTranspose(fft_complex *src_buf) {
156     // All-to-all transpose by constructing and sending
157     // point-to-point messages to each chare in the array.
158     for(int i = thisIndex; i < thisIndex+numChunks; i++) {
159       //  Stagger communication order to avoid hotspots and the
160       //  associated contention.
161       int k = i % numChunks;
162       for(int j = 0, l = 0; j < N/numChunks; j++)
163         memcpy(msgs[k]->data[(l++)*N/numChunks], src_buf[k*N/numChunks+j*N], sizeof(fft_complex)*N/numChunks);
164
165       // Tag each message with the iteration in which it was
166       // generated, to prevent mis-matched messages from chares that
167       // got all of their input quickly and moved to the next step.
168       CkSetRefNum(msgs[k], iteration);
169       thisProxy[k].getTranspose(msgs[k]);
170       // Runtime system takes ownership of messages once they're sent
171       msgs[k] = NULL;
172     }
173   }
174
175   void applyTranspose(fftMsg *m) {
176     int k = m->source;
177     for(int j = 0, l = 0; j < N/numChunks; j++)
178       for(int i = 0; i < N/numChunks; i++) {
179         out[k*N/numChunks+(i*N+j)][0] = m->data[l][0];
180         out[k*N/numChunks+(i*N+j)][1] = m->data[l++][1];
181       }
182
183     // Save just-received messages to reuse for later sends, to
184     // avoid reallocation
185     delete msgs[k];
186     msgs[k] = m;
187     msgs[k]->source = thisIndex;
188   }
189
190   void twiddle(realType sign) {
191     realType a, c, s, re, im;
192
193     int k = thisIndex;
194     for(int i = 0; i < N/numChunks; i++)
195       for(int j = 0; j < N; j++) {
196         a = sign * (TWOPI*(i+k*N/numChunks)*j)/(N*N);
197         c = cos(a);
198         s = sin(a);
199
200         int idx = i*N+j;
201
202         re = c*out[idx][0] - s*out[idx][1];
203         im = s*out[idx][0] + c*out[idx][1];
204         out[idx][0] = re;
205         out[idx][1] = im;
206       }
207   }
208   void fftHelperLaunch()
209   {
210     //kick off thread computation
211     FuncNodeHelper *nth = nodeHelperProxy[CkMyNode()].ckLocalBranch();
212     nth->parallelizeFunc(doCalc, numThreads, numThreads, thisIndex, numThreads, 1, 1, plan, 0, NULL);
213     
214   }
215
216   void initValidation() {
217     memcpy(in, out, sizeof(fft_complex) * n);
218
219     validating = true;
220     int length[] = {nPerThread};
221     CmiLock(fft_plan_lock);
222     size_t offset=0;
223     plan= new fft_plan[numThreads];
224     for(int i=0; i < numThreads; i++,offset+=nPerThread)
225       {
226         //      fft_destroy_plan(plan[i]);
227         plan[i] = fft_plan_many_dft(1, length, N/numChunks/numThreads, out+offset, length, 1, N/numThreads,
228                             out+offset, length, 1, N/numThreads, FFTW_BACKWARD, FFTW_ESTIMATE);
229       }
230     CmiUnlock(fft_plan_lock);
231     contribute(CkCallback(CkIndex_Main::startFFT(), mainProxy));
232   }
233
234   void calcResidual() {
235     double infNorm = 0.0;
236
237     srand48(thisIndex);
238     for(int i = 0; i < n; i++) {
239       out[i][0] = out[i][0]/(N*N) - drand48();
240       out[i][1] = out[i][1]/(N*N) - drand48();
241
242       double mag = sqrt(pow(out[i][0], 2) + pow(out[i][1], 2));
243       if(mag > infNorm) infNorm = mag;
244     }
245
246     double r = infNorm / (std::numeric_limits<double>::epsilon() * log((double)N * N));
247
248     CkCallback cb(CkReductionTarget(Main, printResidual), mainProxy);
249     contribute(sizeof(double), &r, CkReduction::max_double, cb);
250   }
251
252   fft(CkMigrateMessage* m) {}
253   ~fft() {}
254 };
255
256 #include "fft1d.def.h"