Skip to content

Commit

Permalink
fix: allow enabling sharded repodata for certain channels
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra committed Nov 12, 2024
1 parent bf67366 commit 559f76c
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 117 deletions.
35 changes: 23 additions & 12 deletions crates/rattler_repodata_gateway/src/gateway/channel_config.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
use crate::fetch::CacheAction;
use rattler_conda_types::Channel;
use std::collections::HashMap;

/// Describes additional properties that influence how the gateway fetches repodata for a specific
/// channel.
use rattler_conda_types::Channel;

use crate::fetch::CacheAction;

/// Describes additional properties that influence how the gateway fetches
/// repodata for a specific channel.
#[derive(Debug, Clone)]
pub struct SourceConfig {
/// When enabled repodata can be fetched incrementally using JLAP (defaults to true)
/// When enabled repodata can be fetched incrementally using JLAP (defaults
/// to true)
pub jlap_enabled: bool,

/// When enabled, the zstd variant will be used if available (defaults to true)
/// When enabled, the zstd variant will be used if available (defaults to
/// true)
pub zstd_enabled: bool,

/// When enabled, the bz2 variant will be used if available (defaults to true)
/// When enabled, the bz2 variant will be used if available (defaults to
/// true)
pub bz2_enabled: bool,

/// When enabled, sharded repodata will be used if available.
pub sharded_enabled: bool,

/// Describes how fetching repodata from a channel should interact with
/// any caches.
pub cache_action: CacheAction,
Expand All @@ -26,6 +34,7 @@ impl Default for SourceConfig {
jlap_enabled: true,
zstd_enabled: true,
bz2_enabled: true,
sharded_enabled: false,
cache_action: CacheAction::default(),
}
}
Expand All @@ -34,17 +43,19 @@ impl Default for SourceConfig {
/// Describes additional information for fetching channels.
/// Describes additional information for fetching channels.
#[derive(Debug, Default)]
pub struct ChannelConfig {
    /// The default source configuration. If a channel does not have a specific
    /// source configuration this configuration will be used.
    pub default: SourceConfig,

    /// Describes per channel properties that influence how the gateway fetches
    /// repodata.
    pub per_channel: HashMap<Channel, SourceConfig>,
}

impl ChannelConfig {
/// Returns the source configuration for the given channel. If the channel
/// does not have a specific source configuration the default source
/// configuration will be returned.
pub fn get(&self, channel: &Channel) -> &SourceConfig {
    // Fall back to the default configuration when no per-channel entry exists.
    match self.per_channel.get(channel) {
        Some(config) => config,
        None => &self.default,
    }
}
Expand Down
3 changes: 3 additions & 0 deletions crates/rattler_repodata_gateway/src/gateway/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ pub enum GatewayError {

#[error(transparent)]
InvalidPackageName(#[from] InvalidPackageNameError),

#[error("{0}")]
CacheError(String),
}

impl From<Cancelled> for GatewayError {
Expand Down
56 changes: 29 additions & 27 deletions crates/rattler_repodata_gateway/src/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,32 +286,37 @@ impl GatewayInner {
"unsupported file based url".to_string(),
));
}
} else if supports_sharded_repodata(&url) {
sharded_subdir::ShardedSubdir::new(
channel.clone(),
platform.to_string(),
self.client.clone(),
self.cache.clone(),
self.concurrent_requests_semaphore.clone(),
reporter.as_deref(),
)
.await
.map(SubdirData::from_client)
} else if url.scheme() == "http"
|| url.scheme() == "https"
|| url.scheme() == "gcs"
|| url.scheme() == "oci"
{
remote_subdir::RemoteSubdirClient::new(
channel.clone(),
platform,
self.client.clone(),
self.cache.clone(),
self.channel_config.get(channel).clone(),
reporter,
)
.await
.map(SubdirData::from_client)
// Check if the channel supports sharded repodata
let source_config = self.channel_config.get(channel);
if self.channel_config.get(channel).sharded_enabled || force_sharded_repodata(&url) {
sharded_subdir::ShardedSubdir::new(
channel.clone(),
platform.to_string(),
self.client.clone(),
self.cache.clone(),
source_config.cache_action,
self.concurrent_requests_semaphore.clone(),
reporter.as_deref(),
)
.await
.map(SubdirData::from_client)
} else {
remote_subdir::RemoteSubdirClient::new(
channel.clone(),
platform,
self.client.clone(),
self.cache.clone(),
source_config.clone(),
reporter,
)
.await
.map(SubdirData::from_client)
}
} else {
return Err(GatewayError::UnsupportedUrl(format!(
"'{}' is not a supported scheme",
Expand Down Expand Up @@ -351,12 +356,9 @@ enum PendingOrFetched<T> {
Fetched(T),
}

fn supports_sharded_repodata(url: &Url) -> bool {
(url.scheme() == "http" || url.scheme() == "https")
&& (url.host_str() == Some("fast.prefiks.dev")
|| url
.host_str()
.map_or(false, |host| host.ends_with("prefix.dev")))
/// Returns true when sharded repodata must always be used for the given
/// URL, regardless of the per-channel configuration. Sharding is forced
/// for the known hosts `fast.prefiks.dev` and `fast.prefix.dev`, and only
/// over HTTP(S).
fn force_sharded_repodata(url: &Url) -> bool {
    let scheme_is_http = url.scheme() == "http" || url.scheme() == "https";
    scheme_is_http
        && url.host_str().map_or(false, |host| {
            host == "fast.prefiks.dev" || host == "fast.prefix.dev"
        })
}

#[cfg(test)]
Expand Down
157 changes: 94 additions & 63 deletions crates/rattler_repodata_gateway/src/gateway/sharded_subdir/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ use tokio::{
use url::Url;

use super::ShardedRepodata;
use crate::{reporter::ResponseReporterExt, utils::url_to_cache_filename, GatewayError, Reporter};
use crate::{
fetch::CacheAction, reporter::ResponseReporterExt, utils::url_to_cache_filename, GatewayError,
Reporter,
};

/// Magic number that identifies the cache file format.
const MAGIC_NUMBER: &[u8] = b"SHARD-CACHE-V1";
Expand All @@ -28,6 +31,7 @@ pub async fn fetch_index(
client: ClientWithMiddleware,
channel_base_url: &Url,
cache_dir: &Path,
cache_action: CacheAction,
concurrent_requests_semaphore: Arc<tokio::sync::Semaphore>,
reporter: Option<&dyn Reporter>,
) -> Result<ShardedRepodata, GatewayError> {
Expand Down Expand Up @@ -125,78 +129,105 @@ pub async fn fetch_index(
let canonical_request = SimpleRequest::get(&canonical_shards_url);

// Try reading the cached file
if let Ok(cache_header) = read_cached_index(&mut cache_reader).await {
match cache_header
.policy
.before_request(&canonical_request, SystemTime::now())
{
BeforeRequest::Fresh(_) => {
if cache_action != CacheAction::NoCache {
if let Ok(cache_header) = read_cached_index(&mut cache_reader).await {
// If we are in cache-only mode we can't fetch the index from the server
if cache_action == CacheAction::ForceCacheOnly {
if let Ok(shard_index) = read_shard_index_from_reader(&mut cache_reader).await {
tracing::debug!("shard index cache hit");
tracing::debug!("using locally cached shard index for {channel_base_url}");
return Ok(shard_index);
}
}
BeforeRequest::Stale {
request: state_request,
..
} => {
// Determine the actual URL to use for the request
let shards_url = channel_base_url
.join(REPODATA_SHARDS_FILENAME)
.expect("invalid shard base url");

// Construct the actual request that we will send
let request = client
.get(shards_url.clone())
.headers(state_request.headers().clone())
.build()
.expect("failed to build request for shard index");

// Acquire a permit to do a request
let _request_permit = concurrent_requests_semaphore.acquire().await;

// Send the request
let download_reporter = reporter.map(|r| (r, r.on_download_start(&shards_url)));
let response = client.execute(request).await?;

match cache_header.policy.after_response(
&state_request,
&response,
SystemTime::now(),
) {
AfterResponse::NotModified(_policy, _) => {
// The cached file is still valid
match read_shard_index_from_reader(&mut cache_reader).await {
Ok(shard_index) => {
tracing::debug!("shard index cache was not modified");
// If reading the file failed for some reason we'll just fetch it
// again.
return Ok(shard_index);
}
Err(e) => {
tracing::warn!("the cached shard index has been corrupted: {e}");
if let Some((reporter, index)) = download_reporter {
reporter.on_download_complete(response.url(), index);
} else {
match cache_header
.policy
.before_request(&canonical_request, SystemTime::now())
{
BeforeRequest::Fresh(_) => {
if let Ok(shard_index) =
read_shard_index_from_reader(&mut cache_reader).await
{
tracing::debug!("shard index cache hit");
return Ok(shard_index);
}
}
BeforeRequest::Stale {
request: state_request,
..
} => {
if cache_action == CacheAction::UseCacheOnly {
return Err(GatewayError::CacheError(
format!("the sharded index cache for {channel_base_url} is stale and cache-only mode is enabled"),
));
}

// Determine the actual URL to use for the request
let shards_url = channel_base_url
.join(REPODATA_SHARDS_FILENAME)
.expect("invalid shard base url");

// Construct the actual request that we will send
let request = client
.get(shards_url.clone())
.headers(state_request.headers().clone())
.build()
.expect("failed to build request for shard index");

// Acquire a permit to do a request
let _request_permit = concurrent_requests_semaphore.acquire().await;

// Send the request
let download_reporter =
reporter.map(|r| (r, r.on_download_start(&shards_url)));
let response = client.execute(request).await?;

match cache_header.policy.after_response(
&state_request,
&response,
SystemTime::now(),
) {
AfterResponse::NotModified(_policy, _) => {
// The cached file is still valid
match read_shard_index_from_reader(&mut cache_reader).await {
Ok(shard_index) => {
tracing::debug!("shard index cache was not modified");
// If reading the file failed for some reason we'll just
// fetch it again.
return Ok(shard_index);
}
Err(e) => {
tracing::warn!(
"the cached shard index has been corrupted: {e}"
);
if let Some((reporter, index)) = download_reporter {
reporter.on_download_complete(response.url(), index);
}
}
}
}
AfterResponse::Modified(policy, _) => {
// Close the old file so we can create a new one.
tracing::debug!("shard index cache has become stale");
return from_response(
cache_reader.into_inner(),
&cache_path,
policy,
response,
download_reporter,
)
.await;
}
}
}
AfterResponse::Modified(policy, _) => {
// Close the old file so we can create a new one.
tracing::debug!("shard index cache has become stale");
return from_response(
cache_reader.into_inner(),
&cache_path,
policy,
response,
download_reporter,
)
.await;
}
}
}
}
};
}

if cache_action == CacheAction::ForceCacheOnly {
return Err(GatewayError::CacheError(format!(
"the sharded index cache for {channel_base_url} is not available"
)));
}

tracing::debug!("fetching fresh shard index");

Expand Down
Loading

0 comments on commit 559f76c

Please sign in to comment.