当使用类java.util.Random时,如何以更有效的方式(特别是在O(1)中)获得从n次调用方法nextInt()中获得的值?
例如,如果我构造一个具有特定种子值的 Random 对象,并且我想快速获取第 100,000 个"nextInt() 值"(即调用方法 nextInt() 100,000 次后获得的值),我可以做到吗?
为简单起见,假设 JDK 的版本为 1.7.06,因为可能需要知道类 Random 中某些私有字段的确切值。说到这里,我发现以下字段与随机值的计算相关:
private static final long multiplier = 0x5DEECE66DL;
private static final long addend = 0xBL;
private static final long mask = (1L << 48) - 1;
在探索了一些关于随机性的知识之后,我发现随机值是使用线性同余生成器获得的。执行算法的实际方法是方法 next(int):
protected int next(int bits) {
long oldseed, nextseed;
AtomicLong seed = this.seed;
do {
oldseed = seed.get();
nextseed = (oldseed * multiplier + addend) & mask;
} while (!seed.compareAndSet(oldseed, nextseed));
return (int)(nextseed >>> (48 - bits));
}
算法的相关行是获取下一个种子值的行:
nextseed = (oldseed * multiplier + addend) & mask;
那么,更具体地说,有没有办法推广这个公式以获得"第 n 个下一个种子"值?我在这里假设在拥有它之后,我可以通过让变量"bits"为 32 来简单地获取第 n 个 int 值(方法nextInt() 只是调用 next(32) 并返回结果)。
提前致谢
PS:也许这是一个更适合数学交流的问题?
你可以在O(log N)
时间内完成。从 s(0)
开始,如果我们暂时忽略模数 (248),我们可以看到(使用 m
和 a
作为 multiplier
和 addend
的简写)
s(1) = s(0) * m + a
s(2) = s(1) * m + a = s(0) * m² + (m + 1) * a
s(3) = s(2) * m + a = s(0) * m³ + (m² + m + 1) * a
...
s(N) = s(0) * m^N + (m^(N-1) + ... + m + 1) * a
现在,可以通过重复平方的模幂轻松地以O(log N)
步长计算m^N (mod 2^48)
。
另一部分有点复杂。暂时再次忽略模量,几何和为
(m^N - 1) / (m - 1)
使计算这个模2^48
有点不平凡的是m - 1
不是模的互质。然而,由于
m = 0x5DEECE66DL
m-1
和模的最大公约数是4,(m-1)/4
有一个模逆inv
模2^48
。让
c = (m^N - 1) (mod 4*2^48)
然后
(c / 4) * inv ≡ (m^N - 1) / (m - 1) (mod 2^48)
所以
- 计算
M ≡ m^N (mod 2^50)
- 计算
inv
要获得
s(N) ≡ s(0)*M + ((M - 1)/4)*inv*a (mod 2^48)
我已经接受了Daniel Fischer的答案,因为它是正确的,并给出了一般的解决方案。使用 Daniel 的答案,这里有一个带有 java 代码的具体示例,它显示了公式的基本实现(我广泛使用了类 BigInteger,所以它可能不是最佳的,但我确认了实际调用方法 nextInt() N 次的基本方式的显着加速):
import java.math.BigInteger;
import java.util.Random;
public class RandomNthNextInt {
// copied from java.util.Random =========================
private static final long multiplier = 0x5DEECE66DL;
private static final long addend = 0xBL;
private static final long mask = (1L << 48) - 1;
private static long initialScramble(long seed) {
return (seed ^ multiplier) & mask;
}
private static int getNextInt(long nextSeed) {
return (int)(nextSeed >>> (48 - 32));
}
// ======================================================
private static final BigInteger mod = BigInteger.valueOf(mask + 1L);
private static final BigInteger inv = BigInteger.valueOf((multiplier - 1L) / 4L).modInverse(mod);
/**
* Returns the value obtained after calling the method {@link Random#nextInt()} {@code n} times from a
* {@link Random} object initialized with the {@code seed} value.
* <p>
* This method does not actually create any {@code Random} instance, instead it applies a direct formula which
* calculates the expected value in a more efficient way (close to O(log N)).
*
* @param seed
* The initial seed value of the supposed {@code Random} object
* @param n
* The index (starting at 1) of the "nextInt() value"
* @return the nth "nextInt() value" of a {@code Random} object initialized with the given seed value
* @throws IllegalArgumentException
* If {@code n} is not positive
*/
public static long getNthNextInt(long seed, long n) {
if (n < 1L) {
throw new IllegalArgumentException("n must be positive");
}
final BigInteger seedZero = BigInteger.valueOf(initialScramble(seed));
final BigInteger nthSeed = calculateNthSeed(seedZero, n);
return getNextInt(nthSeed.longValue());
}
private static BigInteger calculateNthSeed(BigInteger seed0, long n) {
final BigInteger largeM = calculateLargeM(n);
final BigInteger largeMmin1div4 = largeM.subtract(BigInteger.ONE).divide(BigInteger.valueOf(4L));
return seed0.multiply(largeM).add(largeMmin1div4.multiply(inv).multiply(BigInteger.valueOf(addend))).mod(mod);
}
private static BigInteger calculateLargeM(long n) {
return BigInteger.valueOf(multiplier).modPow(BigInteger.valueOf(n), BigInteger.valueOf(1L << 50));
}
// =========================== Testing stuff ======================================
public static void main(String[] args) {
final long n = 100000L; // change this to test other values
final long seed = 1L; // change this to test other values
System.out.println(n + "th nextInt (formula) = " + getNthNextInt(seed, n));
System.out.println(n + "th nextInt (slow) = " + getNthNextIntSlow(seed, n));
}
private static int getNthNextIntSlow(long seed, long n) {
if (n < 1L) {
throw new IllegalArgumentException("n must be positive");
}
final Random rand = new Random(seed);
for (long eL = 0; eL < (n - 1); eL++) {
rand.nextInt();
}
return rand.nextInt();
}
}
注意:请注意方法initialScramble(long),它用于获取第一个种子值。这是类 Random 在使用特定种子初始化实例时的行为。