MSA: Set PageArray's synchronization reduction client at contribute time
[charm.git] / src / libs / ck-libs / multiphaseSharedArrays / msa-distArray.h
1 // emacs mode line -*- mode: c++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*-
2 #ifndef MSA_DISTARRAY_H
3 #define MSA_DISTARRAY_H
4
5 #include <utility>
6 #include <algorithm>
7 #include "msa-DistPageMgr.h"
8
9
10 struct MSA_InvalidHandle { };
11
12 template <typename ENTRY>
13 class Writable
14 {
15     ENTRY &e;
16     
17 public:
18     Writable(ENTRY &e_) : e(e_) {}
19     inline const ENTRY& operator= (const ENTRY& rhs) { e = rhs; }
20 };
21
22 template <typename ENTRY, class ENTRY_OPS_CLASS>
23 class Accumulable
24 {
25     ENTRY &e;
26     
27 public:
28     Accumulable(ENTRY &e_) : e(e_) {}
29     void operator+=(const ENTRY &rhs_)
30         { ENTRY_OPS_CLASS::accumulate(e, rhs_); }
31 };
32
33
34 /**
35    The MSA1D class is a handle to a distributed shared array of items
36    of data type ENTRY. There are nEntries total numer of ENTRY's, with
37    ENTRIES_PER_PAGE data items per "page".  It is implemented as a
38    Chare Array of pages, and a Group representing the local cache.
39
40    The requirements for the templates are:
41      ENTRY: User data class stored in the array, with at least:
42         - A default constructor and destructor
43         - A working assignment operator
44         - A working pup routine
45      ENTRY_OPS_CLASS: Used to combine values for "accumulate":
46         - A method named "getIdentity", taking no arguments and
47           returning an ENTRY to use before any accumulation.
48         - A method named "accumulate", taking a source/dest ENTRY by reference
49           and an ENTRY to add to it by value or const reference.
50      ENTRIES_PER_PAGE: Optional integer number of ENTRY objects
51         to store and communicate at once.  For good performance,
52         make sure this value is a power of two.
53  */
54 template<class ENTRY, class ENTRY_OPS_CLASS, unsigned int ENTRIES_PER_PAGE=MSA_DEFAULT_ENTRIES_PER_PAGE>
55 class MSA1D
56 {
57 public:
58     typedef MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CacheGroup_t;
59     typedef CProxy_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_CacheGroup_t;
60     typedef CProxy_MSA_PageArray<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_PageArray_t;
61
62     // Sun's C++ compiler doesn't understand that nested classes are
63     // members for the sake of access to private data. (2008-10-23)
64     class Read; class Write; class Accum;
65     friend class Read; friend class Write; friend class Accum;
66
67         class Handle
68         {
69     public:
70         inline unsigned int length() const { return msa.length(); }
71
72         protected:
73         MSA1D &msa;
74         bool valid;
75
76         friend class MSA1D;
77
78         void inline checkInvalidate(MSA1D *m) 
79         {
80             if (m != &msa || !valid)
81                 throw MSA_InvalidHandle();
82             valid = false;
83         }
84
85         Handle(MSA1D &msa_) 
86             : msa(msa_), valid(true) 
87         { }
88         void checkValid()
89         {
90             if (!valid)
91                 throw MSA_InvalidHandle();
92         }
93
94     private:
95         // Disallow copy construction
96         Handle(Handle &);
97     };
98
99     class Read : public Handle
100     {
101     protected:
102         friend class MSA1D;
103         Read(MSA1D &msa_)
104             :  Handle(msa_) { }
105         using Handle::checkValid;
106         using Handle::checkInvalidate;
107
108     public:
109         inline const ENTRY& get(unsigned int idx)
110         {
111             checkValid();
112             return Handle::msa.get(idx); 
113         }
114         inline const ENTRY& operator[](unsigned int idx) { return get(idx); }
115         inline const ENTRY& operator()(unsigned int idx) { return get(idx); }
116         inline const ENTRY& get2(unsigned int idx)
117         {
118             checkValid();
119             return Handle::msa.get2(idx);
120         }
121     };
122
123     class Write : public Handle
124     {
125     protected:
126         friend class MSA1D;
127         Write(MSA1D &msa_)
128             : Handle(msa_) { }
129
130     public:
131         inline Writable<ENTRY> set(unsigned int idx)
132         {
133             Handle::checkValid();
134             return Writable<ENTRY>(Handle::msa.set(idx));
135         }
136         inline Writable<ENTRY> operator()(unsigned int idx)
137             { return set(idx); }
138     };
139
140     class Accum : public Handle
141     {
142     protected:
143         friend class MSA1D;
144         Accum(MSA1D &msa_)
145             : Handle(msa_) { }
146         using Handle::checkInvalidate;
147     public:
148         inline Accumulable<ENTRY, ENTRY_OPS_CLASS> accumulate(unsigned int idx)
149         { 
150             Handle::checkValid();
151             return Accumulable<ENTRY, ENTRY_OPS_CLASS>(Handle::msa.accumulate(idx));
152         }
153         inline void accumulate(unsigned int idx, const ENTRY& ent)
154         {
155             Handle::checkValid();
156             Handle::msa.accumulate(idx, ent);
157         }
158
159         void contribute(unsigned int idx, const ENTRY *begin, const ENTRY *end)
160         {
161             Handle::checkValid();
162             for (const ENTRY *e = begin; e != end; ++e, ++idx)
163                 {
164                     Handle::msa.accumulate(idx, *e);
165                 }
166         }
167
168         inline Accumulable<ENTRY, ENTRY_OPS_CLASS> operator() (unsigned int idx)
169             { return accumulate(idx); }
170     };
171
172 protected:
173     /// Total number of ENTRY's in the whole array.
174     unsigned int nEntries;
175     bool initHandleGiven;
176
177     /// Handle to owner of cache.
178     CacheGroup_t* cache;
179     CProxy_CacheGroup_t cg;
180
181     inline const ENTRY* readablePage(unsigned int page)
182     {
183         return (const ENTRY*)(cache->readablePage(page));
184     }
185
186     // known local page.
187     inline const ENTRY* readablePage2(unsigned int page)
188     {
189         return (const ENTRY*)(cache->readablePage2(page));
190     }
191
192     // Returns a pointer to the start of the local copy in the cache of the writeable page.
193     // @@ what if begin - end span across two or more pages?
194     inline ENTRY* writeablePage(unsigned int page, unsigned int offset)
195     {
196         return (ENTRY*)(cache->writeablePage(page, offset));
197     }
198
199 public:
200     // @@ Needed for Jade
201     inline MSA1D() 
202         :initHandleGiven(false) 
203     {}
204
205     virtual void pup(PUP::er &p){
206         p|nEntries;
207         p|cg;
208         if (p.isUnpacking()) cache=cg.ckLocalBranch();
209     }
210
211     /**
212       Create a completely new MSA array.  This call creates the
213       corresponding groups, so only call it once per array.
214     */
215     inline MSA1D(unsigned int nEntries_, unsigned int num_wrkrs, 
216                  unsigned int maxBytes=MSA_DEFAULT_MAX_BYTES) 
217         : nEntries(nEntries_), initHandleGiven(false)
218     {
219         // first create the Page Array and the Page Group
220         unsigned int nPages = (nEntries + ENTRIES_PER_PAGE - 1)/ENTRIES_PER_PAGE;
221         CProxy_PageArray_t pageArray = CProxy_PageArray_t::ckNew(nPages);
222         cg = CProxy_CacheGroup_t::ckNew(nPages, pageArray, maxBytes, nEntries, num_wrkrs);
223         pageArray.setCacheProxy(cg);
224         pageArray.ckSetReductionClient(new CkCallback(CkIndex_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>::SyncDone(), cg));
225         cache = cg.ckLocalBranch();
226     }
227
228     // Deprecated API for accessing CacheGroup directly.
229     inline MSA1D(CProxy_CacheGroup_t cg_) : cg(cg_), initHandleGiven(false)
230     {
231         cache = cg.ckLocalBranch();
232         nEntries = cache->getNumEntries();
233     }
234
235     inline ~MSA1D()
236     {
237         // TODO: how to get rid of the cache group and the page array
238         //(cache->getArray()).destroy();
239         //cg.destroy();
240         // TODO: calling FreeMem does not seem to work. Need to debug it.
241                                 cache->unroll();
242  //       cache->FreeMem();
243     }
244
245     /**
246      * this function is supposed to be called when the thread/object using this array
247      * migrates to another PE.
248      */
249     inline void changePE()
250     {
251         cache = cg.ckLocalBranch();
252
253         /* don't need to update the number of entries, as that does not change */
254     }
255
256     // ================ Accessor/Utility functions ================
257     /// Get the total length of the array, across all processors.
258     inline unsigned int length() const { return nEntries; }
259
260     inline const CProxy_CacheGroup_t &getCacheGroup() const { return cg; }
261
262     // Avoid using the term "page size" because it is confusing: does
263     // it mean in bytes or number of entries?
264     inline unsigned int getNumEntriesPerPage() const { return ENTRIES_PER_PAGE; }
265
266     /// Return the page this entry is stored at.
267     inline unsigned int getPageIndex(unsigned int idx)
268     {
269         return idx / ENTRIES_PER_PAGE;
270     }
271
272     /// Return the offset, in entries, that this entry is stored at within a page.
273     inline unsigned int getOffsetWithinPage(unsigned int idx)
274     {
275         return idx % ENTRIES_PER_PAGE;
276     }
277
278     // ================ MSA API ================
279
280     // We need to know the total number of workers across all
281     // processors, and we also calculate the number of worker threads
282     // running on this processor.
283     //
284     // Blocking method, basically does a barrier until all workers
285     // enroll.
286     inline void enroll(int num_workers)
287     {
288         // @@ This is a hack to identify the number of MSA1D
289         // threads on this processor.  This number is needed for sync.
290         //
291         // @@ What if a MSA1D thread migrates?
292         cache->enroll(num_workers);
293     }
294
295     // idx is the element to be read/written
296     //
297     // This function returns a reference to the first element on the
298     // page that contains idx.
299     inline ENTRY& getPageBottom(unsigned int idx, MSA_Page_Fault_t accessMode)
300     {
301         if (accessMode==Read_Fault) {
302             unsigned int page = idx / ENTRIES_PER_PAGE;
303             return const_cast<ENTRY&>(readablePage(page)[0]);
304         } else {
305             CkAssert(accessMode==Write_Fault || accessMode==Accumulate_Fault);
306             unsigned int page = idx / ENTRIES_PER_PAGE;
307             unsigned int offset = idx % ENTRIES_PER_PAGE;
308             ENTRY* e=writeablePage(page, offset);
309             return e[0];
310         }
311     }
312
313     inline void FreeMem()
314     {
315         cache->FreeMem();
316     }
317
318     /// Non-blocking prefetch of entries from start to end, inclusive.
319     /// Prefetch'd pages are locked into the cache, so you must call
320     ///   unlock afterwards.
321     inline void Prefetch(unsigned int start, unsigned int end)
322     {
323         unsigned int page1 = start / ENTRIES_PER_PAGE;
324         unsigned int page2 = end / ENTRIES_PER_PAGE;
325         cache->Prefetch(page1, page2);
326     }
327
328     /// Block until all prefetched pages arrive.
329     inline int WaitAll()    { return cache->WaitAll(); }
330
331     /// Unlock all locked pages
332     inline void Unlock()    { return cache->UnlockPages(); }
333
334     /// start and end are element indexes.
335     /// Unlocks completely spanned pages given a range of elements
336     /// index'd from "start" to "end", inclusive.  If start/end does not span a
337     /// page completely, i.e. start/end is in the middle of a page,
338     /// the entire page is still unlocked--in particular, this means
339     /// you should not have several adjacent ranges locked.
340     inline void Unlock(unsigned int start, unsigned int end)
341     {
342         unsigned int page1 = start / ENTRIES_PER_PAGE;
343         unsigned int page2 = end / ENTRIES_PER_PAGE;
344         cache->UnlockPages(page1, page2);
345     }
346
347     static const int DEFAULT_SYNC_SINGLE = 0;
348
349     inline Read &syncToRead(Handle &m, int single = DEFAULT_SYNC_SINGLE)
350     {
351         m.checkInvalidate(this);
352         delete &m;
353         sync(single);
354         return *(new Read(*this));
355     }
356
357     inline Write &syncToWrite(Handle &m, int single = DEFAULT_SYNC_SINGLE)
358     {
359         m.checkInvalidate(this);
360         delete &m;
361         sync(single);
362         return *(new Write(*this));
363     }
364
365     inline Accum &syncToAccum(Handle &m, int single = DEFAULT_SYNC_SINGLE)
366     {
367         m.checkInvalidate(this);
368         delete &m;
369         sync(single);
370         return *(new Accum(*this));
371     }
372
373     inline Write &getInitialWrite()
374     {
375         if (initHandleGiven)
376             throw MSA_InvalidHandle();
377
378         Write *w = new Write(*this);
379         sync();
380         initHandleGiven = true;
381         return *w;
382     }
383
384     inline Accum &getInitialAccum()
385     {
386         if (initHandleGiven)
387             throw MSA_InvalidHandle();
388
389         Accum *a = new Accum(*this);
390         sync();
391         initHandleGiven = true;
392         return *a;
393     }
394
395   // These are the meat of the MSA API, but they are only accessible
396   // through appropriate handles (defined in the public section above).
397 protected:
398     /// Return a read-only copy of the element at idx.
399     ///   May block if the element is not already in the cache.
400     inline const ENTRY& get(unsigned int idx)
401     {
402         unsigned int page = idx / ENTRIES_PER_PAGE;
403         unsigned int offset = idx % ENTRIES_PER_PAGE;
404         return readablePage(page)[offset];
405     }
406
407     inline const ENTRY& operator[](unsigned int idx)
408     {
409         return get(idx);
410     }
411
412     /// Return a read-only copy of the element at idx;
413     ///   ONLY WORKS WHEN ELEMENT IS ALREADY IN THE CACHE--
414     ///   WILL SEGFAULT IF ELEMENT NOT ALREADY PRESENT.
415     ///    Never blocks; may crash if element not already present.
416     inline const ENTRY& get2(unsigned int idx)
417     {
418         unsigned int page = idx / ENTRIES_PER_PAGE;
419         unsigned int offset = idx % ENTRIES_PER_PAGE;
420         return readablePage2(page)[offset];
421     }
422
423     /// Return a writeable copy of the element at idx.
424     ///    Never blocks; will create a new blank element if none exists locally.
425     ///    UNDEFINED if two threads set the same element.
426     inline ENTRY& set(unsigned int idx)
427     {
428         unsigned int page = idx / ENTRIES_PER_PAGE;
429         unsigned int offset = idx % ENTRIES_PER_PAGE;
430         ENTRY* e=writeablePage(page, offset);
431         return e[offset];
432     }
433
434     /// Fetch the ENTRY at idx to be accumulated.
435     ///   You must perform the accumulation on 
436     ///     the return value before calling "sync".
437     ///   Never blocks.
438     inline ENTRY& accumulate(unsigned int idx)
439     {
440         unsigned int page = idx / ENTRIES_PER_PAGE;
441         unsigned int offset = idx % ENTRIES_PER_PAGE;
442         return cache->accumulate(page, offset);
443     }
444     
445     /// Add ent to the element at idx.
446     ///   Never blocks.
447     ///   Merges together accumulates from different threads.
448     inline void accumulate(unsigned int idx, const ENTRY& ent)
449     {
450         ENTRY_OPS_CLASS::accumulate(accumulate(idx),ent);
451     }
452
453     /// Synchronize reads and writes across the entire array.
454     inline void sync(int single=0)
455     {
456         cache->SyncReq(single); 
457     }
458 };
459
460
461 // define a 2d distributed array based on the 1D array, support row major and column
462 // major arrangement of data
463 template<class ENTRY, class ENTRY_OPS_CLASS, unsigned int ENTRIES_PER_PAGE=MSA_DEFAULT_ENTRIES_PER_PAGE, MSA_Array_Layout_t ARRAY_LAYOUT=MSA_ROW_MAJOR>
464 class MSA2D : public MSA1D<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>
465 {
466 public:
467     typedef CProxy_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_CacheGroup_t;
468     typedef MSA1D<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> super;
469
470 protected:
471     unsigned int rows, cols;
472
473 public:
474     // @@ Needed for Jade
475     inline MSA2D() : super() {}
476     virtual void pup(PUP::er &p) {
477        super::pup(p);
478        p|rows; p|cols;
479     };
480
481         class Handle
482         {
483         protected:
484         MSA2D &msa;
485         bool valid;
486
487         friend class MSA2D;
488
489         inline void checkInvalidate(MSA2D *m)
490         {
491             if (&msa != m || !valid)
492                 throw MSA_InvalidHandle();
493             valid = false;
494         }
495
496         Handle(MSA2D &msa_) 
497             : msa(msa_), valid(true) 
498         { }
499         inline void checkValid()
500         {
501             if (!valid)
502                 throw MSA_InvalidHandle();
503         }
504     private:
505         // Disallow copy construction
506         Handle(Handle &);
507     };
508
509     class Read : public Handle
510     {
511     private:
512         friend class MSA2D;
513         Read(MSA2D &msa_)
514             :  Handle(msa_) { }
515
516     public: 
517         inline const ENTRY& get(unsigned int row, unsigned int col)
518         {
519             Handle::checkValid();
520             return Handle::msa.get(row, col);
521         }
522         inline const ENTRY& get2(unsigned int row, unsigned int col)
523         {
524             Handle::checkValid();
525             return Handle::msa.get2(row, col);
526         }
527
528         inline const ENTRY& operator() (unsigned int row, unsigned int col)
529             {
530                 return get(row,col);
531             }
532
533     };
534
535     class Write : public Handle
536     {
537     private:
538         friend class MSA2D;
539         Write(MSA2D &msa_)
540             :  Handle(msa_) { }
541
542     public: 
543         inline Writable<ENTRY> set(unsigned int row, unsigned int col)
544         {
545             Handle::checkValid();
546             return Writable<ENTRY>(Handle::msa.set(row, col));
547         }
548     };
549
550     inline MSA2D(unsigned int rows_, unsigned int cols_, unsigned int numwrkrs,
551                  unsigned int maxBytes=MSA_DEFAULT_MAX_BYTES)
552         :super(rows_*cols_, numwrkrs, maxBytes)
553     {
554         rows = rows_; cols = cols_;
555     }
556
557     inline MSA2D(unsigned int rows_, unsigned int cols_, CProxy_CacheGroup_t cg_)
558         : rows(rows_), cols(cols_), super(cg_)
559     {}
560
561     // get the 1D index of the given entry as per the row major/column major format
562     inline unsigned int getIndex(unsigned int row, unsigned int col)
563     {
564         unsigned int index;
565
566         if(ARRAY_LAYOUT==MSA_ROW_MAJOR)
567             index = row*cols + col;
568         else
569             index = col*rows + row;
570
571         return index;
572     }
573
574     // Which page is (row, col) on?
575     inline unsigned int getPageIndex(unsigned int row, unsigned int col)
576     {
577         return getIndex(row, col)/ENTRIES_PER_PAGE;
578     }
579
580     inline unsigned int getOffsetWithinPage(unsigned int row, unsigned int col)
581     {
582         return getIndex(row, col)%ENTRIES_PER_PAGE;
583     }
584
585     inline unsigned int getRows(void) const {return rows;}
586     inline unsigned int getCols(void) const {return cols;}
587     inline unsigned int getColumns(void) const {return cols;}
588     inline MSA_Array_Layout_t getArrayLayout() const {return ARRAY_LAYOUT;}
589
590     inline void Prefetch(unsigned int start, unsigned int end)
591     {
592         // prefetch the start ... end rows/columns into the cache
593         if(start > end)
594         {
595             unsigned int temp = start;
596             start = end;
597             end = temp;
598         }
599
600         unsigned int index1 = (ARRAY_LAYOUT==MSA_ROW_MAJOR) ? getIndex(start, 0) : getIndex(0, start);
601         unsigned int index2 = (ARRAY_LAYOUT==MSA_ROW_MAJOR) ? getIndex(end, cols-1) : getIndex(rows-1, end);
602
603         MSA1D<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>::Prefetch(index1, index2);
604     }
605
606     // Unlocks pages starting from row "start" through row "end", inclusive
607     inline void UnlockPages(unsigned int start, unsigned int end)
608     {
609         if(start > end)
610         {
611             unsigned int temp = start;
612             start = end;
613             end = temp;
614         }
615
616         unsigned int index1 = (ARRAY_LAYOUT==MSA_ROW_MAJOR) ? getIndex(start, 0) : getIndex(0, start);
617         unsigned int index2 = (ARRAY_LAYOUT==MSA_ROW_MAJOR) ? getIndex(end, cols-1) : getIndex(rows-1, end);
618
619         MSA1D<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>::Unlock(index1, index2);
620     }
621
622     inline Read& syncToRead(Handle &m, int single = super::DEFAULT_SYNC_SINGLE)
623     {
624         m.checkInvalidate(this);
625         delete &m;
626         super::sync(single);
627         return *(new Read(*this));
628     }
629
630     inline Write& syncToWrite(Handle &m, int single = super::DEFAULT_SYNC_SINGLE)
631     {
632         m.checkInvalidate(this);
633         delete &m;
634         super::sync(single);
635         return *(new Write(*this));
636     }
637
638     inline Write& getInitialWrite()
639     {
640         if (super::initHandleGiven)
641             throw MSA_InvalidHandle();
642
643         Write *w = new Write(*this);
644         super::initHandleGiven = true;
645         return *w;
646     }
647
648 protected:
649     inline const ENTRY& get(unsigned int row, unsigned int col)
650     {
651         return super::get(getIndex(row, col));
652     }
653
654     // known local
655     inline const ENTRY& get2(unsigned int row, unsigned int col)
656     {
657         return super::get2(getIndex(row, col));
658     }
659
660     // MSA2D::
661     inline ENTRY& set(unsigned int row, unsigned int col)
662     {
663         return super::set(getIndex(row, col));
664     }
665 };
666
667 namespace MSA
668 {
669     using std::min;
670     using std::max;
671
672
673 /**
674    The MSA3D class is a handle to a distributed shared array of items
675    of data type ENTRY. There are nEntries total numer of ENTRY's, with
676    ENTRIES_PER_PAGE data items per "page".  It is implemented as a
677    Chare Array of pages, and a Group representing the local cache.
678
679    The requirements for the templates are:
680      ENTRY: User data class stored in the array, with at least:
681         - A default constructor and destructor
682         - A working assignment operator
683         - A working pup routine
684      ENTRY_OPS_CLASS: Used to combine values for "accumulate":
685         - A method named "getIdentity", taking no arguments and
686           returning an ENTRY to use before any accumulation.
687         - A method named "accumulate", taking a source/dest ENTRY by reference
688           and an ENTRY to add to it by value or const reference.
689      ENTRIES_PER_PAGE: Optional integer number of ENTRY objects
690         to store and communicate at once.  For good performance,
691         make sure this value is a power of two.
692  */
693 template<class ENTRY, class ENTRY_OPS_CLASS, unsigned int ENTRIES_PER_PAGE>
694 class MSA3D
695 {
696     unsigned dim_x, dim_y, dim_z;
697
698
699 public:
700     typedef MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CacheGroup_t;
701     typedef CProxy_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_CacheGroup_t;
702     typedef CProxy_MSA_PageArray<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_PageArray_t;
703
704     // Sun's C++ compiler doesn't understand that nested classes are
705     // members for the sake of access to private data. (2008-10-23)
706     class Read; class Write; class Accum;
707     friend class Read; friend class Write; friend class Accum;
708
709         class Handle
710         {
711         protected:
712         MSA3D *msa;
713         bool valid;
714
715         friend class MSA3D;
716
717         void inline checkInvalidate() 
718         {
719             if (!valid)
720                 throw MSA_InvalidHandle();
721             valid = false;
722         }
723
724         Handle(MSA3D *msa_) 
725             : msa(msa_), valid(true) 
726         { }
727         void checkValid()
728         {
729             if (!valid)
730                 throw MSA_InvalidHandle();
731         }
732         
733     public:
734         inline void syncRelease()
735             {
736                 checkInvalidate();
737                 if (msa->active)
738                     msa->cache->SyncRelease();
739                 else
740                     CmiAbort("sync from an inactive thread!\n");
741                 msa->active = false;
742             }
743
744         inline void syncDone()
745             {
746                 checkInvalidate();
747                 msa->sync(DEFAULT_SYNC_SINGLE);
748             }
749
750         inline Read syncToRead()
751             {
752                 checkInvalidate();
753                 msa->sync(DEFAULT_SYNC_SINGLE);
754                 return Read(msa);
755             }
756
757         inline Write syncToWrite()
758             {
759                 checkInvalidate();
760                 msa->sync(DEFAULT_SYNC_SINGLE);
761                 return Write(msa);
762             }
763
764         inline Accum syncToAccum()
765             {
766                 checkInvalidate();
767                 msa->sync(DEFAULT_SYNC_SINGLE);
768                 return Accum(msa);
769             }
770
771         void pup(PUP::er &p)
772             {
773                 bool real;
774                 if(!p.isUnpacking())
775                     real = msa != NULL;
776                 p|real;
777                 if(real)
778                 {
779                     if(p.isUnpacking())
780                         msa = new MSA3D;
781                     p|(*msa);
782                     p|valid;
783                 }
784             }
785
786         Handle() : msa(NULL), valid(false) {}
787     };
788
789     class Read : public Handle
790     {
791     protected:
792         friend class MSA3D;
793         Read(MSA3D *msa_)
794             :  Handle(msa_) { }
795         using Handle::checkValid;
796         using Handle::checkInvalidate;
797
798     public:
799         Read() {}
800
801         inline const ENTRY& get(unsigned x, unsigned y, unsigned z)
802         {
803             checkValid();
804             return Handle::msa->get(x, y, z); 
805         }
806         inline const ENTRY& operator()(unsigned x, unsigned y, unsigned z) { return get(x, y, z); }
807         inline const ENTRY& get2(unsigned x, unsigned y, unsigned z)
808         {
809             checkValid();
810             return Handle::msa->get2(x, y, z);
811         }
812
813         // Reads the specified range into the provided buffer in row-major order
814         void read(ENTRY *buf, unsigned x1, unsigned y1, unsigned z1, unsigned x2, unsigned y2, unsigned z2)
815         {
816             checkValid();
817
818             CkAssert(x1 <= x2);
819             CkAssert(y1 <= y2);
820             CkAssert(z1 <= z2);
821
822             CkAssert(x1 >= 0);
823             CkAssert(y1 >= 0);
824             CkAssert(z1 >= 0);
825
826             CkAssert(x2 < Handle::msa->dim_x);
827             CkAssert(y2 < Handle::msa->dim_y);
828             CkAssert(z2 < Handle::msa->dim_z);
829
830             unsigned i = 0;
831
832             for (unsigned ix = x1; ix <= x2; ++ix)
833                 for (unsigned iy = y1; iy <= y2; ++iy)
834                     for (unsigned iz = z1; iz <= z2; ++iz)
835                         buf[i++] = Handle::msa->get(ix, iy, iz);
836         }
837     };
838
839     class Write : public Handle
840     {
841     protected:
842         friend class MSA3D;
843         Write(MSA3D *msa_)
844             : Handle(msa_) { }
845
846     public:
847         Write() {}
848
849         inline Writable<ENTRY> set(unsigned x, unsigned y, unsigned z)
850         {
851             Handle::checkValid();
852             return Writable<ENTRY>(Handle::msa->set(x,y,z));
853         }
854         inline Writable<ENTRY> operator()(unsigned x, unsigned y, unsigned z)
855         {
856             return set(x,y,z);
857         }
858
859         void write(unsigned x1, unsigned y1, unsigned z1, unsigned x2, unsigned y2, unsigned z2, const ENTRY *buf)
860         {
861             Handle::checkValid();
862
863             CkAssert(x1 <= x2);
864             CkAssert(y1 <= y2);
865             CkAssert(z1 <= z2);
866
867             CkAssert(x1 >= 0);
868             CkAssert(y1 >= 0);
869             CkAssert(z1 >= 0);
870
871             CkAssert(x2 < Handle::msa->dim_x);
872             CkAssert(y2 < Handle::msa->dim_y);
873             CkAssert(z2 < Handle::msa->dim_z);
874
875             unsigned i = 0;
876
877             for (unsigned ix = x1; ix <= x2; ++ix)
878                 for (unsigned iy = y1; iy <= y2; ++iy)
879                     for (unsigned iz = z1; iz <= z2; ++iz)
880                     {
881                         if (isnan(buf[i]))
882                             CmiAbort("Tried to write a NaN!");
883                         Handle::msa->set(ix, iy, iz) = buf[i++];
884                     }
885         }
886 #if 0
887     private:
888         Write(Write &);
889 #endif
890     };
891
892     class Accum : public Handle
893     {
894     protected:
895         friend class MSA3D;
896         Accum(MSA3D *msa_)
897             : Handle(msa_) { }
898         using Handle::checkInvalidate;
899     public:
900         Accum() {}
901
902         inline Accumulable<ENTRY, ENTRY_OPS_CLASS> accumulate(unsigned int x, unsigned int y, unsigned int z)
903         {
904             Handle::checkValid();
905             return Accumulable<ENTRY, ENTRY_OPS_CLASS>(Handle::msa->accumulate(x,y,z));
906         }
907         inline void accumulate(unsigned int x, unsigned int y, unsigned int z, const ENTRY& ent)
908         {
909             Handle::checkValid();
910             Handle::msa->accumulate(x,y,z, ent);
911         }
912
913         void accumulate(unsigned x1, unsigned y1, unsigned z1, unsigned x2, unsigned y2, unsigned z2, const ENTRY *buf)
914         {
915             Handle::checkValid();
916             CkAssert(x1 <= x2);
917             CkAssert(y1 <= y2);
918             CkAssert(z1 <= z2);
919
920             CkAssert(x1 >= 0);
921             CkAssert(y1 >= 0);
922             CkAssert(z1 >= 0);
923
924             CkAssert(x2 < Handle::msa->dim_x);
925             CkAssert(y2 < Handle::msa->dim_y);
926             CkAssert(z2 < Handle::msa->dim_z);
927
928             unsigned i = 0;
929
930             for (unsigned ix = x1; ix <= x2; ++ix)
931                 for (unsigned iy = y1; iy <= y2; ++iy)
932                     for (unsigned iz = z1; iz <= z2; ++iz)
933                         Handle::msa->accumulate(ix, iy, iz, buf[i++]);
934         }
935
936         inline Accumulable<ENTRY, ENTRY_OPS_CLASS> operator() (unsigned int x, unsigned int y, unsigned int z)
937             { return accumulate(x,y,z); }
938     };
939
940 protected:
941     /// Total number of ENTRY's in the whole array.
942     unsigned int nEntries;
943     bool initHandleGiven;
944
945     /// Handle to owner of cache.
946     CacheGroup_t* cache;
947     CProxy_CacheGroup_t cg;
948
949     inline const ENTRY* readablePage(unsigned int page)
950     {
951         return (const ENTRY*)(cache->readablePage(page));
952     }
953
954     // known local page.
955     inline const ENTRY* readablePage2(unsigned int page)
956     {
957         return (const ENTRY*)(cache->readablePage2(page));
958     }
959
960     // Returns a pointer to the start of the local copy in the cache of the writeable page.
961     // @@ what if begin - end span across two or more pages?
962     inline ENTRY* writeablePage(unsigned int page, unsigned int offset)
963     {
964         return (ENTRY*)(cache->writeablePage(page, offset));
965     }
966
967 public:
968     // @@ Needed for Jade
969     inline MSA3D() 
970         :initHandleGiven(false) 
971     {}
972
973     virtual void pup(PUP::er &p){
974         p|dim_x;
975         p|dim_y;
976         p|dim_z;
977         p|nEntries;
978         p|cg;
979         if (p.isUnpacking()) cache=cg.ckLocalBranch();
980     }
981
982     /**
983       Create a completely new MSA array.  This call creates the
984       corresponding groups, so only call it once per array.
985     */
986     inline MSA3D(unsigned x, unsigned y, unsigned z, unsigned int num_wrkrs, 
987                  unsigned int maxBytes=MSA_DEFAULT_MAX_BYTES)
988         : dim_x(x), dim_y(y), dim_z(z), initHandleGiven(false)
989     {
990         unsigned nEntries = x*y*z;
991         unsigned int nPages = (nEntries + ENTRIES_PER_PAGE - 1)/ENTRIES_PER_PAGE;
992         CProxy_PageArray_t pageArray = CProxy_PageArray_t::ckNew(nPages);
993         cg = CProxy_CacheGroup_t::ckNew(nPages, pageArray, maxBytes, nEntries, num_wrkrs);
994         pageArray.setCacheProxy(cg);
995         //pageArray.ckSetReductionClient(new CkCallback(CkIndex_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>::SyncDone(), cg));
996         cache = cg.ckLocalBranch();
997     }
998
999     inline ~MSA3D()
1000     {
1001         // TODO: how to get rid of the cache group and the page array
1002         //(cache->getArray()).destroy();
1003         //cg.destroy();
1004         // TODO: calling FreeMem does not seem to work. Need to debug it.
1005         //cache->unroll();
1006         //cache->FreeMem();
1007     }
1008
1009     /**
1010      * this function is supposed to be called when the thread/object using this array
1011      * migrates to another PE.
1012      */
1013     inline void changePE()
1014     {
1015         cache = cg.ckLocalBranch();
1016
1017         /* don't need to update the number of entries, as that does not change */
1018     }
1019
1020     // ================ Accessor/Utility functions ================
1021
1022     inline const CProxy_CacheGroup_t &getCacheGroup() const { return cg; }
1023
1024     // Avoid using the term "page size" because it is confusing: does
1025     // it mean in bytes or number of entries?
1026     inline unsigned int getNumEntriesPerPage() const { return ENTRIES_PER_PAGE; }
1027
1028     inline unsigned int index(unsigned x, unsigned y, unsigned z)
1029     {
1030         CkAssert(x < dim_x);
1031         CkAssert(y < dim_y);
1032         CkAssert(z < dim_z);
1033         return ((x*dim_y) + y) * dim_z + z;
1034     }
1035     
1036     /// Return the page this entry is stored at.
1037     inline unsigned int getPageIndex(unsigned int idx)
1038     {
1039         return idx / ENTRIES_PER_PAGE;
1040     }
1041
1042     /// Return the offset, in entries, that this entry is stored at within a page.
1043     inline unsigned int getOffsetWithinPage(unsigned int idx)
1044     {
1045         return idx % ENTRIES_PER_PAGE;
1046     }
1047
1048     // ================ MSA API ================
1049
1050     // We need to know the total number of workers across all
1051     // processors, and we also calculate the number of worker threads
1052     // running on this processor.
1053     //
1054     // Blocking method, basically does a barrier until all workers
1055     // enroll.
1056     inline void enroll(int num_workers)
1057     {
1058         // @@ This is a hack to identify the number of MSA3D
1059         // threads on this processor.  This number is needed for sync.
1060         //
1061         // @@ What if a MSA3D thread migrates?
1062         cache->enroll(num_workers);
1063     }
1064
1065     // idx is the element to be read/written
1066     //
1067     // This function returns a reference to the first element on the
1068     // page that contains idx.
1069     inline ENTRY& getPageBottom(unsigned int idx, MSA_Page_Fault_t accessMode)
1070     {
1071         if (accessMode==Read_Fault) {
1072             unsigned int page = idx / ENTRIES_PER_PAGE;
1073             return const_cast<ENTRY&>(readablePage(page)[0]);
1074         } else {
1075             CkAssert(accessMode==Write_Fault || accessMode==Accumulate_Fault);
1076             unsigned int page = idx / ENTRIES_PER_PAGE;
1077             unsigned int offset = idx % ENTRIES_PER_PAGE;
1078             ENTRY* e=writeablePage(page, offset);
1079             return e[0];
1080         }
1081     }
1082
1083     inline void FreeMem()
1084     {
1085         cache->FreeMem();
1086     }
1087
1088     /// Non-blocking prefetch of entries from start to end, inclusive.
1089     /// Prefetch'd pages are locked into the cache, so you must call
1090     ///   unlock afterwards.
1091     inline void Prefetch(unsigned int start, unsigned int end)
1092     {
1093         unsigned int page1 = start / ENTRIES_PER_PAGE;
1094         unsigned int page2 = end / ENTRIES_PER_PAGE;
1095         cache->Prefetch(page1, page2);
1096     }
1097
1098     /// Block until all prefetched pages arrive.
1099     inline int WaitAll()    { return cache->WaitAll(); }
1100
1101     /// Unlock all locked pages
1102     inline void Unlock()    { return cache->UnlockPages(); }
1103
1104     /// start and end are element indexes.
1105     /// Unlocks completely spanned pages given a range of elements
1106     /// index'd from "start" to "end", inclusive.  If start/end does not span a
1107     /// page completely, i.e. start/end is in the middle of a page,
1108     /// the entire page is still unlocked--in particular, this means
1109     /// you should not have several adjacent ranges locked.
1110     inline void Unlock(unsigned int start, unsigned int end)
1111     {
1112         unsigned int page1 = start / ENTRIES_PER_PAGE;
1113         unsigned int page2 = end / ENTRIES_PER_PAGE;
1114         cache->UnlockPages(page1, page2);
1115     }
1116
1117     static const int DEFAULT_SYNC_SINGLE = 0;
1118
1119     inline Write getInitialWrite()
1120     {
1121         if (initHandleGiven)
1122             CmiAbort("Trying to get an MSA's initial handle a second time");
1123
1124         //Write *w = new Write(*this);
1125         //sync();
1126         initHandleGiven = true;
1127         return Write(this);
1128     }
1129
1130     inline Accum getInitialAccum()
1131     {
1132         if (initHandleGiven)
1133             CmiAbort("Trying to get an MSA's initial handle a second time");
1134
1135         //Accum *a = new Accum(*this);
1136         //sync();
1137         initHandleGiven = true;
1138         return Accum(this);
1139     }
1140
1141   // These are the meat of the MSA API, but they are only accessible
1142   // through appropriate handles (defined in the public section above).
1143 protected:
1144     /// Return a read-only copy of the element at idx.
1145     ///   May block if the element is not already in the cache.
1146     inline const ENTRY& get(unsigned x, unsigned y, unsigned z)
1147     {
1148         unsigned int idx = index(x,y,z);
1149         unsigned int page = idx / ENTRIES_PER_PAGE;
1150         unsigned int offset = idx % ENTRIES_PER_PAGE;
1151         return readablePage(page)[offset];
1152     }
1153
1154     /// Return a read-only copy of the element at idx;
1155     ///   ONLY WORKS WHEN ELEMENT IS ALREADY IN THE CACHE--
1156     ///   WILL SEGFAULT IF ELEMENT NOT ALREADY PRESENT.
1157     ///    Never blocks; may crash if element not already present.
1158     inline const ENTRY& get2(unsigned x, unsigned y, unsigned z)
1159     {
1160         unsigned int idx = index(x,y,z);
1161         unsigned int page = idx / ENTRIES_PER_PAGE;
1162         unsigned int offset = idx % ENTRIES_PER_PAGE;
1163         return readablePage2(page)[offset];
1164     }
1165
1166     /// Return a writeable copy of the element at idx.
1167     ///    Never blocks; will create a new blank element if none exists locally.
1168     ///    UNDEFINED if two threads set the same element.
1169     inline ENTRY& set(unsigned x, unsigned y, unsigned z)
1170     {
1171         unsigned int idx = index(x,y,z);
1172         unsigned int page = idx / ENTRIES_PER_PAGE;
1173         unsigned int offset = idx % ENTRIES_PER_PAGE;
1174         ENTRY* e=writeablePage(page, offset);
1175         return e[offset];
1176     }
1177
1178     /// Fetch the ENTRY at idx to be accumulated.
1179     ///   You must perform the accumulation on 
1180     ///     the return value before calling "sync".
1181     ///   Never blocks.
1182     inline ENTRY& accumulate(unsigned x, unsigned y, unsigned z)
1183     {
1184         unsigned int idx = index(x,y,z);
1185         unsigned int page = idx / ENTRIES_PER_PAGE;
1186         unsigned int offset = idx % ENTRIES_PER_PAGE;
1187         return cache->accumulate(page, offset);
1188     }
1189     
1190     /// Add ent to the element at idx.
1191     ///   Never blocks.
1192     ///   Merges together accumulates from different threads.
1193     inline void accumulate(unsigned x, unsigned y, unsigned z, const ENTRY& ent)
1194     {
1195         ENTRY_OPS_CLASS::accumulate(accumulate(x,y,z),ent);
1196     }
1197
1198     /// Synchronize reads and writes across the entire array.
1199     inline void sync(int single=0)
1200     {
1201         cache->SyncReq(single); 
1202     }
1203 };
1204
1205 }
1206 #endif