001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.rng.simple;
018
019import java.util.EnumMap;
020import java.util.Map;
021import java.util.concurrent.locks.ReentrantLock;
022import org.apache.commons.rng.UniformRandomProvider;
023
024/**
025 * This class provides a thread-local {@link UniformRandomProvider}.
026 *
027 * <p>The {@link UniformRandomProvider} is created once-per-thread using the default
028 * construction method {@link RandomSource#create()}.
029 *
030 * <p>Example:</p>
031 * <pre><code>
032 * import org.apache.commons.rng.simple.RandomSource;
033 * import org.apache.commons.rng.simple.ThreadLocalRandomSource;
034 * import org.apache.commons.rng.sampling.distribution.PoissonSampler;
035 *
036 * // Access a thread-safe random number generator
037 * UniformRandomProvider rng = ThreadLocalRandomSource.current(RandomSource.SPLIT_MIX_64);
038 *
039 * // One-time Poisson sample
040 * double mean = 12.3;
041 * int counts = PoissonSampler.of(rng, mean).sample();
042 * </code></pre>
043 *
044 * <p>Note if the {@link RandomSource} requires additional arguments then it is not
045 * supported. The same can be achieved using:</p>
046 *
047 * <pre><code>
048 * import org.apache.commons.rng.simple.RandomSource;
049 * import org.apache.commons.rng.sampling.distribution.PoissonSampler;
050 *
051 * // Provide a thread-safe random number generator with data arguments
052 * private static ThreadLocal&lt;UniformRandomProvider&gt; rng =
053 *     new ThreadLocal&lt;UniformRandomProvider&gt;() {
054 *         &#64;Override
055 *         protected UniformRandomProvider initialValue() {
056 *             return RandomSource.TWO_CMRES_SELECT.create(null, 3, 4);
057 *         }
058 *     };
059 *
060 * // One-time Poisson sample using a thread-safe random number generator
061 * double mean = 12.3;
062 * int counts = PoissonSampler.of(rng.get(), mean).sample();
063 * </code></pre>
064 *
065 * @since 1.3
066 */
067public final class ThreadLocalRandomSource {
068    /**
069     * A map containing the {@link ThreadLocal} instance for each {@link RandomSource}.
070     *
071     * <p>This should only be modified to create new instances in a synchronized block.
072     */
073    private static final Map<RandomSource, ThreadLocal<UniformRandomProvider>> SOURCES =
074        new EnumMap<>(RandomSource.class);
075    /** An object to use for synchonization. */
076    private static final ReentrantLock LOCK = new ReentrantLock();
077
078    /** No public construction. */
079    private ThreadLocalRandomSource() {}
080
081    /**
082     * Extend the {@link ThreadLocal} to allow creation of the desired {@link RandomSource}.
083     */
084    private static class ThreadLocalRng extends ThreadLocal<UniformRandomProvider> {
085        /** The source. */
086        private final RandomSource source;
087
088        /**
089         * Create a new instance.
090         *
091         * @param source the source
092         */
093        ThreadLocalRng(RandomSource source) {
094            this.source = source;
095        }
096
097        @Override
098        protected UniformRandomProvider initialValue() {
099            // Create with the default seed generation method
100            return source.create();
101        }
102    }
103
104    /**
105     * Returns the current thread's copy of the given {@code source}. If there is no
106     * value for the current thread, it is first initialized to the value returned
107     * by {@link RandomSource#create()}.
108     *
109     * <p>Note if the {@code source} requires additional arguments then it is not
110     * supported.
111     *
112     * @param source the source
113     * @return the current thread's value of the {@code source}.
114     * @throws IllegalArgumentException if the source is null or the source requires arguments
115     */
116    public static UniformRandomProvider current(RandomSource source) {
117        ThreadLocal<UniformRandomProvider> rng = SOURCES.get(source);
118        // Implement double-checked locking:
119        // https://en.wikipedia.org/wiki/Double-checked_locking#Usage_in_Java
120        if (rng == null) {
121            // Do the checks on the source here since it is an edge case
122            // and the EnumMap handles null (returning null).
123            if (source == null) {
124                throw new IllegalArgumentException("Random source is null");
125            }
126
127            try {
128                LOCK.lock();
129                rng = SOURCES.computeIfAbsent(source, ThreadLocalRng::new);
130            } finally {
131                LOCK.unlock();
132            }
133        }
134        return rng.get();
135    }
136}