// Copyright © 2023 Apple Inc. #pragma once #include #include "mlx/array.h" #include "mlx/stream.h" namespace mlx::core::random { class KeySequence { public: explicit KeySequence(uint64_t seed); void seed(uint64_t seed); array next(); // static default static KeySequence& default_() { static KeySequence ks(0); return ks; } private: array key_; }; /** Get a PRNG key from a seed. */ array key(uint64_t seed); /** Seed the default PRNG key. */ void seed(uint64_t seed); /** Generate an array with type uint32 filled with random bits. */ array bits( const std::vector& shape, int width, const std::optional& key = std::nullopt, StreamOrDevice s = {}); inline array bits( const std::vector& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return bits(shape, 4, key, s); } /** Split the rng key into a pair of keys. */ std::pair split(const array& key, StreamOrDevice s = {}); /** Split the rng key into `num` keys. */ array split(const array& key, int num, StreamOrDevice s = {}); /** Generate uniform random numbers between low and high. */ array uniform( const array& low, const array& high, const std::vector& shape, Dtype dtype = float32, const std::optional& key = std::nullopt, StreamOrDevice s = {}); template array uniform( T low, U high, const std::vector& shape, Dtype dtype = float32, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return uniform(array(low), array(high), shape, dtype, key, to_stream(s)); } /** Generate uniform random numbers between 0 and 1. */ array uniform( const std::vector& shape, Dtype dtype, const std::optional& key = std::nullopt, StreamOrDevice s = {}); inline array uniform( const std::vector& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return uniform(shape, float32, key); } /** Generate samples from the standard normal distribution. 
*/ array normal( const std::vector& shape, Dtype dtype, const std::optional& key = std::nullopt, StreamOrDevice s = {}); inline array normal( const std::vector& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return normal(shape, float32, key, s); } /** Generate integer samples uniformly at random */ array randint( const array& low, const array& high, const std::vector& shape, Dtype dtype = int32, const std::optional& key = std::nullopt, StreamOrDevice s = {}); template array randint( T low, U high, const std::vector& shape, Dtype dtype = int32, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return randint(array(low), array(high), shape, dtype, key, to_stream(s)); }; /** Generate binary variables with probability to be true equal to p */ array bernoulli( const array& p, const std::vector& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}); array bernoulli( const array& p, const std::optional& key = std::nullopt, StreamOrDevice s = {}); template array bernoulli( T p, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return bernoulli(array(p), key, s); }; template array bernoulli( T p, const std::vector& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}) { return bernoulli(array(p), shape, key, s); }; array bernoulli( const std::optional& key = std::nullopt, StreamOrDevice s = {}); array truncated_normal( const array& lower, const array& upper, const std::vector& shape, Dtype dtype = float32, const std::optional& key = std::nullopt, StreamOrDevice s = {}); array truncated_normal( const array& lower, const array& upper, Dtype dtype = float32, const std::optional& key = std::nullopt, StreamOrDevice s = {}); array gumbel( const std::vector& shape, Dtype dtype = float32, const std::optional& key = std::nullopt, StreamOrDevice s = {}); array categorical( const array& logits, int axis, const std::vector& shape, const std::optional& key = std::nullopt, StreamOrDevice s = {}); 
/** Categorical overload taking a sample count `num_samples` instead of a shape. */
array categorical(
    const array& logits_,
    int axis,
    int num_samples,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Categorical overload with neither a shape nor a sample count.
 * `axis` defaults to -1.
 */
array categorical(
    const array& logits,
    int axis = -1,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

} // namespace mlx::core::random