Skip to content

Commit

Permalink
feat: add shards_base_url and write shards atomically (#747)
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra authored Jun 11, 2024
1 parent 9a33a80 commit ccba523
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 27 deletions.
11 changes: 9 additions & 2 deletions crates/rattler_conda_types/src/repo_data/sharded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
use fxhash::{FxHashMap, FxHashSet};
use rattler_digest::Sha256Hash;
use serde::{Deserialize, Serialize};
use url::Url;

use crate::PackageRecord;

Expand All @@ -24,7 +23,15 @@ pub struct ShardedSubdirInfo {

/// The base url of the subdirectory. This is the location where the actual
/// packages are stored.
pub base_url: Url,
///
/// This is used to construct the full url of the packages.
pub base_url: String,

/// The base url of the individual shards. This is the location where the actual
/// shards are stored.
///
/// This is used to construct the full url of the shard.
pub shards_base_url: String,
}

/// An individual shard that contains repodata for a single package name.
Expand Down
119 changes: 94 additions & 25 deletions crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
use crate::gateway::error::SubdirNotFoundError;
use crate::reporter::ResponseReporterExt;
use crate::Reporter;
use crate::{fetch::FetchRepoDataError, gateway::subdir::SubdirClient, GatewayError};
use futures::TryFutureExt;
use http::header::CACHE_CONTROL;
use http::{HeaderValue, StatusCode};
use std::{borrow::Cow, io::Write, path::PathBuf, sync::Arc};

use http::{header::CACHE_CONTROL, HeaderValue, StatusCode};
use rattler_conda_types::{Channel, PackageName, RepoDataRecord, Shard, ShardedRepodata};
use reqwest_middleware::ClientWithMiddleware;
use simple_spawn_blocking::tokio::run_blocking_task;
use std::{borrow::Cow, path::PathBuf, sync::Arc};
use token::TokenClient;
use url::Url;

use crate::{
fetch::FetchRepoDataError,
gateway::{error::SubdirNotFoundError, subdir::SubdirClient},
reporter::ResponseReporterExt,
GatewayError, Reporter,
};

mod index;
mod token;

pub struct ShardedSubdir {
channel: Channel,
client: ClientWithMiddleware,
shard_base_url: Url,
shards_base_url: Url,
package_base_url: Url,
token_client: TokenClient,
sharded_repodata: ShardedRepodata,
cache_dir: PathBuf,
Expand All @@ -35,21 +38,21 @@ impl ShardedSubdir {
reporter: Option<&dyn Reporter>,
) -> Result<Self, GatewayError> {
// Construct the base url of the shard index (e.g. `<channel>/<subdir>/`).
let shard_base_url = add_trailing_slash(channel.base_url())
let index_base_url = add_trailing_slash(channel.base_url())
.join(&format!("{subdir}/"))
.expect("invalid subdir url");

// Construct a token client to fetch the token when we need it.
let token_client = TokenClient::new(
client.clone(),
shard_base_url.clone(),
index_base_url.clone(),
concurrent_requests_semaphore.clone(),
);

// Fetch the shard index
let sharded_repodata = index::fetch_index(
client.clone(),
&shard_base_url,
&index_base_url,
&token_client,
&cache_dir,
concurrent_requests_semaphore.clone(),
Expand All @@ -67,6 +70,26 @@ impl ShardedSubdir {
e => e,
})?;

// Convert the URLs
let shards_base_url = Url::options()
.base_url(Some(&index_base_url))
.parse(&sharded_repodata.info.shards_base_url)
.map_err(|_| {
GatewayError::Generic(format!(
"shard index contains invalid `shards_base_url`: {}",
&sharded_repodata.info.shards_base_url
))
})?;
let package_base_url = Url::options()
.base_url(Some(&index_base_url))
.parse(&sharded_repodata.info.base_url)
.map_err(|_| {
GatewayError::Generic(format!(
"shard index contains invalid `base_url`: {}",
&sharded_repodata.info.base_url
))
})?;

// Determine the cache directory and make sure it exists.
let cache_dir = cache_dir.join("shards-v1");
tokio::fs::create_dir_all(&cache_dir)
Expand All @@ -76,7 +99,8 @@ impl ShardedSubdir {
Ok(Self {
channel,
client,
shard_base_url,
shards_base_url: add_trailing_slash(&shards_base_url).into_owned(),
package_base_url: add_trailing_slash(&package_base_url).into_owned(),
token_client,
sharded_repodata,
cache_dir,
Expand Down Expand Up @@ -107,7 +131,7 @@ impl SubdirClient for ShardedSubdir {
return parse_records(
cached_bytes,
self.channel.canonical_name(),
self.sharded_repodata.info.base_url.clone(),
self.package_base_url.clone(),
)
.await
.map(Arc::from);
Expand All @@ -122,11 +146,9 @@ impl SubdirClient for ShardedSubdir {
let token = self.token_client.get_token(reporter).await?;

// Download the shard
let shard_url = token
.shard_base_url
.as_ref()
.unwrap_or(&self.shard_base_url)
.join(&format!("shards/{shard:x}.msgpack.zst"))
let shard_url = self
.shards_base_url
.join(&format!("{shard:x}.msgpack.zst"))
.expect("invalid shard url");

let mut shard_request = self
Expand Down Expand Up @@ -162,15 +184,13 @@ impl SubdirClient for ShardedSubdir {
let shard_bytes = decode_zst_bytes_async(shard_bytes).await?;

// Create a future to write the cached bytes to disk
let write_to_cache_fut = tokio::fs::write(&shard_cache_path, shard_bytes.clone())
.map_err(FetchRepoDataError::IoError)
.map_err(GatewayError::from);
let write_to_cache_fut = write_shard_to_cache(shard_cache_path, shard_bytes.clone());

// Create a future to parse the records from the shard
let parse_records_fut = parse_records(
shard_bytes,
self.channel.canonical_name(),
self.sharded_repodata.info.base_url.clone(),
self.package_base_url.clone(),
);

// Await both futures concurrently.
Expand All @@ -180,6 +200,54 @@ impl SubdirClient for ShardedSubdir {
}
}

/// Atomically writes the shard bytes to the cache.
///
/// The bytes are first written to a temporary file created in the parent
/// directory of `shard_cache_path`; that temporary file is then persisted to
/// `shard_cache_path`. The whole operation runs inside `run_blocking_task`
/// because it uses synchronous `std::io` writes.
///
/// # Errors
///
/// Returns [`GatewayError::IoError`] if the temporary file cannot be created,
/// written to, or persisted. A persist failure of kind
/// [`std::io::ErrorKind::AlreadyExists`] is treated as success.
///
/// # Panics
///
/// Panics if `shard_cache_path` has no parent component.
async fn write_shard_to_cache(
    shard_cache_path: PathBuf,
    shard_bytes: Vec<u8>,
) -> Result<(), GatewayError> {
    run_blocking_task(move || {
        // Resolve the parent directory once; it is both where the temporary
        // file is created and what the error messages refer to.
        let shard_cache_parent_path = shard_cache_path
            .parent()
            .expect("file path must have a parent");

        // Create the temporary file next to the final destination.
        let mut temp_file = tempfile::Builder::new()
            .tempfile_in(shard_cache_parent_path)
            .map_err(|e| {
                GatewayError::IoError(
                    format!(
                        "failed to create temporary file to write shard in {}",
                        shard_cache_parent_path.display()
                    ),
                    e,
                )
            })?;

        temp_file.write_all(&shard_bytes).map_err(|e| {
            GatewayError::IoError(
                format!(
                    "failed to write shard to temporary file in {}",
                    shard_cache_parent_path.display()
                ),
                e,
            )
        })?;

        // Move the fully written temporary file into its final location.
        match temp_file.persist(&shard_cache_path) {
            Ok(_) => Ok(()),
            Err(e) if e.error.kind() == std::io::ErrorKind::AlreadyExists => {
                // The file already exists, we don't need to write it again.
                Ok(())
            }
            Err(e) => Err(GatewayError::IoError(
                format!("failed to persist shard to {}", shard_cache_path.display()),
                e.error,
            )),
        }
    })
    .await
}

async fn decode_zst_bytes_async<R: AsRef<[u8]> + Send + 'static>(
bytes: R,
) -> Result<Vec<u8>, GatewayError> {
Expand All @@ -199,14 +267,15 @@ async fn parse_records<R: AsRef<[u8]> + Send + 'static>(
base_url: Url,
) -> Result<Vec<RepoDataRecord>, GatewayError> {
run_blocking_task(move || {
// let shard = serde_json::from_slice::<Shard>(bytes.as_ref()).map_err(std::io::Error::from)?;
// let shard =
// serde_json::from_slice::<Shard>(bytes.as_ref()).
// map_err(std::io::Error::from)?;
let shard = rmp_serde::from_slice::<Shard>(bytes.as_ref())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
.map_err(FetchRepoDataError::IoError)?;
let packages =
itertools::chain(shard.packages.into_iter(), shard.conda_packages.into_iter())
.filter(|(name, _record)| !shard.removed.contains(name));
let base_url = add_trailing_slash(&base_url);
Ok(packages
.map(|(file_name, package_record)| RepoDataRecord {
url: base_url
Expand Down

0 comments on commit ccba523

Please sign in to comment.