diff --git a/crates/rattler_repodata_gateway/src/gateway/channel_config.rs b/crates/rattler_repodata_gateway/src/gateway/channel_config.rs index e9affbb42..65419292e 100644 --- a/crates/rattler_repodata_gateway/src/gateway/channel_config.rs +++ b/crates/rattler_repodata_gateway/src/gateway/channel_config.rs @@ -1,20 +1,28 @@ -use crate::fetch::CacheAction; -use rattler_conda_types::Channel; use std::collections::HashMap; -/// Describes additional properties that influence how the gateway fetches repodata for a specific -/// channel. +use rattler_conda_types::Channel; + +use crate::fetch::CacheAction; + +/// Describes additional properties that influence how the gateway fetches +/// repodata for a specific channel. #[derive(Debug, Clone)] pub struct SourceConfig { - /// When enabled repodata can be fetched incrementally using JLAP (defaults to true) + /// When enabled repodata can be fetched incrementally using JLAP (defaults + /// to true) pub jlap_enabled: bool, - /// When enabled, the zstd variant will be used if available (defaults to true) + /// When enabled, the zstd variant will be used if available (defaults to + /// true) pub zstd_enabled: bool, - /// When enabled, the bz2 variant will be used if available (defaults to true) + /// When enabled, the bz2 variant will be used if available (defaults to + /// true) pub bz2_enabled: bool, + /// When enabled, sharded repodata will be used if available. + pub sharded_enabled: bool, + /// Describes fetching repodata from a channel should interact with any /// caches. pub cache_action: CacheAction, @@ -26,6 +34,7 @@ impl Default for SourceConfig { jlap_enabled: true, zstd_enabled: true, bz2_enabled: true, + sharded_enabled: false, cache_action: CacheAction::default(), } } @@ -34,17 +43,19 @@ impl Default for SourceConfig { /// Describes additional information for fetching channels. #[derive(Debug, Default)] pub struct ChannelConfig { - /// The default source configuration. 
If a channel does not have a specific source configuration - /// this configuration will be used. + /// The default source configuration. If a channel does not have a specific + /// source configuration this configuration will be used. pub default: SourceConfig, - /// Describes per channel properties that influence how the gateway fetches repodata. + /// Describes per channel properties that influence how the gateway fetches + /// repodata. pub per_channel: HashMap, } impl ChannelConfig { - /// Returns the source configuration for the given channel. If the channel does not have a - /// specific source configuration the default source configuration will be returned. + /// Returns the source configuration for the given channel. If the channel + /// does not have a specific source configuration the default source + /// configuration will be returned. pub fn get(&self, channel: &Channel) -> &SourceConfig { self.per_channel.get(channel).unwrap_or(&self.default) } diff --git a/crates/rattler_repodata_gateway/src/gateway/error.rs b/crates/rattler_repodata_gateway/src/gateway/error.rs index af421daa8..0a2d8e053 100644 --- a/crates/rattler_repodata_gateway/src/gateway/error.rs +++ b/crates/rattler_repodata_gateway/src/gateway/error.rs @@ -47,6 +47,9 @@ pub enum GatewayError { #[error(transparent)] InvalidPackageName(#[from] InvalidPackageNameError), + + #[error("{0}")] + CacheError(String), } impl From for GatewayError { diff --git a/crates/rattler_repodata_gateway/src/gateway/mod.rs b/crates/rattler_repodata_gateway/src/gateway/mod.rs index d56aa5140..c768acd06 100644 --- a/crates/rattler_repodata_gateway/src/gateway/mod.rs +++ b/crates/rattler_repodata_gateway/src/gateway/mod.rs @@ -286,32 +286,37 @@ impl GatewayInner { "unsupported file based url".to_string(), )); } - } else if supports_sharded_repodata(&url) { - sharded_subdir::ShardedSubdir::new( - channel.clone(), - platform.to_string(), - self.client.clone(), - self.cache.clone(), - 
self.concurrent_requests_semaphore.clone(), - reporter.as_deref(), - ) - .await - .map(SubdirData::from_client) } else if url.scheme() == "http" || url.scheme() == "https" || url.scheme() == "gcs" || url.scheme() == "oci" { - remote_subdir::RemoteSubdirClient::new( - channel.clone(), - platform, - self.client.clone(), - self.cache.clone(), - self.channel_config.get(channel).clone(), - reporter, - ) - .await - .map(SubdirData::from_client) + // Check if the channel supports sharded repodata + let source_config = self.channel_config.get(channel); + if source_config.sharded_enabled || force_sharded_repodata(&url) { + sharded_subdir::ShardedSubdir::new( + channel.clone(), + platform.to_string(), + self.client.clone(), + self.cache.clone(), + source_config.cache_action, + self.concurrent_requests_semaphore.clone(), + reporter.as_deref(), + ) + .await + .map(SubdirData::from_client) + } else { + remote_subdir::RemoteSubdirClient::new( + channel.clone(), + platform, + self.client.clone(), + self.cache.clone(), + source_config.clone(), + reporter, + ) + .await + .map(SubdirData::from_client) + } } else { return Err(GatewayError::UnsupportedUrl(format!( "'{}' is not a supported scheme", @@ -351,12 +356,9 @@ enum PendingOrFetched { Fetched(T), } -fn supports_sharded_repodata(url: &Url) -> bool { - (url.scheme() == "http" || url.scheme() == "https") - && (url.host_str() == Some("fast.prefiks.dev") - || url - .host_str() - .map_or(false, |host| host.ends_with("prefix.dev"))) +fn force_sharded_repodata(url: &Url) -> bool { + matches!(url.scheme(), "http" | "https") + && matches!(url.host_str(), Some("fast.prefiks.dev" | "fast.prefix.dev")) } #[cfg(test)] diff --git a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/index.rs b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/index.rs index f2713c4d0..223075c13 100644 --- a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/index.rs +++ 
b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/index.rs @@ -16,7 +16,10 @@ use tokio::{ use url::Url; use super::ShardedRepodata; -use crate::{reporter::ResponseReporterExt, utils::url_to_cache_filename, GatewayError, Reporter}; +use crate::{ + fetch::CacheAction, reporter::ResponseReporterExt, utils::url_to_cache_filename, GatewayError, + Reporter, +}; /// Magic number that identifies the cache file format. const MAGIC_NUMBER: &[u8] = b"SHARD-CACHE-V1"; @@ -28,6 +31,7 @@ pub async fn fetch_index( client: ClientWithMiddleware, channel_base_url: &Url, cache_dir: &Path, + cache_action: CacheAction, concurrent_requests_semaphore: Arc, reporter: Option<&dyn Reporter>, ) -> Result { @@ -125,78 +129,105 @@ pub async fn fetch_index( let canonical_request = SimpleRequest::get(&canonical_shards_url); // Try reading the cached file - if let Ok(cache_header) = read_cached_index(&mut cache_reader).await { - match cache_header - .policy - .before_request(&canonical_request, SystemTime::now()) - { - BeforeRequest::Fresh(_) => { + if cache_action != CacheAction::NoCache { + if let Ok(cache_header) = read_cached_index(&mut cache_reader).await { + // If we are in cache-only mode we can't fetch the index from the server + if cache_action == CacheAction::ForceCacheOnly { if let Ok(shard_index) = read_shard_index_from_reader(&mut cache_reader).await { - tracing::debug!("shard index cache hit"); + tracing::debug!("using locally cached shard index for {channel_base_url}"); return Ok(shard_index); } - } - BeforeRequest::Stale { - request: state_request, - .. 
- } => { - // Determine the actual URL to use for the request - let shards_url = channel_base_url - .join(REPODATA_SHARDS_FILENAME) - .expect("invalid shard base url"); - - // Construct the actual request that we will send - let request = client - .get(shards_url.clone()) - .headers(state_request.headers().clone()) - .build() - .expect("failed to build request for shard index"); - - // Acquire a permit to do a request - let _request_permit = concurrent_requests_semaphore.acquire().await; - - // Send the request - let download_reporter = reporter.map(|r| (r, r.on_download_start(&shards_url))); - let response = client.execute(request).await?; - - match cache_header.policy.after_response( - &state_request, - &response, - SystemTime::now(), - ) { - AfterResponse::NotModified(_policy, _) => { - // The cached file is still valid - match read_shard_index_from_reader(&mut cache_reader).await { - Ok(shard_index) => { - tracing::debug!("shard index cache was not modified"); - // If reading the file failed for some reason we'll just fetch it - // again. - return Ok(shard_index); - } - Err(e) => { - tracing::warn!("the cached shard index has been corrupted: {e}"); - if let Some((reporter, index)) = download_reporter { - reporter.on_download_complete(response.url(), index); + } else { + match cache_header + .policy + .before_request(&canonical_request, SystemTime::now()) + { + BeforeRequest::Fresh(_) => { + if let Ok(shard_index) = + read_shard_index_from_reader(&mut cache_reader).await + { + tracing::debug!("shard index cache hit"); + return Ok(shard_index); + } + } + BeforeRequest::Stale { + request: state_request, + .. 
+ } => { + if cache_action == CacheAction::UseCacheOnly { + return Err(GatewayError::CacheError( + format!("the sharded index cache for {channel_base_url} is stale and cache-only mode is enabled"), + )); + } + + // Determine the actual URL to use for the request + let shards_url = channel_base_url + .join(REPODATA_SHARDS_FILENAME) + .expect("invalid shard base url"); + + // Construct the actual request that we will send + let request = client + .get(shards_url.clone()) + .headers(state_request.headers().clone()) + .build() + .expect("failed to build request for shard index"); + + // Acquire a permit to do a request + let _request_permit = concurrent_requests_semaphore.acquire().await; + + // Send the request + let download_reporter = + reporter.map(|r| (r, r.on_download_start(&shards_url))); + let response = client.execute(request).await?; + + match cache_header.policy.after_response( + &state_request, + &response, + SystemTime::now(), + ) { + AfterResponse::NotModified(_policy, _) => { + // The cached file is still valid + match read_shard_index_from_reader(&mut cache_reader).await { + Ok(shard_index) => { + tracing::debug!("shard index cache was not modified"); + // If reading the file failed for some reason we'll just + // fetch it again. + return Ok(shard_index); + } + Err(e) => { + tracing::warn!( + "the cached shard index has been corrupted: {e}" + ); + if let Some((reporter, index)) = download_reporter { + reporter.on_download_complete(response.url(), index); + } + } } } + AfterResponse::Modified(policy, _) => { + // Close the old file so we can create a new one. + tracing::debug!("shard index cache has become stale"); + return from_response( + cache_reader.into_inner(), + &cache_path, + policy, + response, + download_reporter, + ) + .await; + } } } - AfterResponse::Modified(policy, _) => { - // Close the old file so we can create a new one. 
- tracing::debug!("shard index cache has become stale"); - return from_response( - cache_reader.into_inner(), - &cache_path, - policy, - response, - download_reporter, - ) - .await; - } } } } - }; + } + + if cache_action == CacheAction::ForceCacheOnly { + return Err(GatewayError::CacheError(format!( + "the sharded index cache for {channel_base_url} is not available" + ))); + } tracing::debug!("fetching fresh shard index"); diff --git a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs index 39a176bfb..cd5a6827d 100644 --- a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs +++ b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs @@ -7,7 +7,7 @@ use simple_spawn_blocking::tokio::run_blocking_task; use url::Url; use crate::{ - fetch::FetchRepoDataError, + fetch::{CacheAction, FetchRepoDataError}, gateway::{error::SubdirNotFoundError, subdir::SubdirClient}, reporter::ResponseReporterExt, GatewayError, Reporter, @@ -22,6 +22,7 @@ pub struct ShardedSubdir { package_base_url: Url, sharded_repodata: ShardedRepodata, cache_dir: PathBuf, + cache_action: CacheAction, concurrent_requests_semaphore: Arc, } @@ -31,6 +32,7 @@ impl ShardedSubdir { subdir: String, client: ClientWithMiddleware, cache_dir: PathBuf, + cache_action: CacheAction, concurrent_requests_semaphore: Arc, reporter: Option<&dyn Reporter>, ) -> Result { @@ -44,6 +46,7 @@ impl ShardedSubdir { client.clone(), &index_base_url, &cache_dir, + cache_action, concurrent_requests_semaphore.clone(), reporter, ) @@ -92,6 +95,7 @@ impl ShardedSubdir { package_base_url: add_trailing_slash(&package_base_url).into_owned(), sharded_repodata, cache_dir, + cache_action, concurrent_requests_semaphore, }) } @@ -113,21 +117,34 @@ impl SubdirClient for ShardedSubdir { let shard_cache_path = self.cache_dir.join(format!("{shard:x}.msgpack")); // Read the cached shard - match tokio::fs::read(&shard_cache_path).await { - 
Ok(cached_bytes) => { - // Decode the cached shard - return parse_records( - cached_bytes, - self.channel.canonical_name(), - self.package_base_url.clone(), - ) - .await - .map(Arc::from); - } - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - // The file is missing from the cache, we need to download it. + if self.cache_action != CacheAction::NoCache { + match tokio::fs::read(&shard_cache_path).await { + Ok(cached_bytes) => { + // Decode the cached shard + return parse_records( + cached_bytes, + self.channel.canonical_name(), + self.package_base_url.clone(), + ) + .await + .map(Arc::from); + } + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + // The file is missing from the cache, we need to download + // it. + } + Err(err) => return Err(FetchRepoDataError::IoError(err).into()), } - Err(err) => return Err(FetchRepoDataError::IoError(err).into()), + } + + if matches!( + self.cache_action, + CacheAction::UseCacheOnly | CacheAction::ForceCacheOnly + ) { + return Err(GatewayError::CacheError(format!( + "the shard for package '{}' is not in the cache", + name.as_source() + ))); } // Download the shard