1 /// Main API of SIMD-oriented Fast Mersenne Twister (SFMT).
2 module sfmt;
3 
4 version (unittest)
5     import std.stdio : stderr;
6 static import sfmt.internal;
7 import sfmt.internal : func1, func2, idxof, ucent_;
8 
9 import std.algorithm : max, min;
10 
// Generate the predefined generator types for every supported Mersenne
// exponent.  `row` takes the values 1 through 32 and is handed to
// `sfmt.internal.sfmtMixin`, whose returned code is mixed in at module scope.
// NOTE(review): the alias below and the module constructor refer to names
// `SFMT<mexp>_0` .. `SFMT<mexp>_31`, so sfmtMixin presumably names the type
// after `row - 1` -- confirm against sfmt.internal.
static foreach (mexp; [size_t(607), 1279, 2281, 4253, 11213, 19937])
{
    import std.range : iota;
    import std.algorithm : map;
    import std.format : format;
    static foreach (row; iota(1, 33))
        mixin (sfmt.internal.sfmtMixin(mexp, row));
    // `SFMT<mexp>` is shorthand for the first parameter set of that exponent.
    mixin (format("alias SFMT%d = SFMT%d_0;", mexp, mexp));
}
/** SFMT random number generator whose parameters are set at run time.

The Mersenne exponent, shifts, masks and parity are ordinary fields, and the
state buffer is a dynamically allocated array sized by `mexp`.
Copying a `RunTimeSFMT` duplicates the state buffer, so each copy advances
independently of the original.
*/
struct RunTimeSFMT
{
    /// Set parameters and seed with `std.random.unpredictableSeed`.
    this (in sfmt.internal.Parameters parameters)
    {
        setParameters(parameters);
        import std.random : unpredictableSeed;
        seed(unpredictableSeed);
    }
    /// Set parameters and _seed with specified seed.
    this (in sfmt.internal.Parameters parameters, in uint seed)
    {
        setParameters(parameters);
        this.seed(seed);
    }
    /// ditto
    this (in sfmt.internal.Parameters parameters, uint[] seed)
    {
        setParameters(parameters);
        this.seed(seed);
    }
    /** Postblit: give the copy its own state buffer.

    `state` is a slice; without this, a copy would share the buffer with the
    original while keeping a separate `idx`, so advancing either generator
    would silently change the output of the other.
    */
    this (this)
    {
        state = state.dup;
    }
    /** Set parameters.

    The internal state is undefined after the call: call `seed$(LPAREN)$(RPAREN)`.
    */
    void setParameters(in sfmt.internal.Parameters parameters)
    {
        mexp(parameters.mersenneExponent);
        m = parameters.m;
        shifts = parameters.shifts;
        masks = parameters.masks;
        parity = parameters.parity;
    }
    /// Mersenne exponent.
    size_t mexp() const @property
    {
        return mersenneExponent;
    }
    /** Set the Mersenne exponent.

    Resizes the state buffer to `n = (value >> 7) + 1` 128-bit words and
    recomputes every value derived from the state size (`size`, `lag`, `mid`
    and the precomputed indices used by `seed(uint[])`).

    Returns: `value`.
    */
    size_t mexp(in size_t value) @property
    {
        mersenneExponent = value;
        n = (value >> 7) + 1;
        state.length = n;
        size = n << 2; // number of 32-bit words in the state
        if (size >= 623)
            lag = 11;
        else if (size >= 68)
            lag = 7;
        else if (size >= 39)
            lag = 5;
        else
            lag = 3;
        mid = (size - lag) / 2;
        tail = (size - 1).idxof;
        imid = mid.idxof;
        iml = (mid+lag).idxof;
        return value;
    }
    ptrdiff_t n/**READONLY*/, size/**READONLY*/;
    size_t m;///
    size_t[4] shifts/***/, masks/***/, parity/***/;

    /// Identification string built from the exponent, `m`, the shifts and the masks.
    string id() const
    {
        return "SFMT-%d:%d-%(%d-%):%(%08x-%)".format(
                mersenneExponent, m,
                shifts[],
                masks[]
                );
    }

    mixin SFMTMixin;

private:
    // Run-time counterpart of the compile-time recursion;
    // see also sfmt.internal.recursion
    void recursion(ref ucent_ r, ref ucent_ a, ref ucent_ b, ref ucent_ c, ref ucent_ d)
    {
        immutable
            sl1 = shifts[0],
            sl2 = shifts[1],
            sr1 = shifts[2],
            sr2 = shifts[3];
        immutable
            m0 = masks[idxof!0],
            m1 = masks[idxof!1],
            m2 = masks[idxof!2],
            m3 = masks[idxof!3];
        // x and y are shifts of the whole 128-bit word; the remaining shifts
        // are applied to each 32-bit lane separately.
        auto
            x = a << sl2,
            y = c >> sr2;
        r.u32[0] = a.u32[0] ^ x.u32[0] ^ ((b.u32[0] >> sr1) & m0) ^ y.u32[0] ^ (d.u32[0] << sl1);
        r.u32[1] = a.u32[1] ^ x.u32[1] ^ ((b.u32[1] >> sr1) & m1) ^ y.u32[1] ^ (d.u32[1] << sl1);
        r.u32[2] = a.u32[2] ^ x.u32[2] ^ ((b.u32[2] >> sr1) & m2) ^ y.u32[2] ^ (d.u32[2] << sl1);
        r.u32[3] = a.u32[3] ^ x.u32[3] ^ ((b.u32[3] >> sr1) & m3) ^ y.u32[3] ^ (d.u32[3] << sl1);
    }
    size_t mersenneExponent;
    ptrdiff_t lag, mid;
    size_t tail, imid, iml;
    ucent_[] state;
}
/// `RunTimeSFMT` generators with predefined variation of parameters.
RunTimeSFMT[] rtSFMTs;
///
unittest
{
    // six Mersenne exponents, 32 parameter sets each
    assert (rtSFMTs.length == 192);
}
// Module constructor: append one run-time generator per predefined
// compile-time parameter set, in order of exponent and then row.
static this ()
{
    import std.range : iota;
    import std.format : format;
    static foreach (mexp; [size_t(607), 1279, 2281, 4253, 11213, 19937])
        static foreach (row; iota(32))
        {{
            enum typeName = format("SFMT%d_%d", mexp, row);
            rtSFMTs ~= RunTimeSFMT(mixin (typeName).params);
        }}
}
146 
/// _SFMT random number generator whose parameters are fixed at compile time.
struct SFMT(sfmt.internal.Parameters parameters)
{
    /// Construct a generator seeded with a single 32-bit value.
    this (uint seed)
    {
        this.seed(seed);
    }
    /// Construct a generator seeded with an array of 32-bit values.
    this (uint[] seed)
    {
        this.seed(seed);
    }
    /// Mersenne exponent taken from `parameters`.
    enum mersenneExponent = parameters.mersenneExponent;
    /// Short name for `mersenneExponent`.
    alias mexp = mersenneExponent;
    /// Number of 128-bit words in the state.
    enum n = (mersenneExponent >> 7) + 1;
    /// Number of 32-bit words in the state.
    enum size = n << 2;
    /// Parameter `m` taken from `parameters`.
    enum m = parameters.m;
    /// Shift amounts taken from `parameters`.
    enum shifts = parameters.shifts;
    /// Masks taken from `parameters`.
    enum masks = parameters.masks;
    /// Parity vector taken from `parameters`.
    enum parity = parameters.parity;
    /// Identification string built from the exponent, `m`, the shifts and the masks.
    enum id = format("SFMT-%d:%d-%(%d-%):%(%08x-%)",
                mersenneExponent, m,
                shifts[],
                masks[]);

    mixin SFMTMixin;

private:
    alias recursion = sfmt.internal.recursion!(shifts, masks);
    alias params = parameters;
    // Constants used by seed(uint[]); they depend only on the state size.
    enum lag = size >= 623 ? 11
             : size >= 68  ? 7
             : size >= 39  ? 5
             : 3;
    enum mid = (size - lag) / 2;
    enum tail = idxof!(size - 1);
    enum imid = idxof!mid;
    enum iml = idxof!(mid+lag);
    ucent_[n] state;
}
///
unittest
{
    /// SFMT19937 is an alias of SFMT!(...).
    import std.random : isUniformRNG;
    static assert (isUniformRNG!SFMT19937);
    // First 64-bit output for seed 4321.
    auto rng = SFMT19937(4321u);
    assert (rng.front == 16924766246869039260UL);
}
///
unittest
{
    import std.algorithm : equal;
    import std.range : take;
    // Block generation and the input-range interface must agree.
    auto asBlock = SFMT19937(4321u).next!(ulong[])(1000);
    auto asRange = SFMT19937(4321u).take(1000);
    assert (asBlock.equal(asRange));
    stderr.writeln("checked next!ulong[] and range functionality");
}
///
unittest
{
    import std.random : uniform01;
    auto rng = SFMT19937(4321u);
    // Six draws per iteration, 1000 iterations: every value must lie in [0, 1).
    foreach (_; 0..1000)
    {
        assert (0 <= rng.uniform01!real);
        assert (0 <= rng.uniform01!double);
        assert (0 <= rng.uniform01!float);
        assert (rng.uniform01!real < 1);
        assert (rng.uniform01!double < 1);
        assert (rng.uniform01!float < 1);
    }
    stderr.writeln("checked uniform01");

    // The value now at the front must also be reached by a freshly seeded
    // generator advanced with popFront alone.
    immutable expected = rng.front;
    rng = SFMT19937(4321u);
    foreach (_; 0..6000)
        rng.popFront;
    assert (rng.front == expected);
    stderr.writeln("checked call-only-popFront case");
}
///
unittest
{
    // Two consecutive blocks from next!(U[]) must equal the values produced
    // one at a time by frontPop!U on a copy taken before the first block.
    void testNext(U, ISFMT)(ISFMT sfmt)
    {
        // Number of hex digits needed to print a U (two per byte); used as
        // the zero-padded field width in the failure messages.
        enum hexDigits = U.sizeof << 1;
        auto copy = sfmt;
        auto firstBlock = sfmt.next!(U[])(10000);
        auto secondBlock = sfmt.next!(U[])(10000);
        U s;
        foreach (i, b; firstBlock)
            assert (b == (s = copy.frontPop!U), "mismatch: first[%d] = %0*,8x != %0*,8x".format(i, hexDigits, b, hexDigits, s));
        foreach (i, b; secondBlock)
            assert (b == (s = copy.frontPop!U), "mismatch: second[%d;%d] = %0*,8x != %0*,8x".format(i, i+firstBlock.length, hexDigits, b, hexDigits, s));
    }
    testNext!ulong(SFMT19937(4321u));
    testNext!ulong(SFMT19937([uint(5), 4, 3, 2, 1]));
    testNext!uint(SFMT19937(1234u));
    testNext!uint(SFMT19937([uint(0x1234), 0x5678, 0x9abc, 0xdef0]));
    stderr.writeln("checked frontPop!U and next!U[] (U = ulong, uint)");
}
///
unittest
{
    // Advance by `firstSize` 128-bit words with popFront, then take a block of
    // `secondSize` 128-bit words; it must match the same window of the range.
    void testPopFrontThenBlock(size_t firstSize, size_t secondSize)
    {
        import std.range : drop, take;
        import std.algorithm : equal;
        immutable skipped = firstSize*2;
        immutable wanted = secondSize*2;
        auto rng = SFMT19937(4321u);
        foreach (_; 0..skipped)
            rng.popFront;
        auto block = rng.next!(ulong[])(wanted);
        auto reference = SFMT19937(4321u).drop(skipped).take(wanted);
        assert (block.equal(reference));
    }
    foreach (first; 0..SFMT19937.n)
        foreach (second; SFMT19937.n..SFMT19937.n*2)
            testPopFrontThenBlock(first, second);
    stderr.writeln("checked next!U[]");
}
/** Common functions of RunTimeSFMT and SFMT.

The host struct must provide `state`, `n`, `size`, `m`, `parity`, `lag`,
`mid`, `tail`, `imid`, `iml` and a `recursion` function.
*/
mixin template SFMTMixin()
{
    enum isUniformRandom = true;///
    enum min = ulong.min;///
    enum max = ulong.max;///
    /// Set every byte of the state to `b`.
    void fillState(ubyte b)
    {
        ucent_ x;
        // replicate the byte over one 32-bit lane, then over all lanes
        x.u32[0] = b;
        x.u32[0] = x.u32[0] << 8 | x.u32[0];
        x.u32[0] = x.u32[0] << 16 | x.u32[0];
        x.u32[1..$] = x.u32[0];
        state[] = x;
    }
    /// Seed the generator with a single 32-bit value.
    void seed(uint seed)
    {
        uint* psfmt32 = &(state[0].u32[0]);
        psfmt32[idxof!0] = seed;
        foreach (i; 1..size)
        {
            immutable prev = psfmt32[(i - 1).idxof];
            // arithmetic is modulo 2^32; the index is narrowed explicitly so
            // the statement does not rely on the width of the loop variable
            psfmt32[i.idxof] = 1812433253U * (prev ^ (prev >> 30)) + cast(uint) i;
        }
        idx = cast(int) size;
        assureLongPeriod;
        generateAll;
    }
    /// Seed the generator with an array of 32-bit values.
    void seed(uint[] seed)
    {
        fillState(0x8b);
        immutable count = seed.length.max(size - 1);
        uint* psfmt32 = &(state[0].u32[0]);
        uint r = func1(psfmt32[idxof!0] ^ psfmt32[imid] ^ psfmt32[tail]);
        psfmt32[imid] += r;
        r += seed.length;
        psfmt32[iml] += r;
        psfmt32[idxof!0] = r;

        size_t i = 1;
        // first pass: mix in the seed words
        foreach (j; 0..count.min(seed.length))
        {
            r = func1(
                    psfmt32[i.idxof]
                  ^ psfmt32[((i+mid)%size).idxof]
                  ^ psfmt32[((i+size-1)%size).idxof]);
            psfmt32[((i+mid)%size).idxof] += r;
            r += seed[j] + i;
            psfmt32[((i+mid+lag)%size).idxof] += r;
            psfmt32[i.idxof] = r;
            i = (i + 1) % size;
        }
        // second pass: continue without seed words until `count` steps are done
        foreach (j; count.min(seed.length)..count)
        {
            r = func1(
                    psfmt32[i.idxof]
                  ^ psfmt32[((i+mid)%size).idxof]
                  ^ psfmt32[((i+size-1)%size).idxof]);
            psfmt32[((i+mid)%size).idxof] += r;
            r += i;
            psfmt32[((i+mid+lag)%size).idxof] += r;
            psfmt32[i.idxof] = r;
            i = (i + 1) % size;
        }
        // third pass: one more round over the whole state using func2
        foreach (j; 0..size)
        {
            r = func2(
                    psfmt32[i.idxof]
                  + psfmt32[((i+mid)%size).idxof]
                  + psfmt32[((i+size-1)%size).idxof]);
            psfmt32[((i+mid)%size).idxof] ^= r;
            r -= i;
            psfmt32[((i+mid+lag)%size).idxof] ^= r;
            psfmt32[i.idxof] = r;
            i = (i + 1) % size;
        }
        idx = cast(int) size;
        assureLongPeriod;
        generateAll;
    }
    /// input range interface.
    enum empty = false;
    /// ditto
    ulong front() @property
    {
        assert (idx % 2 == 0, "out of alignment");
        ulong* psfmt64 = &(state[0].u64[0]);
        return psfmt64[idx / 2];
    }
    /// ditto
    void popFront()
    {
        idx += 2;
        if (size <= idx) // in current implementation,
            generateAll; // this is necessary when only popFront is called repeatedly.
    }
    version (Big32){} else
    /// Return the current 64-bit value and advance past it.
    T frontPop(T : ulong)()
    {
        auto ret = front;
        popFront;
        return ret;
    }
    version (Big64){} else
    /// Return the current 32-bit value and advance past it.
    T frontPop(T : uint)()
    {
        uint* psfmt32 = &(state[0].u32[0]);
        immutable r = psfmt32[idx];
        idx += 1;
        if (size <= idx)
            generateAll;
        return r;
    }
    /** Generate `size` values at once into a newly allocated array.

    The array is reinterpreted as 128-bit words, so `size * T.sizeof` must be
    a multiple of 16 and must cover at least `n` 128-bit words, and the
    generator must currently be positioned on a 128-bit boundary
    (see the contract of `fill`).
    */
    T next(T)(size_t size)
        if (is (T == ulong[]) || is (T == uint[]))
    {
        return cast(T)fill(cast(ucent_[])(new T(size)));
    }
    // Fill `array` with the next `array.length` 128-bit words of the sequence:
    // first the unread remainder of `state`, then freshly generated words.
    // Afterwards `state` holds the `n` words that follow the array and the
    // read position is reset to the start of `state`.
    private auto fill(ucent_[] array)
    in
    {
        assert (n <= array.length);
        assert (idx % 4 == 0, "out of alignment");
    }
    do
    {
        immutable size_t
            index = idx / 4,
            size = array.length;
        immutable size_t
            prepared = n-index;
        array[0..prepared] = state[index..$];
        // array[prepared-j] == state[n-j]
        // array[i-n] == state[i-prepared]
        // array[i] == state[i+n-prepared] == state[i+index]
        if (prepared <= 0)
        {
            // Nothing was copied, so both predecessors still live in `state`.
            // (Not reached while the generator regenerates as soon as the
            // state is exhausted, but kept correct.)
            recursion(
                array[0], state[0],
                state[m],
                state[index-2], state[index-1]);
        }
        if (prepared <= 1)
        {
            recursion(
                array[1], state[1-prepared],
                state[1+m-prepared],
                state[index-1], array[0]);
        }
        if (prepared <= n-m-1)
        foreach (i; prepared.max(2)..n-m)
        {
            recursion(
                array[i], state[i-prepared],
                state[i-(prepared-m)],
                array[i-2], array[i-1]);
        }
        foreach (i; prepared.max(n-m)..n)
        {
            recursion(
                array[i], state[i-prepared],
                array[i-(n-m)],
                array[i-2], array[i-1]);
        }
        foreach (i; n .. size)
        {
            recursion(
                array[i], array[i-n],
                array[i-(n-m)],
                array[i-2], array[i-1]);
        }
        // Rebuild `state` as the n words that follow the end of `array`.
        // array[$-n+i] == state[i-n]
        // array[$+i] == state[i]
        recursion(
            state[0], array[$-n],
            array[$-(n-m)],
            array[$-2], array[$-1]);
        recursion(
            state[1], array[$+1-n],
            array[$+1-(n-m)],
            array[$-1], state[0]);
        foreach (i; 2..(n-m))
        {
            recursion(
                state[i], array[$+i-n],
                array[$+i-(n-m)],
                state[i-2], state[i-1]);
        }
        foreach (i; (n-m)..n)
        {
            recursion(
                state[i], array[$+i-n],
                state[i-(n-m)],
                state[i-2], state[i-1]);
        }
        // `state` now starts right after the last word written to `array`;
        // without this reset the first `index` words of it would be skipped.
        idx = 0;
        return array;
    }
    // Replace the whole state by the next n words and rewind the read position.
    private void generateAll()
    {
        recursion(
            state[0], state[0],
            state[0+m],
            state[n - 2], state[n - 1]);
        recursion(
            state[1], state[1],
            state[1+m],
            state[n - 1], state[0]);
        foreach (i; 2..n-m)
        {
            recursion(
                state[i], state[i],
                state[i+m],
                state[i - 2], state[i - 1]);
        }
        foreach (i; n-m..n)
        {
            recursion(
                state[i], state[i],
                state[i+m-n],
                state[i - 2], state[i - 1]);
        }
        idx = 0;
    }
    // Read position inside `state`, counted in 32-bit words.
    int idx;
    /// returns true if modification is done
    bool assureLongPeriod()
    {
        uint inner;
        uint* psfmt32 = &(state[0].u32[0]);
        static foreach (i; 0..4)
            inner ^= psfmt32[idxof!i] & parity[i];
        // fold all bits into bit 0 to obtain the parity of `inner`
        foreach (i; [16, 8, 4, 2, 1])
            inner ^= inner >> i;
        inner &= 1;
        if (inner == 1)
            return false;
        // parity check failed: flip the lowest bit that is set in `parity`
        foreach (i; 0..4)
        {
            uint working = 1;
            foreach (j; 0..32)
            {
                if (working & parity[i])
                {
                    psfmt32[i.idxof] ^= working;
                    return true;
                }
                working <<= 1;
            }
        }
        assert (false, "unreachable?");
    }
}
///
unittest
{
    // A run-time generator configured with the parameters of SFMT19937 and
    // given the same seed must produce the same sequence.
    enum uint seedValue = 13579u;
    auto compileTime = SFMT19937(seedValue);
    RunTimeSFMT runTime;
    runTime.mexp(compileTime.mersenneExponent);
    runTime.m = compileTime.m;
    runTime.shifts = compileTime.shifts;
    runTime.masks = compileTime.masks;
    runTime.parity = compileTime.parity;
    runTime.seed(seedValue);
    foreach (_; 0..1000)
        assert (compileTime.frontPop!ulong == runTime.frontPop!ulong);
    stderr.writeln("checked compile time and run time");
}
541 
// Endianness configuration.
// On a big-endian target the build must choose Only64bit or With32bit; the
// choice sets Big64 or Big32, which remove frontPop!uint or frontPop!ulong
// respectively from SFMTMixin.
// NOTE(review): Big32/Big64 are tested inside SFMTMixin earlier in this file
// but assigned here; D rejects a version identifier that is set after it has
// been used -- confirm this builds on a BigEndian target.
version (BigEndian)
{
    pragma (msg, "not tested");
    version (Only64bit)
        version = Big64;
    else version (With32bit)
        version = Big32;
    else static assert (false, "Specify Only64bit or With32bit in BigEndian environment");
}
version (LittleEndian)
{
    // prints "supported" at every compilation of this module
    pragma (msg, "supported");
}

// The two options are mutually exclusive on every target.
version (Only64bit)
    version (With32bit)
        static assert (false, "Specify (at most) one of Only64bit or With32bit");
559