32 #ifndef VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP 33 #define VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP 53 template <
typename RealType, std::
size_t Dim>
54 class NormalMVDistribution
64 "**NormalMVDistributon::param_type** USED WITH RealType OTHER " 65 "THAN float OR double");
74 , null_mean_(
mean == nullptr)
75 , null_chol_(
chol == nullptr)
77 static_assert(Dim !=
Dynamic,
"**NormalMVDistribution::param_type*" 78 "* OBJECT DECLARED WITH DYNAMIC " 87 , chol_(dim * (dim + 1) / 2)
88 , null_mean_(
mean == nullptr)
89 , null_chol_(
chol == nullptr)
91 static_assert(Dim ==
Dynamic,
"**NormalMVDistribution::param_type*" 92 "* OBJECT DECLARED WITH FIXED " 97 std::size_t
dim()
const {
return mean_.size(); }
106 if (param1.norm_ != param2.norm_)
108 if (param1.mean_ != param2.mean_)
110 if (param1.chol_ != param2.chol_)
112 if (param1.null_mean_ != param2.null_mean_)
114 if (param1.null_chol_ != param2.null_chol_)
122 return !(param1 == param2);
125 template <
typename CharT,
typename Traits>
132 os << param.norm_ <<
' ';
133 os << param.
dim() <<
' ';
134 os << param.mean_ <<
' ';
135 os << param.chol_ <<
' ';
136 os << param.null_mean_ <<
' ';
137 os << param.null_chol_;
142 template <
typename CharT,
typename Traits>
155 is >> std::ws >> rnorm;
160 is >> std::ws >>
dim;
166 is >> std::ws >>
mean;
167 is >> std::ws >>
chol;
168 is >> std::ws >> null_mean;
169 is >> std::ws >> null_chol;
172 param.rnorm_ = std::move(rnorm);
173 param.mean_ = std::move(mean);
174 param.chol_ = std::move(chol);
175 param.null_mean_ = null_mean;
176 param.null_chol_ = null_chol;
178 is.setstate(std::ios_base::failbit);
193 void init(
const result_type *
mean,
const result_type *
chol)
196 std::fill(mean_.begin(), mean_.end(), 0);
198 std::copy_n(mean, mean_.size(), mean_.begin());
201 std::fill(chol_.begin(), chol_.end(), 0);
203 std::copy_n(chol, chol_.size(), chol_.begin());
206 for (std::size_t i = 0; i != mean_.size(); ++i)
207 chol_[i * (i + 1) / 2 + i] = 1;
235 std::fill_n(x,
dim(), std::numeric_limits<result_type>::lowest());
240 std::fill_n(x,
dim(), std::numeric_limits<result_type>::max());
243 void reset() { param_.rnorm_.reset(); }
245 std::size_t
dim()
const {
return param_.dim(); }
261 param_ = std::move(
param);
265 template <
typename RNGType>
271 template <
typename RNGType>
274 generate(rng, r, param);
277 template <
typename RNGType>
283 template <
typename RNGType>
288 (param.null_mean_ ? param.
mean() :
nullptr),
289 (param.null_chol_ ? param.
chol() :
nullptr));
295 if (dist1.param_ != dist2.param_)
303 return !(dist1 == dist2);
306 template <
typename CharT,
typename Traits>
313 os << dist.param_ <<
' ';
318 template <
typename CharT,
typename Traits>
326 is >> std::ws >>
param;
328 dist.param_ = std::move(param);
336 template <
typename RNGType>
339 param_.rnorm_(rng, param.
dim(), r);
340 if (!param.null_chol_)
342 if (!param.null_mean_)
346 void mulchol(
float *r,
const param_type ¶m)
348 ::cblas_stpmv(::CblasRowMajor, ::CblasLower, ::CblasNoTrans,
349 ::CblasNonUnit, static_cast<VSMC_CBLAS_INT>(
dim()), param.
chol(),
353 void mulchol(
double *r,
const param_type ¶m)
355 ::cblas_dtpmv(::CblasRowMajor, ::CblasLower, ::CblasNoTrans,
356 ::CblasNonUnit, static_cast<VSMC_CBLAS_INT>(
dim()), param.
chol(),
365 std::size_t n,
float *r, std::size_t m,
const float *
chol)
367 ::cblas_strmm(::CblasRowMajor, ::CblasRight, ::CblasLower, ::CblasTrans,
368 ::CblasNonUnit, static_cast<VSMC_CBLAS_INT>(n),
369 static_cast<VSMC_CBLAS_INT>(m), 1, chol,
370 static_cast<VSMC_CBLAS_INT>(m), r, static_cast<VSMC_CBLAS_INT>(m));
374 std::size_t n,
double *r, std::size_t m,
const double *
chol)
376 ::cblas_dtrmm(::CblasRowMajor, ::CblasRight, ::CblasLower, ::CblasTrans,
377 ::CblasNonUnit, static_cast<VSMC_CBLAS_INT>(n),
378 static_cast<VSMC_CBLAS_INT>(m), 1, chol,
379 static_cast<VSMC_CBLAS_INT>(m), r, static_cast<VSMC_CBLAS_INT>(m));
386 template <
typename RealType,
typename RNGType>
388 std::size_t
dim,
const RealType *
mean,
const RealType *
chol)
391 "**normal_mv_distribution** USED WITH RealType OTHER THAN float OR " 395 if (chol !=
nullptr) {
397 for (std::size_t i = 0; i !=
dim; ++i)
398 for (std::size_t j = 0; j <= i; ++j)
399 cholf[i * dim + j] = *chol++;
403 for (std::size_t i = 0; i != n; ++i, r +=
dim)
404 add(dim, mean, r, r);
407 template <
typename RealType,
typename RNGType>
414 template <
typename RealType, std::
size_t Dim,
typename RNGType>
416 std::size_t n, RealType *r)
423 #endif // VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP
void param(const param_type &param)
const result_type * mean() const
param_type(std::size_t dim, const result_type *mean=nullptr, const result_type *chol=nullptr)
NormalMVDistribution(const result_type *mean=nullptr, const result_type *chol=nullptr)
Only usable when Dim > 0
param_type(const result_type *mean=nullptr, const result_type *chol=nullptr)
typename std::conditional< std::is_scalar< T >::value, AlignedVector< T >, std::vector< T >>::type Vector
AlignedVector for scalar types and std::vector for other types.
void normal_mv_distribution(RNGType &, std::size_t, RealType *, std::size_t, const RealType *, const RealType *)
Generating multivariate Normal random variates.
friend std::basic_ostream< CharT, Traits > & operator>>(std::basic_istream< CharT, Traits > &is, const param_type &param)
friend std::basic_istream< CharT, Traits > & operator>>(std::basic_istream< CharT, Traits > &is, distribution_type &dist)
void rng_rand(RNGType &rng, BetaDistribution< RealType > &dist, std::size_t n, RealType *r)
void normal_mv_distribution_mulchol(std::size_t n, float *r, std::size_t m, const float *chol)
const result_type * chol() const
void operator()(RNGType &rng, std::size_t n, result_type *r)
void max(result_type *x) const
friend bool operator!=(const distribution_type &dist1, const distribution_type &dist2)
friend bool operator==(const distribution_type &dist1, const distribution_type &dist2)
void normal_distribution(RNGType &, std::size_t, RealType *, RealType, RealType)
Generating Normal random variates.
void min(result_type *x) const
NormalMVDistribution(std::size_t dim, const result_type *mean=nullptr, const result_type *chol=nullptr)
Only usable when Dim == Dynamic
void operator()(RNGType &rng, result_type *r, const param_type &param)
friend std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const distribution_type &dist)
Multivariate Normal distribution.
friend bool operator==(const param_type &param1, const param_type &param2)
friend std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const param_type &param)
void resize(std::array< T, N > &, std::size_t)
void operator()(RNGType &rng, std::size_t n, result_type *r, const param_type &param)
const result_type * mean() const
typename std::conditional< Dim==Dynamic, Vector< T >, std::array< T, Dim >>::type Array
void normal_mv_distribution_mulchol(std::size_t n, double *r, std::size_t m, const double *chol)
void param(param_type &&param)
void add(std::size_t n, const float *a, const float *b, float *y)
NormalMVDistribution< RealType, Dim > distribution_type
friend bool operator!=(const param_type &param1, const param_type &param2)
void operator()(RNGType &rng, result_type *r)
const result_type * chol() const
NormalMVDistribution< RealType, Dim > distribution_type