vSMC  v3.0.0
Scalable Monte Carlo
normal_mv_distribution.hpp
Go to the documentation of this file.
1 //============================================================================
2 // vSMC/include/vsmc/rng/normal_mv_distribution.hpp
3 //----------------------------------------------------------------------------
4 // vSMC: Scalable Monte Carlo
5 //----------------------------------------------------------------------------
6 // Copyright (c) 2013-2016, Yan Zhou
7 // All rights reserved.
8 //
9 // Redistribution and use in source and binary forms, with or without
10 // modification, are permitted provided that the following conditions are met:
11 //
12 // Redistributions of source code must retain the above copyright notice,
13 // this list of conditions and the following disclaimer.
14 //
15 // Redistributions in binary form must reproduce the above copyright notice,
16 // this list of conditions and the following disclaimer in the documentation
17 // and/or other materials provided with the distribution.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS AS IS
20 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29 // POSSIBILITY OF SUCH DAMAGE.
30 //============================================================================
31 
32 #ifndef VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP
33 #define VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP
34 
37 
38 namespace vsmc
39 {
40 
53 template <typename RealType, std::size_t Dim>
54 class NormalMVDistribution
55 {
56 
57  public:
58  using result_type = RealType;
60 
61  class param_type
62  {
64  "**NormalMVDistributon::param_type** USED WITH RealType OTHER "
65  "THAN float OR double");
66 
67  public:
68  using result_type = RealType;
70 
71  explicit param_type(const result_type *mean = nullptr,
72  const result_type *chol = nullptr)
73  : rnorm_(0, 1)
74  , null_mean_(mean == nullptr)
75  , null_chol_(chol == nullptr)
76  {
77  static_assert(Dim != Dynamic, "**NormalMVDistribution::param_type*"
78  "* OBJECT DECLARED WITH DYNAMIC "
79  "DIMENSION");
80  init(mean, chol);
81  }
82 
83  explicit param_type(std::size_t dim, const result_type *mean = nullptr,
84  const result_type *chol = nullptr)
85  : rnorm_(0, 1)
86  , mean_(dim)
87  , chol_(dim * (dim + 1) / 2)
88  , null_mean_(mean == nullptr)
89  , null_chol_(chol == nullptr)
90  {
91  static_assert(Dim == Dynamic, "**NormalMVDistribution::param_type*"
92  "* OBJECT DECLARED WITH FIXED "
93  "DIMENSION");
94  init(mean, chol);
95  }
96 
97  std::size_t dim() const { return mean_.size(); }
98 
99  const result_type *mean() const { return mean_.data(); }
100 
101  const result_type *chol() const { return chol_.data(); }
102 
103  friend bool operator==(
104  const param_type &param1, const param_type &param2)
105  {
106  if (param1.norm_ != param2.norm_)
107  return false;
108  if (param1.mean_ != param2.mean_)
109  return false;
110  if (param1.chol_ != param2.chol_)
111  return false;
112  if (param1.null_mean_ != param2.null_mean_)
113  return false;
114  if (param1.null_chol_ != param2.null_chol_)
115  return false;
116  return true;
117  }
118 
119  friend bool operator!=(
120  const param_type &param1, const param_type &param2)
121  {
122  return !(param1 == param2);
123  }
124 
125  template <typename CharT, typename Traits>
126  friend std::basic_ostream<CharT, Traits> &operator<<(
127  std::basic_ostream<CharT, Traits> &os, const param_type &param)
128  {
129  if (!os)
130  return os;
131 
132  os << param.norm_ << ' ';
133  os << param.dim() << ' ';
134  os << param.mean_ << ' ';
135  os << param.chol_ << ' ';
136  os << param.null_mean_ << ' ';
137  os << param.null_chol_;
138 
139  return os;
140  }
141 
142  template <typename CharT, typename Traits>
143  friend std::basic_istream<CharT, Traits> &operator>>(
144  std::basic_istream<CharT, Traits> &is, const param_type &param)
145  {
146  if (!is)
147  return is;
148 
150 
151  is >> std::ws >> rnorm;
152  if (!is)
153  return is;
154 
155  std::size_t dim = 0;
156  is >> std::ws >> dim;
157  if (!is)
158  return is;
159 
161  internal::StaticVector<result_type, Dim *(Dim + 1) / 2> chol(
162  dim * (dim + 1) / 2);
163  bool null_mean;
164  bool null_chol;
165 
166  is >> std::ws >> mean;
167  is >> std::ws >> chol;
168  is >> std::ws >> null_mean;
169  is >> std::ws >> null_chol;
170 
171  if (is) {
172  param.rnorm_ = std::move(rnorm);
173  std::move(mean.begin(), mean.end(), param.mean_.begin());
174  std::move(chol.begin(), chol.end(), param.chol_.begin());
175  param.null_mean_ = null_mean;
176  param.null_chol_ = null_chol;
177  } else {
178  is.setstate(std::ios_base::failbit);
179  }
180 
181  return is;
182  }
183 
184  private:
187  internal::StaticVector<result_type, Dim *(Dim + 1) / 2> chol_;
188  bool null_mean_;
189  bool null_chol_;
190 
191  friend distribution_type;
192 
193  void init(const result_type *mean, const result_type *chol)
194  {
195  if (mean == nullptr)
196  std::fill(mean_.begin(), mean_.end(), 0);
197  else
198  std::copy_n(mean, mean_.size(), mean_.begin());
199 
200  if (chol == nullptr)
201  std::fill(chol_.begin(), chol_.end(), 0);
202  else
203  std::copy_n(chol, chol_.size(), chol_.begin());
204 
205  if (chol == nullptr)
206  for (std::size_t i = 0; i != mean_.size(); ++i)
207  chol_[i * (i + 1) / 2 + i] = 1;
208  }
209  }; // class param_type
210 
219  const result_type *mean = nullptr, const result_type *chol = nullptr)
220  : param_(mean, chol)
221  {
222  internal::size_check<VSMC_BLAS_INT>(
223  Dim, "NormalMVDistribution::NormalMVDistribution");
224  reset();
225  }
226 
228  explicit NormalMVDistribution(std::size_t dim,
229  const result_type *mean = nullptr, const result_type *chol = nullptr)
230  : param_(dim, mean, chol)
231  {
232  internal::size_check<VSMC_BLAS_INT>(
233  dim, "NormalMVDistribution::NormalMVDistribution");
234  reset();
235  }
236 
237  void min(result_type *x) const
238  {
239  std::fill_n(x, dim(), std::numeric_limits<result_type>::lowest());
240  }
241 
242  void max(result_type *x) const
243  {
244  std::fill_n(x, dim(), std::numeric_limits<result_type>::max());
245  }
246 
247  void reset() { param_.rnorm_.reset(); }
248 
249  std::size_t dim() const { return param_.dim(); }
250 
251  const result_type *mean() const { return param_.mean(); }
252 
253  const result_type *chol() const { return param_.chol(); }
254 
255  param_type param() const { return param_; }
256 
257  void param(const param_type &param)
258  {
259  param_ = param;
260  reset();
261  }
262 
264  {
265  param_ = std::move(param);
266  reset();
267  }
268 
269  template <typename RNGType>
270  void operator()(RNGType &rng, result_type *r)
271  {
272  operator()(rng, r, param_);
273  }
274 
275  template <typename RNGType>
276  void operator()(RNGType &rng, result_type *r, const param_type &param)
277  {
278  generate(rng, r, param);
279  }
280 
281  template <typename RNGType>
282  void operator()(RNGType &rng, std::size_t n, result_type *r)
283  {
284  operator()(rng, n, r, param_);
285  }
286 
287  template <typename RNGType>
289  RNGType &rng, std::size_t n, result_type *r, const param_type &param)
290  {
291  normal_mv_distribution(rng, n, r, param.dim(),
292  (param.null_mean_ ? param.mean() : nullptr),
293  (param.null_chol_ ? param.chol() : nullptr));
294  }
295 
296  friend bool operator==(
297  const distribution_type &dist1, const distribution_type &dist2)
298  {
299  if (dist1.param_ != dist2.param_)
300  return false;
301  return true;
302  }
303 
304  friend bool operator!=(
305  const distribution_type &dist1, const distribution_type &dist2)
306  {
307  return !(dist1 == dist2);
308  }
309 
310  template <typename CharT, typename Traits>
311  friend std::basic_ostream<CharT, Traits> &operator<<(
312  std::basic_ostream<CharT, Traits> &os, const distribution_type &dist)
313  {
314  if (!os)
315  return os;
316 
317  os << dist.param_;
318 
319  return os;
320  }
321 
322  template <typename CharT, typename Traits>
323  friend std::basic_istream<CharT, Traits> &operator>>(
324  std::basic_istream<CharT, Traits> &is, distribution_type &dist)
325  {
326  if (!is)
327  return is;
328 
330  is >> std::ws >> param;
331  if (is)
332  dist.param_ = std::move(param);
333 
334  return is;
335  }
336 
337  private:
338  param_type param_;
339 
340  template <typename RNGType>
341  void generate(RNGType &rng, result_type *r, const param_type &param)
342  {
343  param_.rnorm_(rng, param.dim(), r);
344  if (!param.null_chol_)
345  mulchol(r, param);
346  if (!param.null_mean_)
347  add(param.dim(), param.mean(), r, r);
348  }
349 
350  void mulchol(float *r, const param_type &param)
351  {
352  internal::cblas_stpmv(internal::CblasRowMajor, internal::CblasLower,
353  internal::CblasNoTrans, internal::CblasNonUnit,
354  static_cast<VSMC_BLAS_INT>(dim()), param.chol(), r, 1);
355  }
356 
357  void mulchol(double *r, const param_type &param)
358  {
359  internal::cblas_dtpmv(internal::CblasRowMajor, internal::CblasLower,
360  internal::CblasNoTrans, internal::CblasNonUnit,
361  static_cast<VSMC_BLAS_INT>(dim()), param.chol(), r, 1);
362  }
363 }; // class NormalMVDistribution
364 
365 namespace internal
366 {
367 
369  std::size_t n, float *r, std::size_t dim, const float *chol)
370 {
371  cblas_strmm(CblasRowMajor, CblasRight, CblasLower, CblasTrans,
372  CblasNonUnit, static_cast<VSMC_BLAS_INT>(n),
373  static_cast<VSMC_BLAS_INT>(dim), 1, chol,
374  static_cast<VSMC_BLAS_INT>(dim), r, static_cast<VSMC_BLAS_INT>(dim));
375 }
376 
378  std::size_t n, double *r, std::size_t dim, const double *chol)
379 {
380  cblas_dtrmm(CblasRowMajor, CblasRight, CblasLower, CblasTrans,
381  CblasNonUnit, static_cast<VSMC_BLAS_INT>(n),
382  static_cast<VSMC_BLAS_INT>(dim), 1, chol,
383  static_cast<VSMC_BLAS_INT>(dim), r, static_cast<VSMC_BLAS_INT>(dim));
384 }
385 
386 } // namespace vsmc::internal
387 
390 template <typename RealType, typename RNGType>
391 inline void normal_mv_distribution(RNGType &rng, std::size_t n, RealType *r,
392  std::size_t dim, const RealType *mean, const RealType *chol)
393 {
395  "**normal_mv_distribution** USED WITH RealType OTHER THAN float OR "
396  "double");
397 
398  internal::size_check<VSMC_BLAS_INT>(n, "normal_mv_distribution");
399  internal::size_check<VSMC_BLAS_INT>(dim, "normal_mv_distribution");
400 
401  normal_distribution(rng, n * dim, r, 0.0, 1.0);
402  if (chol != nullptr) {
403  Vector<RealType> cholf(dim * dim);
404  for (std::size_t i = 0; i != dim; ++i)
405  for (std::size_t j = 0; j <= i; ++j)
406  cholf[i * dim + j] = *chol++;
407  internal::normal_mv_distribution_mulchol(n, r, dim, cholf.data());
408  }
409  if (mean != nullptr)
410  for (std::size_t i = 0; i != n; ++i, r += dim)
411  add(dim, mean, r, r);
412 }
413 
414 template <typename RealType, typename RNGType>
415 inline void normal_mv_distribution(RNGType &rng, std::size_t n, RealType *r,
417 {
418  normal_mv_distribution(rng, n, r, param.dim(), param.mean(), param.chol());
419 }
420 
421 template <typename RealType, std::size_t Dim, typename RNGType>
422 inline void rand(RNGType &rng, NormalMVDistribution<RealType, Dim> &dist,
423  std::size_t n, RealType *r)
424 {
425  dist(rng, n, r);
426 }
427 
428 } // namespace vsmc
429 
430 #endif // VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP
std::vector< T, Alloc > Vector
std::vector with Allocator as default allocator
Definition: monitor.hpp:48
void param(const param_type &param)
const result_type * mean() const
friend std::basic_istream< CharT, Traits > & operator>>(std::basic_istream< CharT, Traits > &is, const param_type &param)
param_type(std::size_t dim, const result_type *mean=nullptr, const result_type *chol=nullptr)
NormalMVDistribution(const result_type *mean=nullptr, const result_type *chol=nullptr)
Only usable when Dim > 0
param_type(const result_type *mean=nullptr, const result_type *chol=nullptr)
void normal_mv_distribution(RNGType &, std::size_t, RealType *, std::size_t, const RealType *, const RealType *)
Generating multivariate Normal random variates.
void normal_mv_distribution_mulchol(std::size_t n, double *r, std::size_t dim, const double *chol)
friend std::basic_istream< CharT, Traits > & operator>>(std::basic_istream< CharT, Traits > &is, distribution_type &dist)
void operator()(RNGType &rng, std::size_t n, result_type *r)
void max(result_type *x) const
friend bool operator!=(const distribution_type &dist1, const distribution_type &dist2)
friend bool operator==(const distribution_type &dist1, const distribution_type &dist2)
void normal_distribution(RNGType &, std::size_t, RealType *, RealType, RealType)
Generating Normal random variates.
void min(result_type *x) const
void normal_mv_distribution_mulchol(std::size_t n, float *r, std::size_t dim, const float *chol)
NormalMVDistribution(std::size_t dim, const result_type *mean=nullptr, const result_type *chol=nullptr)
Only usable when Dim == Dynamic
void operator()(RNGType &rng, result_type *r, const param_type &param)
friend std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const distribution_type &dist)
Multivariate Normal distribution.
friend bool operator==(const param_type &param1, const param_type &param2)
friend std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const param_type &param)
void operator()(RNGType &rng, std::size_t n, result_type *r, const param_type &param)
void rand(RNGType &rng, ArcsineDistribution< RealType > &dist, std::size_t N, RealType *r)
void param(param_type &&param)
void add(std::size_t n, const float *a, const float *b, float *y)
Definition: vmath.hpp:74
NormalMVDistribution< RealType, Dim > distribution_type
typename std::conditional< N==Dynamic, Vector< T >, std::array< T, N >>::type StaticVector
friend bool operator!=(const param_type &param1, const param_type &param2)
void operator()(RNGType &rng, result_type *r)
const result_type * chol() const
NormalMVDistribution< RealType, Dim > distribution_type