16 位浮点数 MPI_Reduce?
16-bit float MPI_Reduce?
我有一个分布式应用程序,它使用 MPI_Reduce()
进行某些通信。在精度方面,我们使用 16 位浮点数(半精度)得到完全准确的结果。
为了加速通信(减少数据移动量),有没有办法在16位浮点数上调用MPI_Reduce()
?
(我查看了 MPI 文档,没有看到任何关于 16 位浮点数的信息。)
MPI
标准在其内部数据类型中仅定义了 32 位 (MPI_FLOAT
) 或 64 位 (MPI_DOUBLE
) 浮点数。
但是,您始终可以创建您自己的 MPI_Datatype
和您自己的自定义化简操作。下面的代码给出了一些关于如何执行此操作的粗略概念。由于不清楚您使用的是哪种 16 位浮点数实现,我将把类型简称为 float16_t
,将加法运算简称为 fp16_add()
.
// define custom reduce operation
void my_fp16_sum(void* invec, void* inoutvec, int *len,
MPI_Datatype *datatype) {
// cast invec and inoutvec to your float16 type
float16_t* in = (float16_t)invec;
float16_t* inout = (float16_t)inoutvec;
for (int i = 0; i < *len; ++i) {
// sum your 16 bit floats
*inout = fp16_add(*in, *inout);
}
}
// ...
// in your code:
// create 2-byte datatype (send raw, un-interpreted bytes)
MPI_Datatype mpi_type_float16;
MPI_Type_contiguous(2, MPI_BYTE, &mpi_type_float16);
MPI_Type_commit(&mpi_type_float16);
// create user op (pass function pointer to your user function)
MPI_Op mpi_fp16sum;
MPI_Op_create(&my_fp16_sum, 1, &mpi_fp16sum);
// call MPI_Reduce using your custom reduction operation
MPI_Reduce(&fp16_val, &fp16_result, 1, mpi_type_float16, mpi_fp16sum, 0, MPI_COMM_WORLD);
// clean up (freeing of the custom MPI_Op and MPI_Datatype)
MPI_Type_free(&mpi_type_float16);
MPI_Op_free(&mpi_fp16sum);
我有一个分布式应用程序,它使用 MPI_Reduce()
进行某些通信。在精度方面,我们使用 16 位浮点数(半精度)得到完全准确的结果。
为了加速通信(减少数据移动量),有没有办法在16位浮点数上调用MPI_Reduce()
?
(我查看了 MPI 文档,没有看到任何关于 16 位浮点数的信息。)
MPI
标准在其内部数据类型中仅定义了 32 位 (MPI_FLOAT
) 或 64 位 (MPI_DOUBLE
) 浮点数。
但是,您始终可以创建您自己的 MPI_Datatype
和您自己的自定义化简操作。下面的代码给出了一些关于如何执行此操作的粗略概念。由于不清楚您使用的是哪种 16 位浮点数实现,我将把类型简称为 float16_t
,将加法运算简称为 fp16_add()
.
// define custom reduce operation
void my_fp16_sum(void* invec, void* inoutvec, int *len,
MPI_Datatype *datatype) {
// cast invec and inoutvec to your float16 type
float16_t* in = (float16_t)invec;
float16_t* inout = (float16_t)inoutvec;
for (int i = 0; i < *len; ++i) {
// sum your 16 bit floats
*inout = fp16_add(*in, *inout);
}
}
// ...
// in your code:
// create 2-byte datatype (send raw, un-interpreted bytes)
MPI_Datatype mpi_type_float16;
MPI_Type_contiguous(2, MPI_BYTE, &mpi_type_float16);
MPI_Type_commit(&mpi_type_float16);
// create user op (pass function pointer to your user function)
MPI_Op mpi_fp16sum;
MPI_Op_create(&my_fp16_sum, 1, &mpi_fp16sum);
// call MPI_Reduce using your custom reduction operation
MPI_Reduce(&fp16_val, &fp16_result, 1, mpi_type_float16, mpi_fp16sum, 0, MPI_COMM_WORLD);
// clean up (freeing of the custom MPI_Op and MPI_Datatype)
MPI_Type_free(&mpi_type_float16);
MPI_Op_free(&mpi_fp16sum);