From e05d2adc2029ff693a799e718a99fe92de664a6c Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Tue, 12 Nov 2024 16:57:05 +0100 Subject: [PATCH] feat: introduce condaurl to enable better channel comparison --- .../src/channel/conda_url.rs | 76 +++++++++++++++ crates/rattler_conda_types/src/channel/mod.rs | 95 ++++++++----------- crates/rattler_conda_types/src/lib.rs | 2 +- .../rattler_conda_types/src/repo_data/mod.rs | 5 +- .../src/gateway/channel_config.rs | 7 +- .../src/gateway/mod.rs | 7 +- .../src/gateway/sharded_subdir/mod.rs | 4 +- .../src/sparse/mod.rs | 3 +- py-rattler/Cargo.lock | 5 +- py-rattler/src/lock/mod.rs | 2 +- py-rattler/src/repo_data/gateway.rs | 2 +- 11 files changed, 138 insertions(+), 70 deletions(-) create mode 100644 crates/rattler_conda_types/src/channel/conda_url.rs diff --git a/crates/rattler_conda_types/src/channel/conda_url.rs b/crates/rattler_conda_types/src/channel/conda_url.rs new file mode 100644 index 000000000..49231f4fa --- /dev/null +++ b/crates/rattler_conda_types/src/channel/conda_url.rs @@ -0,0 +1,76 @@ +use std::fmt::{Display, Formatter}; + +use serde::{Deserialize, Deserializer, Serialize}; +use url::Url; + +use crate::Platform; + +/// Represents a channel base url. This is a wrapper around an url that is +/// normalized: +/// +/// * The URL always contains a trailing `/`. +/// +/// This is useful to be able to compare different channels. +#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize)] +#[serde(transparent)] +pub struct CondaUrl(Url); + +impl CondaUrl { + /// Returns the base Url of the channel. + pub fn url(&self) -> &Url { + &self.0 + } + + /// Returns the string representation of the url. + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + /// Append the platform to the base url. + pub fn platform_url(&self, platform: Platform) -> Url { + self.0 + .join(&format!("{}/", platform.as_str())) // trailing slash is important here as this signifies a directory + .expect("platform is a valid url fragment") + } +} + +impl<'de> Deserialize<'de> for CondaUrl { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let url = Url::deserialize(deserializer)?; + Ok(url.into()) + } +} + +impl From for CondaUrl { + fn from(url: Url) -> Self { + let path = url.path(); + if path.ends_with('/') { + Self(url) + } else { + let mut url = url.clone(); + url.set_path(&format!("{path}/")); + Self(url) + } + } +} + +impl From for Url { + fn from(value: CondaUrl) -> Self { + value.0 + } +} + +impl AsRef for CondaUrl { + fn as_ref(&self) -> &Url { + &self.0 + } +} + +impl Display for CondaUrl { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } +} diff --git a/crates/rattler_conda_types/src/channel/mod.rs b/crates/rattler_conda_types/src/channel/mod.rs index b9050ec04..ce01661ac 100644 --- a/crates/rattler_conda_types/src/channel/mod.rs +++ b/crates/rattler_conda_types/src/channel/mod.rs @@ -13,10 +13,11 @@ use typed_path::{Utf8NativePathBuf, Utf8TypedPath, Utf8TypedPathBuf}; use url::Url; use super::{ParsePlatformError, Platform}; -use crate::utils::{ - path::is_path, - url::{add_trailing_slash, parse_scheme}, -}; +use crate::utils::{path::is_path, url::parse_scheme}; + +mod conda_url; + +pub use conda_url::CondaUrl; const DEFAULT_CHANNEL_ALIAS: &str = "https://conda.anaconda.org"; @@ -105,7 +106,7 @@ impl NamedChannelOrUrl { /// Converts the channel to a base url using the given configuration. /// This method ensures that the base url always ends with a `/`. - pub fn into_base_url(self, config: &ChannelConfig) -> Result { + pub fn into_base_url(self, config: &ChannelConfig) -> Result { let url = match self { NamedChannelOrUrl::Name(name) => { let mut base_url = config.channel_alias.clone(); @@ -114,16 +115,17 @@ impl NamedChannelOrUrl { segments.push(segment); } } - base_url + base_url.into() } - NamedChannelOrUrl::Url(url) => url, + NamedChannelOrUrl::Url(url) => url.into(), NamedChannelOrUrl::Path(path) => { let absolute_path = absolute_path(path.as_str(), &config.root_dir)?; directory_path_to_url(absolute_path.to_path()) .map_err(|_err| ParseChannelError::InvalidPath(path.to_string()))? + .into() } }; - Ok(add_trailing_slash(&url).into_owned()) + Ok(url) } /// Converts this instance into a channel. @@ -136,7 +138,8 @@ impl NamedChannelOrUrl { let base_url = self.into_base_url(config)?; Ok(Channel { name, - ..Channel::from_url(base_url) + base_url, + platforms: None, }) } } @@ -189,7 +192,7 @@ pub struct Channel { pub platforms: Option>, /// Base URL of the channel, everything is relative to this url. - pub base_url: Url, + pub base_url: CondaUrl, /// The name of the channel pub name: Option, @@ -221,7 +224,7 @@ impl Channel { .map_err(|_err| ParseChannelError::InvalidPath(channel.to_owned()))?; Self { platforms, - base_url: url, + base_url: url.into(), name: Some(channel.to_owned()), } } @@ -252,15 +255,6 @@ impl Channel { // Get the path part of the URL but trim the directory suffix let path = url.path().trim_end_matches('/'); - // Ensure that the base_url does always ends in a `/` - let base_url = if url.path().ends_with('/') { - url.clone() - } else { - let mut url = url.clone(); - url.set_path(&format!("{path}/")); - url - }; - // Case 1: No path give, channel name is "" // Case 2: migrated_custom_channels @@ -268,23 +262,23 @@ impl Channel { // Case 4: custom_channels matches // Case 5: channel_alias match - if base_url.has_host() { + if url.has_host() { // Case 7: Fallback let name = path.trim_start_matches('/'); Self { platforms: None, name: (!name.is_empty()).then_some(name).map(str::to_owned), - base_url, + base_url: url.into(), } } else { // Case 6: non-otherwise-specified file://-type urls let name = path .rsplit_once('/') - .map_or_else(|| base_url.path(), |(_, path_part)| path_part); + .map_or_else(|| path, |(_, path_part)| path_part); Self { platforms: None, name: (!name.is_empty()).then_some(name).map(str::to_owned), - base_url, + base_url: url.into(), } } } @@ -305,7 +299,8 @@ impl Channel { base_url: config .channel_alias .join(dir_name.as_ref()) - .expect("name is not a valid Url"), + .expect("name is not a valid Url") + .into(), name: (!name.is_empty()).then_some(name).map(str::to_owned), } } @@ -329,14 +324,14 @@ impl Channel { let url = Url::from_directory_path(path).expect("path is a valid url"); Self { platforms: None, - base_url: url, + base_url: url.into(), name: None, } } /// Returns the name of the channel pub fn name(&self) -> &str { - match self.base_url().scheme() { + match self.base_url.url().scheme() { // The name of the channel is only defined for http and https channels. // If the name is not defined we return the base url. "https" | "http" => self @@ -347,17 +342,9 @@ impl Channel { } } - /// Returns the base Url of the channel. This does not include the platform - /// part. - pub fn base_url(&self) -> &Url { - &self.base_url - } - /// Returns the Urls for the given platform pub fn platform_url(&self, platform: Platform) -> Url { - self.base_url() - .join(&format!("{}/", platform.as_str())) // trailing slash is important here as this signifies a directory - .expect("platform is a valid url fragment") + self.base_url.platform_url(platform) } /// Returns the Urls for all the supported platforms of this package. @@ -380,7 +367,7 @@ impl Channel { /// Returns the canonical name of the channel pub fn canonical_name(&self) -> String { - self.base_url.clone().redact().to_string() + self.base_url.url().clone().redact().to_string() } } @@ -579,7 +566,7 @@ mod tests { let channel = Channel::from_str("conda-forge", &config).unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/conda-forge/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("conda-forge")); @@ -596,14 +583,14 @@ mod tests { let channel = Channel::from_str("https://conda.anaconda.org/conda-forge/", &config).unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/conda-forge/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("conda-forge")); assert_eq!(channel.name(), "conda-forge"); assert_eq!(channel.platforms, None); assert_eq!( - channel.base_url().to_string(), + channel.base_url.to_string(), "https://conda.anaconda.org/conda-forge/" ); @@ -622,12 +609,12 @@ mod tests { assert_eq!(channel.name.as_deref(), Some("conda-forge")); assert_eq!(channel.name(), "file:///var/channels/conda-forge/"); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("file:///var/channels/conda-forge/").unwrap() ); assert_eq!(channel.platforms, None); assert_eq!( - channel.base_url().to_string(), + channel.base_url.to_string(), "file:///var/channels/conda-forge/" ); @@ -643,7 +630,7 @@ mod tests { ); assert_eq!(channel.platforms, None); assert_eq!( - channel.base_url().to_file_path().unwrap(), + channel.base_url.url().to_file_path().unwrap(), current_dir.join("dir/does/not_exist") ); } @@ -654,7 +641,7 @@ mod tests { let channel = Channel::from_str("http://localhost:1234", &config).unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("http://localhost:1234/").unwrap() ); assert_eq!(channel.name, None); @@ -681,7 +668,7 @@ mod tests { ) .unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/conda-forge/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("conda-forge")); @@ -693,7 +680,7 @@ mod tests { ) .unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/pkgs/main/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("pkgs/main")); @@ -701,7 +688,7 @@ mod tests { let channel = Channel::from_str("conda-forge/label/rust_dev", &config).unwrap(); assert_eq!( - channel.base_url, + channel.base_url.url().clone(), Url::from_str("https://conda.anaconda.org/conda-forge/label/rust_dev/").unwrap() ); assert_eq!(channel.name.as_deref(), Some("conda-forge/label/rust_dev")); @@ -785,8 +772,8 @@ mod tests { for channel_str in test_channels { let channel = Channel::from_str(channel_str, &channel_config).unwrap(); - assert!(channel.base_url().as_str().ends_with('/')); - assert!(!channel.base_url().as_str().ends_with("//")); + assert!(channel.base_url.as_str().ends_with('/')); + assert!(!channel.base_url.as_str().ends_with("//")); let named_channel = NamedChannelOrUrl::from_str(channel_str).unwrap(); let base_url = named_channel @@ -798,8 +785,8 @@ mod tests { assert!(!base_url_str.ends_with("//")); let channel = named_channel.into_channel(&channel_config).unwrap(); - assert!(channel.base_url().as_str().ends_with('/')); - assert!(!channel.base_url().as_str().ends_with("//")); + assert!(channel.base_url.as_str().ends_with('/')); + assert!(!channel.base_url.as_str().ends_with("//")); } } @@ -813,14 +800,14 @@ mod tests { let channel = Channel::from_str("conda-forge", &channel_config).unwrap(); assert_eq!( &channel.base_url, - named.into_channel(&channel_config).unwrap().base_url() + &named.into_channel(&channel_config).unwrap().base_url ); let named = NamedChannelOrUrl::Name("nvidia/label/cuda-11.8.0".to_string()); let channel = Channel::from_str("nvidia/label/cuda-11.8.0", &channel_config).unwrap(); assert_eq!( - channel.base_url(), - named.into_channel(&channel_config).unwrap().base_url() + channel.base_url, + named.into_channel(&channel_config).unwrap().base_url ); } } diff --git a/crates/rattler_conda_types/src/lib.rs b/crates/rattler_conda_types/src/lib.rs index d79622b17..a49df2253 100644 --- a/crates/rattler_conda_types/src/lib.rs +++ b/crates/rattler_conda_types/src/lib.rs @@ -28,7 +28,7 @@ pub mod prefix_record; use std::path::{Path, PathBuf}; pub use build_spec::{BuildNumber, BuildNumberSpec, ParseBuildNumberSpecError}; -pub use channel::{Channel, ChannelConfig, NamedChannelOrUrl, ParseChannelError}; +pub use channel::{Channel, ChannelConfig, CondaUrl, NamedChannelOrUrl, ParseChannelError}; pub use channel_data::{ChannelData, ChannelDataPackage}; pub use environment_yaml::{EnvironmentYaml, MatchSpecOrSubSection}; pub use explicit_environment_spec::{ diff --git a/crates/rattler_conda_types/src/repo_data/mod.rs b/crates/rattler_conda_types/src/repo_data/mod.rs index 56c3ec299..6f2abcb7e 100644 --- a/crates/rattler_conda_types/src/repo_data/mod.rs +++ b/crates/rattler_conda_types/src/repo_data/mod.rs @@ -234,7 +234,8 @@ impl RepoData { records.push(RepoDataRecord { url: compute_package_url( &channel - .base_url() + .base_url + .url() .join(&package_record.subdir) .expect("cannot join channel base_url and subdir"), base_url.as_deref(), @@ -609,7 +610,7 @@ mod test { &ChannelConfig::default_with_root_dir(std::env::current_dir().unwrap()), ) .unwrap(); - let base_url = channel.base_url().join("linux-64/").unwrap(); + let base_url = channel.base_url.url().join("linux-64/").unwrap(); assert_eq!( compute_package_url(&base_url, None, "bla.conda").to_string(), "https://conda.anaconda.org/conda-forge/linux-64/bla.conda" diff --git a/crates/rattler_repodata_gateway/src/gateway/channel_config.rs b/crates/rattler_repodata_gateway/src/gateway/channel_config.rs index 65419292e..ad42e2301 100644 --- a/crates/rattler_repodata_gateway/src/gateway/channel_config.rs +++ b/crates/rattler_repodata_gateway/src/gateway/channel_config.rs @@ -1,7 +1,6 @@ +use rattler_conda_types::CondaUrl; use std::collections::HashMap; -use rattler_conda_types::Channel; - use crate::fetch::CacheAction; /// Describes additional properties that influence how the gateway fetches @@ -49,14 +48,14 @@ pub struct ChannelConfig { /// Describes per channel properties that influence how the gateway fetches /// repodata. - pub per_channel: HashMap, + pub per_channel: HashMap, } impl ChannelConfig { /// Returns the source configuration for the given channel. If the channel /// does not have a specific source configuration the default source /// configuration will be returned. - pub fn get(&self, channel: &Channel) -> &SourceConfig { + pub fn get(&self, channel: &CondaUrl) -> &SourceConfig { self.per_channel.get(channel).unwrap_or(&self.default) } } diff --git a/crates/rattler_repodata_gateway/src/gateway/mod.rs b/crates/rattler_repodata_gateway/src/gateway/mod.rs index c768acd06..2a198d745 100644 --- a/crates/rattler_repodata_gateway/src/gateway/mod.rs +++ b/crates/rattler_repodata_gateway/src/gateway/mod.rs @@ -143,7 +143,7 @@ impl Gateway { /// This method does not clear any on-disk cache. pub fn clear_repodata_cache(&self, channel: &Channel, subdirs: SubdirSelection) { self.inner.subdirs.retain(|key, _| { - key.0.base_url() != channel.base_url() || !subdirs.contains(key.1.as_str()) + key.0.base_url != channel.base_url || !subdirs.contains(key.1.as_str()) }); } } @@ -292,8 +292,9 @@ impl GatewayInner { || url.scheme() == "oci" { // Check if the channel supports sharded repodata - let source_config = self.channel_config.get(channel); - if self.channel_config.get(channel).sharded_enabled || force_sharded_repodata(&url) { + let source_config = self.channel_config.get(&channel.base_url); + tracing::warn!("{:#?}", source_config); + if source_config.sharded_enabled || force_sharded_repodata(&url) { sharded_subdir::ShardedSubdir::new( channel.clone(), platform.to_string(), diff --git a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs index cd5a6827d..47a92318b 100644 --- a/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs +++ b/crates/rattler_repodata_gateway/src/gateway/sharded_subdir/mod.rs @@ -37,7 +37,9 @@ impl ShardedSubdir { reporter: Option<&dyn Reporter>, ) -> Result { // Construct the base url for the shards (e.g. `/`). - let index_base_url = add_trailing_slash(channel.base_url()) + let index_base_url = channel + .base_url + .url() .join(&format!("{subdir}/")) .expect("invalid subdir url"); diff --git a/crates/rattler_repodata_gateway/src/sparse/mod.rs b/crates/rattler_repodata_gateway/src/sparse/mod.rs index dea1363ab..6800d43f2 100644 --- a/crates/rattler_repodata_gateway/src/sparse/mod.rs +++ b/crates/rattler_repodata_gateway/src/sparse/mod.rs @@ -309,6 +309,7 @@ fn parse_records<'i>( url: compute_package_url( &channel .base_url + .url() .join(&format!("{}/", &package_record.subdir)) .expect("failed determine repo_base_url"), base_url, @@ -477,13 +478,13 @@ mod test { use std::path::{Path, PathBuf}; use bytes::Bytes; + use fs_err as fs; use itertools::Itertools; use rattler_conda_types::{Channel, ChannelConfig, PackageName, RepoData, RepoDataRecord}; use rstest::rstest; use super::{load_repo_data_recursively, PackageFilename, SparseRepoData}; use crate::utils::test::fetch_repo_data; - use fs_err as fs; fn test_dir() -> PathBuf { Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data") diff --git a/py-rattler/Cargo.lock b/py-rattler/Cargo.lock index 9301142b0..2c9c92a0e 100644 --- a/py-rattler/Cargo.lock +++ b/py-rattler/Cargo.lock @@ -2890,7 +2890,7 @@ dependencies = [ [[package]] name = "rattler_repodata_gateway" -version = "0.21.19" +version = "0.21.20" dependencies = [ "anyhow", "async-compression", @@ -2946,6 +2946,7 @@ name = "rattler_shell" version = "0.22.5" dependencies = [ "enum_dispatch", + "fs-err 3.0.0", "indexmap 2.6.0", "itertools 0.13.0", "rattler_conda_types", @@ -2958,7 +2959,7 @@ dependencies = [ [[package]] name = "rattler_solve" -version = "1.2.0" +version = "1.2.1" dependencies = [ "chrono", "futures", diff --git a/py-rattler/src/lock/mod.rs b/py-rattler/src/lock/mod.rs index d2186fafe..53308000c 100644 --- a/py-rattler/src/lock/mod.rs +++ b/py-rattler/src/lock/mod.rs @@ -270,7 +270,7 @@ impl From for PyLockChannel { impl From for PyLockChannel { fn from(value: rattler_conda_types::Channel) -> Self { Self { - inner: Channel::from(value.base_url().to_string()), + inner: Channel::from(value.base_url.to_string()), } } } diff --git a/py-rattler/src/repo_data/gateway.rs b/py-rattler/src/repo_data/gateway.rs index fb12fbe87..99fd04634 100644 --- a/py-rattler/src/repo_data/gateway.rs +++ b/py-rattler/src/repo_data/gateway.rs @@ -61,7 +61,7 @@ impl PyGateway { default: default_config.into(), per_channel: per_channel_config .into_iter() - .map(|(k, v)| (k.into(), v.into())) + .map(|(k, v)| (k.inner.base_url, v.into())) .collect(), };