vSMC
vSMC: Scalable Monte Carlo
normal_mv_distribution.hpp
Go to the documentation of this file.
1 //============================================================================
2 // vSMC/include/vsmc/rng/normal_mv_distribution.hpp
3 //----------------------------------------------------------------------------
4 // vSMC: Scalable Monte Carlo
5 //----------------------------------------------------------------------------
6 // Copyright (c) 2013-2016, Yan Zhou
7 // All rights reserved.
8 //
9 // Redistribution and use in source and binary forms, with or without
10 // modification, are permitted provided that the following conditions are met:
11 //
12 // Redistributions of source code must retain the above copyright notice,
13 // this list of conditions and the following disclaimer.
14 //
15 // Redistributions in binary form must reproduce the above copyright notice,
16 // this list of conditions and the following disclaimer in the documentation
17 // and/or other materials provided with the distribution.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS AS IS
20 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29 // POSSIBILITY OF SUCH DAMAGE.
30 //============================================================================
31 
32 #ifndef VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP
33 #define VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP
34 
37 
38 namespace vsmc
39 {
40 
53 template <typename RealType, std::size_t Dim>
54 class NormalMVDistribution
55 {
56 
57  public:
58  using result_type = RealType;
60 
61  class param_type
62  {
64  "**NormalMVDistributon::param_type** USED WITH RealType OTHER "
65  "THAN float OR double");
66 
67  public:
68  using result_type = RealType;
70 
71  explicit param_type(const result_type *mean = nullptr,
72  const result_type *chol = nullptr)
73  : rnorm_(0, 1)
74  , null_mean_(mean == nullptr)
75  , null_chol_(chol == nullptr)
76  {
77  static_assert(Dim != Dynamic, "**NormalMVDistribution::param_type*"
78  "* OBJECT DECLARED WITH DYNAMIC "
79  "DIMENSION");
80  init(mean, chol);
81  }
82 
83  explicit param_type(std::size_t dim, const result_type *mean = nullptr,
84  const result_type *chol = nullptr)
85  : rnorm_(0, 1)
86  , mean_(dim)
87  , chol_(dim * (dim + 1) / 2)
88  , null_mean_(mean == nullptr)
89  , null_chol_(chol == nullptr)
90  {
91  static_assert(Dim == Dynamic, "**NormalMVDistribution::param_type*"
92  "* OBJECT DECLARED WITH FIXED "
93  "DIMENSION");
94  init(mean, chol);
95  }
96 
97  std::size_t dim() const { return mean_.size(); }
98 
99  const result_type *mean() const { return mean_.data(); }
100 
101  const result_type *chol() const { return chol_.data(); }
102 
103  friend bool operator==(
104  const param_type &param1, const param_type &param2)
105  {
106  if (param1.norm_ != param2.norm_)
107  return false;
108  if (param1.mean_ != param2.mean_)
109  return false;
110  if (param1.chol_ != param2.chol_)
111  return false;
112  if (param1.null_mean_ != param2.null_mean_)
113  return false;
114  if (param1.null_chol_ != param2.null_chol_)
115  return false;
116  return true;
117  }
118 
119  friend bool operator!=(
120  const param_type &param1, const param_type &param2)
121  {
122  return !(param1 == param2);
123  }
124 
125  template <typename CharT, typename Traits>
126  friend std::basic_ostream<CharT, Traits> &operator<<(
127  std::basic_ostream<CharT, Traits> &os, const param_type &param)
128  {
129  if (!os.good())
130  return os;
131 
132  os << param.norm_ << ' ';
133  os << param.dim() << ' ';
134  os << param.mean_ << ' ';
135  os << param.chol_ << ' ';
136  os << param.null_mean_ << ' ';
137  os << param.null_chol_;
138 
139  return os;
140  }
141 
142  template <typename CharT, typename Traits>
143  friend std::basic_ostream<CharT, Traits> &operator>>(
144  std::basic_istream<CharT, Traits> &is, const param_type &param)
145  {
146  if (!is.good())
147  return is;
148 
151  internal::Array<result_type, Dim *(Dim + 1) / 2> chol;
152  bool null_mean;
153  bool null_chol;
154 
155  is >> std::ws >> rnorm;
156  if (!is.good())
157  return is;
158 
159  std::size_t dim = 0;
160  is >> std::ws >> dim;
161  if (!is.good())
162  return is;
163 
164  internal::resize(mean, dim);
165  internal::resize(chol, dim * (dim + 1) / 2);
166  is >> std::ws >> mean;
167  is >> std::ws >> chol;
168  is >> std::ws >> null_mean;
169  is >> std::ws >> null_chol;
170 
171  if (is.good()) {
172  param.rnorm_ = std::move(rnorm);
173  param.mean_ = std::move(mean);
174  param.chol_ = std::move(chol);
175  param.null_mean_ = null_mean;
176  param.null_chol_ = null_chol;
177  } else {
178  is.setstate(std::ios_base::failbit);
179  }
180 
181  return is;
182  }
183 
184  private:
187  internal::Array<result_type, Dim *(Dim + 1) / 2> chol_;
188  bool null_mean_;
189  bool null_chol_;
190 
191  friend distribution_type;
192 
193  void init(const result_type *mean, const result_type *chol)
194  {
195  if (mean == nullptr)
196  std::fill(mean_.begin(), mean_.end(), 0);
197  else
198  std::copy_n(mean, mean_.size(), mean_.begin());
199 
200  if (chol == nullptr)
201  std::fill(chol_.begin(), chol_.end(), 0);
202  else
203  std::copy_n(chol, chol_.size(), chol_.begin());
204 
205  if (chol == nullptr)
206  for (std::size_t i = 0; i != mean_.size(); ++i)
207  chol_[i * (i + 1) / 2 + i] = 1;
208  }
209  }; // class param_type
210 
219  const result_type *mean = nullptr, const result_type *chol = nullptr)
220  : param_(mean, chol)
221  {
222  reset();
223  }
224 
226  explicit NormalMVDistribution(std::size_t dim,
227  const result_type *mean = nullptr, const result_type *chol = nullptr)
228  : param_(dim, mean, chol)
229  {
230  reset();
231  }
232 
233  void min(result_type *x) const
234  {
235  std::fill_n(x, dim(), std::numeric_limits<result_type>::lowest());
236  }
237 
238  void max(result_type *x) const
239  {
240  std::fill_n(x, dim(), std::numeric_limits<result_type>::max());
241  }
242 
243  void reset() { param_.rnorm_.reset(); }
244 
245  std::size_t dim() const { return param_.dim(); }
246 
247  const result_type *mean() const { return param_.mean(); }
248 
249  const result_type *chol() const { return param_.chol(); }
250 
251  param_type param() const { return param_; }
252 
253  void param(const param_type &param)
254  {
255  param_ = param;
256  reset();
257  }
258 
260  {
261  param_ = std::move(param);
262  reset();
263  }
264 
265  template <typename RNGType>
266  void operator()(RNGType &rng, result_type *r)
267  {
268  operator()(rng, r, param_);
269  }
270 
271  template <typename RNGType>
272  void operator()(RNGType &rng, result_type *r, const param_type &param)
273  {
274  generate(rng, r, param);
275  }
276 
277  template <typename RNGType>
278  void operator()(RNGType &rng, std::size_t n, result_type *r)
279  {
280  operator()(rng, n, r, param_);
281  }
282 
283  template <typename RNGType>
285  RNGType &rng, std::size_t n, result_type *r, const param_type &param)
286  {
287  normal_mv_distribution(rng, n, r, param.dim(),
288  (param.null_mean_ ? param.mean() : nullptr),
289  (param.null_chol_ ? param.chol() : nullptr));
290  }
291 
292  friend bool operator==(
293  const distribution_type &dist1, const distribution_type &dist2)
294  {
295  if (dist1.param_ != dist2.param_)
296  return false;
297  return true;
298  }
299 
300  friend bool operator!=(
301  const distribution_type &dist1, const distribution_type &dist2)
302  {
303  return !(dist1 == dist2);
304  }
305 
306  template <typename CharT, typename Traits>
307  friend std::basic_ostream<CharT, Traits> &operator<<(
308  std::basic_ostream<CharT, Traits> &os, const distribution_type &dist)
309  {
310  if (!os.good())
311  return os;
312 
313  os << dist.param_ << ' ';
314 
315  return os;
316  }
317 
318  template <typename CharT, typename Traits>
319  friend std::basic_istream<CharT, Traits> &operator>>(
320  std::basic_istream<CharT, Traits> &is, distribution_type &dist)
321  {
322  if (!is.good())
323  return is;
324 
326  is >> std::ws >> param;
327  if (is.good())
328  dist.param_ = std::move(param);
329 
330  return is;
331  }
332 
333  private:
334  param_type param_;
335 
336  template <typename RNGType>
337  void generate(RNGType &rng, result_type *r, const param_type &param)
338  {
339  param_.rnorm_(rng, param.dim(), r);
340  if (!param.null_chol_)
341  mulchol(r, param);
342  if (!param.null_mean_)
343  add(param.dim(), param.mean(), r, r);
344  }
345 
346  void mulchol(float *r, const param_type &param)
347  {
348  ::cblas_stpmv(::CblasRowMajor, ::CblasLower, ::CblasNoTrans,
349  ::CblasNonUnit, static_cast<VSMC_CBLAS_INT>(dim()), param.chol(),
350  r, 1);
351  }
352 
353  void mulchol(double *r, const param_type &param)
354  {
355  ::cblas_dtpmv(::CblasRowMajor, ::CblasLower, ::CblasNoTrans,
356  ::CblasNonUnit, static_cast<VSMC_CBLAS_INT>(dim()), param.chol(),
357  r, 1);
358  }
359 }; // class NormalMVDistribution
360 
361 namespace internal
362 {
363 
365  std::size_t n, float *r, std::size_t m, const float *chol)
366 {
367  ::cblas_strmm(::CblasRowMajor, ::CblasRight, ::CblasLower, ::CblasTrans,
368  ::CblasNonUnit, static_cast<VSMC_CBLAS_INT>(n),
369  static_cast<VSMC_CBLAS_INT>(m), 1, chol,
370  static_cast<VSMC_CBLAS_INT>(m), r, static_cast<VSMC_CBLAS_INT>(m));
371 }
372 
374  std::size_t n, double *r, std::size_t m, const double *chol)
375 {
376  ::cblas_dtrmm(::CblasRowMajor, ::CblasRight, ::CblasLower, ::CblasTrans,
377  ::CblasNonUnit, static_cast<VSMC_CBLAS_INT>(n),
378  static_cast<VSMC_CBLAS_INT>(m), 1, chol,
379  static_cast<VSMC_CBLAS_INT>(m), r, static_cast<VSMC_CBLAS_INT>(m));
380 }
381 
382 } // namespace vsmc::internal
383 
386 template <typename RealType, typename RNGType>
387 inline void normal_mv_distribution(RNGType &rng, std::size_t n, RealType *r,
388  std::size_t dim, const RealType *mean, const RealType *chol)
389 {
391  "**normal_mv_distribution** USED WITH RealType OTHER THAN float OR "
392  "double");
393 
394  normal_distribution(rng, n * dim, r, 0.0, 1.0);
395  if (chol != nullptr) {
396  Vector<RealType> cholf(dim * dim);
397  for (std::size_t i = 0; i != dim; ++i)
398  for (std::size_t j = 0; j <= i; ++j)
399  cholf[i * dim + j] = *chol++;
400  internal::normal_mv_distribution_mulchol(n, r, dim, cholf.data());
401  }
402  if (mean != nullptr)
403  for (std::size_t i = 0; i != n; ++i, r += dim)
404  add(dim, mean, r, r);
405 }
406 
407 template <typename RealType, typename RNGType>
408 inline void normal_mv_distribution(RNGType &rng, std::size_t n, RealType *r,
410 {
411  normal_mv_distribution(rng, n, r, param.dim(), param.mean(), param.chol());
412 }
413 
414 template <typename RealType, std::size_t Dim, typename RNGType>
415 inline void rng_rand(RNGType &rng, NormalMVDistribution<RealType, Dim> &dist,
416  std::size_t n, RealType *r)
417 {
418  dist(rng, n, r);
419 }
420 
421 } // namespace vsmc
422 
423 #endif // VSMC_RNG_NORMAL_MV_DISTRIBUTION_HPP
Definition: monitor.hpp:49
void param(const param_type &param)
const result_type * mean() const
param_type(std::size_t dim, const result_type *mean=nullptr, const result_type *chol=nullptr)
NormalMVDistribution(const result_type *mean=nullptr, const result_type *chol=nullptr)
Only usable when Dim > 0
param_type(const result_type *mean=nullptr, const result_type *chol=nullptr)
typename std::conditional< std::is_scalar< T >::value, AlignedVector< T >, std::vector< T >>::type Vector
AlignedVector for scalar type and std::vector for others.
void normal_mv_distribution(RNGType &, std::size_t, RealType *, std::size_t, const RealType *, const RealType *)
Generating multivariate Normal random variates.
friend std::basic_ostream< CharT, Traits > & operator>>(std::basic_istream< CharT, Traits > &is, const param_type &param)
friend std::basic_istream< CharT, Traits > & operator>>(std::basic_istream< CharT, Traits > &is, distribution_type &dist)
void rng_rand(RNGType &rng, BetaDistribution< RealType > &dist, std::size_t n, RealType *r)
void normal_mv_distribution_mulchol(std::size_t n, float *r, std::size_t m, const float *chol)
void operator()(RNGType &rng, std::size_t n, result_type *r)
void max(result_type *x) const
friend bool operator!=(const distribution_type &dist1, const distribution_type &dist2)
friend bool operator==(const distribution_type &dist1, const distribution_type &dist2)
void normal_distribution(RNGType &, std::size_t, RealType *, RealType, RealType)
Generating Normal random variates.
void min(result_type *x) const
Normal distribution.
Definition: common.hpp:582
NormalMVDistribution(std::size_t dim, const result_type *mean=nullptr, const result_type *chol=nullptr)
Only usable when Dim == Dynamic
void operator()(RNGType &rng, result_type *r, const param_type &param)
friend std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const distribution_type &dist)
Multivariate Normal distribution.
Definition: common.hpp:585
friend bool operator==(const param_type &param1, const param_type &param2)
friend std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const param_type &param)
void resize(std::array< T, N > &, std::size_t)
Definition: common.hpp:138
void operator()(RNGType &rng, std::size_t n, result_type *r, const param_type &param)
typename std::conditional< Dim==Dynamic, Vector< T >, std::array< T, Dim >>::type Array
Definition: common.hpp:135
void normal_mv_distribution_mulchol(std::size_t n, double *r, std::size_t m, const double *chol)
void param(param_type &&param)
void add(std::size_t n, const float *a, const float *b, float *y)
Definition: vmath.hpp:109
NormalMVDistribution< RealType, Dim > distribution_type
friend bool operator!=(const param_type &param1, const param_type &param2)
void operator()(RNGType &rng, result_type *r)
const result_type * chol() const
NormalMVDistribution< RealType, Dim > distribution_type