如何在 CUDA 中正确避免 single/double 精度的复制粘贴方法
How to properly avoid copy pasting methods for single/double precision in CUDA
使用 CUDA 时,我经常比较单精度和双精度的执行时间 (float/double)。为了避免复制粘贴方法,我经常在标准情况下使用模板在 float 和 double 之间切换。
当我必须使用来自 cusparse/cublas 库的外部方法时,问题就开始了。在这种特殊情况下,例如:
cublasSaxpy() // single precision
cublasDaxpy() // double precision
如果懒,最简单的办法就是复制粘贴的方法
myFloatMethod(float var)
{
// do stuff in float
cublasSaxpy(var);
}
myDoubleMethod(double var)
{
// do stuff in double
cublasDaxpy(var);
}
我已经尝试搜索这个问题,我找到的唯一解决方案是全局定义这样的方法:
#define cublasTaxpy cublasSaxpy // or cublasDaxpy
#define DATATYPE float // or double
并使用 cublasTaxpy 而不是 cublasSaxpy/cublasDaxpy。每次我想更改精度时,我只更改定义而没有重复代码或遍历整个代码。
有什么好的方法可以做得更好吗?
您可以为 cublasTaxpy()
创建重载而不是宏
void cublasTaxpy(float f) { cublasSaxpy(f); }
void cublasTaxpy(double d) { cublasDaxpy(d); }
或者将整个函数集包装在专门的结构中:
template<typename FLOAT> struct helper_cublas;
template<> struct helper_cublas<float> {
static void cublasTaxpy(float f) { cublasSaxpy(f); }
// other functions
};
template<> struct helper_cublas<double> {
static void cublasTaxpy(double d) { cublasDaxpy(d); }
// other functions
};
使用 CUDA 时,我经常比较单精度和双精度的执行时间 (float/double)。为了避免复制粘贴方法,我经常在标准情况下使用模板在 float 和 double 之间切换。
当我必须使用来自 cusparse/cublas 库的外部方法时,问题就开始了。在这种特殊情况下,例如:
cublasSaxpy() // single precision
cublasDaxpy() // double precision
如果懒,最简单的办法就是复制粘贴的方法
myFloatMethod(float var)
{
// do stuff in float
cublasSaxpy(var);
}
myDoubleMethod(double var)
{
// do stuff in double
cublasDaxpy(var);
}
我已经尝试搜索这个问题,我找到的唯一解决方案是全局定义这样的方法:
#define cublasTaxpy cublasSaxpy // or cublasDaxpy
#define DATATYPE float // or double
并使用 cublasTaxpy 而不是 cublasSaxpy/cublasDaxpy。每次我想更改精度时,我只更改定义而没有重复代码或遍历整个代码。
有什么好的方法可以做得更好吗?
您可以为 cublasTaxpy()
void cublasTaxpy(float f) { cublasSaxpy(f); }
void cublasTaxpy(double d) { cublasDaxpy(d); }
或者将整个函数集包装在专门的结构中:
template<typename FLOAT> struct helper_cublas;
template<> struct helper_cublas<float> {
static void cublasTaxpy(float f) { cublasSaxpy(f); }
// other functions
};
template<> struct helper_cublas<double> {
static void cublasTaxpy(double d) { cublasDaxpy(d); }
// other functions
};