Skip to content

Commit

Permalink
fix: Fallback to fully reading the package stream when downloading be…
Browse files Browse the repository at this point in the history
…fore attempting decompression (#797)
  • Loading branch information
jpcorreia99 authored Aug 1, 2024
1 parent a8df43f commit ce331ff
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 42 deletions.
2 changes: 2 additions & 0 deletions crates/rattler_package_streaming/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["fs"] }
tokio-util = { workspace = true, features = ["io-util"] }
tracing = { workspace = true }
url = { workspace = true }
zip = { workspace = true, features = ["deflate", "time"] }
zstd = { workspace = true, features = ["zstdmt"] }
Expand All @@ -44,3 +45,4 @@ tools = { path = "../tools" }
walkdir = { workspace = true }
rstest = { workspace = true }
rstest_reuse = { workspace = true }
insta = { workspace = true, features = ["yaml"] }
2 changes: 1 addition & 1 deletion crates/rattler_package_streaming/src/fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result<ExtractResu
/// ```
pub fn extract_conda(archive: &Path, destination: &Path) -> Result<ExtractResult, ExtractError> {
    // Open the archive on disk and hand it straight to the streaming extractor.
    crate::read::extract_conda_via_streaming(File::open(archive)?, destination)
}

/// Extracts the contents a package archive at the specified path to a directory. The type of
Expand Down
90 changes: 69 additions & 21 deletions crates/rattler_package_streaming/src/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
//! [`std::io::Read`] trait.

use super::{ExtractError, ExtractResult};
use rattler_digest::HashingReader;
use std::io::{copy, Seek, SeekFrom};
use std::mem::ManuallyDrop;
use std::{ffi::OsStr, io::Read, path::Path};
use zip::read::read_zipfile_from_stream;
use tempfile::SpooledTempFile;
use zip::read::{read_zipfile_from_stream, ZipArchive, ZipFile};

/// Returns the `.tar.bz2` as a decompressed `tar::Archive`. The `tar::Archive` can be used to
/// extract the files from it, or perform introspection.
Expand Down Expand Up @@ -44,7 +47,10 @@ pub fn extract_tar_bz2(
}

/// Extracts the contents of a `.conda` package archive.
pub fn extract_conda(reader: impl Read, destination: &Path) -> Result<ExtractResult, ExtractError> {
pub fn extract_conda_via_streaming(
reader: impl Read,
destination: &Path,
) -> Result<ExtractResult, ExtractError> {
// Construct the destination path if it doesnt exist yet
std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;

Expand All @@ -56,27 +62,69 @@ pub fn extract_conda(reader: impl Read, destination: &Path) -> Result<ExtractRes

// Iterate over all entries in the zip-file and extract them one-by-one
while let Some(file) = read_zipfile_from_stream(&mut md5_reader)? {
// If an error occurs while we are reading the contents of the zip we don't want to
// seek to the end of the file. Using [`ManuallyDrop`] we prevent `drop` to be called on
// the `file` in case the stack unwinds.
let mut file = ManuallyDrop::new(file);

if file
.mangled_name()
.file_name()
.map(OsStr::to_string_lossy)
.map_or(false, |file_name| file_name.ends_with(".tar.zst"))
{
stream_tar_zst(&mut *file)?.unpack(destination)?;
} else {
// Manually read to the end of the stream if that didn't happen.
std::io::copy(&mut *file, &mut std::io::sink())?;
}

// Take the file out of the [`ManuallyDrop`] to properly drop it.
let _ = ManuallyDrop::into_inner(file);
extract_zipfile(file, destination)?;
}
compute_hashes(md5_reader)
}

/// Extracts the contents of a `.conda` package archive by fully reading the stream
/// into a buffer and then decompressing from it.
///
/// This is the fallback path for packages whose zip entries use data descriptors,
/// which cannot be handled while streaming (the local headers lack the entry
/// sizes), so a seekable source is required.
///
/// # Errors
///
/// Returns an [`ExtractError`] if the destination cannot be (re)created, the
/// stream cannot be buffered, or the archive fails to decompress.
pub fn extract_conda_via_buffering(
    reader: impl Read,
    destination: &Path,
) -> Result<ExtractResult, ExtractError> {
    // This method is usually invoked as a fallback after a failed streaming
    // extraction, so wipe any partially-extracted files first.
    if destination.exists() {
        std::fs::remove_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;
    }
    std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;

    // Buffer the package in memory, spilling to a temporary file past 5MB.
    let mut temp_file = SpooledTempFile::new(5 * 1024 * 1024);
    let sha256_reader = rattler_digest::HashingReader::<_, rattler_digest::Sha256>::new(reader);
    let mut md5_reader =
        rattler_digest::HashingReader::<_, rattler_digest::Md5>::new(sha256_reader);

    // Copying into the buffer consumes the reader to EOF, so both hashing
    // readers have already seen every byte by the time the copy returns.
    // (No extra drain of `md5_reader` is needed afterwards.)
    copy(&mut md5_reader, &mut temp_file)?;
    temp_file.seek(SeekFrom::Start(0))?;

    // The buffered source is seekable, so `ZipArchive` can read the central
    // directory and handle entries that rely on data descriptors.
    let mut archive = ZipArchive::new(temp_file)?;
    for i in 0..archive.len() {
        let file = archive.by_index(i)?;
        extract_zipfile(file, destination)?;
    }

    compute_hashes(md5_reader)
}

/// Extracts a single zip entry of a `.conda` package into `destination`.
///
/// Entries whose file name ends in `.tar.zst` are decompressed and unpacked
/// into `destination`; any other entry is drained and discarded so the
/// underlying stream position stays consistent for the next entry.
fn extract_zipfile(zip_file: ZipFile<'_>, destination: &Path) -> Result<(), ExtractError> {
    // If an error occurs while we are reading the contents of the zip we don't want to
    // seek to the end of the file. Using [`ManuallyDrop`] we prevent `drop` to be called on
    // the `file` in case the stack unwinds.
    let mut file = ManuallyDrop::new(zip_file);

    if file
        .mangled_name()
        .file_name()
        .map(OsStr::to_string_lossy)
        .map_or(false, |file_name| file_name.ends_with(".tar.zst"))
    {
        // Inner `.tar.zst` archive: stream-decompress and unpack in place.
        stream_tar_zst(&mut *file)?.unpack(destination)?;
    } else {
        // Manually read to the end of the stream if that didn't happen.
        std::io::copy(&mut *file, &mut std::io::sink())?;
    }

    // Take the file out of the [`ManuallyDrop`] to properly drop it
    // (only reached on the success path — see note above).
    let _ = ManuallyDrop::into_inner(file);

    Ok(())
}

fn compute_hashes<R: Read>(
mut md5_reader: HashingReader<HashingReader<R, rattler_digest::Sha256>, rattler_digest::Md5>,
) -> Result<ExtractResult, ExtractError> {
// Read the file to the end to make sure the hash is properly computed.
std::io::copy(&mut md5_reader, &mut std::io::sink())?;

Expand Down
38 changes: 33 additions & 5 deletions crates/rattler_package_streaming/src/reqwest/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@ use std::sync::Arc;
use tokio::io::BufReader;
use tokio_util::either::Either;
use tokio_util::io::StreamReader;
use tracing;
use url::Url;
use zip::result::ZipError;

/// Zip files may use data descriptors to signal that the decompressor needs to seek ahead in the
/// buffer to find the compressed data length.
/// Since we stream the package over a non-seekable HTTP connection, this condition will cause an
/// error during decompression. In this case, we fall back to reading the whole data to a buffer
/// before attempting decompression.
/// Read more in <https://github.com/conda-incubator/rattler/issues/794>
const DATA_DESCRIPTOR_ERROR_MESSAGE: &str = "The file length is not available in the local header";

fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
response
Expand Down Expand Up @@ -131,12 +140,31 @@ pub async fn extract_conda(
reporter: Option<Arc<dyn DownloadReporter>>,
) -> Result<ExtractResult, ExtractError> {
// The `response` is used to stream in the package data
let reader = get_reader(url.clone(), client, expected_sha256, reporter.clone()).await?;
let result = crate::tokio::async_read::extract_conda(reader, destination).await?;
if let Some(reporter) = &reporter {
reporter.on_download_complete();
let reader = get_reader(
url.clone(),
client.clone(),
expected_sha256,
reporter.clone(),
)
.await?;
match crate::tokio::async_read::extract_conda(reader, destination).await {
Ok(result) => {
if let Some(reporter) = &reporter {
reporter.on_download_complete();
}
Ok(result)
}
// https://github.com/conda-incubator/rattler/issues/794
Err(ExtractError::ZipError(ZipError::UnsupportedArchive(zip_error)))
if (zip_error.contains(DATA_DESCRIPTOR_ERROR_MESSAGE)) =>
{
tracing::warn!("Failed to stream decompress conda package from '{}' due to the presence of zip data descriptors. Falling back to non streaming decompression", url);
let new_reader =
get_reader(url.clone(), client, expected_sha256, reporter.clone()).await?;
crate::tokio::async_read::extract_conda_via_buffering(new_reader, destination).await
}
Err(e) => Err(e),
}
Ok(result)
}

/// Extracts the contents a package archive from the specified remote location. The type of package
Expand Down
36 changes: 34 additions & 2 deletions crates/rattler_package_streaming/src/tokio/async_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//! [`tokio::io::AsyncRead`] trait.

use crate::{ExtractError, ExtractResult};
use std::io::Read;
use std::path::Path;
use tokio::io::AsyncRead;
use tokio_util::io::SyncIoBridge;
Expand Down Expand Up @@ -30,17 +31,48 @@ pub async fn extract_tar_bz2(
}

/// Extracts the contents of a `.conda` package archive.
/// This will perform on-the-fly decompression by streaming the reader.
pub async fn extract_conda(
    reader: impl AsyncRead + Send + 'static,
    destination: &Path,
) -> Result<ExtractResult, ExtractError> {
    // Streaming variant: entries are decompressed as the bytes arrive.
    extract_conda_internal(reader, destination, crate::read::extract_conda_via_streaming).await
}

/// Extracts the contents of a .conda package archive by fully reading the stream and then decompressing
pub async fn extract_conda_via_buffering(
    reader: impl AsyncRead + Send + 'static,
    destination: &Path,
) -> Result<ExtractResult, ExtractError> {
    // Buffering variant: the whole package is read before decompression starts.
    extract_conda_internal(reader, destination, crate::read::extract_conda_via_buffering).await
}

/// Extracts the contents of a `.conda` package archive using the provided extraction function
async fn extract_conda_internal(
reader: impl AsyncRead + Send + 'static,
destination: &Path,
extract_fn: fn(Box<dyn Read>, &Path) -> Result<ExtractResult, ExtractError>,
) -> Result<ExtractResult, ExtractError> {
// Create a async -> sync bridge
let reader = SyncIoBridge::new(Box::pin(reader));

// Spawn a block task to perform the extraction
let destination = destination.to_owned();
match tokio::task::spawn_blocking(move || crate::read::extract_conda(reader, &destination))
.await
match tokio::task::spawn_blocking(move || {
let reader: Box<dyn Read> = Box::new(reader);
extract_fn(reader, &destination)
})
.await
{
Ok(result) => result,
Err(err) => {
Expand Down
51 changes: 40 additions & 11 deletions crates/rattler_package_streaming/tests/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ use std::{

use rattler_conda_types::package::IndexJson;
use rattler_package_streaming::{
read::{extract_conda, extract_tar_bz2},
read::{extract_conda_via_buffering, extract_conda_via_streaming, extract_tar_bz2},
ExtractError,
};
use rstest::rstest;
use rstest_reuse::{self, apply, template};
use serde_json::json;
use url::Url;

fn test_data_dir() -> PathBuf {
Expand Down Expand Up @@ -111,7 +112,7 @@ fn test_extract_conda(#[case] input: Url, #[case] sha256: &str, #[case] md5: &st
println!("Target dir: {}", temp_dir.display());
let file_path = tools::download_and_cache_file(input, sha256).unwrap();
let target_dir = temp_dir.join(file_path.file_stem().unwrap());
let result = extract_conda(
let result = extract_conda_via_streaming(
File::open(test_data_dir().join(file_path)).unwrap(),
&target_dir,
)
Expand Down Expand Up @@ -211,14 +212,15 @@ async fn test_extract_conda_async(#[case] input: Url, #[case] sha256: &str, #[ca
.unwrap();

let target_dir = temp_dir.join(file_path.file_stem().unwrap());
let result = rattler_package_streaming::tokio::async_read::extract_conda(
tokio::fs::File::open(&test_data_dir().join(file_path))
.await
.unwrap(),
&target_dir,
)
.await
.unwrap();
let result: rattler_package_streaming::ExtractResult =
rattler_package_streaming::tokio::async_read::extract_conda(
tokio::fs::File::open(&test_data_dir().join(file_path))
.await
.unwrap(),
&target_dir,
)
.await
.unwrap();

assert_eq!(&format!("{:x}", result.sha256), sha256);
assert_eq!(&format!("{:x}", result.md5), md5);
Expand Down Expand Up @@ -266,7 +268,7 @@ fn test_extract_flaky_conda(#[values(0, 1, 13, 50, 74, 150, 8096, 16384, 20000)]
let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR"));
println!("Target dir: {}", temp_dir.display());
let target_dir = temp_dir.join(package_path.file_stem().unwrap());
let result = extract_conda(
let result = extract_conda_via_streaming(
FlakyReader {
reader: File::open(package_path).unwrap(),
total_read: 0,
Expand All @@ -279,6 +281,33 @@ fn test_extract_flaky_conda(#[values(0, 1, 13, 50, 74, 150, 8096, 16384, 20000)]
assert_matches::assert_matches!(result, ExtractError::IoError(_));
}

#[rstest]
fn test_extract_data_descriptor_package_fails_streaming_and_uses_buffering() {
    let archive = "tests/resources/ca-certificates-2024.7.4-hbcca054_0.conda";
    let target_dir =
        Path::new(env!("CARGO_TARGET_TMPDIR")).join("package_using_data_descriptors");

    // Streaming extraction cannot handle zip entries that use data
    // descriptors; it must fail with a recoverable error rather than panic.
    let streaming_err = extract_conda_via_streaming(File::open(archive).unwrap(), &target_dir)
        .expect_err("this should error out and not panic");
    assert_matches::assert_matches!(
        streaming_err,
        ExtractError::ZipError(zip::result::ZipError::UnsupportedArchive(
            "The file length is not available in the local header"
        ))
    );

    // The buffering fallback reads the whole package first and succeeds.
    let extracted =
        extract_conda_via_buffering(File::open(archive).unwrap(), &target_dir).unwrap();

    // Pin the package digests so a regression in either hash computation is caught.
    let digests = json!({
        "sha256": format!("{:x}", extracted.sha256),
        "md5": format!("{:x}", extracted.md5),
    });
    insta::assert_snapshot!(digests, @r###"{"sha256":"6a5d6d8a1a7552dbf8c617312ef951a77d2dac09f2aeaba661deebce603a7a97","md5":"a1d1adb5a5dc516dfb3dccc7b9b574a9"}"###);
}

struct FlakyReader<R: Read> {
reader: R,
cutoff: usize,
Expand Down
Binary file not shown.
4 changes: 2 additions & 2 deletions crates/rattler_package_streaming/tests/write.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use rattler_conda_types::package::ArchiveType;
use rattler_package_streaming::read::{extract_conda, extract_tar_bz2};
use rattler_package_streaming::read::{extract_conda_via_streaming, extract_tar_bz2};
use rattler_package_streaming::write::{
write_conda_package, write_tar_bz2_package, CompressionLevel,
};
Expand Down Expand Up @@ -209,7 +209,7 @@ fn test_rewrite_conda() {

let name = file_path.file_stem().unwrap().to_string_lossy();
let target_dir = temp_dir.join(file_path.file_stem().unwrap());
extract_conda(File::open(&file_path).unwrap(), &target_dir).unwrap();
extract_conda_via_streaming(File::open(&file_path).unwrap(), &target_dir).unwrap();

let new_archive = temp_dir.join(format!(
"{}-new.conda",
Expand Down

0 comments on commit ce331ff

Please sign in to comment.