32 #ifndef VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP 33 #define VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP 53 template <
typename RealType, std::
size_t Dim>
54 class NormalMVDistribution
64 "**NormalMVDistributon::param_type** USED WITH RealType OTHER " 65 "THAN float OR double");
74 , null_mean_(
mean == nullptr)
75 , null_chol_(
chol == nullptr)
77 static_assert(Dim !=
Dynamic,
"**NormalMVDistribution::param_type*" 78 "* OBJECT DECLARED WITH DYNAMIC " 87 , chol_(dim * (dim + 1) / 2)
88 , null_mean_(
mean == nullptr)
89 , null_chol_(
chol == nullptr)
91 static_assert(Dim ==
Dynamic,
"**NormalMVDistribution::param_type*" 92 "* OBJECT DECLARED WITH FIXED " 97 std::size_t
dim()
const {
return mean_.size(); }
106 if (param1.norm_ != param2.norm_)
108 if (param1.mean_ != param2.mean_)
110 if (param1.chol_ != param2.chol_)
112 if (param1.null_mean_ != param2.null_mean_)
114 if (param1.null_chol_ != param2.null_chol_)
122 return !(param1 == param2);
125 template <
typename CharT,
typename Traits>
132 os << param.norm_ <<
' ';
133 os << param.
dim() <<
' ';
134 os << param.mean_ <<
' ';
135 os << param.chol_ <<
' ';
136 os << param.null_mean_ <<
' ';
137 os << param.null_chol_;
142 template <
typename CharT,
typename Traits>
151 is >> std::ws >> rnorm;
156 is >> std::ws >>
dim;
162 dim * (dim + 1) / 2);
166 is >> std::ws >>
mean;
167 is >> std::ws >>
chol;
168 is >> std::ws >> null_mean;
169 is >> std::ws >> null_chol;
172 param.rnorm_ = std::move(rnorm);
173 std::move(mean.begin(), mean.end(), param.mean_.begin());
174 std::move(chol.begin(), chol.end(), param.chol_.begin());
175 param.null_mean_ = null_mean;
176 param.null_chol_ = null_chol;
178 is.setstate(std::ios_base::failbit);
193 void init(
const result_type *
mean,
const result_type *
chol)
196 std::fill(mean_.begin(), mean_.end(), 0);
198 std::copy_n(mean, mean_.size(), mean_.begin());
201 std::fill(chol_.begin(), chol_.end(), 0);
203 std::copy_n(chol, chol_.size(), chol_.begin());
206 for (std::size_t i = 0; i != mean_.size(); ++i)
207 chol_[i * (i + 1) / 2 + i] = 1;
222 internal::size_check<VSMC_BLAS_INT>(
223 Dim,
"NormalMVDistribution::NormalMVDistribution");
232 internal::size_check<VSMC_BLAS_INT>(
233 dim,
"NormalMVDistribution::NormalMVDistribution");
239 std::fill_n(x,
dim(), std::numeric_limits<result_type>::lowest());
244 std::fill_n(x,
dim(), std::numeric_limits<result_type>::max());
247 void reset() { param_.rnorm_.reset(); }
249 std::size_t
dim()
const {
return param_.dim(); }
265 param_ = std::move(
param);
269 template <
typename RNGType>
275 template <
typename RNGType>
278 generate(rng, r, param);
281 template <
typename RNGType>
287 template <
typename RNGType>
292 (param.null_mean_ ? param.
mean() :
nullptr),
293 (param.null_chol_ ? param.
chol() :
nullptr));
299 if (dist1.param_ != dist2.param_)
307 return !(dist1 == dist2);
310 template <
typename CharT,
typename Traits>
322 template <
typename CharT,
typename Traits>
330 is >> std::ws >>
param;
332 dist.param_ = std::move(param);
340 template <
typename RNGType>
343 param_.rnorm_(rng, param.
dim(), r);
344 if (!param.null_chol_)
346 if (!param.null_mean_)
350 void mulchol(
float *r,
const param_type ¶m)
352 internal::cblas_stpmv(internal::CblasRowMajor, internal::CblasLower,
353 internal::CblasNoTrans, internal::CblasNonUnit,
354 static_cast<VSMC_BLAS_INT>(
dim()), param.
chol(), r, 1);
357 void mulchol(
double *r,
const param_type ¶m)
359 internal::cblas_dtpmv(internal::CblasRowMajor, internal::CblasLower,
360 internal::CblasNoTrans, internal::CblasNonUnit,
361 static_cast<VSMC_BLAS_INT>(
dim()), param.
chol(), r, 1);
369 std::size_t n,
float *r, std::size_t
dim,
const float *
chol)
371 cblas_strmm(CblasRowMajor, CblasRight, CblasLower, CblasTrans,
372 CblasNonUnit, static_cast<VSMC_BLAS_INT>(n),
373 static_cast<VSMC_BLAS_INT>(dim), 1, chol,
374 static_cast<VSMC_BLAS_INT>(dim), r, static_cast<VSMC_BLAS_INT>(dim));
378 std::size_t n,
double *r, std::size_t
dim,
const double *
chol)
380 cblas_dtrmm(CblasRowMajor, CblasRight, CblasLower, CblasTrans,
381 CblasNonUnit, static_cast<VSMC_BLAS_INT>(n),
382 static_cast<VSMC_BLAS_INT>(dim), 1, chol,
383 static_cast<VSMC_BLAS_INT>(dim), r, static_cast<VSMC_BLAS_INT>(dim));
390 template <
typename RealType,
typename RNGType>
392 std::size_t
dim,
const RealType *
mean,
const RealType *
chol)
395 "**normal_mv_distribution** USED WITH RealType OTHER THAN float OR " 398 internal::size_check<VSMC_BLAS_INT>(n,
"normal_mv_distribution");
399 internal::size_check<VSMC_BLAS_INT>(
dim,
"normal_mv_distribution");
402 if (chol !=
nullptr) {
404 for (std::size_t i = 0; i !=
dim; ++i)
405 for (std::size_t j = 0; j <= i; ++j)
406 cholf[i * dim + j] = *chol++;
410 for (std::size_t i = 0; i != n; ++i, r +=
dim)
411 add(dim, mean, r, r);
414 template <
typename RealType,
typename RNGType>
421 template <
typename RealType, std::
size_t Dim,
typename RNGType>
423 std::size_t n, RealType *r)
430 #endif // VSMC_RNG_NORMAL_DISTRIBUTION_HPP std::vector< T, Alloc > Vector
std::vector with Allocator as default allocator
void param(const param_type ¶m)
const result_type * mean() const
friend std::basic_istream< CharT, Traits > & operator>>(std::basic_istream< CharT, Traits > &is, const param_type ¶m)
param_type(std::size_t dim, const result_type *mean=nullptr, const result_type *chol=nullptr)
NormalMVDistribution(const result_type *mean=nullptr, const result_type *chol=nullptr)
Only usable when Dim > 0
param_type(const result_type *mean=nullptr, const result_type *chol=nullptr)
void normal_mv_distribution(RNGType &, std::size_t, RealType *, std::size_t, const RealType *, const RealType *)
Generating multivariate Normal random varaites.
void normal_mv_distribution_mulchol(std::size_t n, double *r, std::size_t dim, const double *chol)
friend std::basic_istream< CharT, Traits > & operator>>(std::basic_istream< CharT, Traits > &is, distribution_type &dist)
const result_type * chol() const
void operator()(RNGType &rng, std::size_t n, result_type *r)
void max(result_type *x) const
friend bool operator!=(const distribution_type &dist1, const distribution_type &dist2)
friend bool operator==(const distribution_type &dist1, const distribution_type &dist2)
void normal_distribution(RNGType &, std::size_t, RealType *, RealType, RealType)
Generating Normal random variates.
void min(result_type *x) const
void normal_mv_distribution_mulchol(std::size_t n, float *r, std::size_t dim, const float *chol)
NormalMVDistribution(std::size_t dim, const result_type *mean=nullptr, const result_type *chol=nullptr)
Only usable when Dim == Dynamic
void operator()(RNGType &rng, result_type *r, const param_type ¶m)
friend std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const distribution_type &dist)
Multivariate Normal distribution.
friend bool operator==(const param_type ¶m1, const param_type ¶m2)
friend std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const param_type ¶m)
void operator()(RNGType &rng, std::size_t n, result_type *r, const param_type ¶m)
const result_type * mean() const
void rand(RNGType &rng, ArcsineDistribution< RealType > &dist, std::size_t N, RealType *r)
void param(param_type &¶m)
void add(std::size_t n, const float *a, const float *b, float *y)
NormalMVDistribution< RealType, Dim > distribution_type
typename std::conditional< N==Dynamic, Vector< T >, std::array< T, N >>::type StaticVector
friend bool operator!=(const param_type ¶m1, const param_type ¶m2)
void operator()(RNGType &rng, result_type *r)
const result_type * chol() const
NormalMVDistribution< RealType, Dim > distribution_type