Skip to content

Commit

Permalink
Update seed handling in random generator classes
Browse files Browse the repository at this point in the history
Refactored seed type from uint64_t to seed_int_t for consistency across bfloat16, half, and standard random generators. Adjusted assignment logic to use vector assignment instead of static casting.
  • Loading branch information
inakleinbottle committed Oct 30, 2024
1 parent 42ff75a commit ca4bfc0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
13 changes: 7 additions & 6 deletions scalars/src/random/standard_random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ class StandardRandomGenerator : public RandomGenerator
StandardRandomGenerator& operator=(StandardRandomGenerator&&) noexcept
= delete;

void set_seed(Slice<uint64_t> seed_data) override;
rpy::Vec<uint64_t> get_seed() const override;
void set_seed(Slice<seed_int_t> seed_data) override;
rpy::Vec<seed_int_t> get_seed() const override;
string get_type() const override;
string get_state() const override;

Expand Down Expand Up @@ -117,7 +117,7 @@ string StandardRandomGenerator<ScalarImpl, BitGenerator>::get_type() const
template <typename ScalarImpl, typename BitGenerator>
StandardRandomGenerator<ScalarImpl, BitGenerator>::StandardRandomGenerator(
const ScalarType* stype,
Slice<uint64_t> seed
Slice<seed_int_t> seed
)
: RandomGenerator(stype)
{
Expand All @@ -137,15 +137,16 @@ StandardRandomGenerator<ScalarImpl, BitGenerator>::StandardRandomGenerator(
continue_bits -= so_rd_int;
}
} else {
m_seed = static_cast<Vec<seed_int_t>>(seed);
// m_seed = static_cast<Vec<seed_int_t>>(seed);
m_seed.assign(seed.begin(), seed.end());
}

m_generator = BitGenerator(m_seed[0]);
}

template <typename ScalarImpl, typename BitGenerator>
void StandardRandomGenerator<ScalarImpl, BitGenerator>::set_seed(
Slice<uint64_t> seed_data
Slice<seed_int_t> seed_data
)
{
RPY_CHECK(seed_data.size() >= 1);
Expand All @@ -166,7 +167,7 @@ void StandardRandomGenerator<ScalarImpl, BitGenerator>::set_state(
}

template <typename ScalarImpl, typename BitGenerator>
rpy::Vec<uint64_t>
rpy::Vec<seed_int_t>
StandardRandomGenerator<ScalarImpl, BitGenerator>::get_seed() const
{
return {m_seed[0]};
Expand Down
4 changes: 3 additions & 1 deletion scalars/src/types/bfloat16/bfloat16_random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ StandardRandomGenerator<bfloat16, BitGenerator>::StandardRandomGenerator(
s |= static_cast<seed_int_t>(dev());
continue_bits -= so_rd_int;
}
} else { m_seed = static_cast<Vec<uint64_t>>(seed); }
} else {
m_seed.assign(seed.begin(), seed.end());
}

m_generator = BitGenerator(m_seed[0]);
}
Expand Down
6 changes: 4 additions & 2 deletions scalars/src/types/half/half_random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class StandardRandomGenerator<half, BitGenerator> : public RandomGenerator
using scalar_type = half;
using bit_generator = BitGenerator;

rpy::Vec<uint64_t> m_seed;
rpy::Vec<seed_int_t> m_seed;

mutable BitGenerator m_generator;
mutable std::mutex m_lock;
Expand Down Expand Up @@ -62,7 +62,9 @@ StandardRandomGenerator<half, BitGenerator>::StandardRandomGenerator(
s |= static_cast<seed_int_t>(dev());
continue_bits -= so_rd_int;
}
} else { m_seed = static_cast<Vec<uint64_t>>(seed); }
} else {
m_seed.assign(seed.begin(), seed.end());
}

m_generator = BitGenerator(m_seed[0]);
}
Expand Down

0 comments on commit ca4bfc0

Please sign in to comment.