Implement alltoall with iget. Currently shadowed as MPI_Alltoall2
authorYan Shi <yanshi@uiuc.edu>
Wed, 12 Apr 2006 04:20:20 +0000 (04:20 +0000)
committerYan Shi <yanshi@uiuc.edu>
Wed, 12 Apr 2006 04:20:20 +0000 (04:20 +0000)
src/libs/ck-libs/ampi/ampi.C
src/libs/ck-libs/ampi/ampi.ci
src/libs/ck-libs/ampi/ampiimpl.h

index 1e808a7bba2a6d30e5e3ba53d9241e28ef9a87ee..14730fe2bab32b3b60770ccac3661d3fb2f60e6c 100644 (file)
@@ -1607,6 +1607,22 @@ ampi::bcastraw(void* buf, int len, CkArrayID aid)
   pa.generic(msg);
 }
 
+
+AmpiMsg* 
+ampi::Alltoall_RemoteIGet(int disp, int cnt, MPI_Datatype type, int tag)
+{
+  CkAssert(tag==MPI_ATA_TAG && AlltoallGetFlag);
+  int unit;
+  CkDDT_DataType *ddt = getDDT()->getType(type);
+  unit = ddt->getSize(1);
+  int totalsize = unit*cnt;
+
+  AmpiMsg *msg = new (totalsize, 0) AmpiMsg(-1, -1, -1, thisIndex,totalsize,myComm.getComm());
+  char* addr = (char*)Alltoallbuff+disp*unit;
+  ddt->serialize((char*)msg->data, addr, cnt, (-1));
+  return msg;
+}
+
 int MPI_null_copy_fn (MPI_Comm comm, int keyval, void *extra_state,
                        void *attr_in, void *attr_out, int *flag){
   (*flag) = 0;
@@ -3164,12 +3180,15 @@ int AMPI_Alltoall(void *sendbuf, int sendcount, MPI_Datatype sendtype,
   AMPIAPI("AMPI_Alltoall");
   if(getAmpiParent()->isInter(comm)) CkAbort("MPI_Alltoall not allowed for Inter-communicator!");
   if(comm==MPI_COMM_SELF) return copyDatatype(comm,sendtype,sendcount,sendbuf,recvbuf);
+  AMPI_Barrier(comm);
   ampi *ptr = getAmpiInstance(comm);
   int size = ptr->getSize(comm);
   CkDDT_DataType *dttype;
   int itemsize;
   int i;
 
+  AMPI_Barrier(comm);
+
     // post receives
   dttype = ptr->getDDT()->getType(recvtype) ;
   itemsize = dttype->getSize(recvcount) ;
@@ -3179,7 +3198,7 @@ int AMPI_Alltoall(void *sendbuf, int sendcount, MPI_Datatype sendtype,
               i, MPI_ATA_TAG, comm, &reqs[i]);
   }
   //AMPI_Yield(comm);
-  //AMPI_Barrier(comm);
+  AMPI_Barrier(comm);
 
   dttype = ptr->getDDT()->getType(sendtype) ;
   itemsize = dttype->getSize(sendcount) ;
@@ -3212,6 +3231,56 @@ int AMPI_Alltoall(void *sendbuf, int sendcount, MPI_Datatype sendtype,
   return 0;
 }
 
+CDECL
+int AMPI_Alltoall2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
+                 void *recvbuf, int recvcount, MPI_Datatype recvtype,
+                 MPI_Comm comm)
+{
+  AMPIAPI("AMPI_Alltoall2");
+  if(getAmpiParent()->isInter(comm)) CkAbort("MPI_Alltoall not allowed for Inter-communicator!");
+  if(comm==MPI_COMM_SELF) return copyDatatype(comm,sendtype,sendcount,sendbuf,recvbuf);
+  ampi *ptr = getAmpiInstance(comm);
+  CProxy_ampi pa(ptr->ckGetArrayID());
+  int size = ptr->getSize(comm);
+  CkDDT_DataType *dttype;
+  int itemsize;
+  int recvdisp;
+  int myrank;
+  int i;
+  // Set flags for others to get
+  ptr->setA2AIGetFlag((void*)sendbuf);
+  MPI_Comm_rank(comm,&myrank);
+  recvdisp = myrank*recvcount;
+
+  AMPI_Barrier(comm);
+    // post receives
+  MPI_Request *reqs = new MPI_Request[size];
+  for(i=0;i<size;i++) {
+         reqs[i] = pa[i].Alltoall_RemoteIGet(recvdisp, recvcount, recvtype,
+MPI_ATA_TAG);
+  }
+
+  dttype = ptr->getDDT()->getType(recvtype) ;
+  itemsize = dttype->getSize(recvcount) ;
+  AmpiMsg *msg;
+  for(i=0;i<size;i++) {
+         msg = (AmpiMsg*)CkWaitReleaseFuture(reqs[i]);
+         memcpy((char*)recvbuf+(itemsize*i), msg->data,itemsize);
+         delete msg;
+  }
+  
+  delete [] reqs;
+  AMPI_Barrier(comm);
+
+  // Reset flags 
+  ptr->resetA2AIGetFlag();
+  
+#if AMPI_COUNTER
+  getAmpiParent()->counters.alltoall++;
+#endif
+  return 0;
+}
+
 CDECL
 int AMPI_Ialltoall(void *sendbuf, int sendcount, MPI_Datatype sendtype,
                  void *recvbuf, int recvcount, MPI_Datatype recvtype,
index 8582358656e40ed714e0e3e0d788772e3d33badb..879fa52376e484f70826681dd713530c61bbdf6d 100644 (file)
@@ -40,6 +40,7 @@ module ampi {
     entry void winRemoteUnlock(int winIndex, CkFutureID ftHandle, int pe_src, int requestRank);
     entry [iget] AmpiMsg *winRemoteIGet(int orgdisp, int orgcnt, MPI_Datatype orgtype,
                            MPI_Aint targdisp, int targcnt, MPI_Datatype targtype, int winIndex);
+    entry [iget] AmpiMsg *Alltoall_RemoteIGet(int disp, int cnt, MPI_Datatype type, int tag);
   };
 
   group [migratable] ampiWorlds {
index 944bf23b7b3b0758da10efaa00bfa844a7518449..b5fed7b94c9fa0c3aba0c5f59c27587d3e48b97d 100644 (file)
@@ -1394,6 +1394,14 @@ friend class IReq;
     void winGetName(WinStruct win, char *name, int *length);
     win_obj* getWinObjInstance(WinStruct win); 
     int getNewSemaId(); 
+
+    AmpiMsg* Alltoall_RemoteIGet(int disp, int targcnt, MPI_Datatype targtype, int tag);
+private:
+    int AlltoallGetFlag;
+    void *Alltoallbuff;
+public:
+    void setA2AIGetFlag(void* ptr) {AlltoallGetFlag=1;Alltoallbuff=ptr;}
+    void resetA2AIGetFlag() {AlltoallGetFlag=0;Alltoallbuff=NULL;} 
     //------------------------ End of code by YAN ---------------------
 };