Make Writable<T>::operator= return its RHS, as required
[charm.git] / src / libs / ck-libs / multiphaseSharedArrays / msa-distArray.h
1 // emacs mode line -*- mode: c++; tab-width: 4; c-basic-offset: 4; indent-tabs-mode: nil -*-
2 #ifndef MSA_DISTARRAY_H
3 #define MSA_DISTARRAY_H
4
5 #include <utility>
6 #include <algorithm>
7 #include "msa-DistPageMgr.h"
8
9
10 struct MSA_InvalidHandle { };
11
12 template <typename ENTRY>
13 class Writable
14 {
15     ENTRY &e;
16     
17 public:
18     Writable(ENTRY &e_) : e(e_) {}
19     inline const ENTRY& operator= (const ENTRY& rhs) { e = rhs; return rhs; }
20 };
21
22 template <typename ENTRY, class ENTRY_OPS_CLASS>
23 class Accumulable
24 {
25     ENTRY &e;
26     
27 public:
28     Accumulable(ENTRY &e_) : e(e_) {}
29     void operator+=(const ENTRY &rhs_)
30         { ENTRY_OPS_CLASS::accumulate(e, rhs_); }
31 };
32
33
34 /**
35    The MSA1D class is a handle to a distributed shared array of items
36    of data type ENTRY. There are nEntries total numer of ENTRY's, with
37    ENTRIES_PER_PAGE data items per "page".  It is implemented as a
38    Chare Array of pages, and a Group representing the local cache.
39
40    The requirements for the templates are:
41      ENTRY: User data class stored in the array, with at least:
42         - A default constructor and destructor
43         - A working assignment operator
44         - A working pup routine
45      ENTRY_OPS_CLASS: Used to combine values for "accumulate":
46         - A method named "getIdentity", taking no arguments and
47           returning an ENTRY to use before any accumulation.
48         - A method named "accumulate", taking a source/dest ENTRY by reference
49           and an ENTRY to add to it by value or const reference.
50      ENTRIES_PER_PAGE: Optional integer number of ENTRY objects
51         to store and communicate at once.  For good performance,
52         make sure this value is a power of two.
53  */
54 template<class ENTRY, class ENTRY_OPS_CLASS, unsigned int ENTRIES_PER_PAGE=MSA_DEFAULT_ENTRIES_PER_PAGE>
55 class MSA1D
56 {
57 public:
58     typedef MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CacheGroup_t;
59     typedef CProxy_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_CacheGroup_t;
60     typedef CProxy_MSA_PageArray<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_PageArray_t;
61
62     // Sun's C++ compiler doesn't understand that nested classes are
63     // members for the sake of access to private data. (2008-10-23)
64     class Read; class Write; class Accum;
65     friend class Read; friend class Write; friend class Accum;
66
67         class Handle
68         {
69     public:
70         inline unsigned int length() const { return msa.length(); }
71
72         protected:
73         MSA1D &msa;
74         bool valid;
75
76         friend class MSA1D;
77
78         void inline checkInvalidate(MSA1D *m) 
79         {
80             if (m != &msa || !valid)
81                 throw MSA_InvalidHandle();
82             valid = false;
83         }
84
85         Handle(MSA1D &msa_) 
86             : msa(msa_), valid(true) 
87         { }
88         void checkValid()
89         {
90             if (!valid)
91                 throw MSA_InvalidHandle();
92         }
93
94     private:
95         // Disallow copy construction
96         Handle(Handle &);
97     };
98
99     class Read : public Handle
100     {
101     protected:
102         friend class MSA1D;
103         Read(MSA1D &msa_)
104             :  Handle(msa_) { }
105         using Handle::checkValid;
106         using Handle::checkInvalidate;
107
108     public:
109         inline const ENTRY& get(unsigned int idx)
110         {
111             checkValid();
112             return Handle::msa.get(idx); 
113         }
114         inline const ENTRY& operator[](unsigned int idx) { return get(idx); }
115         inline const ENTRY& operator()(unsigned int idx) { return get(idx); }
116         inline const ENTRY& get2(unsigned int idx)
117         {
118             checkValid();
119             return Handle::msa.get2(idx);
120         }
121     };
122
123     class Write : public Handle
124     {
125     protected:
126         friend class MSA1D;
127         Write(MSA1D &msa_)
128             : Handle(msa_) { }
129
130     public:
131         inline Writable<ENTRY> set(unsigned int idx)
132         {
133             Handle::checkValid();
134             return Writable<ENTRY>(Handle::msa.set(idx));
135         }
136         inline Writable<ENTRY> operator()(unsigned int idx)
137             { return set(idx); }
138     };
139
140     class Accum : public Handle
141     {
142     protected:
143         friend class MSA1D;
144         Accum(MSA1D &msa_)
145             : Handle(msa_) { }
146         using Handle::checkInvalidate;
147     public:
148         inline Accumulable<ENTRY, ENTRY_OPS_CLASS> accumulate(unsigned int idx)
149         { 
150             Handle::checkValid();
151             return Accumulable<ENTRY, ENTRY_OPS_CLASS>(Handle::msa.accumulate(idx));
152         }
153         inline void accumulate(unsigned int idx, const ENTRY& ent)
154         {
155             Handle::checkValid();
156             Handle::msa.accumulate(idx, ent);
157         }
158
159         void contribute(unsigned int idx, const ENTRY *begin, const ENTRY *end)
160         {
161             Handle::checkValid();
162             for (const ENTRY *e = begin; e != end; ++e, ++idx)
163                 {
164                     Handle::msa.accumulate(idx, *e);
165                 }
166         }
167
168         inline Accumulable<ENTRY, ENTRY_OPS_CLASS> operator() (unsigned int idx)
169             { return accumulate(idx); }
170     };
171
172 protected:
173     /// Total number of ENTRY's in the whole array.
174     unsigned int nEntries;
175     bool initHandleGiven;
176
177     /// Handle to owner of cache.
178     CacheGroup_t* cache;
179     CProxy_CacheGroup_t cg;
180
181     inline const ENTRY* readablePage(unsigned int page)
182     {
183         return (const ENTRY*)(cache->readablePage(page));
184     }
185
186     // known local page.
187     inline const ENTRY* readablePage2(unsigned int page)
188     {
189         return (const ENTRY*)(cache->readablePage2(page));
190     }
191
192     // Returns a pointer to the start of the local copy in the cache of the writeable page.
193     // @@ what if begin - end span across two or more pages?
194     inline ENTRY* writeablePage(unsigned int page, unsigned int offset)
195     {
196         return (ENTRY*)(cache->writeablePage(page, offset));
197     }
198
199 public:
200     // @@ Needed for Jade
201     inline MSA1D() 
202         :initHandleGiven(false) 
203     {}
204
205     virtual void pup(PUP::er &p){
206         p|nEntries;
207         p|cg;
208         if (p.isUnpacking()) cache=cg.ckLocalBranch();
209     }
210
211     /**
212       Create a completely new MSA array.  This call creates the
213       corresponding groups, so only call it once per array.
214     */
215     inline MSA1D(unsigned int nEntries_, unsigned int num_wrkrs, 
216                  unsigned int maxBytes=MSA_DEFAULT_MAX_BYTES) 
217         : nEntries(nEntries_), initHandleGiven(false)
218     {
219         // first create the Page Array and the Page Group
220         unsigned int nPages = (nEntries + ENTRIES_PER_PAGE - 1)/ENTRIES_PER_PAGE;
221         CProxy_PageArray_t pageArray = CProxy_PageArray_t::ckNew(nPages);
222         cg = CProxy_CacheGroup_t::ckNew(nPages, pageArray, maxBytes, nEntries, num_wrkrs);
223         pageArray.setCacheProxy(cg);
224         pageArray.ckSetReductionClient(new CkCallback(CkIndex_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>::SyncDone(), cg));
225         cache = cg.ckLocalBranch();
226     }
227
228     // Deprecated API for accessing CacheGroup directly.
229     inline MSA1D(CProxy_CacheGroup_t cg_) : cg(cg_), initHandleGiven(false)
230     {
231         cache = cg.ckLocalBranch();
232         nEntries = cache->getNumEntries();
233     }
234
235     inline ~MSA1D()
236     {
237         // TODO: how to get rid of the cache group and the page array
238         //(cache->getArray()).destroy();
239         //cg.destroy();
240         // TODO: calling FreeMem does not seem to work. Need to debug it.
241                                 cache->unroll();
242  //       cache->FreeMem();
243     }
244
245     /**
246      * this function is supposed to be called when the thread/object using this array
247      * migrates to another PE.
248      */
249     inline void changePE()
250     {
251         cache = cg.ckLocalBranch();
252
253         /* don't need to update the number of entries, as that does not change */
254     }
255
256     // ================ Accessor/Utility functions ================
257     /// Get the total length of the array, across all processors.
258     inline unsigned int length() const { return nEntries; }
259
260     inline const CProxy_CacheGroup_t &getCacheGroup() const { return cg; }
261
262     // Avoid using the term "page size" because it is confusing: does
263     // it mean in bytes or number of entries?
264     inline unsigned int getNumEntriesPerPage() const { return ENTRIES_PER_PAGE; }
265
266     /// Return the page this entry is stored at.
267     inline unsigned int getPageIndex(unsigned int idx)
268     {
269         return idx / ENTRIES_PER_PAGE;
270     }
271
272     /// Return the offset, in entries, that this entry is stored at within a page.
273     inline unsigned int getOffsetWithinPage(unsigned int idx)
274     {
275         return idx % ENTRIES_PER_PAGE;
276     }
277
278     // ================ MSA API ================
279
280     // We need to know the total number of workers across all
281     // processors, and we also calculate the number of worker threads
282     // running on this processor.
283     //
284     // Blocking method, basically does a barrier until all workers
285     // enroll.
286     inline void enroll(int num_workers)
287     {
288         // @@ This is a hack to identify the number of MSA1D
289         // threads on this processor.  This number is needed for sync.
290         //
291         // @@ What if a MSA1D thread migrates?
292         cache->enroll(num_workers);
293     }
294
295     // idx is the element to be read/written
296     //
297     // This function returns a reference to the first element on the
298     // page that contains idx.
299     inline ENTRY& getPageBottom(unsigned int idx, MSA_Page_Fault_t accessMode)
300     {
301         if (accessMode==Read_Fault) {
302             unsigned int page = idx / ENTRIES_PER_PAGE;
303             return const_cast<ENTRY&>(readablePage(page)[0]);
304         } else {
305             CkAssert(accessMode==Write_Fault || accessMode==Accumulate_Fault);
306             unsigned int page = idx / ENTRIES_PER_PAGE;
307             unsigned int offset = idx % ENTRIES_PER_PAGE;
308             ENTRY* e=writeablePage(page, offset);
309             return e[0];
310         }
311     }
312
313     inline void FreeMem()
314     {
315         cache->FreeMem();
316     }
317
318     /// Non-blocking prefetch of entries from start to end, inclusive.
319     /// Prefetch'd pages are locked into the cache, so you must call
320     ///   unlock afterwards.
321     inline void Prefetch(unsigned int start, unsigned int end)
322     {
323         unsigned int page1 = start / ENTRIES_PER_PAGE;
324         unsigned int page2 = end / ENTRIES_PER_PAGE;
325         cache->Prefetch(page1, page2);
326     }
327
328     /// Block until all prefetched pages arrive.
329     inline int WaitAll()    { return cache->WaitAll(); }
330
331     /// Unlock all locked pages
332     inline void Unlock()    { return cache->UnlockPages(); }
333
334     /// start and end are element indexes.
335     /// Unlocks completely spanned pages given a range of elements
336     /// index'd from "start" to "end", inclusive.  If start/end does not span a
337     /// page completely, i.e. start/end is in the middle of a page,
338     /// the entire page is still unlocked--in particular, this means
339     /// you should not have several adjacent ranges locked.
340     inline void Unlock(unsigned int start, unsigned int end)
341     {
342         unsigned int page1 = start / ENTRIES_PER_PAGE;
343         unsigned int page2 = end / ENTRIES_PER_PAGE;
344         cache->UnlockPages(page1, page2);
345     }
346
347     static const int DEFAULT_SYNC_SINGLE = 0;
348
349     inline Read &syncToRead(Handle &m, int single = DEFAULT_SYNC_SINGLE)
350     {
351         m.checkInvalidate(this);
352         delete &m;
353         sync(single);
354         return *(new Read(*this));
355     }
356
357     inline Write &syncToWrite(Handle &m, int single = DEFAULT_SYNC_SINGLE)
358     {
359         m.checkInvalidate(this);
360         delete &m;
361         sync(single);
362         return *(new Write(*this));
363     }
364
365     inline Accum &syncToAccum(Handle &m, int single = DEFAULT_SYNC_SINGLE)
366     {
367         m.checkInvalidate(this);
368         delete &m;
369         sync(single);
370         return *(new Accum(*this));
371     }
372
373     inline Write &getInitialWrite()
374     {
375         if (initHandleGiven)
376             throw MSA_InvalidHandle();
377
378         Write *w = new Write(*this);
379         sync();
380         initHandleGiven = true;
381         return *w;
382     }
383
384     inline Accum &getInitialAccum()
385     {
386         if (initHandleGiven)
387             throw MSA_InvalidHandle();
388
389         Accum *a = new Accum(*this);
390         sync();
391         initHandleGiven = true;
392         return *a;
393     }
394
395   // These are the meat of the MSA API, but they are only accessible
396   // through appropriate handles (defined in the public section above).
397 protected:
398     /// Return a read-only copy of the element at idx.
399     ///   May block if the element is not already in the cache.
400     inline const ENTRY& get(unsigned int idx)
401     {
402         unsigned int page = idx / ENTRIES_PER_PAGE;
403         unsigned int offset = idx % ENTRIES_PER_PAGE;
404         return readablePage(page)[offset];
405     }
406
407     inline const ENTRY& operator[](unsigned int idx)
408     {
409         return get(idx);
410     }
411
412     /// Return a read-only copy of the element at idx;
413     ///   ONLY WORKS WHEN ELEMENT IS ALREADY IN THE CACHE--
414     ///   WILL SEGFAULT IF ELEMENT NOT ALREADY PRESENT.
415     ///    Never blocks; may crash if element not already present.
416     inline const ENTRY& get2(unsigned int idx)
417     {
418         unsigned int page = idx / ENTRIES_PER_PAGE;
419         unsigned int offset = idx % ENTRIES_PER_PAGE;
420         return readablePage2(page)[offset];
421     }
422
423     /// Return a writeable copy of the element at idx.
424     ///    Never blocks; will create a new blank element if none exists locally.
425     ///    UNDEFINED if two threads set the same element.
426     inline ENTRY& set(unsigned int idx)
427     {
428         unsigned int page = idx / ENTRIES_PER_PAGE;
429         unsigned int offset = idx % ENTRIES_PER_PAGE;
430         ENTRY* e=writeablePage(page, offset);
431         return e[offset];
432     }
433
434     /// Fetch the ENTRY at idx to be accumulated.
435     ///   You must perform the accumulation on 
436     ///     the return value before calling "sync".
437     ///   Never blocks.
438     inline ENTRY& accumulate(unsigned int idx)
439     {
440         unsigned int page = idx / ENTRIES_PER_PAGE;
441         unsigned int offset = idx % ENTRIES_PER_PAGE;
442         return cache->accumulate(page, offset);
443     }
444     
445     /// Add ent to the element at idx.
446     ///   Never blocks.
447     ///   Merges together accumulates from different threads.
448     inline void accumulate(unsigned int idx, const ENTRY& ent)
449     {
450         ENTRY_OPS_CLASS::accumulate(accumulate(idx),ent);
451     }
452
453     /// Synchronize reads and writes across the entire array.
454     inline void sync(int single=0)
455     {
456         cache->SyncReq(single); 
457     }
458 };
459
460
461 // define a 2d distributed array based on the 1D array, support row major and column
462 // major arrangement of data
463 template<class ENTRY, class ENTRY_OPS_CLASS, unsigned int ENTRIES_PER_PAGE=MSA_DEFAULT_ENTRIES_PER_PAGE, MSA_Array_Layout_t ARRAY_LAYOUT=MSA_ROW_MAJOR>
464 class MSA2D : public MSA1D<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>
465 {
466 public:
467     typedef CProxy_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_CacheGroup_t;
468     typedef MSA1D<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> super;
469
470 protected:
471     unsigned int rows, cols;
472
473 public:
474     // @@ Needed for Jade
475     inline MSA2D() : super() {}
476     virtual void pup(PUP::er &p) {
477        super::pup(p);
478        p|rows; p|cols;
479     };
480
481         class Handle
482         {
483         protected:
484         MSA2D &msa;
485         bool valid;
486
487         friend class MSA2D;
488
489         inline void checkInvalidate(MSA2D *m)
490         {
491             if (&msa != m || !valid)
492                 throw MSA_InvalidHandle();
493             valid = false;
494         }
495
496         Handle(MSA2D &msa_) 
497             : msa(msa_), valid(true) 
498         { }
499         inline void checkValid()
500         {
501             if (!valid)
502                 throw MSA_InvalidHandle();
503         }
504     private:
505         // Disallow copy construction
506         Handle(Handle &);
507     };
508
509     class Read : public Handle
510     {
511     private:
512         friend class MSA2D;
513         Read(MSA2D &msa_)
514             :  Handle(msa_) { }
515
516     public: 
517         inline const ENTRY& get(unsigned int row, unsigned int col)
518         {
519             Handle::checkValid();
520             return Handle::msa.get(row, col);
521         }
522         inline const ENTRY& get2(unsigned int row, unsigned int col)
523         {
524             Handle::checkValid();
525             return Handle::msa.get2(row, col);
526         }
527
528         inline const ENTRY& operator() (unsigned int row, unsigned int col)
529             {
530                 return get(row,col);
531             }
532
533     };
534
535     class Write : public Handle
536     {
537     private:
538         friend class MSA2D;
539         Write(MSA2D &msa_)
540             :  Handle(msa_) { }
541
542     public: 
543         inline Writable<ENTRY> set(unsigned int row, unsigned int col)
544         {
545             Handle::checkValid();
546             return Writable<ENTRY>(Handle::msa.set(row, col));
547         }
548     };
549
550     inline MSA2D(unsigned int rows_, unsigned int cols_, unsigned int numwrkrs,
551                  unsigned int maxBytes=MSA_DEFAULT_MAX_BYTES)
552         :super(rows_*cols_, numwrkrs, maxBytes)
553     {
554         rows = rows_; cols = cols_;
555     }
556
557     inline MSA2D(unsigned int rows_, unsigned int cols_, CProxy_CacheGroup_t cg_)
558         : rows(rows_), cols(cols_), super(cg_)
559     {}
560
561     // get the 1D index of the given entry as per the row major/column major format
562     inline unsigned int getIndex(unsigned int row, unsigned int col)
563     {
564         unsigned int index;
565
566         if(ARRAY_LAYOUT==MSA_ROW_MAJOR)
567             index = row*cols + col;
568         else
569             index = col*rows + row;
570
571         return index;
572     }
573
574     // Which page is (row, col) on?
575     inline unsigned int getPageIndex(unsigned int row, unsigned int col)
576     {
577         return getIndex(row, col)/ENTRIES_PER_PAGE;
578     }
579
580     inline unsigned int getOffsetWithinPage(unsigned int row, unsigned int col)
581     {
582         return getIndex(row, col)%ENTRIES_PER_PAGE;
583     }
584
585     inline unsigned int getRows(void) const {return rows;}
586     inline unsigned int getCols(void) const {return cols;}
587     inline unsigned int getColumns(void) const {return cols;}
588     inline MSA_Array_Layout_t getArrayLayout() const {return ARRAY_LAYOUT;}
589
590     inline void Prefetch(unsigned int start, unsigned int end)
591     {
592         // prefetch the start ... end rows/columns into the cache
593         if(start > end)
594         {
595             unsigned int temp = start;
596             start = end;
597             end = temp;
598         }
599
600         unsigned int index1 = (ARRAY_LAYOUT==MSA_ROW_MAJOR) ? getIndex(start, 0) : getIndex(0, start);
601         unsigned int index2 = (ARRAY_LAYOUT==MSA_ROW_MAJOR) ? getIndex(end, cols-1) : getIndex(rows-1, end);
602
603         MSA1D<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>::Prefetch(index1, index2);
604     }
605
606     // Unlocks pages starting from row "start" through row "end", inclusive
607     inline void UnlockPages(unsigned int start, unsigned int end)
608     {
609         if(start > end)
610         {
611             unsigned int temp = start;
612             start = end;
613             end = temp;
614         }
615
616         unsigned int index1 = (ARRAY_LAYOUT==MSA_ROW_MAJOR) ? getIndex(start, 0) : getIndex(0, start);
617         unsigned int index2 = (ARRAY_LAYOUT==MSA_ROW_MAJOR) ? getIndex(end, cols-1) : getIndex(rows-1, end);
618
619         MSA1D<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>::Unlock(index1, index2);
620     }
621
622     inline Read& syncToRead(Handle &m, int single = super::DEFAULT_SYNC_SINGLE)
623     {
624         m.checkInvalidate(this);
625         delete &m;
626         super::sync(single);
627         return *(new Read(*this));
628     }
629
630     inline Write& syncToWrite(Handle &m, int single = super::DEFAULT_SYNC_SINGLE)
631     {
632         m.checkInvalidate(this);
633         delete &m;
634         super::sync(single);
635         return *(new Write(*this));
636     }
637
638     inline Write& getInitialWrite()
639     {
640         if (super::initHandleGiven)
641             throw MSA_InvalidHandle();
642
643         Write *w = new Write(*this);
644         super::initHandleGiven = true;
645         return *w;
646     }
647
648 protected:
649     inline const ENTRY& get(unsigned int row, unsigned int col)
650     {
651         return super::get(getIndex(row, col));
652     }
653
654     // known local
655     inline const ENTRY& get2(unsigned int row, unsigned int col)
656     {
657         return super::get2(getIndex(row, col));
658     }
659
660     // MSA2D::
661     inline ENTRY& set(unsigned int row, unsigned int col)
662     {
663         return super::set(getIndex(row, col));
664     }
665 };
666
667 namespace MSA
668 {
669     using std::min;
670     using std::max;
671
672
673 /**
674    The MSA3D class is a handle to a distributed shared array of items
675    of data type ENTRY. There are nEntries total numer of ENTRY's, with
676    ENTRIES_PER_PAGE data items per "page".  It is implemented as a
677    Chare Array of pages, and a Group representing the local cache.
678
679    The requirements for the templates are:
680      ENTRY: User data class stored in the array, with at least:
681         - A default constructor and destructor
682         - A working assignment operator
683         - A working pup routine
684      ENTRY_OPS_CLASS: Used to combine values for "accumulate":
685         - A method named "getIdentity", taking no arguments and
686           returning an ENTRY to use before any accumulation.
687         - A method named "accumulate", taking a source/dest ENTRY by reference
688           and an ENTRY to add to it by value or const reference.
689      ENTRIES_PER_PAGE: Optional integer number of ENTRY objects
690         to store and communicate at once.  For good performance,
691         make sure this value is a power of two.
692  */
693 template<class ENTRY, class ENTRY_OPS_CLASS, unsigned int ENTRIES_PER_PAGE>
694 class MSA3D
695 {
696     unsigned dim_x, dim_y, dim_z;
697
698
699 public:
700     typedef MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CacheGroup_t;
701     typedef CProxy_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_CacheGroup_t;
702     typedef CProxy_MSA_PageArray<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE> CProxy_PageArray_t;
703
704     // Sun's C++ compiler doesn't understand that nested classes are
705     // members for the sake of access to private data. (2008-10-23)
706     class Read; class Write; class Accum;
707     friend class Read; friend class Write; friend class Accum;
708
709         class Handle
710         {
711         protected:
712         MSA3D *msa;
713         bool valid;
714
715         friend class MSA3D;
716
717         void inline checkInvalidate() 
718         {
719             if (!valid)
720                 throw MSA_InvalidHandle();
721             valid = false;
722         }
723
724         Handle(MSA3D *msa_) 
725             : msa(msa_), valid(true) 
726         { }
727         void checkValid()
728         {
729             if (!valid)
730                 throw MSA_InvalidHandle();
731         }
732         
733     public:
734         inline void syncRelease()
735             {
736                 checkInvalidate();
737                 if (msa->active)
738                     msa->cache->SyncRelease();
739                 else
740                     CmiAbort("sync from an inactive thread!\n");
741                 msa->active = false;
742             }
743
744         inline void syncDone()
745             {
746                 checkInvalidate();
747                 msa->sync(DEFAULT_SYNC_SINGLE);
748             }
749
750         inline Read syncToRead()
751             {
752                 checkInvalidate();
753                 msa->sync(DEFAULT_SYNC_SINGLE);
754                 return Read(msa);
755             }
756
757         inline Write syncToWrite()
758             {
759                 checkInvalidate();
760                 msa->sync(DEFAULT_SYNC_SINGLE);
761                 return Write(msa);
762             }
763
764         inline Accum syncToAccum()
765             {
766                 checkInvalidate();
767                 msa->sync(DEFAULT_SYNC_SINGLE);
768                 return Accum(msa);
769             }
770
771         void pup(PUP::er &p)
772             {
773                 p|valid;
774                 if (valid)
775                 {
776                     if (p.isUnpacking())
777                         msa = new MSA3D;
778                     p|(*msa);
779                 }
780                 else if (p.isUnpacking())
781                     msa = NULL;
782             }
783
784         Handle() : msa(NULL), valid(false) {}
785     };
786
787     class Read : public Handle
788     {
789     protected:
790         friend class MSA3D;
791         Read(MSA3D *msa_)
792             :  Handle(msa_) { }
793         using Handle::checkValid;
794         using Handle::checkInvalidate;
795
796     public:
797         Read() {}
798
799         inline const ENTRY& get(unsigned x, unsigned y, unsigned z)
800         {
801             checkValid();
802             return Handle::msa->get(x, y, z); 
803         }
804         inline const ENTRY& operator()(unsigned x, unsigned y, unsigned z) { return get(x, y, z); }
805         inline const ENTRY& get2(unsigned x, unsigned y, unsigned z)
806         {
807             checkValid();
808             return Handle::msa->get2(x, y, z);
809         }
810
811         // Reads the specified range into the provided buffer in row-major order
812         void read(ENTRY *buf, unsigned x1, unsigned y1, unsigned z1, unsigned x2, unsigned y2, unsigned z2)
813         {
814             checkValid();
815
816             CkAssert(x1 <= x2);
817             CkAssert(y1 <= y2);
818             CkAssert(z1 <= z2);
819
820             CkAssert(x1 >= 0);
821             CkAssert(y1 >= 0);
822             CkAssert(z1 >= 0);
823
824             CkAssert(x2 < Handle::msa->dim_x);
825             CkAssert(y2 < Handle::msa->dim_y);
826             CkAssert(z2 < Handle::msa->dim_z);
827
828             unsigned i = 0;
829
830             for (unsigned ix = x1; ix <= x2; ++ix)
831                 for (unsigned iy = y1; iy <= y2; ++iy)
832                     for (unsigned iz = z1; iz <= z2; ++iz)
833                         buf[i++] = Handle::msa->get(ix, iy, iz);
834         }
835     };
836
837     class Write : public Handle
838     {
839     protected:
840         friend class MSA3D;
841         Write(MSA3D *msa_)
842             : Handle(msa_) { }
843
844     public:
845         Write() {}
846
847         inline Writable<ENTRY> set(unsigned x, unsigned y, unsigned z)
848         {
849             Handle::checkValid();
850             return Writable<ENTRY>(Handle::msa->set(x,y,z));
851         }
852         inline Writable<ENTRY> operator()(unsigned x, unsigned y, unsigned z)
853         {
854             return set(x,y,z);
855         }
856
857         void write(unsigned x1, unsigned y1, unsigned z1, unsigned x2, unsigned y2, unsigned z2, const ENTRY *buf)
858         {
859             Handle::checkValid();
860
861             CkAssert(x1 <= x2);
862             CkAssert(y1 <= y2);
863             CkAssert(z1 <= z2);
864
865             CkAssert(x1 >= 0);
866             CkAssert(y1 >= 0);
867             CkAssert(z1 >= 0);
868
869             CkAssert(x2 < Handle::msa->dim_x);
870             CkAssert(y2 < Handle::msa->dim_y);
871             CkAssert(z2 < Handle::msa->dim_z);
872
873             unsigned i = 0;
874
875             for (unsigned ix = x1; ix <= x2; ++ix)
876                 for (unsigned iy = y1; iy <= y2; ++iy)
877                     for (unsigned iz = z1; iz <= z2; ++iz)
878                     {
879                         if (isnan(buf[i]))
880                             CmiAbort("Tried to write a NaN!");
881                         Handle::msa->set(ix, iy, iz) = buf[i++];
882                     }
883         }
884 #if 0
885     private:
886         Write(Write &);
887 #endif
888     };
889
890     class Accum : public Handle
891     {
892     protected:
893         friend class MSA3D;
894         Accum(MSA3D *msa_)
895             : Handle(msa_) { }
896         using Handle::checkInvalidate;
897     public:
898         Accum() {}
899
900         inline Accumulable<ENTRY, ENTRY_OPS_CLASS> accumulate(unsigned int x, unsigned int y, unsigned int z)
901         {
902             Handle::checkValid();
903             return Accumulable<ENTRY, ENTRY_OPS_CLASS>(Handle::msa->accumulate(x,y,z));
904         }
905         inline void accumulate(unsigned int x, unsigned int y, unsigned int z, const ENTRY& ent)
906         {
907             Handle::checkValid();
908             Handle::msa->accumulate(x,y,z, ent);
909         }
910
911         void accumulate(unsigned x1, unsigned y1, unsigned z1, unsigned x2, unsigned y2, unsigned z2, const ENTRY *buf)
912         {
913             Handle::checkValid();
914             CkAssert(x1 <= x2);
915             CkAssert(y1 <= y2);
916             CkAssert(z1 <= z2);
917
918             CkAssert(x1 >= 0);
919             CkAssert(y1 >= 0);
920             CkAssert(z1 >= 0);
921
922             CkAssert(x2 < Handle::msa->dim_x);
923             CkAssert(y2 < Handle::msa->dim_y);
924             CkAssert(z2 < Handle::msa->dim_z);
925
926             unsigned i = 0;
927
928             for (unsigned ix = x1; ix <= x2; ++ix)
929                 for (unsigned iy = y1; iy <= y2; ++iy)
930                     for (unsigned iz = z1; iz <= z2; ++iz)
931                         Handle::msa->accumulate(ix, iy, iz, buf[i++]);
932         }
933
934         inline Accumulable<ENTRY, ENTRY_OPS_CLASS> operator() (unsigned int x, unsigned int y, unsigned int z)
935             { return accumulate(x,y,z); }
936     };
937
938 protected:
939     /// Total number of ENTRY's in the whole array.
940     unsigned int nEntries;
941     bool initHandleGiven;
942
943     /// Handle to owner of cache.
944     CacheGroup_t* cache;
945     CProxy_CacheGroup_t cg;
946
947     inline const ENTRY* readablePage(unsigned int page)
948     {
949         return (const ENTRY*)(cache->readablePage(page));
950     }
951
952     // known local page.
953     inline const ENTRY* readablePage2(unsigned int page)
954     {
955         return (const ENTRY*)(cache->readablePage2(page));
956     }
957
958     // Returns a pointer to the start of the local copy in the cache of the writeable page.
959     // @@ what if begin - end span across two or more pages?
960     inline ENTRY* writeablePage(unsigned int page, unsigned int offset)
961     {
962         return (ENTRY*)(cache->writeablePage(page, offset));
963     }
964
965 public:
966     // @@ Needed for Jade
967     inline MSA3D() 
968         :initHandleGiven(false) 
969     {}
970
971     virtual void pup(PUP::er &p){
972         p|dim_x;
973         p|dim_y;
974         p|dim_z;
975         p|nEntries;
976         p|cg;
977         if (p.isUnpacking()) cache=cg.ckLocalBranch();
978     }
979
980     /**
981       Create a completely new MSA array.  This call creates the
982       corresponding groups, so only call it once per array.
983     */
984     inline MSA3D(unsigned x, unsigned y, unsigned z, unsigned int num_wrkrs, 
985                  unsigned int maxBytes=MSA_DEFAULT_MAX_BYTES)
986         : dim_x(x), dim_y(y), dim_z(z), initHandleGiven(false)
987     {
988         unsigned nEntries = x*y*z;
989         unsigned int nPages = (nEntries + ENTRIES_PER_PAGE - 1)/ENTRIES_PER_PAGE;
990         CProxy_PageArray_t pageArray = CProxy_PageArray_t::ckNew(nPages);
991         cg = CProxy_CacheGroup_t::ckNew(nPages, pageArray, maxBytes, nEntries, num_wrkrs);
992         pageArray.setCacheProxy(cg);
993         //pageArray.ckSetReductionClient(new CkCallback(CkIndex_MSA_CacheGroup<ENTRY, ENTRY_OPS_CLASS, ENTRIES_PER_PAGE>::SyncDone(), cg));
994         cache = cg.ckLocalBranch();
995     }
996
997     inline ~MSA3D()
998     {
999         // TODO: how to get rid of the cache group and the page array
1000         //(cache->getArray()).destroy();
1001         //cg.destroy();
1002         // TODO: calling FreeMem does not seem to work. Need to debug it.
1003         //cache->unroll();
1004         //cache->FreeMem();
1005     }
1006
1007     /**
1008      * this function is supposed to be called when the thread/object using this array
1009      * migrates to another PE.
1010      */
1011     inline void changePE()
1012     {
1013         cache = cg.ckLocalBranch();
1014
1015         /* don't need to update the number of entries, as that does not change */
1016     }
1017
1018     // ================ Accessor/Utility functions ================
1019
1020     inline const CProxy_CacheGroup_t &getCacheGroup() const { return cg; }
1021
1022     // Avoid using the term "page size" because it is confusing: does
1023     // it mean in bytes or number of entries?
1024     inline unsigned int getNumEntriesPerPage() const { return ENTRIES_PER_PAGE; }
1025
1026     inline unsigned int index(unsigned x, unsigned y, unsigned z)
1027     {
1028         CkAssert(x < dim_x);
1029         CkAssert(y < dim_y);
1030         CkAssert(z < dim_z);
1031         return ((x*dim_y) + y) * dim_z + z;
1032     }
1033     
1034     /// Return the page this entry is stored at.
1035     inline unsigned int getPageIndex(unsigned int idx)
1036     {
1037         return idx / ENTRIES_PER_PAGE;
1038     }
1039
1040     /// Return the offset, in entries, that this entry is stored at within a page.
1041     inline unsigned int getOffsetWithinPage(unsigned int idx)
1042     {
1043         return idx % ENTRIES_PER_PAGE;
1044     }
1045
1046     // ================ MSA API ================
1047
1048     // We need to know the total number of workers across all
1049     // processors, and we also calculate the number of worker threads
1050     // running on this processor.
1051     //
1052     // Blocking method, basically does a barrier until all workers
1053     // enroll.
1054     inline void enroll(int num_workers)
1055     {
1056         // @@ This is a hack to identify the number of MSA3D
1057         // threads on this processor.  This number is needed for sync.
1058         //
1059         // @@ What if a MSA3D thread migrates?
1060         cache->enroll(num_workers);
1061     }
1062
1063     // idx is the element to be read/written
1064     //
1065     // This function returns a reference to the first element on the
1066     // page that contains idx.
1067     inline ENTRY& getPageBottom(unsigned int idx, MSA_Page_Fault_t accessMode)
1068     {
1069         if (accessMode==Read_Fault) {
1070             unsigned int page = idx / ENTRIES_PER_PAGE;
1071             return const_cast<ENTRY&>(readablePage(page)[0]);
1072         } else {
1073             CkAssert(accessMode==Write_Fault || accessMode==Accumulate_Fault);
1074             unsigned int page = idx / ENTRIES_PER_PAGE;
1075             unsigned int offset = idx % ENTRIES_PER_PAGE;
1076             ENTRY* e=writeablePage(page, offset);
1077             return e[0];
1078         }
1079     }
1080
1081     inline void FreeMem()
1082     {
1083         cache->FreeMem();
1084     }
1085
1086     /// Non-blocking prefetch of entries from start to end, inclusive.
1087     /// Prefetch'd pages are locked into the cache, so you must call
1088     ///   unlock afterwards.
1089     inline void Prefetch(unsigned int start, unsigned int end)
1090     {
1091         unsigned int page1 = start / ENTRIES_PER_PAGE;
1092         unsigned int page2 = end / ENTRIES_PER_PAGE;
1093         cache->Prefetch(page1, page2);
1094     }
1095
1096     /// Block until all prefetched pages arrive.
1097     inline int WaitAll()    { return cache->WaitAll(); }
1098
1099     /// Unlock all locked pages
1100     inline void Unlock()    { return cache->UnlockPages(); }
1101
1102     /// start and end are element indexes.
1103     /// Unlocks completely spanned pages given a range of elements
1104     /// index'd from "start" to "end", inclusive.  If start/end does not span a
1105     /// page completely, i.e. start/end is in the middle of a page,
1106     /// the entire page is still unlocked--in particular, this means
1107     /// you should not have several adjacent ranges locked.
1108     inline void Unlock(unsigned int start, unsigned int end)
1109     {
1110         unsigned int page1 = start / ENTRIES_PER_PAGE;
1111         unsigned int page2 = end / ENTRIES_PER_PAGE;
1112         cache->UnlockPages(page1, page2);
1113     }
1114
1115     static const int DEFAULT_SYNC_SINGLE = 0;
1116
1117     inline Write getInitialWrite()
1118     {
1119         if (initHandleGiven)
1120             CmiAbort("Trying to get an MSA's initial handle a second time");
1121
1122         //Write *w = new Write(*this);
1123         //sync();
1124         initHandleGiven = true;
1125         return Write(this);
1126     }
1127
1128     inline Accum getInitialAccum()
1129     {
1130         if (initHandleGiven)
1131             CmiAbort("Trying to get an MSA's initial handle a second time");
1132
1133         //Accum *a = new Accum(*this);
1134         //sync();
1135         initHandleGiven = true;
1136         return Accum(this);
1137     }
1138
1139   // These are the meat of the MSA API, but they are only accessible
1140   // through appropriate handles (defined in the public section above).
1141 protected:
1142     /// Return a read-only copy of the element at idx.
1143     ///   May block if the element is not already in the cache.
1144     inline const ENTRY& get(unsigned x, unsigned y, unsigned z)
1145     {
1146         unsigned int idx = index(x,y,z);
1147         unsigned int page = idx / ENTRIES_PER_PAGE;
1148         unsigned int offset = idx % ENTRIES_PER_PAGE;
1149         return readablePage(page)[offset];
1150     }
1151
1152     /// Return a read-only copy of the element at idx;
1153     ///   ONLY WORKS WHEN ELEMENT IS ALREADY IN THE CACHE--
1154     ///   WILL SEGFAULT IF ELEMENT NOT ALREADY PRESENT.
1155     ///    Never blocks; may crash if element not already present.
1156     inline const ENTRY& get2(unsigned x, unsigned y, unsigned z)
1157     {
1158         unsigned int idx = index(x,y,z);
1159         unsigned int page = idx / ENTRIES_PER_PAGE;
1160         unsigned int offset = idx % ENTRIES_PER_PAGE;
1161         return readablePage2(page)[offset];
1162     }
1163
1164     /// Return a writeable copy of the element at idx.
1165     ///    Never blocks; will create a new blank element if none exists locally.
1166     ///    UNDEFINED if two threads set the same element.
1167     inline ENTRY& set(unsigned x, unsigned y, unsigned z)
1168     {
1169         unsigned int idx = index(x,y,z);
1170         unsigned int page = idx / ENTRIES_PER_PAGE;
1171         unsigned int offset = idx % ENTRIES_PER_PAGE;
1172         ENTRY* e=writeablePage(page, offset);
1173         return e[offset];
1174     }
1175
1176     /// Fetch the ENTRY at idx to be accumulated.
1177     ///   You must perform the accumulation on 
1178     ///     the return value before calling "sync".
1179     ///   Never blocks.
1180     inline ENTRY& accumulate(unsigned x, unsigned y, unsigned z)
1181     {
1182         unsigned int idx = index(x,y,z);
1183         unsigned int page = idx / ENTRIES_PER_PAGE;
1184         unsigned int offset = idx % ENTRIES_PER_PAGE;
1185         return cache->accumulate(page, offset);
1186     }
1187     
1188     /// Add ent to the element at idx.
1189     ///   Never blocks.
1190     ///   Merges together accumulates from different threads.
1191     inline void accumulate(unsigned x, unsigned y, unsigned z, const ENTRY& ent)
1192     {
1193         ENTRY_OPS_CLASS::accumulate(accumulate(x,y,z),ent);
1194     }
1195
1196     /// Synchronize reads and writes across the entire array.
1197     inline void sync(int single=0)
1198     {
1199         cache->SyncReq(single); 
1200     }
1201 };
1202
1203 }
1204 #endif