From c3d8136604eecad53171526f018b9604d6ebb31a Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 18 Jul 2024 11:57:50 -0400 Subject: [PATCH] update columns dot product for complex types --- stan/math/rev/fun/columns_dot_product.hpp | 33 ++++++++++++++++++++++- stan/math/rev/meta/return_var_matrix.hpp | 7 +++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/stan/math/rev/fun/columns_dot_product.hpp b/stan/math/rev/fun/columns_dot_product.hpp index 89c5b4784c5..d1f522de93c 100644 --- a/stan/math/rev/fun/columns_dot_product.hpp +++ b/stan/math/rev/fun/columns_dot_product.hpp @@ -14,6 +14,35 @@ namespace stan { namespace math { +/** + * Returns the dot product of columns of the specified matrices. + * + * @tparam Mat1 type of the first matrix (must be derived from \c + * Eigen::MatrixBase) + * @tparam Mat2 type of the second matrix (must be derived from \c + * Eigen::MatrixBase) + * + * @param v1 Matrix of first vectors. + * @param v2 Matrix of second vectors. + * @return Dot product of the vectors. + * @throw std::domain_error If the vectors are not the same + * size or if they are both not vector dimensioned. + */ +template * = nullptr, + require_any_eigen_vt* = nullptr, + require_vt_complex* = nullptr, + require_vt_complex* = nullptr> +inline Eigen::Matrix, 1, Mat1::ColsAtCompileTime> +columns_dot_product(const Mat1& v1, const Mat2& v2) { + check_matching_sizes("dot_product", "v1", v1, "v2", v2); + Eigen::Matrix ret(1, v1.cols()); + for (size_type j = 0; j < v1.cols(); ++j) { + ret.coeffRef(j) = dot_product(v1.col(j), v2.col(j)); + } + return ret; +} + /** * Returns the dot product of columns of the specified matrices. * @@ -27,7 +56,9 @@ namespace math { * size or if they are both not vector dimensioned. */ template * = nullptr> + require_all_matrix_t* = nullptr, + require_not_st_complex* = nullptr, + require_not_st_complex* = nullptr> inline auto columns_dot_product(Mat1&& v1, Mat2&& v2) { check_matching_sizes("columns_dot_product", "v1", v1, "v2", v2); using inner_return_t = decltype( diff --git a/stan/math/rev/meta/return_var_matrix.hpp b/stan/math/rev/meta/return_var_matrix.hpp index 153dfa05201..2e33c596616 100644 --- a/stan/math/rev/meta/return_var_matrix.hpp +++ b/stan/math/rev/meta/return_var_matrix.hpp @@ -23,8 +23,11 @@ using return_var_matrix_t = std::conditional_t< is_any_var_matrix::value, stan::math::var_value< stan::math::promote_scalar_t>>, - stan::math::promote_scalar_t, - plain_type_t>>; + std::conditional_t>::value, + stan::math::promote_scalar_t>, + plain_type_t>, + stan::math::promote_scalar_t, + plain_type_t>>>; } // namespace stan #endif