CUDA中的多精度乘法
multi-precision multiplication in CUDA
我正在尝试在 CUDA 中实现多精度乘法。为此,我实现了一个内核,该内核应该计算 uint32_t
类型操作数与 256 位操作数的乘法并将结果放入 288 位数组中。到目前为止,我想出了这个代码:
__device__ __constant__ UN_256fe B_const;
__global__ void multiply32x256Kernel(uint32_t A, UN_288bite* result){
uint8_t tid = blockIdx.x * blockDim.x + threadIdx.x;
//for managing warps
//uint8_t laineid = tid % 32;
//allocate partial products into array of uint64_t
__shared__ uint64_t partialMuls[8];
uint32_t carry, r;
if((tid < 8) && (tid != 0)){
//compute partial products
partialMuls[tid] = A * B_const.uint32[tid];
//add partial products and propagate carry
result->uint32[8] = (uint32_t)partialMuls[7];
r = (partialMuls[tid] >> 32) + ((uint32_t)partialMuls[tid - 1]);
carry = r < (partialMuls[tid] >> 32);
result->uint32[0] = (partialMuls[0] >> 32);
while(__any(carry)){
r = r + carry;
//new carry?
carry = r < carry;
}
result->uint32[tid] = r;
}
我的数据类型是:
typedef struct UN_256fe{
uint32_t uint32[8];
}UN_256fe;
typedef struct UN_288bite{
uint32_t uint32[9];
}UN_288bite;
我的内核可以工作,但它给了我错误的结果。我无法在内核内部进行调试,所以如果有人让我知道问题出在哪里或如何在 tegra-ubuntu
和 cuda-6.0
上调试内核内部的代码,我将不胜感激。
谢谢
这个答案与CUDA本身无关,而是一个通用的C实现。
我不太明白你在做什么(尤其是 carry
),但你可以根据我自己的 big num 函数尝试这个片段。我定义了 dtype
以便更容易地使用较小的字段进行测试。请注意,我没有专门使用carry
,而是将部分产品进行了处理。
// little-endian
#include <stdio.h>
#include <stdint.h>
#include <limits.h>
#define dtype uint8_t // for testing
//#define dtype uint32_t // for proper ver
#define SHIFTS (sizeof(dtype)*CHAR_BIT)
#define NIBBLES (SHIFTS/4)
#define ARRLEN 8
typedef struct UN_256fe {
dtype uint[ARRLEN];
} UN_256fe;
typedef struct UN_288bite {
dtype uint[ARRLEN+1];
} UN_288bite;
void multiply(UN_288bite *product, UN_256fe *operand, dtype multiplier)
{
int i;
uint64_t partial = 0;
for (i=0; i<ARRLEN; i++) {
partial = partial + (uint64_t)multiplier * operand->uint[i];
product->uint[i] = (dtype)partial;
partial >>= SHIFTS; // carry
}
product->uint[i] = (dtype)partial;
}
int main(void)
{
int i;
dtype multiplier = 0xAA;
UN_256fe operand = { 1, 2, 3, 4, 5, 6, 7, 8};
UN_288bite product;
multiply(&product, &operand, multiplier);
for(i=ARRLEN-1; i>=0; i--)
printf("%0*X", NIBBLES, operand.uint[i]);
printf("\n * %0*X = \n", NIBBLES, multiplier);
for(i=ARRLEN; i>=0; i--)
printf("%0*X", NIBBLES, product.uint[i]);
printf("\n");
return 0;
}
uint8_t
的程序输出
0807060504030201
* AA =
0554A9FF54A9FF54AA
我正在尝试在 CUDA 中实现多精度乘法。为此,我实现了一个内核,该内核应该计算 uint32_t
类型操作数与 256 位操作数的乘法并将结果放入 288 位数组中。到目前为止,我想出了这个代码:
__device__ __constant__ UN_256fe B_const;
__global__ void multiply32x256Kernel(uint32_t A, UN_288bite* result){
uint8_t tid = blockIdx.x * blockDim.x + threadIdx.x;
//for managing warps
//uint8_t laineid = tid % 32;
//allocate partial products into array of uint64_t
__shared__ uint64_t partialMuls[8];
uint32_t carry, r;
if((tid < 8) && (tid != 0)){
//compute partial products
partialMuls[tid] = A * B_const.uint32[tid];
//add partial products and propagate carry
result->uint32[8] = (uint32_t)partialMuls[7];
r = (partialMuls[tid] >> 32) + ((uint32_t)partialMuls[tid - 1]);
carry = r < (partialMuls[tid] >> 32);
result->uint32[0] = (partialMuls[0] >> 32);
while(__any(carry)){
r = r + carry;
//new carry?
carry = r < carry;
}
result->uint32[tid] = r;
}
我的数据类型是:
typedef struct UN_256fe{
uint32_t uint32[8];
}UN_256fe;
typedef struct UN_288bite{
uint32_t uint32[9];
}UN_288bite;
我的内核可以工作,但它给了我错误的结果。我无法在内核内部进行调试,所以如果有人让我知道问题出在哪里或如何在 tegra-ubuntu
和 cuda-6.0
上调试内核内部的代码,我将不胜感激。
谢谢
这个答案与CUDA本身无关,而是一个通用的C实现。
我不太明白你在做什么(尤其是 carry
),但你可以根据我自己的 big num 函数尝试这个片段。我定义了 dtype
以便更容易地使用较小的字段进行测试。请注意,我没有专门使用carry
,而是将部分产品进行了处理。
// little-endian
#include <stdio.h>
#include <stdint.h>
#include <limits.h>
#define dtype uint8_t // for testing
//#define dtype uint32_t // for proper ver
#define SHIFTS (sizeof(dtype)*CHAR_BIT)
#define NIBBLES (SHIFTS/4)
#define ARRLEN 8
typedef struct UN_256fe {
dtype uint[ARRLEN];
} UN_256fe;
typedef struct UN_288bite {
dtype uint[ARRLEN+1];
} UN_288bite;
void multiply(UN_288bite *product, UN_256fe *operand, dtype multiplier)
{
int i;
uint64_t partial = 0;
for (i=0; i<ARRLEN; i++) {
partial = partial + (uint64_t)multiplier * operand->uint[i];
product->uint[i] = (dtype)partial;
partial >>= SHIFTS; // carry
}
product->uint[i] = (dtype)partial;
}
int main(void)
{
int i;
dtype multiplier = 0xAA;
UN_256fe operand = { 1, 2, 3, 4, 5, 6, 7, 8};
UN_288bite product;
multiply(&product, &operand, multiplier);
for(i=ARRLEN-1; i>=0; i--)
printf("%0*X", NIBBLES, operand.uint[i]);
printf("\n * %0*X = \n", NIBBLES, multiplier);
for(i=ARRLEN; i>=0; i--)
printf("%0*X", NIBBLES, product.uint[i]);
printf("\n");
return 0;
}
uint8_t
0807060504030201
* AA =
0554A9FF54A9FF54AA