C++矩阵乘法类型检测
C++ Matrix multiplication type detection
在我的 C++ 代码中,我有一个矩阵 class,以及一些用于将它们相乘的运算符。我的 class 是模板化的,这意味着我可以拥有 int、float、double ... 矩阵。
我猜我的运算符重载是 classic
template <typename T, typename U>
Matrix<T>& operator*(const Matrix<T>& a, const Matrix<U>& b)
{
assert(a.rows() == b.cols() && "You have to multiply a MxN matrix with a NxP one to get a MxP matrix\n");
Matrix<T> *c = new Matrix<T>(a.rows(), b.cols());
for (int ci=0 ; ci<c->rows() ; ++ci)
{
for (int cj=0 ; cj<c->cols() ; ++cj)
{
c->at(ci,cj)=0;
for (int k=0 ; k<a.cols() ; ++k)
{
c->at(ci,cj) += (T)(a.at(ci,k)*b.at(k,cj));
}
}
}
return *c;
}
在这段代码中,我 return 一个与第一个参数相同类型的矩阵,即 Matrix<int> * Matrix<float> = Matrix<int>
。我的问题是如何在我给出的两个类型中检测到最精确的类型而不损失太多精度,即 Matrix<int> * Matrix<float> = Matrix<float>
?有高手吗?
您想要的只是将 T
乘以 U
时出现的类型。可以通过以下方式给出:
template <class T, class U>
using product_type = decltype(std::declval<T>() * std::declval<U>());
您可以将其用作额外的默认模板参数:
template <typename T, typename U, typename R = product_type<T, U>>
Matrix<R> operator*(const Matrix<T>& a, const Matrix<U>& b) {
...
}
在 C++03 中,您可以通过使用许多小辅助类型进行大量重载来完成同样的事情(Boost 就是这样做的):
template <int I> struct arith;
template <int I, typename T> struct arith_helper {
typedef T type;
typedef char (&result_type)[I];
};
template <> struct arith<1> : arith_helper<1, bool> { };
template <> struct arith<2> : arith_helper<2, bool> { };
template <> struct arith<3> : arith_helper<3, signed char> { };
template <> struct arith<4> : arith_helper<4, short> { };
// ... lots more
然后我们可以写:
template <class T, class U>
class common_type {
private:
static arith<1>::result_type select(arith<1>::type );
static arith<2>::result_type select(arith<2>::type );
static arith<3>::result_type select(arith<3>::type );
// ...
static bool cond();
public:
typedef typename arith<sizeof(select(cond() ? T() : U() ))>::type type;
};
假设你写出了所有的整数类型,那么你可以在我之前使用 product_type
.
的地方使用 typename common_type<T, U>::type
如果这不是 C++11 有多酷的演示,我不知道什么才是。
注意,operator*
不应该 return 引用。你在做什么会泄漏内存。
在我的 C++ 代码中,我有一个矩阵 class,以及一些用于将它们相乘的运算符。我的 class 是模板化的,这意味着我可以拥有 int、float、double ... 矩阵。
我猜我的运算符重载是 classic
template <typename T, typename U>
Matrix<T>& operator*(const Matrix<T>& a, const Matrix<U>& b)
{
assert(a.rows() == b.cols() && "You have to multiply a MxN matrix with a NxP one to get a MxP matrix\n");
Matrix<T> *c = new Matrix<T>(a.rows(), b.cols());
for (int ci=0 ; ci<c->rows() ; ++ci)
{
for (int cj=0 ; cj<c->cols() ; ++cj)
{
c->at(ci,cj)=0;
for (int k=0 ; k<a.cols() ; ++k)
{
c->at(ci,cj) += (T)(a.at(ci,k)*b.at(k,cj));
}
}
}
return *c;
}
在这段代码中,我 return 一个与第一个参数相同类型的矩阵,即 Matrix<int> * Matrix<float> = Matrix<int>
。我的问题是如何在我给出的两个类型中检测到最精确的类型而不损失太多精度,即 Matrix<int> * Matrix<float> = Matrix<float>
?有高手吗?
您想要的只是将 T
乘以 U
时出现的类型。可以通过以下方式给出:
template <class T, class U>
using product_type = decltype(std::declval<T>() * std::declval<U>());
您可以将其用作额外的默认模板参数:
template <typename T, typename U, typename R = product_type<T, U>>
Matrix<R> operator*(const Matrix<T>& a, const Matrix<U>& b) {
...
}
在 C++03 中,您可以通过使用许多小辅助类型进行大量重载来完成同样的事情(Boost 就是这样做的):
template <int I> struct arith;
template <int I, typename T> struct arith_helper {
typedef T type;
typedef char (&result_type)[I];
};
template <> struct arith<1> : arith_helper<1, bool> { };
template <> struct arith<2> : arith_helper<2, bool> { };
template <> struct arith<3> : arith_helper<3, signed char> { };
template <> struct arith<4> : arith_helper<4, short> { };
// ... lots more
然后我们可以写:
template <class T, class U>
class common_type {
private:
static arith<1>::result_type select(arith<1>::type );
static arith<2>::result_type select(arith<2>::type );
static arith<3>::result_type select(arith<3>::type );
// ...
static bool cond();
public:
typedef typename arith<sizeof(select(cond() ? T() : U() ))>::type type;
};
假设你写出了所有的整数类型,那么你可以在我之前使用 product_type
.
typename common_type<T, U>::type
如果这不是 C++11 有多酷的演示,我不知道什么才是。
注意,operator*
不应该 return 引用。你在做什么会泄漏内存。