1 module sfmt;
2 
3 import std.stdio;
4 static import sfmt.internal;
5 import sfmt.internal : func1, func2, idxof, ucent_;
6 
7 import std.algorithm : max, min;
8 
9 
// Injects the declarations generated at compile time by
// sfmt.internal.sfmtMixins from the two lists below.
// NOTE(review): the first list looks like the supported Mersenne exponents and
// the second like a list of parameter-set ids; their exact meaning is defined in
// sfmt.internal — confirm there.
mixin(sfmt.internal.sfmtMixins(
    [size_t(607), 1279, 2281, 4253, 11213, 19937],
    [size_t(1), 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
     17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]));
15 
/**
 * SFMT pseudo-random generator, parameterized at compile time by an
 * `sfmt.internal.Parameters` value. The structure follows the reference SFMT
 * C implementation (seeding, bulk/whole-state generation, period certification).
 *
 * The state is `n` blocks of type `ucent_`, each holding four 32-bit words,
 * i.e. `size` 32-bit words in total. `idx` is the index (in 32-bit words) of
 * the next unread output word; `idx == size` means the buffer is exhausted and
 * the next scalar request regenerates the whole state.
 */
struct SFMT(sfmt.internal.Parameters parameters)
{
    enum mersenneExponent = parameters.mersenneExponent;
    /// Number of `ucent_` blocks in the state.
    enum n = (mersenneExponent >> 7) + 1;
    /// Number of 32-bit words in the state (four per block).
    enum size = n << 2;
    /// Block offset of the second input of each recursion step (wraps modulo `n`).
    enum m = parameters.m;
    enum shifts = parameters.shifts;
    enum masks = parameters.masks;
    /// Four mask words used by `assureLongPeriod`.
    enum parity = parameters.parity;
    enum id = parameters.id;
    /// One SFMT recursion step; its exact definition lives in `sfmt.internal`.
    alias recursion = sfmt.internal.recursion!(shifts, masks);

    /// Constructs a generator seeded with a single 32-bit value (see `seed(uint)`).
    this (uint seed)
    {
        this.seed(seed);
    }

    /// Constructs a generator seeded with an array of 32-bit key words (see `seed(uint[])`).
    this (uint[] seed)
    {
        this.seed(seed);
    }

    /// Debug helper: writes "state:" and then the first two and last two state
    /// blocks (each as its 32-bit words in hex) to stdout.
    void printState()
    {
        import std.stdio;
        "state:".writeln;
        foreach (row; state[0..2]~state[$-2..$])
            "%(%08x %)".writefln(row.u32);
    }

    /// Sets every 32-bit word of the state to `b` replicated into all four of its bytes.
    void fillState(ubyte b)
    {
        ucent_ x;
        // b * 0x01010101 == b | b << 8 | b << 16 | b << 24 for any byte b.
        x.u32[] = b * 0x01010101u;
        state[] = x;
    }

    // checked
    /**
     * Seeds the generator from a single 32-bit value.
     *
     * Word 0 is `seed`; word i is 1812433253 * (w[i-1] ^ (w[i-1] >> 30)) + i.
     * Afterwards the output buffer is marked exhausted and `assureLongPeriod`
     * is applied.
     * NOTE(review): `idxof` presumably maps a logical word index to its storage
     * position (endianness handling) — see sfmt.internal.
     */
    void seed(uint seed)
    {
        uint* psfmt32 = &(state[0].u32[0]);
        psfmt32[idxof!0] = seed;
        foreach (i; 1..size)
            psfmt32[i.idxof] = 1812433253U * (psfmt32[(i - 1).idxof] ^ (psfmt32[(i - 1).idxof] >> 30)) + i;
        idx = size;
        assureLongPeriod;
    }

    /**
     * Seeds the generator from an array of 32-bit key words.
     *
     * The state is first filled with the byte 0x8b, then mixed in three passes:
     * `func1` with additive updates while key words remain, `func1` without key
     * words until `count` steps are done, and a final full sweep with `func2`
     * and XOR updates. Afterwards the output buffer is marked exhausted and
     * `assureLongPeriod` is applied.
     */
    void seed(uint[] seed)
    {
        // Distance between the two words updated in each mixing step,
        // chosen by state size.
        static if (size >= 623)
            enum lag = 11;
        else static if (size >= 68)
            enum lag = 7;
        else static if (size >= 39)
            enum lag = 5;
        else
            enum lag = 3;
        enum mid = (size - lag) / 2;
        fillState(0x8b);
        // Number of mixing steps after the initial one: max(key length, size - 1).
        immutable count = seed.length.max(size - 1);
        uint* psfmt32 = &(state[0].u32[0]);
        uint r = func1(psfmt32[idxof!0] ^ psfmt32[idxof!mid] ^ psfmt32[idxof!(size - 1)]);
        psfmt32[idxof!mid] += r;
        r += seed.length; // wraps modulo 2^32 since r is a uint
        psfmt32[idxof!(mid+lag)] += r;
        psfmt32[idxof!0] = r;

        size_t i = 1;
        // Pass 1: one mixing step per key word.
        foreach (j; 0..count.min(seed.length))
        {
            r = func1(
                    psfmt32[i.idxof]
                  ^ psfmt32[((i+mid)%size).idxof]
                  ^ psfmt32[((i+size-1)%size).idxof]);
            psfmt32[((i+mid)%size).idxof] += r;
            r += seed[j] + i;
            psfmt32[((i+mid+lag)%size).idxof] += r;
            psfmt32[i.idxof] = r;
            i = (i + 1) % size;
        }
        // Pass 2: continue without key words until `count` steps are done.
        foreach (j; count.min(seed.length)..count)
        {
            r = func1(
                    psfmt32[i.idxof]
                  ^ psfmt32[((i+mid)%size).idxof]
                  ^ psfmt32[((i+size-1)%size).idxof]);
            psfmt32[((i+mid)%size).idxof] += r;
            r += i;
            psfmt32[((i+mid+lag)%size).idxof] += r;
            psfmt32[i.idxof] = r;
            i = (i + 1) % size;
        }
        // Pass 3: one full sweep using func2 and XOR updates.
        foreach (j; 0..size)
        {
            r = func2(
                    psfmt32[i.idxof]
                  + psfmt32[((i+mid)%size).idxof]
                  + psfmt32[((i+size-1)%size).idxof]);
            psfmt32[((i+mid)%size).idxof] ^= r;
            r -= i;
            psfmt32[((i+mid+lag)%size).idxof] ^= r;
            psfmt32[i.idxof] = r;
            i = (i + 1) % size;
        }
        idx = size;
        assureLongPeriod;
    }

    /// Returns the next 64-bit output, regenerating the state when exhausted.
    /// `idx` must be even (asserted), so do not interleave an odd number of
    /// 32-bit reads. Not available when version `Big32` is set.
    version (Big32){} else
    T next(T)()
        if (is (T == ulong))
    {
        ulong* psfmt64 = &(state[0].u64[0]);
        assert (idx % 2 == 0, "out of alignment");
        if (size <= idx)
            generateAll;
        immutable r = psfmt64[idx / 2];
        idx += 2;
        return r;
    }

    /// Returns the next 32-bit output, regenerating the state when exhausted.
    /// Not available when version `Big64` is set.
    version (Big64){} else
    T next(T)()
        if (is (T == uint))
    {
        uint* psfmt32 = &(state[0].u32[0]);
        if (size <= idx)
            generateAll;
        immutable r = psfmt32[idx];
        idx += 1;
        return r;
    }

    /**
     * Allocates a new array of `size` elements (this parameter shadows the
     * member enum `size`) and fills it in bulk with the next outputs.
     *
     * The array is reinterpreted as `ucent_` blocks, so `size * T.sizeof` must be
     * a whole number of blocks and at least `n` blocks (the precondition of
     * `fill`). Any outputs still buffered in the state are skipped.
     */
    T next(T)(size_t size)
        if (is (T == ulong[]) || is (T == uint[]))
    {
        return cast(T)fill(cast(ucent_[])(new T(size)));
    }

    /**
     * Fills `array` with consecutive output blocks (continuing the sequence
     * from the current state) and leaves the last `n` produced blocks in
     * `state`, so generation continues seamlessly afterwards. Marks the output
     * buffer as exhausted, because everything in `state` has already been
     * returned through `array`.
     *
     * Requires `n <= array.length`; lengths between `n` and `2n` are handled.
     */
    private auto fill(ucent_[] array)
    in
    {
        assert (n <= array.length);
    }
    body
    {
        immutable len = array.length;
        // Blocks 0 .. n are derived from the current state (the first two use
        // the last state blocks as the "previous two" inputs).
        recursion(
            array[0], state[0],
            state[0 + m],
            state[n - 2], state[n - 1]);
        recursion(
            array[1], state[1],
            state[1 + m],
            state[n - 1], array[0]);

        foreach (i; 2 .. n-m)
        {
            recursion(
                array[i], state[i],
                state[i + m],
                array[i - 2], array[i - 1]);
        }
        foreach (i; n-m .. n)
        {
            recursion(
                array[i], state[i],
                array[i + m - n],
                array[i - 2], array[i - 1]);
        }
        // Blocks that will not end up in the new state (empty when len < 2n).
        foreach (i; n .. len-n)
        {
            recursion(
                array[i], array[i - n],
                array[i + m - n],
                array[i - 2], array[i - 1]);
        }
        // The new state is array[len - n .. len]. When len < 2n, the blocks
        // len - n .. n were already produced above, so copy them and resume the
        // recursion at n (not at len - n, which would recompute them and
        // underflow `i - n`).
        size_t tailStart = max(size_t(n), len - n);
        size_t j = tailStart - (len - n);
        state[0 .. j] = array[len - n .. tailStart];
        foreach (i; tailStart .. len)
        {
            recursion(
                array[i], array[i - n],
                array[i + m - n],
                array[i - 2], array[i - 1]);
            state[j] = array[i];
            j += 1;
        }
        // Every block now in `state` was also written to `array`; force the
        // next scalar read to regenerate instead of repeating those outputs.
        idx = size;
        return array;
    }

    /// Regenerates all `n` state blocks in place. Block i is combined with block
    /// (i + m) mod n and the two most recently produced blocks. Resets `idx` to 0.
    private void generateAll()
    {
        recursion(
            state[0], state[0],
            state[0+m],
            state[n - 2], state[n - 1]);
        recursion(
            state[1], state[1],
            state[1+m],
            state[n - 1], state[0]);
        foreach (i; 2..n-m)
        {
            recursion(
                state[i], state[i],
                state[i+m],
                state[i - 2], state[i - 1]);
        }
        foreach (i; n-m..n)
        {
            recursion(
                state[i], state[i],
                state[i+m-n],
                state[i - 2], state[i - 1]);
        }
        idx = 0;
    }

    /// Internal state: `n` blocks, also read as 32-/64-bit words via `u32`/`u64`.
    ucent_[n] state;
    /// Index (in 32-bit words) of the next unread output word; `size` = exhausted.
    int idx;

    /**
     * Period certification (the reference SFMT's `period_certification`).
     *
     * Computes the parity of the first four state words ANDed with `parity`.
     * If it is odd, the state is left unchanged and false is returned.
     * Otherwise it flips, in the corresponding state word, the lowest set bit of
     * the first nonzero `parity` word, and returns true.
     * Asserts if every `parity` word is zero.
     */
    bool assureLongPeriod()
    {
        uint inner;
        uint* psfmt32 = &(state[0].u32[0]);
        foreach (i; 0..4)
            inner ^= psfmt32[i.idxof] & parity[i];
        // Fold all 32 bits into bit 0 (XOR reduction). A shift loop avoids
        // allocating an array literal on every call.
        for (uint shift = 16; shift > 0; shift >>= 1)
            inner ^= inner >> shift;
        inner &= 1;
        if (inner == 1)
            return false;
        foreach (i; 0..4)
        {
            uint working = 1;
            foreach (j; 0..32)
            {
                if (working & parity[i])
                {
                    psfmt32[i.idxof] ^= working;
                    return true;
                }
                working <<= 1;
            }
        }
        assert (false, "unreachable?");
    }
}
260 
// Build configuration for endianness.
// On big-endian targets exactly one of Only64bit / With32bit must be given:
//   Only64bit -> version Big64, which removes SFMT's 32-bit next!uint overload;
//   With32bit -> version Big32, which removes SFMT's 64-bit next!ulong overload.
// NOTE(review): Big32/Big64 are tested inside SFMT, which appears earlier in the
// file than these specifications. D rejects a version identifier defined after
// it has been used in some contexts — confirm this compiles on BigEndian targets
// (moving this block above SFMT would avoid the question).
version (BigEndian)
{
    pragma (msg, "not tested");
    version (Only64bit)
        version = Big64;
    else version (With32bit)
        version = Big32;
    else static assert (false, "Specify Only64bit or With32bit in BigEndian environment");
}
version (LittleEndian)
{
    pragma (msg, "supported");
}

// Only64bit and With32bit are mutually exclusive.
version (Only64bit)
    version (With32bit)
        static assert (false, "Specify (at most) one of Only64bit or With32bit");
278