Skip to content

Commit

Permalink
refactor: sample to use an RNG helper function
Browse files Browse the repository at this point in the history
to reduce redundant code

Also skipped unneeded bounds checking in hot loops
  • Loading branch information
jqnatividad committed Feb 6, 2025
1 parent f5ab083 commit c7f5ba1
Showing 1 changed file with 41 additions and 56 deletions.
97 changes: 41 additions & 56 deletions src/cmd/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ enum RngKind {
Cryptosecure,
}

// rng helper function
fn create_rng<T: SeedableRng>(seed: Option<u64>) -> T {
if let Some(seed) = seed {
T::seed_from_u64(seed) // DevSkim: ignore DS148264
} else {
T::from_os_rng()
}
}

pub fn run(argv: &[&str]) -> CliResult<()> {
let mut args: Args = util::get_args(USAGE, argv)?;

Expand All @@ -121,28 +130,23 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
let temp_download = NamedTempFile::new()?;

args.arg_input = match args.arg_input {
Some(uri) => {
if Url::parse(&uri).is_ok() && uri.starts_with("http") {
let max_size_bytes = args.flag_max_size.map(|mb| mb * 1024 * 1024);

// its a remote file, download it first
let future = util::download_file(
&uri,
temp_download.path().to_path_buf(),
false,
args.flag_user_agent,
args.flag_timeout,
max_size_bytes,
);
tokio::runtime::Runtime::new()?.block_on(future)?;
// safety: temp_download is a NamedTempFile, so we know can unwrap.to_string
let temp_download_path = temp_download.path().to_str().unwrap().to_string();
Some(temp_download_path)
} else {
// its a local file
Some(uri)
}
Some(uri) if Url::parse(&uri).is_ok() && uri.starts_with("http") => {
let max_size_bytes = args.flag_max_size.map(|mb| mb * 1024 * 1024);

// its a remote file, download it first
let future = util::download_file(
&uri,
temp_download.path().to_path_buf(),
false,
args.flag_user_agent,
args.flag_timeout,
max_size_bytes,
);
tokio::runtime::Runtime::new()?.block_on(future)?;
// safety: temp_download is a NamedTempFile, so we know can unwrap.to_string
Some(temp_download.path().to_str().unwrap().to_string())
},
Some(uri) => Some(uri), // local file
None => None,
};

Expand Down Expand Up @@ -174,32 +178,23 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
"doing standard sample_random_access. Seed: {:?}",
args.flag_seed
);
let mut rng: StdRng = match args.flag_seed {
None => StdRng::from_os_rng(),
Some(seed) => StdRng::seed_from_u64(seed), // DevSkim: ignore DS148264
};
let mut rng = create_rng::<StdRng>(args.flag_seed);
SliceRandom::shuffle(&mut *all_indices, &mut rng);
},
RngKind::Faster => {
log::info!(
"doing --faster sample_random_access. Seed: {:?}",
args.flag_seed
);
let mut rng = match args.flag_seed {
None => Xoshiro256Plus::from_os_rng(),
Some(seed) => Xoshiro256Plus::seed_from_u64(seed), // DevSkim: ignore DS148264
};
let mut rng = create_rng::<Xoshiro256Plus>(args.flag_seed);
SliceRandom::shuffle(&mut *all_indices, &mut rng);
},
RngKind::Cryptosecure => {
log::info!(
"doing --cryptosecure sample_random_access. Seed: {:?}",
args.flag_seed
);
let mut rng = match args.flag_seed {
None => Hc128Rng::from_os_rng(),
Some(seed) => Hc128Rng::seed_from_u64(seed), // DevSkim: ignore DS148264
};
let mut rng = create_rng::<Hc128Rng>(args.flag_seed);
SliceRandom::shuffle(&mut *all_indices, &mut rng);
},
}
Expand Down Expand Up @@ -242,52 +237,42 @@ fn sample_reservoir<R: io::Read>(
reservoir.push(row?);
}

// safety: we know that reservoir has at least sample_size elements
// because we push sample_size elements into it in the loop above
match *rng_kind {
RngKind::Standard => {
log::info!("doing standard sample_random_access. Seed: {seed:?}",);
let mut rng: StdRng = match seed {
None => StdRng::from_os_rng(),
Some(seed) => StdRng::seed_from_u64(seed), // DevSkim: ignore DS148264
};

log::info!("doing standard sample_random_access. Seed: {seed:?}");
let mut rng = create_rng::<StdRng>(seed);
let mut random: usize;
// Now do the sampling.

for (i, row) in records {
random = rng.random_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
unsafe { *reservoir.get_unchecked_mut(random) = row? };
}
}
},
RngKind::Faster => {
log::info!("doing --faster sample_random_access. Seed: {seed:?}",);

let mut rng = match seed {
None => Xoshiro256Plus::from_os_rng(),
Some(seed) => Xoshiro256Plus::seed_from_u64(seed), // DevSkim: ignore DS148264
};

let mut rng = create_rng::<Xoshiro256Plus>(seed);
let mut random: usize;
// Now do the sampling.

for (i, row) in records {
random = rng.random_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
unsafe { *reservoir.get_unchecked_mut(random) = row? };
}
}
},
RngKind::Cryptosecure => {
log::info!("doing --cryptosecure sample_random_access. Seed: {seed:?}",);

let mut rng: Hc128Rng = match seed {
None => Hc128Rng::from_os_rng(),
Some(seed) => Hc128Rng::seed_from_u64(seed), // DevSkim: ignore DS148264
};
let mut rng = create_rng::<Hc128Rng>(seed);
let mut random: usize;

for (i, row) in records {
let random = rng.random_range(0..=i);
random = rng.random_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
unsafe { *reservoir.get_unchecked_mut(random) = row? };
}
}
},
Expand Down

0 comments on commit c7f5ba1

Please sign in to comment.