Skip to content

Commit

Permalink
update columns dot product for complex types
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Jul 18, 2024
1 parent a3f3cd8 commit c3d8136
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
33 changes: 32 additions & 1 deletion stan/math/rev/fun/columns_dot_product.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,35 @@
namespace stan {
namespace math {

/**
* Returns the dot product of columns of the specified matrices.
*
* @tparam Mat1 type of the first matrix (must be derived from \c
* Eigen::MatrixBase)
* @tparam Mat2 type of the second matrix (must be derived from \c
* Eigen::MatrixBase)
*
* @param v1 Matrix of first vectors.
* @param v2 Matrix of second vectors.
* @return Dot product of the vectors.
* @throw std::domain_error If the vectors are not the same
* size or if they are both not vector dimensioned.
*/
template <typename Mat1, typename Mat2,
require_all_eigen_t<Mat1, Mat2>* = nullptr,
require_any_eigen_vt<is_var, Mat1, Mat2>* = nullptr,
require_vt_complex<Mat1>* = nullptr,
require_vt_complex<Mat2>* = nullptr>
inline Eigen::Matrix<return_type_t<Mat1, Mat2>, 1, Mat1::ColsAtCompileTime>
columns_dot_product(const Mat1& v1, const Mat2& v2) {
check_matching_sizes("dot_product", "v1", v1, "v2", v2);
Eigen::Matrix<var, 1, Mat1::ColsAtCompileTime> ret(1, v1.cols());
for (size_type j = 0; j < v1.cols(); ++j) {
ret.coeffRef(j) = dot_product(v1.col(j), v2.col(j));
}
return ret;
}

/**
* Returns the dot product of columns of the specified matrices.
*
Expand All @@ -27,7 +56,9 @@ namespace math {
* size or if they are both not vector dimensioned.
*/
template <typename Mat1, typename Mat2,
require_all_matrix_t<Mat1, Mat2>* = nullptr>
require_all_matrix_t<Mat1, Mat2>* = nullptr,
require_not_st_complex<Mat1>* = nullptr,
require_not_st_complex<Mat2>* = nullptr>
inline auto columns_dot_product(Mat1&& v1, Mat2&& v2) {
check_matching_sizes("columns_dot_product", "v1", v1, "v2", v2);
using inner_return_t = decltype(
Expand Down
7 changes: 5 additions & 2 deletions stan/math/rev/meta/return_var_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ using return_var_matrix_t = std::conditional_t<
is_any_var_matrix<ReturnType, Types...>::value,
stan::math::var_value<
stan::math::promote_scalar_t<double, plain_type_t<ReturnType>>>,
stan::math::promote_scalar_t<stan::math::var_value<double>,
plain_type_t<ReturnType>>>;
std::conditional_t<is_complex<scalar_type_t<ReturnType>>::value,
stan::math::promote_scalar_t<std::complex<stan::math::var_value<double>>,
plain_type_t<ReturnType>>,
stan::math::promote_scalar_t<stan::math::var_value<double>,
plain_type_t<ReturnType>>>>;
} // namespace stan

#endif

0 comments on commit c3d8136

Please sign in to comment.