vSMC
vSMC: Scalable Monte Carlo
aes_ni.hpp
Go to the documentation of this file.
1 //============================================================================
2 // vSMC/include/vsmc/rng/aes_ni.hpp
3 //----------------------------------------------------------------------------
4 // vSMC: Scalable Monte Carlo
5 //----------------------------------------------------------------------------
6 // Copyright (c) 2013,2014, Yan Zhou
7 // All rights reserved.
8 //
9 // Redistribution and use in source and binary forms, with or without
10 // modification, are permitted provided that the following conditions are met:
11 //
12 // Redistributions of source code must retain the above copyright notice,
13 // this list of conditions and the following disclaimer.
14 //
15 // Redistributions in binary form must reproduce the above copyright notice,
16 // this list of conditions and the following disclaimer in the documentation
17 // and/or other materials provided with the distribution.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS AS IS
20 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29 // POSSIBILITY OF SUCH DAMAGE.
30 //============================================================================
31 
32 #ifndef VSMC_RNG_AES_NI_HPP
33 #define VSMC_RNG_AES_NI_HPP
34 
36 #include <vsmc/rng/m128i.hpp>
37 #include <wmmintrin.h>
38 
// Compile-time check that AESNIEngine is instantiated with Blocks > 0.
39 #define VSMC_STATIC_ASSERT_RNG_AES_NI_BLOCKS(Blocks) \
40  VSMC_STATIC_ASSERT((Blocks > 0), USE_AESNIEngine_WITH_ZERO_BLOCKS)
41 
// Compile-time check that ResultType is an unsigned type
// (cxx11::is_unsigned<ResultType>::value).
42 #define VSMC_STATIC_ASSERT_RNG_AES_NI_RESULT_TYPE(ResultType) \
43  VSMC_STATIC_ASSERT((cxx11::is_unsigned<ResultType>::value), \
44  USE_AESNIEngine_WITH_RESULT_TYPE_NOT_AN_UNSIGNED_INTEGER)
45 
// Runs both checks above. It takes no arguments and refers to the names
// `Blocks` and `ResultType` directly, so it only compiles where those names
// are in scope (e.g. inside the AESNIEngine class template).
46 #define VSMC_STATIC_ASSERT_RNG_AES_NI \
47  VSMC_STATIC_ASSERT_RNG_AES_NI_BLOCKS(Blocks); \
48  VSMC_STATIC_ASSERT_RNG_AES_NI_RESULT_TYPE(ResultType);
49 
50 namespace vsmc {
51 
52 namespace internal {
53 
// Primary template for storage of an expanded AES key sequence. Only the
// KeySeqInit == true (cache the sequence) and KeySeqInit == false
// (regenerate on demand) partial specializations below are defined.
54 template <typename KeySeq, bool KeySeqInit, std::size_t Rounds>
56 
// Key-sequence storage for KeySeqInit == true: set() expands the key once
// with KeySeq::generate and caches the result in key_seq_; get() simply
// returns a copy of that cache.
57 template <typename KeySeq, std::size_t Rounds>
58 class AESNIKeySeqStorage<KeySeq, true, Rounds>
59 {
60  public :
61 
62  typedef typename KeySeq::key_type key_type;
64 
// Returns the cached sequence. The key argument is ignored; the result
// reflects the key most recently passed to set().
65  key_seq_type get (const key_type &) const {return key_seq_;}
66 
// Expands key k and stores the resulting sequence in key_seq_.
67  void set (const key_type &k)
68  {
69  KeySeq seq;
70  seq.generate(k, key_seq_);
71  }
72 
// Writes the Rounds + 1 cached elements with m128i_output, each followed
// by a single space. Does nothing if the stream is not good.
73  template <typename CharT, typename Traits>
74  friend inline std::basic_ostream<CharT, Traits> &operator<< (
75  std::basic_ostream<CharT, Traits> &os,
77  {
78  if (!os.good())
79  return os;
80 
81  for (std::size_t i = 0; i != Rounds + 1; ++i) {
82  m128i_output(os, ks.key_seq_[i]);
83  os << ' ';
84  }
85 
86  return os;
87  }
88 
// Reads Rounds + 1 elements with m128i_input into a temporary (ks_tmp) and
// assigns it to ks only if the stream is still good afterwards, so a failed
// read leaves ks unchanged.
// NOTE(review): good() is also false when only eofbit is set; if the last
// extraction ends exactly at end of stream, a successful read is discarded.
// Checking !is.fail() would accept that case -- confirm against
// m128i_input's behaviour.
89  template <typename CharT, typename Traits>
90  friend inline std::basic_istream<CharT, Traits> &operator>> (
91  std::basic_istream<CharT, Traits> &is,
93  {
94  if (!is.good())
95  return is;
96 
98  for (std::size_t i = 0; i != Rounds + 1; ++i)
99  m128i_input(is, ks_tmp.key_seq_[i]);
100 
101  if (is.good()) {
102 #if VSMC_HAS_CXX11_RVALUE_REFERENCES
103  ks = cxx11::move(ks_tmp);
104 #else
105  ks = ks_tmp;
106 #endif
107  }
108 
109  return is;
110  }
111 
112  private :
113 
// Cached expanded key sequence (Rounds + 1 elements).
114  key_seq_type key_seq_;
115 }; // class AESNIKeySeqStorage (KeySeqInit == true)
116 
// Key-sequence storage for KeySeqInit == false: nothing is cached. get()
// re-expands the key with KeySeq::generate on every call, and set() is a
// no-op.
117 template <typename KeySeq, std::size_t Rounds>
118 class AESNIKeySeqStorage<KeySeq, false, Rounds>
119 {
120  public :
121 
122  typedef typename KeySeq::key_type key_type;
124 
// Expands key k into a fresh sequence and returns it by value.
125  key_seq_type get (const key_type &k) const
126  {
127  key_seq_type ks;
128  KeySeq seq;
129  seq.generate(k, ks);
130 
131  return ks;
132  }
133 
// Nothing to cache in this specialization.
134  void set (const key_type &) {}
135 
// Writes nothing; the stream is returned unchanged.
136  template <typename CharT, typename Traits>
137  friend inline std::basic_ostream<CharT, Traits> &operator<< (
138  std::basic_ostream<CharT, Traits> &os,
139  const AESNIKeySeqStorage<KeySeq, false, Rounds> &) {return os;}
140 
141  template <typename CharT, typename Traits>
142  friend inline std::basic_istream<CharT, Traits> &operator>> (
143  std::basic_istream<CharT, Traits> &is,
145 }; // class AESNIKeySeqStorage (KeySeqInit == false)
146 
147 } // namespace vsmc::internal
148 
// RNG engine using AES-NI instructions in counter mode. Each refill
// encrypts Blocks 128-bit counters with a key sequence of Rounds + 1
// elements, and the ciphertext is handed out as result_type values.
//
//   ResultType  unsigned integer type returned by operator()
//               (checked by VSMC_STATIC_ASSERT_RNG_AES_NI_RESULT_TYPE)
//   KeySeq      key-sequence generator; provides key_type and
//               generate(key, key_seq)
//   KeySeqInit  true: cache the expanded key sequence in the engine;
//               false: re-expand it on every buffer refill
//   Rounds      number of AES rounds (Rounds + 1 round keys are used)
//   Blocks      number of 128-bit blocks encrypted per refill (> 0)
211 template <typename ResultType, typename KeySeq, bool KeySeqInit,
212  std::size_t Rounds, std::size_t Blocks>
214 {
// Number of result_type values produced by one refill of Blocks blocks.
215  static VSMC_CONSTEXPR const std::size_t K_ =
216  sizeof(__m128i) / sizeof(ResultType) * Blocks;
217 
218  public :
219 
220  typedef ResultType result_type;
// One 128-bit counter viewed as an array of result_type.
222  typedef Array<ResultType, sizeof(__m128i) / sizeof(ResultType)> ctr_type;
224  typedef typename KeySeq::key_type key_type;
226 
227  private :
228 
// Helper that resets, sets and increments the counter block
// (Counter is defined elsewhere).
229  typedef Counter<ctr_type> counter;
230 
231  public :
232 
// Seed from a single integer (default 0) via seed(s). index_(K_) marks the
// output buffer as empty, so the first draw triggers a refill.
233  explicit AESNIEngine (result_type s = 0) : index_(K_)
234  {
236  seed(s);
237  }
238 
// Seed from a seed sequence via seed(seq). This overload only takes part in
// overload resolution when internal::is_seed_seq<...>::value holds; that
// trait is defined elsewhere.
239  template <typename SeedSeq>
240  explicit AESNIEngine (SeedSeq &seq, typename cxx11::enable_if<
241  internal::is_seed_seq<SeedSeq, result_type, key_type,
243  >::value>::type * = VSMC_NULLPTR) : index_(K_)
244  {
246  seed(seq);
247  }
248 
// Seed directly from a key via seed(k).
// NOTE(review): not `explicit`, so a key_type converts implicitly to an
// engine -- confirm this is intended.
249  AESNIEngine (const key_type &k) : index_(K_)
250  {
252  seed(k);
253  }
254 
255  void seed (result_type s)
256  {
257  counter::reset(ctr_block_);
258  key_.fill(0);
259  key_.front() = s;
260  key_seq_.set(key_);
261  index_ = K_;
262  }
263 
264  template <typename SeedSeq>
265  void seed (SeedSeq &seq, typename cxx11::enable_if<internal::is_seed_seq<
266  SeedSeq, result_type, key_type>:: value>::type * = VSMC_NULLPTR)
267  {
268  counter::reset(ctr_block_);
269  seq.generate(key_.begin(), key_.end());
270  key_seq_.set(key_);
271  index_ = K_;
272  }
273 
274  void seed (const key_type &k)
275  {
276  counter::reset(ctr_block_);
277  key_ = k;
278  key_seq_.set(k);
279  index_ = K_;
280  }
281 
282  template <std::size_t B>
283  ctr_type ctr () const {return ctr_block_[Position<B>()];}
284 
285  ctr_block_type ctr_block () const {return ctr_block_;}
286 
287  key_type key () const {return key_;}
288 
289  key_seq_type key_seq () const {return key_seq_.get(key_);}
290 
291  void ctr (const ctr_type &c)
292  {
293  counter::set(ctr_block_, c);
294  index_ = K_;
295  }
296 
297  void key (const key_type &k)
298  {
299  key_ = k;
300  key_seq_.set(k);
301  index_ = K_;
302  }
303 
304  result_type operator() ()
305  {
306  if (index_ == K_) {
307  counter::increment(ctr_block_);
308  generate_buffer(ctr_block_, buffer_);
309  index_ = 0;
310  }
311 
312  return reinterpret_cast<const result_type *>(buffer_.data())[index_++];
313  }
314 
317  buffer_type operator() (const ctr_type &c) const
318  {
319  ctr_block_type cb;
320  counter::set(cb, c);
321  buffer_type buf;
322  generate_buffer(cb, buf);
323 
324  return buf;
325  }
326 
329  buffer_type operator() (const ctr_block_type &cb) const
330  {
331  buffer_type buf;
332  generate_buffer(cb, buf);
333 
334  return buf;
335  }
336 
339  void operator() (const ctr_type &c, buffer_type &buf) const
340  {
341  ctr_block_type cb;
342  counter::set(cb, c);
343  generate_buffer(cb, buf);
344  }
345 
348  void operator() (const ctr_block_type &cb, buffer_type &buf) const
349  {generate_buffer(cb, buf);}
350 
351  void discard (result_type nskip)
352  {
353  std::size_t n = static_cast<std::size_t>(nskip);
354  if (index_ + n <= K_) {
355  index_ += n;
356  return;
357  }
358 
359  n -= K_ - index_;
360  if (n <= K_) {
361  index_ = K_;
362  operator()();
363  index_ = n;
364  return;
365  }
366 
367  counter::increment(ctr_block_, static_cast<result_type>(n / K_));
368  index_ = K_;
369  operator()();
370  index_ = n % K_;
371  }
372 
373  static VSMC_CONSTEXPR const result_type _Min = 0;
374  static VSMC_CONSTEXPR const result_type _Max = static_cast<result_type>(
375  ~(static_cast<result_type>(0)));
376 
377  static VSMC_CONSTEXPR result_type min VSMC_MNE () {return _Min;}
378  static VSMC_CONSTEXPR result_type max VSMC_MNE () {return _Max;}
379 
380  friend inline bool operator== (
381  const AESNIEngine<
382  ResultType, KeySeq, KeySeqInit, Rounds, Blocks> &eng1,
383  const AESNIEngine<
384  ResultType, KeySeq, KeySeqInit, Rounds, Blocks> &eng2)
385  {
386  return eng1.index_ == eng2.index_ &&
387  eng1.key_ == eng2.key_ &&
388  eng1.ctr_block_ == eng2.ctr_block_;
389  }
390 
391  friend inline bool operator!= (
392  const AESNIEngine<
393  ResultType, KeySeq, KeySeqInit, Rounds, Blocks> &eng1,
394  const AESNIEngine<
395  ResultType, KeySeq, KeySeqInit, Rounds, Blocks> &eng2)
396  {return !(eng1 == eng2);}
397 
398  template <typename CharT, typename Traits>
399  friend inline std::basic_ostream<CharT, Traits> &operator<< (
400  std::basic_ostream<CharT, Traits> &os,
401  const AESNIEngine<
402  ResultType, KeySeq, KeySeqInit, Rounds, Blocks> &eng)
403  {
404  if (!os.good())
405  return os;
406 
407  for (std::size_t i = 0; i != Blocks; ++i) {
408  m128i_output(os, eng.buffer_[i]);
409  os << ' ';
410  }
411  os << eng.ctr_block_ << ' ';
412  os << eng.key_ << ' ';
413  os << eng.index_;
414 
415  return os;
416  }
417 
// Deserialize engine state in the layout written by operator<<: Blocks
// buffered blocks (m128i_input), then the counter block, the key and
// index_, each preceded by std::ws. Values are read into a temporary
// (eng_tmp) and committed to eng only if the stream is still good.
//
// NOTE(review): operator<< writes index_ last with no trailing separator,
// so extracting it from the end of a stream (e.g. `ss << eng; ss >> eng2;`
// on a std::stringstream) sets eofbit. is.good() is then false and the
// successfully read state is discarded. Checking !is.fail() instead would
// fix this -- TODO confirm and fix.
// NOTE(review): index_ is not checked against K_. operator() only refills
// when index_ == K_, so a larger value from the stream would index past
// buffer_.
418  template <typename CharT, typename Traits>
419  friend inline std::basic_istream<CharT, Traits> &operator>> (
420  std::basic_istream<CharT, Traits> &is,
422  {
423  if (!is.good())
424  return is;
425 
427  for (std::size_t i = 0; i != Blocks; ++i)
428  m128i_input(is, eng_tmp.buffer_[i]);
429  is >> std::ws >> eng_tmp.ctr_block_;
430  is >> std::ws >> eng_tmp.key_;
431  is >> std::ws >> eng_tmp.index_;
432 
// Commit only if every extraction above left the stream good.
433  if (is.good()) {
434 #if VSMC_HAS_CXX11_RVALUE_REFERENCES
435  eng = cxx11::move(eng_tmp);
436 #else
437  eng = eng_tmp;
438 #endif
439  }
440 
441  return is;
442  }
443 
444  private :
445 
446  // FIXME
447  // buffer_ is automatically 16 bytes aligned (its elements are __m128i).
448  // Thus, we assume that ctr_block_ and buffer_ will also be 16 bytes
449  // aligned. NOTE(review): buffer_ is named twice here; the second name was
// presumably meant to be another member -- confirm. Nothing in this
// listing enforces the alignment of ctr_block_.
450 
// Ciphertext of the most recent refill, served by operator().
451  buffer_type buffer_;
// Counters for the Blocks blocks; incremented before each refill.
453  ctr_block_type ctr_block_;
// Current key; expanded through key_seq_ when encrypting.
454  key_type key_;
// Position of the next value in buffer_; K_ means "exhausted, refill".
455  std::size_t index_;
456 
457  void generate_buffer (const ctr_block_type &cb,
458  buffer_type &buf) const
459  {
460  const key_seq_type ks(key_seq_.get(key_));
461  pack(cb, buf);
462  enc_first<0>(ks, buf, cxx11::true_type());
463  enc_round<1>(ks, buf, cxx11::integral_constant<bool, 1 < Rounds>());
464  enc_last <0>(ks, buf, cxx11::true_type());
465  }
466 
467  template <std::size_t>
468  void enc_first (const key_seq_type &, buffer_type &,
469  cxx11::false_type) const {}
470 
471  template <std::size_t B>
472  void enc_first (const key_seq_type &ks, buffer_type &buf,
473  cxx11::true_type) const
474  {
475  buf[Position<B>()] = _mm_xor_si128(buf[Position<B>()], ks.front());
476  enc_first<B + 1>(ks, buf,
477  cxx11::integral_constant<bool, B + 1 < Blocks>());
478  }
479 
480  template <std::size_t>
481  void enc_round (const key_seq_type &, buffer_type &,
482  cxx11::false_type) const {}
483 
484  template <std::size_t N>
485  void enc_round (const key_seq_type &ks, buffer_type &buf,
486  cxx11::true_type) const
487  {
488  enc_round_block<0, N>(ks, buf, cxx11::true_type());
489  enc_round<N + 1>(ks, buf,
490  cxx11::integral_constant<bool, N + 1 < Rounds>());
491  }
492 
493  template <std::size_t, std::size_t>
494  void enc_round_block (const key_seq_type &, buffer_type &,
495  cxx11::false_type) const {}
496 
497  template <std::size_t B, std::size_t N>
498  void enc_round_block (const key_seq_type &ks, buffer_type &buf,
499  cxx11::true_type) const
500  {
501  buf[Position<B>()] = _mm_aesenc_si128(
502  buf[Position<B>()], ks[Position<N>()]);
503  enc_round_block<B + 1, N>(ks, buf,
504  cxx11::integral_constant<bool, B + 1 < Blocks>());
505  }
506 
507  template <std::size_t>
508  void enc_last (const key_seq_type &, buffer_type &,
509  cxx11::false_type) const {}
510 
511  template <std::size_t B>
512  void enc_last (const key_seq_type &ks, buffer_type &buf,
513  cxx11::true_type) const
514  {
515  buf[Position<B>()] = _mm_aesenclast_si128(
516  buf[Position<B>()], ks.back());
517  enc_last<B + 1>(ks, buf,
518  cxx11::integral_constant<bool, B + 1 < Blocks>());
519  }
520 
521  void pack (const ctr_block_type &cb, buffer_type &buf) const
522  {pack_ctr<0>(cb, buf, cxx11::true_type());}
523 
524  template <std::size_t>
525  void pack_ctr (const ctr_block_type &, buffer_type &,
526  cxx11::false_type) const {}
527 
528  template <std::size_t B>
529  void pack_ctr (const ctr_block_type &cb, buffer_type &buf,
530  cxx11::true_type) const
531  {
532  m128i_pack<0>(cb[Position<B>()], buf[Position<B>()]);
533  pack_ctr<B + 1>(cb, buf,
534  cxx11::integral_constant<bool, B + 1 < Blocks>());
535  }
536 }; // class AESNIEngine
537 
538 } // namespace vsmc
539 
540 #endif // VSMC_RNG_AES_NI_HPP
void ctr(const ctr_type &c)
Definition: aes_ni.hpp:291
Definition: adapter.hpp:37
friend bool operator==(const AESNIEngine< ResultType, KeySeq, KeySeqInit, Rounds, Blocks > &eng1, const AESNIEngine< ResultType, KeySeq, KeySeqInit, Rounds, Blocks > &eng2)
Definition: aes_ni.hpp:380
#define VSMC_CONSTEXPR
constexpr
Definition: defines.hpp:55
Array< __m128i, Blocks > buffer_type
Definition: aes_ni.hpp:221
friend std::basic_istream< CharT, Traits > & operator>>(std::basic_istream< CharT, Traits > &is, AESNIEngine< ResultType, KeySeq, KeySeqInit, Rounds, Blocks > &eng)
Definition: aes_ni.hpp:419
static constexpr const result_type _Min
Definition: aes_ni.hpp:373
key_seq_type key_seq() const
Definition: aes_ni.hpp:289
ResultType result_type
Definition: aes_ni.hpp:220
integral_constant< bool, false > false_type
AESNIEngine(SeedSeq &seq, typename cxx11::enable_if< internal::is_seed_seq< SeedSeq, result_type, key_type, AESNIEngine< ResultType, KeySeq, KeySeqInit, Rounds, Blocks > >::value >::type *=nullptr)
Definition: aes_ni.hpp:240
Tag type that carries a compile-time position as a function argument (e.g. `buf[Position<B>()]`).
Definition: defines.hpp:126
ctr_type ctr() const
Definition: aes_ni.hpp:283
AESNIEngine(result_type s=0)
Definition: aes_ni.hpp:233
#define VSMC_MNE
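Work around MSVC headers that define `min`/`max` as function-like macros: placed between a function name and its parameter list, it prevents macro expansion (MNE = Macro No Expansion).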
Definition: defines.hpp:38
static constexpr result_type min()
Definition: aes_ni.hpp:377
friend std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const AESNIEngine< ResultType, KeySeq, KeySeqInit, Rounds, Blocks > &eng)
Definition: aes_ni.hpp:399
KeySeq::key_type key_type
Definition: aes_ni.hpp:224
std::basic_istream< CharT, Traits > & m128i_input(std::basic_istream< CharT, Traits > &is, __m128i &a)
Read an __m128i object from an input stream as 16 unsigned integers, one per byte, in the format written by m128i_output...
Definition: m128i.hpp:119
key_type key() const
Definition: aes_ni.hpp:287
RNG engine using AES-NI instructions.
Definition: aes_ni.hpp:213
remove_reference< T >::type && move(T &&t) noexcept
std::basic_ostream< CharT, Traits > & operator<<(std::basic_ostream< CharT, Traits > &os, const Sampler< T > &sampler)
Definition: sampler.hpp:884
static constexpr result_type max()
Definition: aes_ni.hpp:378
void seed(SeedSeq &seq, typename cxx11::enable_if< internal::is_seed_seq< SeedSeq, result_type, key_type >::value >::type *=nullptr)
Definition: aes_ni.hpp:265
Array< __m128i, Rounds+1 > key_seq_type
Definition: aes_ni.hpp:225
#define VSMC_NULLPTR
nullptr
Definition: defines.hpp:79
void key(const key_type &k)
Definition: aes_ni.hpp:297
integral_constant< bool, true > true_type
pointer data()
Definition: array.hpp:127
Array< ctr_type, Blocks > ctr_block_type
Definition: aes_ni.hpp:223
result_type operator()()
Definition: aes_ni.hpp:304
#define VSMC_STATIC_ASSERT_RNG_AES_NI
Definition: aes_ni.hpp:46
Array< ResultType, sizeof(__m128i)/sizeof(ResultType)> ctr_type
Definition: aes_ni.hpp:222
std::basic_ostream< CharT, Traits > & m128i_output(std::basic_ostream< CharT, Traits > &os, const __m128i &a)
Write an __m128i object to an output stream as 16 unsigned integers, one per byte.
Definition: m128i.hpp:103
void discard(result_type nskip)
Definition: aes_ni.hpp:351
friend bool operator!=(const AESNIEngine< ResultType, KeySeq, KeySeqInit, Rounds, Blocks > &eng1, const AESNIEngine< ResultType, KeySeq, KeySeqInit, Rounds, Blocks > &eng2)
Definition: aes_ni.hpp:391
void seed(const key_type &k)
Definition: aes_ni.hpp:274
void seed(result_type s)
Definition: aes_ni.hpp:255
static constexpr const result_type _Max
Definition: aes_ni.hpp:374
AESNIEngine(const key_type &k)
Definition: aes_ni.hpp:249
ctr_block_type ctr_block() const
Definition: aes_ni.hpp:285