From ca4bfc02f9ccd82e7e9212d0220be4d1c9175370 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Wed, 30 Oct 2024 16:13:48 +0000 Subject: [PATCH] Update seed handling in random generator classes Refactored seed type from uint64_t to seed_int_t for consistency across bfloat16, half, and standard random generators. Adjusted assignment logic to use vector assignment instead of static casting. --- scalars/src/random/standard_random_generator.h | 13 +++++++------ .../src/types/bfloat16/bfloat16_random_generator.h | 4 +++- scalars/src/types/half/half_random_generator.h | 6 ++++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/scalars/src/random/standard_random_generator.h b/scalars/src/random/standard_random_generator.h index e5c86e3d8..fc08e75fe 100644 --- a/scalars/src/random/standard_random_generator.h +++ b/scalars/src/random/standard_random_generator.h @@ -85,8 +85,8 @@ class StandardRandomGenerator : public RandomGenerator StandardRandomGenerator& operator=(StandardRandomGenerator&&) noexcept = delete; - void set_seed(Slice seed_data) override; - rpy::Vec get_seed() const override; + void set_seed(Slice seed_data) override; + rpy::Vec get_seed() const override; string get_type() const override; string get_state() const override; @@ -117,7 +117,7 @@ string StandardRandomGenerator::get_type() const template StandardRandomGenerator::StandardRandomGenerator( const ScalarType* stype, - Slice seed + Slice seed ) : RandomGenerator(stype) { @@ -137,7 +137,8 @@ StandardRandomGenerator::StandardRandomGenerator( continue_bits -= so_rd_int; } } else { - m_seed = static_cast>(seed); + // m_seed = static_cast>(seed); + m_seed.assign(seed.begin(), seed.end()); } m_generator = BitGenerator(m_seed[0]); @@ -145,7 +146,7 @@ StandardRandomGenerator::StandardRandomGenerator( template void StandardRandomGenerator::set_seed( - Slice seed_data + Slice seed_data ) { RPY_CHECK(seed_data.size() >= 1); @@ -166,7 +167,7 @@ void StandardRandomGenerator::set_state( } template -rpy::Vec +rpy::Vec StandardRandomGenerator::get_seed() const { return {m_seed[0]}; diff --git a/scalars/src/types/bfloat16/bfloat16_random_generator.h b/scalars/src/types/bfloat16/bfloat16_random_generator.h index 18f9dc23f..b7ed1fa32 100644 --- a/scalars/src/types/bfloat16/bfloat16_random_generator.h +++ b/scalars/src/types/bfloat16/bfloat16_random_generator.h @@ -62,7 +62,9 @@ StandardRandomGenerator::StandardRandomGenerator( s |= static_cast(dev()); continue_bits -= so_rd_int; } - } else { m_seed = static_cast>(seed); } + } else { + m_seed.assign(seed.begin(), seed.end()); + } m_generator = BitGenerator(m_seed[0]); } diff --git a/scalars/src/types/half/half_random_generator.h b/scalars/src/types/half/half_random_generator.h index 80013d911..dc6556d19 100644 --- a/scalars/src/types/half/half_random_generator.h +++ b/scalars/src/types/half/half_random_generator.h @@ -20,7 +20,7 @@ class StandardRandomGenerator : public RandomGenerator using scalar_type = half; using bit_generator = BitGenerator; - rpy::Vec m_seed; + rpy::Vec m_seed; mutable BitGenerator m_generator; mutable std::mutex m_lock; @@ -62,7 +62,9 @@ StandardRandomGenerator::StandardRandomGenerator( s |= static_cast(dev()); continue_bits -= so_rd_int; } - } else { m_seed = static_cast>(seed); } + } else { + m_seed.assign(seed.begin(), seed.end()); + } m_generator = BitGenerator(m_seed[0]); }