模板函数重载和 SFINAE 实现
Template function overloading and SFINAE implementations
我正在花一些时间学习如何在 C++ 中使用模板。我从没用过
之前,我并不总是确定在不同情况下可以实现什么或不能实现什么。
作为练习,我包装了我在活动中使用的一些 Blas 和 Lapack 函数,
我目前正在研究 ?GELS
的包装(评估线性方程组的解)。
A x + b = 0
?GELS
函数(仅适用于实数值)存在两个名称:SGELS
,用于单精度向量和
DGELS
双精度。
我对接口的想法是这样的函数solve
:
const std::size_t rows = /* number of rows for A */;
const std::size_t cols = /* number of cols for A */;
std::array< double, rows * cols > A = { /* values */ };
std::array< double, ??? > b = { /* values */ }; // ??? it can be either
// rows or cols. It depends on user
// problem, in general
// max( dim(x), dim(b) ) =
// max( cols, rows )
solve< double, rows, cols >(A, b);
// the solution x is stored in b, thus b
// must be "large" enough to accomodate x
根据用户要求,问题可能是多定的或未定的,即:
- 如果是超定的
dim(b) > dim(x)
(解是伪逆)
- 如果不确定
dim(b) < dim(x)
(解决方案是LSQ最小化)
- 或者
dim(b) = dim(x)
的正常情况(解是A
的逆)
(不考虑个别情况)
由于 ?GELS
将结果存储在输入向量 b
中,因此 std::array
应该
有足够的 space 来容纳解决方案,如代码注释 (max(rows, cols)
) 中所述。
我想(编译时)确定采用哪种解决方案(这是一个参数更改
在 ?GELS
通话中)。我有两个功能(为了这个问题我正在简化),
处理精度并且已经知道哪个是 b
的维度和 rows
/cols
:
的数量
namespace wrap {
template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<float, rows * cols> & A, std::array<float, dimb> & b) {
SGELS(/* Called in the right way */);
}
template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<double, rows * cols> & A, std::array<double, dimb> & b) {
DGELS(/* Called in the right way */);
}
}; /* namespace wrap */
它们是内部包装的一部分。用户函数,确定所需的大小
在通过模板的 b
向量中:
#include <type_traits>
/** This struct makes the max between rows and cols */
template < std::size_t rows, std::size_t cols >
struct biggest_dim {
static std::size_t const value = std::conditional< rows >= cols, std::integral_constant< std::size_t, rows >,
std::integral_constant< std::size_t, cols > >::type::value;
};
/** A type for the b array is selected using "biggest_dim" */
template < typename REAL_T, std::size_t rows, std::size_t cols >
using b_array_t = std::array< REAL_T, biggest_dim< rows, cols >::value >;
/** Here we have the function that allows only the call with b of
* the correct size to continue */
template < typename REAL_T, std::size_t rows, std::size_t cols >
void solve(std::array< REAL_T, cols * rows > & A, b_array_t< REAL_T, cols, rows > & b) {
static_assert(std::is_floating_point< REAL_T >::value, "Only float/double accepted");
wrap::solve< rows, cols, biggest_dim< rows, cols >::value >(A, b);
}
这样确实有效。但是我想更进一步,我真的不知道该怎么做。
如果用户试图调用 solve
并且 b
的大小太小,编译器会引发一个极其难以阅读的错误。
我正在尝试插入
static_assert
帮助用户理解他的错误。但我脑海中出现的任何方向
需要使用两个具有相同签名的函数(这就像模板重载?)
我找不到 SFINAE 策略(实际上它们根本不编译)。
您认为在 编译时 b
维度 不更改用户界面 的情况下可以提出静态断言吗?
我希望问题足够清楚。
@Caninonos:对我来说,用户界面就是用户调用求解器的方式,即:
solve< type, number of rows, number of cols > (matrix A, vector b)
这是我为了提高自己的技能而对锻炼施加的约束。这意味着,我不知道是否真的有可能实现解决方案。 b
的类型必须与函数调用匹配,如果我添加另一个模板参数并更改用户界面,这很容易违反我的约束。
最小的完整和工作示例
这是一个最小的完整的工作示例。根据要求,我删除了对线性代数概念的任何引用。是个数的问题。个案为:
N1 = 2, N2 =2
。由于 N3 = max(N1, N2) = 2
一切正常
N1 = 2, N2 =1
。由于 N3 = max(N1, N2) = N1 = 2
一切正常
N1 = 1, N2 =2
。由于 N3 = max(N1, N2) = N2 = 2
一切正常
N1 = 1, N2 =2
。由于 N3 = N1 = 1 < N2
它正确地引发了编译错误。我想用一个静态断言来拦截编译错误,该断言解释了 N3
的维数错误的事实。至于现在的错误很难阅读和理解。
您为什么不尝试将 tag dispatch 与一些 static_assert
组合在一起?我希望以下是实现您想要解决的问题的一种方法。我的意思是,所有三个正确的案例都正确地传递给了正确的 blas
调用,处理了不同的类型和维度不匹配,并且还处理了关于 float
和 double
s 的违规,所有以 user-friendly 的方式,感谢 static_assert
。
编辑。 我不确定你的 C++
版本要求,但下面是 C++11
友好的。
#include <algorithm>
#include <iostream>
#include <type_traits>
template <class value_t, int nrows, int ncols> struct Matrix {};
template <class value_t, int rows> struct Vector {};
template <class value_t> struct blas;
template <> struct blas<float> {
static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};
template <> struct blas<double> {
static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};
class overdet {};
class underdet {};
class normal {};
template <class T1, class T2, int nrows, int ncols, int dim>
void solve(const Matrix<T1, nrows, ncols> &lhs, Vector<T2, dim> &rhs) {
static_assert(std::is_same<T1, T2>::value,
"lhs and rhs must have the same value types");
static_assert(dim >= nrows && dim >= ncols,
"rhs does not have enough space");
static_assert(std::is_same<T1, float>::value ||
std::is_same<T1, double>::value,
"Only float or double are accepted");
solve_impl(lhs, rhs,
typename std::conditional<(nrows < ncols), underdet,
typename std::conditional<(nrows > ncols), overdet,
normal>::type>::type{});
}
template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
Vector<value_t, dim> &rhs, underdet) {
/* get the pointers and dimension information from lhs and rhs */
blas<value_t>::underdet(
/* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}
template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
Vector<value_t, dim> &rhs, overdet) {
/* get the pointers and dimension information from lhs and rhs */
blas<value_t>::overdet(
/* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}
template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
Vector<value_t, dim> &rhs, normal) {
/* get the pointers and dimension information from lhs and rhs */
blas<value_t>::normal(
/* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}
int main() {
/* valid types */
Matrix<float, 2, 4> A1;
Matrix<float, 4, 4> A2;
Matrix<float, 5, 4> A3;
Vector<float, 4> b1;
Vector<float, 5> b2;
solve(A1, b1);
solve(A2, b1);
solve(A3, b2);
Matrix<int, 4, 4> A4;
Vector<int, 4> b3;
// solve(A4, b3); // static_assert for float & double
Matrix<float, 4, 4> A5;
Vector<int, 4> b4;
// solve(A5, b4); // static_assert for different types
// solve(A3, b1); // static_assert for dimension problem
return 0;
}
您必须考虑为什么 界面提供这些(令人费解的)乱七八糟的参数。作者想到了几件事。首先,您可以在一个函数中解决 A x + b == 0
和 A^T x + b == 0
形式的问题。其次,给定的 A
和 b
实际上可以指向大于 alg 所需矩阵的内存。这可以通过LDA
和LDB
参数看出。
让事情变得复杂的是子寻址。如果您想要一个简单但可能足够有用的 API,您可以选择忽略该部分:
using ::std::size_t;
using ::std::array;
template<typename T, size_t rows, size_t cols>
using matrix = array<T, rows * cols>;
enum class TransposeMode : bool {
None = false, Transposed = true
};
// See
template<typename T> struct always_false_t : std::false_type {};
template<typename T> constexpr bool always_false_v = always_false_t<T>::value;
template < typename T, size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
, TransposeMode mode = TransposeMode::None >
void solve(matrix<T, rowsA, colsA>& A, matrix<T, rowsB, colsB>& B)
{
// Since the algorithm works in place, b needs to be able to store
// both input and output
static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
// LDA = rowsA, LDB = rowsB
if constexpr (::std::is_same_v<T, float>) {
// SGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
} else if constexpr (::std::is_same_v<T, double>) {
// DGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
} else {
static_assert(always_false_v<T>, "Unknown type");
}
}
现在,可以使用 LDA
和 LDB
寻址子地址。我建议您将该部分作为您的数据类型,而不是直接作为模板签名的一部分。您希望拥有自己的矩阵类型,可以引用矩阵中的存储。也许是这样的:
// Since we store elements in a column-major order, we can always
// pretend that our matrix has less columns than it actually has
// less rows than allocated. We can not equally pretend less rows
// otherwise the addressing into the array is off.
// Thus, we'd only four total parameters:
// offset = columnSkipped * actualRows + rowSkipped), actualRows, rows, cols
// We store the offset implicitly by adjusting our begin pointer
template<typename T, size_t rows, size_t cols, size_t actualRows>
class matrix_view { // Name derived from string_view :)
static_assert(actualRows >= rows);
T* start;
matrix_view(T* start) : start(start) {}
template<typename U, size_t r, size_t c, size_t ac>
friend class matrix_view;
public:
template<typename U>
matrix_view(matrix<U, rows, cols>& ref)
: start(ref.data()) { }
template<size_t rowSkipped, size_t colSkipped, size_t newRows, size_t newCols>
auto submat() {
static_assert(colSkipped + newCols <= cols, "can only shrink");
static_assert(rowSkipped + newRows <= rows, "can only shrink");
auto newStart = start + colSkipped * actualRows + rowSkipped;
using newType = matrix_view<T, newRows, newCols, actualRows>
return newType{ newStart };
}
T* data() {
return start;
}
};
现在,您需要调整您的界面以适应这种新的数据类型,这基本上只是引入一些新参数。检查基本保持不变。
// Using this instead of just type-defing allows us to use deducation guides
// Replaces: using matrix = std::array from above
template<typename T, size_t rows, size_t cols>
class matrix {
public:
std::array<T, rows * cols> storage;
auto data() { return storage.data(); }
auto data() const { return storage.data(); }
};
extern void dgels(char TRANS
, integer M, integer N , integer NRHS
, double* A, integer LDA
, double* B, integer LDB); // Mock, missing a few parameters at the end
// Replaces the solve method from above
template < typename T, size_t rowsA, size_t colsA, size_t actualRowsA
, size_t rowsB, size_t colsB, size_t actualRowsB
, TransposeMode mode = TransposeMode::None >
void solve(matrix_view<T, rowsA, colsA, actualRowsA> A, matrix_view<T, rowsB, colsB, actualRowsB> B)
{
static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
char transMode = mode == TransposeMode::None ? 'N' : 'T';
// LDA = rowsA, LDB = rowsB
if constexpr (::std::is_same_v<T, float>) {
fgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
} else if constexpr (::std::is_same_v<T, double>) {
dgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
// DGELS(, ....);
} else {
static_assert(always_false_v<T>, "Unknown type");
}
}
用法示例:
int main() {
matrix<float, 5, 5> A;
matrix<float, 4, 1> b;
auto viewA = matrix_view{A}.submat<1, 1, 4, 4>();
auto viewb = matrix_view{b};
solve(viewA, viewb);
// solve(viewA, viewb.submat<1, 0, 2, 1>()); // Error: b is too small
// solve(matrix_view{A}, viewb.submat<0, 0, 5, 1>()); // Error: can only shrink (b is 4x1 and can not be viewed as 5x1)
}
首先是一些改进,可以稍微简化设计并提高可读性:
不需要biggest_dim
。 std::max
从 C++14 开始就是 constexpr。你应该改用它。
不需要b_array_t
。你可以只写 std::array< REAL_T, std::max(N1, N2)>
现在解决你的问题。 C++17 中的一种好方法是:
template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {
if constexpr (N3 == std::max(N1, N2))
wrap::internal< N1, N2, N3 >(A, b);
else
static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
// don't write static_assert(false)
// this would make the program ill-formed (*)
}
或者,正如@max66
所指出的
template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {
static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
if constexpr (N3 == std::max(N1, N2))
wrap::internal< N1, N2, N3 >(A, b);
}
Tadaa!! 简单、优雅、漂亮的错误消息。
constexpr if 版本和 static_assert
之间的区别,即:
void solve(...)
{
static_assert(...);
wrap::internal(...);
}
是因为只有 static_assert
编译器会尝试实例化 wrap::internal
即使 static_assert
失败,污染错误输出。使用 constexpr,如果对 wrap::internal
的调用不是主体的一部分,则条件失败,因此错误输出是干净的。
(*) 我不只写 static_asert(false, "error msg)
的原因是因为那样会使程序 ill-formed,不需要诊断。参见
如果需要,您也可以通过将模板参数移动到 non-deductible 之后来使 float
/ double
可扣除:
template < std::size_t N1, std::size_t N2, std::size_t N3, typename REAL_T>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {
所以调用变成:
solve< n1_3, n2_3>(A_3, b_3);
我正在花一些时间学习如何在 C++ 中使用模板。我从没用过 之前,我并不总是确定在不同情况下可以实现什么或不能实现什么。
作为练习,我包装了我在活动中使用的一些 Blas 和 Lapack 函数,
我目前正在研究 ?GELS
的包装(评估线性方程组的解)。
A x + b = 0
?GELS
函数(仅适用于实数值)存在两个名称:SGELS
,用于单精度向量和
DGELS
双精度。
我对接口的想法是这样的函数solve
:
const std::size_t rows = /* number of rows for A */;
const std::size_t cols = /* number of cols for A */;
std::array< double, rows * cols > A = { /* values */ };
std::array< double, ??? > b = { /* values */ }; // ??? it can be either
// rows or cols. It depends on user
// problem, in general
// max( dim(x), dim(b) ) =
// max( cols, rows )
solve< double, rows, cols >(A, b);
// the solution x is stored in b, thus b
// must be "large" enough to accomodate x
根据用户要求,问题可能是多定的或未定的,即:
- 如果是超定的
dim(b) > dim(x)
(解是伪逆) - 如果不确定
dim(b) < dim(x)
(解决方案是LSQ最小化) - 或者
dim(b) = dim(x)
的正常情况(解是A
的逆)
(不考虑个别情况)
由于 ?GELS
将结果存储在输入向量 b
中,因此 std::array
应该
有足够的 space 来容纳解决方案,如代码注释 (max(rows, cols)
) 中所述。
我想(编译时)确定采用哪种解决方案(这是一个参数更改
在 ?GELS
通话中)。我有两个功能(为了这个问题我正在简化),
处理精度并且已经知道哪个是 b
的维度和 rows
/cols
:
namespace wrap {
template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<float, rows * cols> & A, std::array<float, dimb> & b) {
SGELS(/* Called in the right way */);
}
template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<double, rows * cols> & A, std::array<double, dimb> & b) {
DGELS(/* Called in the right way */);
}
}; /* namespace wrap */
它们是内部包装的一部分。用户函数,确定所需的大小
在通过模板的 b
向量中:
#include <type_traits>
/** This struct makes the max between rows and cols */
template < std::size_t rows, std::size_t cols >
struct biggest_dim {
static std::size_t const value = std::conditional< rows >= cols, std::integral_constant< std::size_t, rows >,
std::integral_constant< std::size_t, cols > >::type::value;
};
/** A type for the b array is selected using "biggest_dim" */
template < typename REAL_T, std::size_t rows, std::size_t cols >
using b_array_t = std::array< REAL_T, biggest_dim< rows, cols >::value >;
/** Here we have the function that allows only the call with b of
* the correct size to continue */
template < typename REAL_T, std::size_t rows, std::size_t cols >
void solve(std::array< REAL_T, cols * rows > & A, b_array_t< REAL_T, cols, rows > & b) {
static_assert(std::is_floating_point< REAL_T >::value, "Only float/double accepted");
wrap::solve< rows, cols, biggest_dim< rows, cols >::value >(A, b);
}
这样确实有效。但是我想更进一步,我真的不知道该怎么做。
如果用户试图调用 solve
并且 b
的大小太小,编译器会引发一个极其难以阅读的错误。
我正在尝试插入
static_assert
帮助用户理解他的错误。但我脑海中出现的任何方向
需要使用两个具有相同签名的函数(这就像模板重载?)
我找不到 SFINAE 策略(实际上它们根本不编译)。
您认为在 编译时 b
维度 不更改用户界面 的情况下可以提出静态断言吗?
我希望问题足够清楚。
@Caninonos:对我来说,用户界面就是用户调用求解器的方式,即:
solve< type, number of rows, number of cols > (matrix A, vector b)
这是我为了提高自己的技能而对锻炼施加的约束。这意味着,我不知道是否真的有可能实现解决方案。 b
的类型必须与函数调用匹配,如果我添加另一个模板参数并更改用户界面,这很容易违反我的约束。
最小的完整和工作示例
这是一个最小的完整的工作示例。根据要求,我删除了对线性代数概念的任何引用。是个数的问题。个案为:
N1 = 2, N2 =2
。由于N3 = max(N1, N2) = 2
一切正常N1 = 2, N2 =1
。由于N3 = max(N1, N2) = N1 = 2
一切正常N1 = 1, N2 =2
。由于N3 = max(N1, N2) = N2 = 2
一切正常N1 = 1, N2 =2
。由于N3 = N1 = 1 < N2
它正确地引发了编译错误。我想用一个静态断言来拦截编译错误,该断言解释了N3
的维数错误的事实。至于现在的错误很难阅读和理解。
您为什么不尝试将 tag dispatch 与一些 static_assert
组合在一起?我希望以下是实现您想要解决的问题的一种方法。我的意思是,所有三个正确的案例都正确地传递给了正确的 blas
调用,处理了不同的类型和维度不匹配,并且还处理了关于 float
和 double
s 的违规,所有以 user-friendly 的方式,感谢 static_assert
。
编辑。 我不确定你的 C++
版本要求,但下面是 C++11
友好的。
#include <algorithm>
#include <iostream>
#include <type_traits>
template <class value_t, int nrows, int ncols> struct Matrix {};
template <class value_t, int rows> struct Vector {};
template <class value_t> struct blas;
template <> struct blas<float> {
static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};
template <> struct blas<double> {
static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};
class overdet {};
class underdet {};
class normal {};
template <class T1, class T2, int nrows, int ncols, int dim>
void solve(const Matrix<T1, nrows, ncols> &lhs, Vector<T2, dim> &rhs) {
static_assert(std::is_same<T1, T2>::value,
"lhs and rhs must have the same value types");
static_assert(dim >= nrows && dim >= ncols,
"rhs does not have enough space");
static_assert(std::is_same<T1, float>::value ||
std::is_same<T1, double>::value,
"Only float or double are accepted");
solve_impl(lhs, rhs,
typename std::conditional<(nrows < ncols), underdet,
typename std::conditional<(nrows > ncols), overdet,
normal>::type>::type{});
}
template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
Vector<value_t, dim> &rhs, underdet) {
/* get the pointers and dimension information from lhs and rhs */
blas<value_t>::underdet(
/* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}
template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
Vector<value_t, dim> &rhs, overdet) {
/* get the pointers and dimension information from lhs and rhs */
blas<value_t>::overdet(
/* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}
template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
Vector<value_t, dim> &rhs, normal) {
/* get the pointers and dimension information from lhs and rhs */
blas<value_t>::normal(
/* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}
int main() {
/* valid types */
Matrix<float, 2, 4> A1;
Matrix<float, 4, 4> A2;
Matrix<float, 5, 4> A3;
Vector<float, 4> b1;
Vector<float, 5> b2;
solve(A1, b1);
solve(A2, b1);
solve(A3, b2);
Matrix<int, 4, 4> A4;
Vector<int, 4> b3;
// solve(A4, b3); // static_assert for float & double
Matrix<float, 4, 4> A5;
Vector<int, 4> b4;
// solve(A5, b4); // static_assert for different types
// solve(A3, b1); // static_assert for dimension problem
return 0;
}
您必须考虑为什么 界面提供这些(令人费解的)乱七八糟的参数。作者想到了几件事。首先,您可以在一个函数中解决 A x + b == 0
和 A^T x + b == 0
形式的问题。其次,给定的 A
和 b
实际上可以指向大于 alg 所需矩阵的内存。这可以通过LDA
和LDB
参数看出。
让事情变得复杂的是子寻址。如果您想要一个简单但可能足够有用的 API,您可以选择忽略该部分:
using ::std::size_t;
using ::std::array;
template<typename T, size_t rows, size_t cols>
using matrix = array<T, rows * cols>;
enum class TransposeMode : bool {
None = false, Transposed = true
};
// See
template<typename T> struct always_false_t : std::false_type {};
template<typename T> constexpr bool always_false_v = always_false_t<T>::value;
template < typename T, size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
, TransposeMode mode = TransposeMode::None >
void solve(matrix<T, rowsA, colsA>& A, matrix<T, rowsB, colsB>& B)
{
// Since the algorithm works in place, b needs to be able to store
// both input and output
static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
// LDA = rowsA, LDB = rowsB
if constexpr (::std::is_same_v<T, float>) {
// SGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
} else if constexpr (::std::is_same_v<T, double>) {
// DGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
} else {
static_assert(always_false_v<T>, "Unknown type");
}
}
现在,可以使用 LDA
和 LDB
寻址子地址。我建议您将该部分作为您的数据类型,而不是直接作为模板签名的一部分。您希望拥有自己的矩阵类型,可以引用矩阵中的存储。也许是这样的:
// Since we store elements in a column-major order, we can always
// pretend that our matrix has less columns than it actually has
// less rows than allocated. We can not equally pretend less rows
// otherwise the addressing into the array is off.
// Thus, we'd only four total parameters:
// offset = columnSkipped * actualRows + rowSkipped), actualRows, rows, cols
// We store the offset implicitly by adjusting our begin pointer
template<typename T, size_t rows, size_t cols, size_t actualRows>
class matrix_view { // Name derived from string_view :)
static_assert(actualRows >= rows);
T* start;
matrix_view(T* start) : start(start) {}
template<typename U, size_t r, size_t c, size_t ac>
friend class matrix_view;
public:
template<typename U>
matrix_view(matrix<U, rows, cols>& ref)
: start(ref.data()) { }
template<size_t rowSkipped, size_t colSkipped, size_t newRows, size_t newCols>
auto submat() {
static_assert(colSkipped + newCols <= cols, "can only shrink");
static_assert(rowSkipped + newRows <= rows, "can only shrink");
auto newStart = start + colSkipped * actualRows + rowSkipped;
using newType = matrix_view<T, newRows, newCols, actualRows>
return newType{ newStart };
}
T* data() {
return start;
}
};
现在,您需要调整您的界面以适应这种新的数据类型,这基本上只是引入一些新参数。检查基本保持不变。
// Using this instead of just type-defing allows us to use deducation guides
// Replaces: using matrix = std::array from above
template<typename T, size_t rows, size_t cols>
class matrix {
public:
std::array<T, rows * cols> storage;
auto data() { return storage.data(); }
auto data() const { return storage.data(); }
};
extern void dgels(char TRANS
, integer M, integer N , integer NRHS
, double* A, integer LDA
, double* B, integer LDB); // Mock, missing a few parameters at the end
// Replaces the solve method from above
template < typename T, size_t rowsA, size_t colsA, size_t actualRowsA
, size_t rowsB, size_t colsB, size_t actualRowsB
, TransposeMode mode = TransposeMode::None >
void solve(matrix_view<T, rowsA, colsA, actualRowsA> A, matrix_view<T, rowsB, colsB, actualRowsB> B)
{
static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
char transMode = mode == TransposeMode::None ? 'N' : 'T';
// LDA = rowsA, LDB = rowsB
if constexpr (::std::is_same_v<T, float>) {
fgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
} else if constexpr (::std::is_same_v<T, double>) {
dgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
// DGELS(, ....);
} else {
static_assert(always_false_v<T>, "Unknown type");
}
}
用法示例:
int main() {
matrix<float, 5, 5> A;
matrix<float, 4, 1> b;
auto viewA = matrix_view{A}.submat<1, 1, 4, 4>();
auto viewb = matrix_view{b};
solve(viewA, viewb);
// solve(viewA, viewb.submat<1, 0, 2, 1>()); // Error: b is too small
// solve(matrix_view{A}, viewb.submat<0, 0, 5, 1>()); // Error: can only shrink (b is 4x1 and can not be viewed as 5x1)
}
首先是一些改进,可以稍微简化设计并提高可读性:
不需要
biggest_dim
。std::max
从 C++14 开始就是 constexpr。你应该改用它。不需要
b_array_t
。你可以只写std::array< REAL_T, std::max(N1, N2)>
现在解决你的问题。 C++17 中的一种好方法是:
template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {
if constexpr (N3 == std::max(N1, N2))
wrap::internal< N1, N2, N3 >(A, b);
else
static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
// don't write static_assert(false)
// this would make the program ill-formed (*)
}
或者,正如@max66
所指出的template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {
static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
if constexpr (N3 == std::max(N1, N2))
wrap::internal< N1, N2, N3 >(A, b);
}
Tadaa!! 简单、优雅、漂亮的错误消息。
constexpr if 版本和 static_assert
之间的区别,即:
void solve(...)
{
static_assert(...);
wrap::internal(...);
}
是因为只有 static_assert
编译器会尝试实例化 wrap::internal
即使 static_assert
失败,污染错误输出。使用 constexpr,如果对 wrap::internal
的调用不是主体的一部分,则条件失败,因此错误输出是干净的。
(*) 我不只写 static_asert(false, "error msg)
的原因是因为那样会使程序 ill-formed,不需要诊断。参见
如果需要,您也可以通过将模板参数移动到 non-deductible 之后来使 float
/ double
可扣除:
template < std::size_t N1, std::size_t N2, std::size_t N3, typename REAL_T>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {
所以调用变成:
solve< n1_3, n2_3>(A_3, b_3);