/// Main API of SIMD-oriented Fast Mersenne Twister (SFMT).
module sfmt;

version (unittest)
import std.stdio : stderr;
static import sfmt.internal;
import sfmt.internal : func1, func2, idxof, ucent_;

import std.algorithm : map, max, min;
import std.format : format;
import std.range : iota;

// For every supported Mersenne exponent, mix in the code produced by
// `sfmt.internal.sfmtMixin` for rows 1 .. 32 (presumably the predefined
// `SFMT<mexp>_<k>` generators used below -- see sfmt.internal), then alias
// `SFMT<mexp>` to `SFMT<mexp>_0`.
static foreach (mexp; [size_t(607), 1279, 2281, 4253, 11213, 19937])
{
    static foreach (row; iota(1, 33))
        mixin (sfmt.internal.sfmtMixin(mexp, row));
    mixin ("alias SFMT%d = SFMT%d_0;".format(mexp, mexp));
}
/** SFMT random number generator whose parameters are chosen at run time.

Configure it through the constructors or `setParameters`, then seed it.
*/
struct RunTimeSFMT
{
    /// Configure with `parameters` and seed from `std.random.unpredictableSeed`.
    this (in sfmt.internal.Parameters parameters)
    {
        import std.random : unpredictableSeed;
        setParameters(parameters);
        seed(unpredictableSeed);
    }
    /// Configure with `parameters` and seed with the given _seed value(s).
    this (in sfmt.internal.Parameters parameters, in uint seed)
    {
        setParameters(parameters);
        this.seed(seed);
    }
    /// ditto
    this (in sfmt.internal.Parameters parameters, uint[] seed)
    {
        setParameters(parameters);
        this.seed(seed);
    }
    /** Copy the Mersenne exponent, `m`, shifts, masks and parity words
    from `parameters`.

    The state is left undefined: call `seed$(LPAREN)$(RPAREN)` afterwards.
    */
    void setParameters(in sfmt.internal.Parameters parameters)
    {
        this.mexp = parameters.mersenneExponent; // also resizes the state
        this.m = parameters.m;
        this.shifts = parameters.shifts;
        this.masks = parameters.masks;
        this.parity = parameters.parity;
    }
    /// Mersenne exponent of the current parameter set.
    size_t mexp() const @property
    {
        return mersenneExponent;
    }
    /// ditto
    size_t mexp(in size_t value) @property
    {
        mersenneExponent = value;
        n = (value >> 7) + 1;   // number of `ucent_` blocks in the state
        size = n << 2;          // number of 32-bit words (4 per block)
        state.length = n;
        // lag / mid / tail / imid / iml are the offsets used by the
        // array-seeding routine `seed(uint[])`.
        lag = size >= 623 ? 11
            : size >= 68 ? 7
            : size >= 39 ? 5
            : 3;
        mid = (size - lag) / 2;
        tail = (size - 1).idxof;
        imid = mid.idxof;
        iml = (mid + lag).idxof;
        return value;
    }
    /// Number of `ucent_` blocks in the state (read-only; set via `mexp`).
    ptrdiff_t n;
    /// Number of 32-bit words in the state (read-only; set via `mexp`).
    ptrdiff_t size;
    /// Offset of the block used as second input of the recursion.
    size_t m;
    /// Shift amounts (sl1, sl2, sr1, sr2) of the recursion.
    size_t[4] shifts;
    /// Per-lane bit masks of the recursion.
    size_t[4] masks;
    /// Parity words used by `assureLongPeriod`.
    size_t[4] parity;

    /// Human-readable identifier built from the exponent, `m`, shifts and masks.
    string id() const
    {
        return format("SFMT-%d:%d-%(%d-%):%(%08x-%)",
            mersenneExponent, m, shifts[], masks[]);
    }

    mixin SFMTMixin;

private:
    // Run-time counterpart of sfmt.internal.recursion: writes f(a, b, c, d)
    // into r, one 32-bit lane at a time.  r may alias a; the shifted copies
    // of a and c are taken before any lane of r is written.
    void recursion(ref ucent_ r, ref ucent_ a, ref ucent_ b, ref ucent_ c, ref ucent_ d)
    {
        immutable sl1 = shifts[0];
        immutable sl2 = shifts[1];
        immutable sr1 = shifts[2];
        immutable sr2 = shifts[3];
        auto aShifted = a << sl2;
        auto cShifted = c >> sr2;
        static foreach (lane; 0 .. 4)
        {
            r.u32[lane] = a.u32[lane] ^ aShifted.u32[lane]
                ^ ((b.u32[lane] >> sr1) & masks[idxof!lane])
                ^ cShifted.u32[lane] ^ (d.u32[lane] << sl1);
        }
    }
    size_t mersenneExponent;
    ptrdiff_t lag, mid;
    size_t tail, imid, iml;
    ucent_[] state;
}
/// `RunTimeSFMT` instances preloaded with every predefined parameter set.
RunTimeSFMT[] rtSFMTs;
///
unittest
{
    assert (rtSFMTs.length == 192);
}
static this ()
{
    // 6 Mersenne exponents x 32 parameter rows = 192 generators.
    static foreach (mexp; [size_t(607), 1279, 2281, 4253, 11213, 19937])
    {
        import std.range : iota;
        import std.format : format;
        static foreach (row; 32.iota)
        {
            rtSFMTs ~= RunTimeSFMT(mixin ("SFMT%d_%d".format(mexp, row)).params);
        }
    }
}

/// _SFMT random number generator, whose parameters are set compile time.
struct SFMT(sfmt.internal.Parameters parameters)
{
    /// Seed with a single 32-bit value.
    this (uint seed)
    {
        this.seed(seed);
    }
    /// Seed with an array of 32-bit values.
    this (const(uint)[] seed)
    {
        this.seed(seed);
    }
    enum mersenneExponent = parameters.mersenneExponent;///
    alias mexp = mersenneExponent;///
    enum n = (mersenneExponent >> 7) + 1;///
    enum size = n << 2;///
    enum m = parameters.m;///
    enum shifts = parameters.shifts;///
    enum masks = parameters.masks;///
    enum parity = parameters.parity;///
    enum id = "SFMT-%d:%d-%(%d-%):%(%08x-%)".format(
        mersenneExponent, m,
        shifts[],
        masks[]
    );///

    mixin SFMTMixin;

private:
    alias recursion = sfmt.internal.recursion!(shifts, masks);
    alias params = parameters;
    static if (size >= 623)
        enum lag = 11;
    else static if (size >= 68)
        enum lag = 7;
    else static if (size >= 39)
        enum lag = 5;
    else
        enum lag = 3;
    enum mid = (size - lag) / 2;
    enum tail = idxof!(size - 1);
    enum imid = idxof!mid;
    enum iml = idxof!(mid+lag);
    ucent_[n] state;
}
///
unittest
{
    /// SFMT19937 is an alias of SFMT!(...).
    import std.random;
    static assert (isUniformRNG!SFMT19937);
    assert (SFMT19937(4321u).front == 16924766246869039260UL);
}
///
unittest
{
    import std.algorithm : equal;
    import std.range : take;
    assert (SFMT19937(4321u).next!(ulong[])(1000).equal(
        SFMT19937(4321u).take(1000)));
    stderr.writeln("checked next!ulong[] and range functionality");
}
///
unittest
{
    import std.random;
    auto sfmt = SFMT19937(4321u);
    foreach (i; 0..1000)
    {
        assert (0 <= sfmt.uniform01!real);
        assert (0 <= sfmt.uniform01!double);
        assert (0 <= sfmt.uniform01!float);
        assert (sfmt.uniform01!real < 1);
        assert (sfmt.uniform01!double < 1);
        assert (sfmt.uniform01!float < 1);
    }
    stderr.writeln("checked uniform01");

    auto sixThousandth = sfmt.front;
    sfmt = SFMT19937(4321u);
    foreach (i; 0..6000)
        sfmt.popFront;
    assert (sfmt.front == sixThousandth);
    stderr.writeln("checked call-only-popFront case");
}
///
unittest
{
    void testNext(U, ISFMT)(ISFMT sfmt)
    {
        auto copy = sfmt;
        auto firstBlock = sfmt.next!(U[])(10000);
        auto secondBlock = sfmt.next!(U[])(10000);
        U s;
        foreach (i, b; firstBlock)
            assert (b == (s = copy.frontPop!U), "mismatch: first[%d] = %0*,8x != %0*,8x".format(i, U.sizeof>>1, b, U.sizeof>>1, s));
        foreach (i, b; secondBlock)
            assert (b == (s = copy.frontPop!U), "mismatch: second[%d;%d] = %0*,8x != %0*,8x".format(i, i+firstBlock.length, U.sizeof>>1, b, U.sizeof>>1, s));
    }
    testNext!ulong(SFMT19937(4321u));
    testNext!ulong(SFMT19937([uint(5), 4, 3, 2, 1]));
    testNext!uint(SFMT19937(1234u));
    testNext!uint(SFMT19937([uint(0x1234), 0x5678, 0x9abc, 0xdef0]));
    stderr.writeln("checked frontPop!U and next!U[] (U = ulong, uint)");
}
///
unittest
{
    void testPopFrontThenBlock(size_t firstSize, size_t secondSize)
    {
        import std.range : drop, take;
        import std.algorithm : equal;
        auto sfmt = SFMT19937(4321u);
        foreach (i; 0..firstSize*2)
            sfmt.popFront;
        auto a = sfmt.next!(ulong[])(secondSize*2);
        auto b = SFMT19937(4321u).drop(firstSize*2).take(secondSize*2);
        assert (a.equal(b));
    }
    foreach (i; 0..SFMT19937.n)
        foreach (j; SFMT19937.n..SFMT19937.n*2)
        {
            testPopFrontThenBlock(i, j);
        }
    stderr.writeln("checked next!U[]");
}
/// Common functions of RunTimeSFMT and SFMT.
mixin template SFMTMixin()
{
    enum isUniformRandom = true;///
    enum min = ulong.min;///
    enum max = ulong.max;///
    // Set every 32-bit lane of every state block to the byte `b` repeated 4 times.
    void fillState(ubyte b)
    {
        ucent_ x;
        x.u32[0] = b;
        x.u32[0] = x.u32[0] << 8 | x.u32[0];
        x.u32[0] = x.u32[0] << 16 | x.u32[0];
        x.u32[1..$] = x.u32[0];
        state[] = x;
    }
    /** Seed with a single 32-bit value.

    The state words are initialised by
    `s[i] = 1812433253 * (s[i-1] ^ (s[i-1] >> 30)) + i` (mod 2^^32),
    then `assureLongPeriod` is applied and the first block is generated.
    */
    void seed(uint seed)
    {
        uint* psfmt32 = &(state[0].u32[0]);
        psfmt32[idxof!0] = seed;
        foreach (i; 1..size)
        {
            immutable prev = psfmt32[(i - 1).idxof];
            // `i` may be a 64-bit integer; the recurrence wraps modulo 2^^32.
            psfmt32[i.idxof] = cast(uint)(1812433253U * (prev ^ (prev >> 30)) + i);
        }
        assureLongPeriod;
        generateAll; // also rewinds idx to 0
    }
    /** Seed with an array of 32-bit values.

    Accepts any `uint[]`, `const(uint)[]` or `immutable(uint)[]`.
    */
    void seed(const(uint)[] seed)
    {
        fillState(0x8b);
        immutable count = seed.length.max(size - 1);
        uint* psfmt32 = &(state[0].u32[0]);
        uint r = func1(psfmt32[idxof!0] ^ psfmt32[imid] ^ psfmt32[tail]);
        psfmt32[imid] += r;
        r += seed.length;
        psfmt32[iml] += r;
        psfmt32[idxof!0] = r;

        size_t i = 1;
        // mix in the seed words
        foreach (j; 0..count.min(seed.length))
        {
            r = func1(
                psfmt32[i.idxof]
                ^ psfmt32[((i+mid)%size).idxof]
                ^ psfmt32[((i+size-1)%size).idxof]);
            psfmt32[((i+mid)%size).idxof] += r;
            r += seed[j] + i;
            psfmt32[((i+mid+lag)%size).idxof] += r;
            psfmt32[i.idxof] = r;
            i = (i + 1) % size;
        }
        // keep mixing without seed words until `count` steps are done
        foreach (j; count.min(seed.length)..count)
        {
            r = func1(
                psfmt32[i.idxof]
                ^ psfmt32[((i+mid)%size).idxof]
                ^ psfmt32[((i+size-1)%size).idxof]);
            psfmt32[((i+mid)%size).idxof] += r;
            r += i;
            psfmt32[((i+mid+lag)%size).idxof] += r;
            psfmt32[i.idxof] = r;
            i = (i + 1) % size;
        }
        // final pass over the whole state using func2
        foreach (j; 0..size)
        {
            r = func2(
                psfmt32[i.idxof]
                + psfmt32[((i+mid)%size).idxof]
                + psfmt32[((i+size-1)%size).idxof]);
            psfmt32[((i+mid)%size).idxof] ^= r;
            r -= i;
            psfmt32[((i+mid+lag)%size).idxof] ^= r;
            psfmt32[i.idxof] = r;
            i = (i + 1) % size;
        }
        assureLongPeriod;
        generateAll; // also rewinds idx to 0
    }
    /// input range interface.
    enum empty = false;
    /// ditto
    ulong front() @property
    {
        assert (idx % 2 == 0, "out of alignment");
        ulong* psfmt64 = &(state[0].u64[0]);
        return psfmt64[idx / 2];
    }
    /// ditto
    void popFront()
    {
        idx += 2;
        if (size <= idx) // in current implementation,
            generateAll; // this is necessary when only popFront is called repeatedly.
    }
    version (Big32){} else
    T frontPop(T : ulong)()///
    {
        auto ret = front;
        popFront;
        return ret;
    }
    version (Big64){} else
    T frontPop(T : uint)()///
    {
        uint* psfmt32 = &(state[0].u32[0]);
        immutable r = psfmt32[idx];
        idx += 1;
        if (size <= idx)
            generateAll;
        return r;
    }
    /** Return a new array of `size` random values (`T` is `ulong[]` or `uint[]`).

    The values are exactly those `frontPop` would have returned next, and the
    generator continues right after them.
    Requirements: the array's byte length must be a multiple of `ucent_.sizeof`
    and cover at least `n` blocks, and the current position must be at a block
    boundary (both are checked by the contract of `fill`).
    */
    T next(T)(size_t size)
    if (is (T == ulong[]) || is (T == uint[]))
    {
        return cast(T)fill(cast(ucent_[])(new T(size)));
    }
    // Write the continuation of the output stream into `array`, starting at
    // the current block-aligned position.  On return `state` holds the block
    // that directly follows `array` and `idx` points at its first word, so
    // the stream carries on seamlessly after the returned values.
    private auto fill(ucent_[] array)
    in
    {
        assert (n <= array.length);
        assert (idx % 4 == 0, "out of alignment");
    }
    body
    {
        immutable size_t
            index = idx / 4,       // first unread block of state
            total = array.length;  // number of blocks to produce
        immutable size_t
            prepared = n-index;    // blocks already available in state
        array[0..prepared] = state[index..$];
        // array[prepared-j] == state[n-j]
        // array[i-n] == state[i-prepared]
        // array[i] == state[i+n-prepared] == state[i+index]
        if (prepared == 0) // unreachable while idx < size, but kept correct
        {
            recursion(
                array[0], state[0],
                state[m],
                state[index-2], state[index-1]);
        }
        if (prepared <= 1)
        {
            recursion(
                array[1], state[1-prepared],
                state[1+m-prepared],
                state[index-1], array[0]);
        }
        if (prepared <= n-m-1)
            foreach (i; prepared.max(2)..n-m)
            {
                recursion(
                    array[i], state[i-prepared],
                    state[i-(prepared-m)],
                    array[i-2], array[i-1]);
            }
        foreach (i; prepared.max(n-m)..n)
        {
            recursion(
                array[i], state[i-prepared],
                array[i-(n-m)],
                array[i-2], array[i-1]);
        }
        foreach (i; n .. total)
        {
            recursion(
                array[i], array[i-n],
                array[i-(n-m)],
                array[i-2], array[i-1]);
        }
        // array[$-n+i] == state[i-n]
        // array[$+i] == state[i]
        recursion(
            state[0], array[$-n],
            array[$-(n-m)],
            array[$-2], array[$-1]);
        recursion(
            state[1], array[$+1-n],
            array[$+1-(n-m)],
            array[$-1], state[0]);
        foreach (i; 2..(n-m))
        {
            recursion(
                state[i], array[$+i-n],
                array[$+i-(n-m)],
                state[i-2], state[i-1]);
        }
        foreach (i; (n-m)..n)
        {
            recursion(
                state[i], array[$+i-n],
                state[i-(n-m)],
                state[i-2], state[i-1]);
        }
        // `state` now starts exactly where `array` ends: the next value to
        // emit is its first word.  Without this reset, calling `next` at a
        // non-zero position would make later outputs skip `idx` words.
        idx = 0;
        return array;
    }
    // Advance the whole state by one generation and rewind idx to 0.
    private void generateAll()
    {
        recursion(
            state[0], state[0],
            state[0+m],
            state[n - 2], state[n - 1]);
        recursion(
            state[1], state[1],
            state[1+m],
            state[n - 1], state[0]);
        foreach (i; 2..n-m)
        {
            recursion(
                state[i], state[i],
                state[i+m],
                state[i - 2], state[i - 1]);
        }
        foreach (i; n-m..n)
        {
            recursion(
                state[i], state[i],
                state[i+m-n],
                state[i - 2], state[i - 1]);
        }
        idx = 0;
    }
    int idx;
    /** Check the parity of the first four state words against `parity` and,
    if the check fails, flip the lowest bit that is set in a parity word.

    Returns: true if modification is done
    */
    bool assureLongPeriod()
    {
        uint inner;
        uint* psfmt32 = &(state[0].u32[0]);
        static foreach (i; 0..4)
            inner ^= psfmt32[idxof!i] & parity[i];
        foreach (i; [16, 8, 4, 2, 1])
            inner ^= inner >> i;
        inner &= 1;
        if (inner == 1)
            return false;
        foreach (i; 0..4)
        {
            uint working = 1;
            foreach (j; 0..32)
            {
                if (working & parity[i])
                {
                    psfmt32[i.idxof] ^= working;
                    return true;
                }
                working <<= 1;
            }
        }
        assert (false, "unreachable?");
    }
}
/// After `next`, the generator continues right after the returned values.
unittest
{
    import std.algorithm : equal;
    import std.range : drop, take;

    auto sfmt = SFMT19937(4321u);
    foreach (i; 0..6)
        sfmt.popFront;
    auto block = sfmt.next!(ulong[])(SFMT19937.n * 2);
    assert (sfmt.take(100).equal(
        SFMT19937(4321u).drop(6 + block.length).take(100)));

    auto sfmt32 = SFMT19937(1234u);
    auto copy = sfmt32;
    foreach (i; 0..4)
        sfmt32.frontPop!uint();
    auto block32 = sfmt32.next!(uint[])(SFMT19937.size);
    foreach (i; 0..4 + block32.length)
        copy.frontPop!uint();
    foreach (i; 0..100)
        assert (sfmt32.frontPop!uint() == copy.frontPop!uint());
    stderr.writeln("checked continuation after next!U[]");
}
///
unittest
{
    auto ct = SFMT19937(13579u);
    RunTimeSFMT rt;
    rt.mexp(ct.mersenneExponent);
    rt.m = ct.m;
    rt.shifts = ct.shifts;
    rt.masks = ct.masks;
    rt.parity = ct.parity;
    rt.seed(13579u);
    foreach (i; 0..1000)
        assert (ct.frontPop!ulong == rt.frontPop!ulong);
    stderr.writeln("checked compile time and run time");
}

version (BigEndian)
{
    pragma (msg, "not tested");
    version (Only64bit)
        version = Big64;
    else version (With32bit)
        version = Big32;
    else static assert (false, "Specify Only64bit or With32bit in BigEndian environment");
}
version (LittleEndian)
{
    pragma (msg, "supported");
}

version (Only64bit)
    version (With32bit)
        static assert (false, "Specify (at most) one of Only64bit or With32bit");