AMPI: Handle MPI_IN_PLACE uniformly where it already works, and error out where it...
authorPhil Miller <mille121@illinois.edu>
Thu, 15 Mar 2012 02:25:51 +0000 (21:25 -0500)
committerPhil Miller <mille121@illinois.edu>
Thu, 15 Mar 2012 03:04:25 +0000 (22:04 -0500)
src/libs/ck-libs/ampi/ampi.C

index 24f9f453e36b561313551a29c88dcf1405b2ec4a..08a134ab3b1869d4aad9d28dcf682021fa770863 100644 (file)
@@ -2989,6 +2989,13 @@ static int copyDatatype(MPI_Comm comm,MPI_Datatype type,int count,const void *in
   return MPI_SUCCESS;
 }
 
+static void handle_MPI_IN_PLACE(void* &inbuf, void* &outbuf)
+{
+  if (inbuf == MPI_IN_PLACE) inbuf = outbuf;
+  if (outbuf == MPI_IN_PLACE) outbuf = inbuf;
+  CmiAssert(inbuf != MPI_IN_PLACE && outbuf != MPI_IN_PLACE);
+}
+
 #define SYNCHRONOUS_REDUCE                           0
 
   CDECL
@@ -2997,9 +3004,7 @@ int AMPI_Reduce(void *inbuf, void *outbuf, int count, int type, MPI_Op op,
 {
   AMPIAPI("AMPI_Reduce");
 
-  if (inbuf == MPI_IN_PLACE) inbuf = outbuf;
-  if (outbuf == MPI_IN_PLACE) outbuf = inbuf;
-  CmiAssert(inbuf != MPI_IN_PLACE && outbuf != MPI_IN_PLACE);
+  handle_MPI_IN_PLACE(inbuf, outbuf);
 
 #if CMK_ERROR_CHECKING
   int ret;
@@ -3068,6 +3073,8 @@ int AMPI_Allreduce(void *inbuf, void *outbuf, int count, int type,
 {
   AMPIAPI("AMPI_Allreduce");
 
+  handle_MPI_IN_PLACE(inbuf, outbuf);
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, count, 1, type, 1, 0, 0, 0, 0, inbuf, 1, outbuf, 1);
@@ -3077,10 +3084,6 @@ int AMPI_Allreduce(void *inbuf, void *outbuf, int count, int type,
 
   ampi *ptr = getAmpiInstance(comm);
 
-  if (inbuf == MPI_IN_PLACE) inbuf = outbuf;
-  if (outbuf == MPI_IN_PLACE) outbuf = inbuf;
-  CmiAssert(inbuf != MPI_IN_PLACE && outbuf != MPI_IN_PLACE);
-
   CkDDT_DataType *ddt_type = ptr->getDDT()->getType(type);
 
 #if CMK_BIGSIM_CHARM
@@ -3132,6 +3135,8 @@ int AMPI_Iallreduce(void *inbuf, void *outbuf, int count, int type,
 {
   AMPIAPI("AMPI_Iallreduce");
 
+  handle_MPI_IN_PLACE(inbuf, outbuf);
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, count, 1, type, 1, 0, 0, 0, 0, inbuf, 1, outbuf, 1);
@@ -3167,6 +3172,8 @@ int AMPI_Reduce_scatter(void* sendbuf, void* recvbuf, int *recvcounts,
 {
   AMPIAPI("AMPI_Reduce_scatter");
 
+  handle_MPI_IN_PLACE(sendbuf, recvbuf);
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, 0, 0, datatype, 1, 0, 0, 0, 0, sendbuf, 1, recvbuf, 1);
@@ -3232,6 +3239,9 @@ int AMPI_Scan(void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MP
   AMPIAPI("AMPI_Scan");
 
 #if CMK_ERROR_CHECKING
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Scan does not implement MPI_IN_PLACE");
+
   int ret;
   ret = errorCheck(comm, 1, count, 1, datatype, 1, 0, 0, 0, 0, sendbuf, 1, recvbuf, 1);
   if(ret != MPI_SUCCESS)
@@ -4229,6 +4239,8 @@ int AMPI_Ireduce(void *sendbuf, void *recvbuf, int count, int type, MPI_Op op,
 {
   AMPIAPI("AMPI_Ireduce");
 
+  handle_MPI_IN_PLACE(sendbuf, recvbuf);
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, count, 1, type, 1, 0, 0, root, 1, sendbuf, 1,
@@ -4265,6 +4277,9 @@ int AMPI_Allgather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 {
   AMPIAPI("AMPI_Allgather");
 
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Allgather does not implement MPI_IN_PLACE");
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, sendcount, 1, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
@@ -4323,6 +4338,9 @@ int AMPI_Iallgather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 {
   AMPIAPI("AMPI_Iallgather");
 
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Iallgather does not implement MPI_IN_PLACE");
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, sendcount, 1, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
@@ -4384,6 +4402,9 @@ int AMPI_Allgatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 {
   AMPIAPI("AMPI_Allgatherv");
 
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Allgatherv does not implement MPI_IN_PLACE");
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, sendcount, 1, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
@@ -4434,6 +4455,9 @@ int AMPI_Gather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 {
   AMPIAPI("AMPI_Gather");
 
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Gather does not implement MPI_IN_PLACE");
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, sendcount, 1, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
@@ -4499,6 +4523,9 @@ int AMPI_Gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 {
   AMPIAPI("AMPI_Gatherv");
 
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Gatherv does not implement MPI_IN_PLACE");
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, sendcount, 1, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
@@ -4560,6 +4587,9 @@ int AMPI_Scatter(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 {
   AMPIAPI("AMPI_Scatter");
 
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Scatter does not implement MPI_IN_PLACE");
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, sendcount, 1, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
@@ -4622,6 +4652,9 @@ int AMPI_Scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype send
 {
   AMPIAPI("AMPI_Scatterv");
 
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Scatterv does not implement MPI_IN_PLACE");
+
 #if CMK_ERROR_CHECKING
   int ret;
   ret = errorCheck(comm, 1, 0, 0, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
@@ -4685,6 +4718,9 @@ int AMPI_Alltoall(void *sendbuf, int sendcount, MPI_Datatype sendtype,
   AMPIAPI("AMPI_Alltoall");
 
 #if CMK_ERROR_CHECKING
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("MPI_Alltoall does not accept MPI_IN_PLACE");
+
   int ret;
   ret = errorCheck(comm, 1, sendcount, 1, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
   if(ret != MPI_SUCCESS)
@@ -4969,6 +5005,9 @@ int AMPI_Alltoall2(void *sendbuf, int sendcount, MPI_Datatype sendtype,
   AMPIAPI("AMPI_Alltoall2");
 
 #if CMK_ERROR_CHECKING
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Alltoall2 does not accept MPI_IN_PLACE");
+
   int ret;
   ret = errorCheck(comm, 1, sendcount, 1, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
   if(ret != MPI_SUCCESS)
@@ -5030,6 +5069,9 @@ int AMPI_Ialltoall(void *sendbuf, int sendcount, MPI_Datatype sendtype,
   AMPIAPI("AMPI_Ialltoall");
 
 #if CMK_ERROR_CHECKING
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("AMPI_Ialltoall does not accept MPI_IN_PLACE");
+
   int ret;
   ret = errorCheck(comm, 1, sendcount, 1, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
   if(ret != MPI_SUCCESS)
@@ -5091,6 +5133,9 @@ int AMPI_Alltoallv(void *sendbuf, int *sendcounts_, int *sdispls_,
   if(comm==MPI_COMM_SELF) return 0;
 
 #if CMK_ERROR_CHECKING
+  if (sendbuf == MPI_IN_PLACE || recvbuf == MPI_IN_PLACE)
+    CmiAbort("MPI_Alltoallv does not accept MPI_IN_PLACE");
+
   int ret;
   ret = errorCheck(comm, 1, 0, 0, sendtype, 1, 0, 0, 0, 0, sendbuf, 1);
   if(ret != MPI_SUCCESS)