/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
017package org.apache.commons.rng.core.util;
018
019import java.util.Objects;
020import java.util.Spliterator;
021import java.util.function.Consumer;
022import java.util.stream.Stream;
023import java.util.stream.StreamSupport;
024import org.apache.commons.rng.SplittableUniformRandomProvider;
025import org.apache.commons.rng.UniformRandomProvider;
026
027/**
028 * Utility for creating streams using a source of randomness.
029 *
030 * @since 1.5
031 */
032public final class RandomStreams {
033    /** The number of bits of each random character in the seed.
034     * The generation algorithm will work if this is in the range [2, 30]. */
035    private static final int SEED_CHAR_BITS = 4;
036
037    /**
038     * A factory for creating objects using a seed and a using a source of randomness.
039     *
040     * @param <T> the object type
041     * @since 1.5
042     */
043    @FunctionalInterface
044    public interface SeededObjectFactory<T> {
045        /**
046         * Creates the object.
047         *
048         * @param seed Seed used to initialise the instance.
049         * @param source Source of randomness used to initialise the instance.
050         * @return the object
051         */
052        T create(long seed, UniformRandomProvider source);
053    }
054
055    /**
056     * Class contains only static methods.
057     */
058    private RandomStreams() {}
059
060    /**
061     * Returns a stream producing the given {@code streamSize} number of new objects
062     * generated using the supplied {@code source} of randomness and object {@code factory}.
063     *
064     * <p>A {@code long} seed is provided for each object instance using the stream position
065     * and random bits created from the supplied {@code source}.
066     *
067     * <p>The stream supports parallel execution by splitting the provided {@code source}
068     * of randomness. Consequently objects in the same position in the stream created from
069     * a sequential stream may be created from a different source of randomness than a parallel
070     * stream; it is not expected that parallel execution will create the same final
071     * collection of objects.
072     *
073     * @param <T> the object type
074     * @param streamSize Number of objects to generate.
075     * @param source A source of randomness used to initialise the new instances; this may
076     * be split to provide a source of randomness across a parallel stream.
077     * @param factory Factory to create new instances.
078     * @return a stream of objects; the stream is limited to the given {@code streamSize}.
079     * @throws IllegalArgumentException if {@code streamSize} is negative.
080     * @throws NullPointerException if {@code source} or {@code factory} is null.
081     */
082    public static <T> Stream<T> generateWithSeed(long streamSize,
083                                                 SplittableUniformRandomProvider source,
084                                                 SeededObjectFactory<T> factory) {
085        if (streamSize < 0) {
086            throw new IllegalArgumentException("Invalid stream size: " + streamSize);
087        }
088        Objects.requireNonNull(source, "source");
089        Objects.requireNonNull(factory, "factory");
090        final long seed = createSeed(source);
091        return StreamSupport
092            .stream(new SeededObjectSpliterator<>(0, streamSize, source, factory, seed), false);
093    }
094
095    /**
096     * Creates a seed to prepend to a counter. The seed is created to satisfy the following
097     * requirements:
098     * <ul>
099     * <li>The least significant bit is set</li>
100     * <li>The seed is composed of characters from an n-bit alphabet</li>
101     * <li>The character used in the least significant bits is unique</li>
102     * <li>The other characters are sampled uniformly from the remaining (n-1) characters</li>
103     * </ul>
104     *
105     * <p>The composed seed is created using {@code ((seed << shift) | count)}
106     * where the shift is applied to ensure non-overlap of the shifted seed and
107     * the count. This is achieved by ensuring the lowest 1-bit of the seed is
108     * above the highest 1-bit of the count. The shift is a multiple of n to ensure
109     * the character used in the least significant bits aligns with higher characters
110     * after a shift. As higher characters exclude the least significant character
111     * no shifted seed can duplicate previously observed composed seeds. This holds
112     * until the least significant character itself is shifted out of the composed seed.
113     *
114     * <p>The seed generation algorithm starts with a random series of bits with the lowest bit
115     * set. Any occurrences of the least significant character in the remaining characters are
116     * replaced using {@link UniformRandomProvider#nextInt()}.
117     *
118     * <p>The remaining characters will be rejected at a rate of 2<sup>-n</sup>. The
119     * character size is a compromise between a low rejection rate and the highest supported
120     * count that may receive a prepended seed.
121     *
122     * <p>The JDK's {@code java.util.random} package uses 4-bits for the character size when
123     * creating a stream of SplittableGenerator. This achieves a rejection rate
124     * of {@code 1/16}. Using this size will require 1 call to generate a {@code long} and
125     * on average 1 call to {@code nextInt(15)}. The maximum supported stream size with a unique
126     * seed per object is 2<sup>60</sup>. The algorithm here also uses a character size of 4-bits;
127     * this simplifies the implementation as there are exactly 16 characters. The algorithm is a
128     * different implementation to the JDK and creates an output seed with similar properties.
129     *
130     * @param rng Source of randomness.
131     * @return the seed
132     */
133    static long createSeed(UniformRandomProvider rng) {
134        // Initial random bits. Lowest bit must be set.
135        long bits = rng.nextLong() | 1;
136        // Mask to extract characters.
137        // Can be used to sample from (n-1) n-bit characters.
138        final long n = (1L << SEED_CHAR_BITS) - 1;
139
140        // Extract the unique character.
141        final long unique = bits & n;
142
143        // Check the rest of the characters do not match the unique character.
144        // This loop extracts the remaining characters and replaces if required.
145        // This will work if the characters do not evenly divide into 64 as we iterate
146        // over the count of remaining bits. The original order is maintained so that
147        // if the bits already satisfy the requirements they are unchanged.
148        for (int i = SEED_CHAR_BITS; i < Long.SIZE; i += SEED_CHAR_BITS) {
149            // Next character
150            long c = (bits >>> i) & n;
151            if (c == unique) {
152                // Branch frequency of 2^-bits.
153                // This code is deliberately branchless.
154                // Avoid nextInt(n) using: c = floor(n * ([0, 2^32) / 2^32))
155                // Rejection rate for non-uniformity will be negligible: 2^32 % 15 == 1
156                // so any rejection algorithm only has to exclude 1 value from nextInt().
157                c = (n * Integer.toUnsignedLong(rng.nextInt())) >>> Integer.SIZE;
158                // Ensure the sample is uniform in [0, n] excluding the unique character
159                c = (unique + c + 1) & n;
160                // Replace by masking out the current character and bitwise add the new one
161                bits = (bits & ~(n << i)) | (c << i);
162            }
163        }
164        return bits;
165    }
166
167    /**
168     * Spliterator for streams of a given object type that can be created from a seed
169     * and source of randomness. The source of randomness is splittable allowing parallel
170     * stream support.
171     *
172     * <p>The seed is mixed with the stream position to ensure each object is created using
173     * a unique seed value. As the position increases the seed is left shifted until there
174     * is no bit overlap between the seed and the position, i.e the right-most 1-bit of the seed
175     * is larger than the left-most 1-bit of the position.
176     *s
177     * @param <T> the object type
178     */
179    private static final class SeededObjectSpliterator<T>
180            implements Spliterator<T> {
181        /** Message when the consumer action is null. */
182        private static final String NULL_ACTION = "action must not be null";
183
184        /** The current position in the range. */
185        private long position;
186        /** The upper limit of the range. */
187        private final long end;
188        /** Seed used to initialise the new instances. The least significant 1-bit of
189         * the seed must be above the most significant bit of the position. This is maintained
190         * by left shift when the position is updated. */
191        private long seed;
192        /** Source of randomness used to initialise the new instances. */
193        private final SplittableUniformRandomProvider source;
194        /** Factory to create new instances. */
195        private final SeededObjectFactory<T> factory;
196
197        /**
198         * @param start Start position of the stream (inclusive).
199         * @param end Upper limit of the stream (exclusive).
200         * @param source Source of randomness used to initialise the new instances.
201         * @param factory Factory to create new instances.
202         * @param seed Seed used to initialise the instances. The least significant 1-bit of
203         * the seed must be above the most significant bit of the {@code start} position.
204         */
205        SeededObjectSpliterator(long start, long end,
206                                SplittableUniformRandomProvider source,
207                                SeededObjectFactory<T> factory,
208                                long seed) {
209            position = start;
210            this.end = end;
211            this.seed = seed;
212            this.source = source;
213            this.factory = factory;
214        }
215
216        @Override
217        public long estimateSize() {
218            return end - position;
219        }
220
221        @Override
222        public int characteristics() {
223            return SIZED | SUBSIZED | IMMUTABLE;
224        }
225
226        @Override
227        public Spliterator<T> trySplit() {
228            final long start = position;
229            final long middle = (start + end) >>> 1;
230            if (middle <= start) {
231                return null;
232            }
233            // The child spliterator can use the same seed as the position does not overlap
234            final SeededObjectSpliterator<T> s =
235                new SeededObjectSpliterator<>(start, middle, source.split(), factory, seed);
236            // Since the position has increased ensure the seed does not overlap
237            position = middle;
238            while (seed != 0 && Long.compareUnsigned(Long.lowestOneBit(seed), middle) <= 0) {
239                seed <<= SEED_CHAR_BITS;
240            }
241            return s;
242        }
243
244        @Override
245        public boolean tryAdvance(Consumer<? super T> action) {
246            Objects.requireNonNull(action, NULL_ACTION);
247            final long pos = position;
248            if (pos < end) {
249                // Advance before exceptions from the action are relayed to the caller
250                position = pos + 1;
251                action.accept(factory.create(seed | pos, source));
252                // If the position overlaps the seed, shift it by 1 character
253                if ((position & seed) != 0) {
254                    seed <<= SEED_CHAR_BITS;
255                }
256                return true;
257            }
258            return false;
259        }
260
261        @Override
262        public void forEachRemaining(Consumer<? super T> action) {
263            Objects.requireNonNull(action, NULL_ACTION);
264            long pos = position;
265            final long last = end;
266            if (pos < last) {
267                // Ensure forEachRemaining is called only once
268                position = last;
269                final SplittableUniformRandomProvider s = source;
270                final SeededObjectFactory<T> f = factory;
271                do {
272                    action.accept(f.create(seed | pos, s));
273                    pos++;
274                    // If the position overlaps the seed, shift it by 1 character
275                    if ((pos & seed) != 0) {
276                        seed <<= SEED_CHAR_BITS;
277                    }
278                } while (pos < last);
279            }
280        }
281    }
282}