NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Stripped generic parameters (`<T: SeedableRng>`, `Option<u64>`,
`create_rng::<...>`) were restored from the removed lines and the methods the
helper calls; blank lines follow the hunk line counts. Indentation of context
lines is a best guess -- confirm against src/cmd/sample.rs before applying
(or apply with `git apply --ignore-whitespace`).

diff --git a/src/cmd/sample.rs b/src/cmd/sample.rs
index 2ceaad604..76a74c3a5 100644
--- a/src/cmd/sample.rs
+++ b/src/cmd/sample.rs
@@ -108,6 +108,15 @@ enum RngKind {
     Cryptosecure,
 }
 
+// rng helper function
+fn create_rng<T: SeedableRng>(seed: Option<u64>) -> T {
+    if let Some(seed) = seed {
+        T::seed_from_u64(seed) // DevSkim: ignore DS148264
+    } else {
+        T::from_os_rng()
+    }
+}
+
 pub fn run(argv: &[&str]) -> CliResult<()> {
     let mut args: Args = util::get_args(USAGE, argv)?;
 
@@ -121,28 +130,23 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
     let temp_download = NamedTempFile::new()?;
 
     args.arg_input = match args.arg_input {
-        Some(uri) => {
-            if Url::parse(&uri).is_ok() && uri.starts_with("http") {
-                let max_size_bytes = args.flag_max_size.map(|mb| mb * 1024 * 1024);
-
-                // its a remote file, download it first
-                let future = util::download_file(
-                    &uri,
-                    temp_download.path().to_path_buf(),
-                    false,
-                    args.flag_user_agent,
-                    args.flag_timeout,
-                    max_size_bytes,
-                );
-                tokio::runtime::Runtime::new()?.block_on(future)?;
-                // safety: temp_download is a NamedTempFile, so we know can unwrap.to_string
-                let temp_download_path = temp_download.path().to_str().unwrap().to_string();
-                Some(temp_download_path)
-            } else {
-                // its a local file
-                Some(uri)
-            }
+        Some(uri) if Url::parse(&uri).is_ok() && uri.starts_with("http") => {
+            let max_size_bytes = args.flag_max_size.map(|mb| mb * 1024 * 1024);
+
+            // its a remote file, download it first
+            let future = util::download_file(
+                &uri,
+                temp_download.path().to_path_buf(),
+                false,
+                args.flag_user_agent,
+                args.flag_timeout,
+                max_size_bytes,
+            );
+            tokio::runtime::Runtime::new()?.block_on(future)?;
+            // safety: temp_download is a NamedTempFile, so we know can unwrap.to_string
+            Some(temp_download.path().to_str().unwrap().to_string())
         },
+        Some(uri) => Some(uri), // local file
         None => None,
     };
 
@@ -174,10 +178,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
                         "doing standard sample_random_access. Seed: {:?}",
                         args.flag_seed
                     );
-                    let mut rng: StdRng = match args.flag_seed {
-                        None => StdRng::from_os_rng(),
-                        Some(seed) => StdRng::seed_from_u64(seed), // DevSkim: ignore DS148264
-                    };
+                    let mut rng = create_rng::<StdRng>(args.flag_seed);
                     SliceRandom::shuffle(&mut *all_indices, &mut rng);
                 },
                 RngKind::Faster => {
@@ -185,10 +186,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
                         "doing --faster sample_random_access. Seed: {:?}",
                         args.flag_seed
                     );
-                    let mut rng = match args.flag_seed {
-                        None => Xoshiro256Plus::from_os_rng(),
-                        Some(seed) => Xoshiro256Plus::seed_from_u64(seed), // DevSkim: ignore DS148264
-                    };
+                    let mut rng = create_rng::<Xoshiro256Plus>(args.flag_seed);
                     SliceRandom::shuffle(&mut *all_indices, &mut rng);
                 },
                 RngKind::Cryptosecure => {
@@ -196,10 +194,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
                         "doing --cryptosecure sample_random_access. Seed: {:?}",
                         args.flag_seed
                     );
-                    let mut rng = match args.flag_seed {
-                        None => Hc128Rng::from_os_rng(),
-                        Some(seed) => Hc128Rng::seed_from_u64(seed), // DevSkim: ignore DS148264
-                    };
+                    let mut rng = create_rng::<Hc128Rng>(args.flag_seed);
                     SliceRandom::shuffle(&mut *all_indices, &mut rng);
                 },
             }
@@ -242,52 +237,42 @@ fn sample_reservoir(
         reservoir.push(row?);
     }
 
+    // safety: we know that reservoir has at least sample_size elements
+    // because we push sample_size elements into it in the loop above
     match *rng_kind {
         RngKind::Standard => {
-            log::info!("doing standard sample_random_access. Seed: {seed:?}",);
-            let mut rng: StdRng = match seed {
-                None => StdRng::from_os_rng(),
-                Some(seed) => StdRng::seed_from_u64(seed), // DevSkim: ignore DS148264
-            };
-
+            log::info!("doing standard sample_random_access. Seed: {seed:?}");
+            let mut rng = create_rng::<StdRng>(seed);
             let mut random: usize;
-            // Now do the sampling.
+
             for (i, row) in records {
                 random = rng.random_range(0..=i);
                 if random < sample_size as usize {
-                    reservoir[random] = row?;
+                    unsafe { *reservoir.get_unchecked_mut(random) = row? };
                 }
             }
         },
         RngKind::Faster => {
             log::info!("doing --faster sample_random_access. Seed: {seed:?}",);
-
-            let mut rng = match seed {
-                None => Xoshiro256Plus::from_os_rng(),
-                Some(seed) => Xoshiro256Plus::seed_from_u64(seed), // DevSkim: ignore DS148264
-            };
-
+            let mut rng = create_rng::<Xoshiro256Plus>(seed);
             let mut random: usize;
-            // Now do the sampling.
+
             for (i, row) in records {
                 random = rng.random_range(0..=i);
                 if random < sample_size as usize {
-                    reservoir[random] = row?;
+                    unsafe { *reservoir.get_unchecked_mut(random) = row? };
                 }
             }
         },
         RngKind::Cryptosecure => {
             log::info!("doing --cryptosecure sample_random_access. Seed: {seed:?}",);
-
-            let mut rng: Hc128Rng = match seed {
-                None => Hc128Rng::from_os_rng(),
-                Some(seed) => Hc128Rng::seed_from_u64(seed), // DevSkim: ignore DS148264
-            };
+            let mut rng = create_rng::<Hc128Rng>(seed);
+            let mut random: usize;
 
             for (i, row) in records {
-                let random = rng.random_range(0..=i);
+                random = rng.random_range(0..=i);
                 if random < sample_size as usize {
-                    reservoir[random] = row?;
+                    unsafe { *reservoir.get_unchecked_mut(random) = row? };
                 }
             }
         },