diff --git a/Cargo.toml b/Cargo.toml index 9f2a7f12..d33cfdc9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "mlar", "mla-fuzz-afl", "bindings/C", + "bindings/python", ] [profile.release] diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml new file mode 100644 index 00000000..ca5d3786 --- /dev/null +++ b/bindings/python/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "pymla" +version = "0.1.0" +edition = "2021" +authors = ["Camille Mougey "] +license = "LGPL-3.0-only" +description = "Multi Layer Archive - A pure rust encrypted and compressed archive file format" +homepage = "https://github.com/ANSSI-FR/MLA" +repository = "https://github.com/ANSSI-FR/MLA" +readme = "../../README.md" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "pymla" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = "0.19.0" +mla = { version = "1", features = ["send"], path = "../../mla"} diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml new file mode 100644 index 00000000..b076d6e9 --- /dev/null +++ b/bindings/python/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = ["maturin>=1.4,<2.0"] +build-backend = "maturin" + +[project] +name = "mla" +description = "Bindings for MLA Archive manipulation" +requires-python = ">=3.8" +keywords = ["archive", "mla"] +license = {file = "../../LICENSE.md"} +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] + +[project.urls] +documentation = "https://github.com/ANSSI-FR/MLA" +repository = "https://github.com/ANSSI-FR/MLA" + +[tool.maturin] +features = ["pyo3/extension-module"] +module-name = "mla" diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs new file mode 100644 index 00000000..a22f4126 --- /dev/null +++ b/bindings/python/src/lib.rs @@ -0,0 +1,598 @@ +use std::{borrow::Cow, collections::HashMap, io::Read}; + +use mla::{ + config::{ArchiveReaderConfig, ArchiveWriterConfig}, + ArchiveReader, ArchiveWriter, Layers, +}; +use pyo3::{ + create_exception, + exceptions::{PyKeyError, PyRuntimeError}, + prelude::*, +}; + +// -------- Error handling -------- + +/// Wrapper over MLA custom error, due to the "orphan rule" +/// - WrappedMLA: MLA specifics errors +/// - WrappedPy: Python related errors +#[derive(Debug)] +enum WrappedError { + WrappedMLA(mla::errors::Error), + WrappedPy(PyErr), +} + +// Add a dedicated MLA Exception (mla.MLAError) and associated sub-Exception +// IOError and AssertionError are not mapped, as they already map to Python Exception +create_exception!(mla, MLAError, pyo3::exceptions::PyException); +create_exception!(mla, WrongMagic, MLAError, "Wrong magic, must be \"MLA\""); +create_exception!( + mla, + UnsupportedVersion, + MLAError, + "Unsupported version, must be 1" +); +create_exception!( + mla, + InvalidECCKeyFormat, + MLAError, + "Supplied ECC key is not in the expected format" +); +create_exception!(mla, WrongBlockSubFileType, MLAError, "Wrong BlockSubFile magic has been encountered. Is the deserializion tarting at the beginning of a block?"); +create_exception!( + mla, + UTF8ConversionError, + MLAError, + "An error has occurred while converting into UTF8. This error could" +); +create_exception!( + mla, + FilenameTooLong, + MLAError, + "Filenames have a limited size `FILENAME_MAX_SIZE`" +); +create_exception!( + mla, + WrongArchiveWriterState, + MLAError, + "The writer state is not in the expected state for the current operation" +); +create_exception!( + mla, + WrongReaderState, + MLAError, + "The reader state is not in the expected state for the current operation" +); +create_exception!( + mla, + WrongWriterState, + MLAError, + "The writer state is not in the expected state for the current operation" +); +create_exception!( + mla, + RandError, + MLAError, + "Error with the inner random generator" +); +create_exception!( + mla, + PrivateKeyNeeded, + MLAError, + "A Private Key is required to decrypt the encrypted cipher key" +); +create_exception!( + mla, + DeserializationError, + MLAError, + "Deserialization error. May happens when starting from a wrong offset / version mismatch" +); +create_exception!( + mla, + SerializationError, + MLAError, + "Serialization error. May happens on I/O errors" +); +create_exception!(mla, MissingMetadata, MLAError, "Missing metadata (usually means the footer has not been correctly read, a repair might be needed)"); +create_exception!( + mla, + BadAPIArgument, + MLAError, + "Error returned on API call with incorrect argument" +); +create_exception!( + mla, + EndOfStream, + MLAError, + "End of stream reached, no more data should be expected" +); +create_exception!( + mla, + ConfigError, + MLAError, + "An error happens in the configuration" +); +create_exception!(mla, DuplicateFilename, MLAError, "Filename already used"); +create_exception!( + mla, + AuthenticatedDecryptionWrongTag, + MLAError, + "Wrong tag while decrypting authenticated data" +); +create_exception!( + mla, + HKDFInvalidKeyLength, + MLAError, + "Unable to expand while using the HKDF" +); + +// Convert potentials errors to the wrapped type + +impl From for WrappedError { + fn from(err: mla::errors::Error) -> Self { + WrappedError::WrappedMLA(err) + } +} + +impl From for WrappedError { + fn from(err: mla::errors::ConfigError) -> Self { + WrappedError::WrappedMLA(mla::errors::Error::ConfigError(err)) + } +} + +impl From for WrappedError { + fn from(err: std::io::Error) -> Self { + WrappedError::WrappedPy(err.into()) + } +} + +/// Convert back the wrapped type to Python errors +impl From for PyErr { + fn from(err: WrappedError) -> PyErr { + match err { + WrappedError::WrappedMLA(inner_err) => { + match inner_err { + mla::errors::Error::IOError(err) => PyErr::new::(err), + mla::errors::Error::AssertionError(msg) => PyErr::new::(msg), + mla::errors::Error::WrongMagic => PyErr::new::("Wrong magic, must be \"MLA\""), + mla::errors::Error::UnsupportedVersion => PyErr::new::("Unsupported version, must be 1"), + mla::errors::Error::InvalidECCKeyFormat => PyErr::new::("Supplied ECC key is not in the expected format"), + mla::errors::Error::WrongBlockSubFileType => PyErr::new::("Wrong BlockSubFile magic has been encountered. Is the deserializion tarting at the beginning of a block?"), + mla::errors::Error::UTF8ConversionError(err) => PyErr::new::(err), + mla::errors::Error::FilenameTooLong => PyErr::new::("Filenames have a limited size `FILENAME_MAX_SIZE`"), + mla::errors::Error::WrongArchiveWriterState { current_state, expected_state } => PyErr::new::(format!("The writer state is not in the expected state for the current operation. Current state: {:?}, expected state: {:?}", current_state, expected_state)), + mla::errors::Error::WrongReaderState(msg) => PyErr::new::(msg), + mla::errors::Error::WrongWriterState(msg) => PyErr::new::(msg), + mla::errors::Error::RandError(err) => PyErr::new::(format!("{:}", err)), + mla::errors::Error::PrivateKeyNeeded => PyErr::new::("A Private Key is required to decrypt the encrypted cipher key"), + mla::errors::Error::DeserializationError => PyErr::new::("Deserialization error. May happens when starting from a wrong offset / version mismatch"), + mla::errors::Error::SerializationError => PyErr::new::("Serialization error. May happens on I/O errors"), + mla::errors::Error::MissingMetadata => PyErr::new::("Missing metadata (usually means the footer has not been correctly read, a repair might be needed)"), + mla::errors::Error::BadAPIArgument(msg) => PyErr::new::(msg), + mla::errors::Error::EndOfStream => PyErr::new::("End of stream reached, no more data should be expected"), + mla::errors::Error::ConfigError(err) => PyErr::new::(format!("{:}", err)), + mla::errors::Error::DuplicateFilename => PyErr::new::("Filename already used"), + mla::errors::Error::AuthenticatedDecryptionWrongTag => PyErr::new::("Wrong tag while decrypting authenticated data"), + mla::errors::Error::HKDFInvalidKeyLength => PyErr::new::("Unable to expand while using the HKDF"), + } + }, + WrappedError::WrappedPy(inner_err) => inner_err + } + } +} + +// -------- mla.FileMetadata -------- + +#[pyclass] +struct FileMetadata { + size: Option, + hash: Option<[u8; 32]>, +} + +#[pymethods] +impl FileMetadata { + #[getter] + fn size(&self) -> Option { + self.size + } + + #[getter] + fn hash(&self) -> Option> { + match self.hash { + Some(ref hash) => Some(Cow::Borrowed(hash)), + None => None, + } + } + + fn __repr__(&self) -> String { + format!("", self.size, self.hash) + } +} + +// -------- mla.ConfigWriter -------- + +// from mla::layers::DEFAULT_COMPRESSION_LEVEL +const DEFAULT_COMPRESSION_LEVEL: u32 = 5; + +#[pyclass] +struct WriterConfig { + inner: ArchiveWriterConfig, +} + +#[pymethods] +impl WriterConfig { + #[new] + #[pyo3(signature = (layers=None, compression_level=DEFAULT_COMPRESSION_LEVEL))] + fn new(layers: Option, compression_level: u32) -> Result { + let mut output = WriterConfig { + inner: ArchiveWriterConfig::new(), + }; + if let Some(layers_enabled) = layers { + output + .inner + .set_layers(Layers::from_bits(layers_enabled).ok_or( + mla::errors::Error::BadAPIArgument(format!("Unknown layers")), + )?); + } + output.inner.with_compression_level(compression_level)?; + + Ok(output) + } + + #[getter] + fn layers(&self) -> Result { + Ok(self.inner.to_persistent()?.layers_enabled.bits()) + } + + /// Enable a layer + fn enable_layer(mut slf: PyRefMut, layer: u8) -> Result, WrappedError> { + slf.inner.enable_layer( + Layers::from_bits(layer) + .ok_or(mla::errors::Error::BadAPIArgument(format!("Unknown layer")))?, + ); + Ok(slf) + } + + /// Disable a layer + fn disable_layer(mut slf: PyRefMut, layer: u8) -> Result, WrappedError> { + slf.inner.disable_layer( + Layers::from_bits(layer) + .ok_or(mla::errors::Error::BadAPIArgument(format!("Unknown layer")))?, + ); + Ok(slf) + } + + /// Set several layers at once + fn set_layers(mut slf: PyRefMut, layers: u8) -> Result, WrappedError> { + slf.inner.set_layers(Layers::from_bits(layers).ok_or( + mla::errors::Error::BadAPIArgument(format!("Unknown layers")), + )?); + Ok(slf) + } + + /// Set the compression level + /// compression level (0-11); bigger values cause denser, but slower compression + fn with_compression_level( + mut slf: PyRefMut, + compression_level: u32, + ) -> Result, WrappedError> { + slf.inner.with_compression_level(compression_level)?; + Ok(slf) + } +} + +// -------- mla.MLAFile -------- + +/// `ArchiveWriter` is a generic type. To avoid generating several Python implementation +/// (see https://pyo3.rs/v0.20.2/class.html#no-generic-parameters), this enum explicitely +/// instanciate `ArchiveWriter` for common & expected types +/// +/// Additionnaly, as the GC in Python might drop objects at any time, we need to use +/// `'static` lifetime for the writer. This should not be a problem as the writer is not +/// supposed to be used after the drop of the parent object +/// (see https://pyo3.rs/v0.20.2/class.html#no-lifetime-parameters) +enum ExplicitWriters { + FileWriter(ArchiveWriter<'static, std::fs::File>), +} + +/// Wrap calls to the inner type +impl ExplicitWriters { + fn finalize(&mut self) -> Result<(), mla::errors::Error> { + match self { + ExplicitWriters::FileWriter(writer) => { + writer.finalize()?; + Ok(()) + } + } + } +} + +/// See `ExplicitWriters` for details +enum ExplicitReaders { + FileReader(ArchiveReader<'static, std::fs::File>), +} + +/// Wrap calls to the inner type +impl ExplicitReaders { + fn list_files(&self) -> Result, mla::errors::Error> { + match self { + ExplicitReaders::FileReader(reader) => reader.list_files(), + } + } +} + +/// Opening Mode for a MLAFile +enum OpeningModeInner { + Read(ExplicitReaders), + Write(ExplicitWriters), +} + +#[pyclass] +pub struct MLAFile { + /// Wrapping over the rust object, depending on the opening mode + inner: OpeningModeInner, + /// Path of the file, used for messages + path: String, +} + +/// Used to check whether the opening mode is the expected one, and unwrap it +/// return a BadAPI argument error if not +/// ```text +/// let inner = check_mode!(self, Read); +/// ``` +macro_rules! check_mode { + ( $self:expr, $x:ident ) => {{ + match &$self.inner { + OpeningModeInner::$x(inner) => inner, + _ => { + return Err(mla::errors::Error::BadAPIArgument(format!( + "This API is only callable in {:} mode", + stringify!($x) + )) + .into()) + } + } + }}; + ( mut $self:expr, $x:ident ) => {{ + match &mut $self.inner { + OpeningModeInner::$x(inner) => inner, + _ => { + return Err(mla::errors::Error::BadAPIArgument(format!( + "This API is only callable in {:} mode", + stringify!($x) + )) + .into()) + } + } + }}; +} + +#[pymethods] +impl MLAFile { + #[new] + #[pyo3(signature = (path, mode="r"))] + fn new(path: &str, mode: &str) -> Result { + match mode { + "r" => { + let config = ArchiveReaderConfig::new(); + let input_file = std::fs::File::open(path)?; + let arch_reader = ArchiveReader::from_config(input_file, config)?; + Ok(MLAFile { + inner: OpeningModeInner::Read(ExplicitReaders::FileReader(arch_reader)), + path: path.to_string(), + }) + } + "w" => { + let mut config = ArchiveWriterConfig::new(); + config.enable_layer(Layers::COMPRESS); + let output_file = std::fs::File::create(path)?; + let arch_writer = ArchiveWriter::from_config(output_file, config)?; + Ok(MLAFile { + inner: OpeningModeInner::Write(ExplicitWriters::FileWriter(arch_writer)), + path: path.to_string(), + }) + } + _ => Err(mla::errors::Error::BadAPIArgument(format!( + "Unknown mode {}, use 'r' or 'w'", + mode + )) + .into()), + } + } + + fn __repr__(&self) -> String { + format!( + "", + self.path, + match self.inner { + OpeningModeInner::Read(_) => "r", + OpeningModeInner::Write(_) => "w", + } + ) + } + + /// Return the list of files in the archive + fn keys(&self) -> Result, WrappedError> { + let inner = check_mode!(self, Read); + Ok(inner.list_files()?.map(|x| x.to_string()).collect()) + } + + /// Return the size of a file in the archive + #[pyo3(signature = (include_size=false, include_hash=false))] + fn list_files( + &mut self, + include_size: bool, + include_hash: bool, + ) -> Result, WrappedError> { + let inner = check_mode!(mut self, Read); + + let mut output = HashMap::new(); + let iter: Vec = inner.list_files()?.cloned().collect(); + for fname in iter { + let mut metadata = FileMetadata { + size: None, + hash: None, + }; + match inner { + ExplicitReaders::FileReader(mla) => { + if include_size { + metadata.size = Some( + mla.get_file(fname.clone())? + .ok_or(WrappedError::WrappedPy(PyRuntimeError::new_err(format!( + "File {} not found", + fname + ))))? + .size, + ); + } + if include_hash { + metadata.hash = + Some(mla.get_hash(&fname)?.ok_or(WrappedError::WrappedPy( + PyRuntimeError::new_err(format!("File {} not found", fname)), + ))?); + } + } + } + output.insert(fname.to_string(), metadata); + } + Ok(output) + } + + /// Return whether the file is in the archive + fn __contains__(&self, key: &str) -> Result { + let inner = check_mode!(self, Read); + Ok(inner.list_files()?.any(|x| x == key)) + } + + /// Return the content of a file as bytes + fn __getitem__(&mut self, key: &str) -> Result, WrappedError> { + let inner = check_mode!(mut self, Read); + match inner { + ExplicitReaders::FileReader(reader) => { + let mut buf = Vec::new(); + let file = reader.get_file(key.to_string())?; + if let Some(mut archive_file) = file { + archive_file.data.read_to_end(&mut buf)?; + Ok(Cow::Owned(buf)) + } else { + Err(WrappedError::WrappedPy(PyKeyError::new_err(format!( + "File {} not found", + key + ))) + .into()) + } + } + } + } + + /// Add a file to the archive + fn __setitem__(&mut self, key: &str, value: &[u8]) -> Result<(), WrappedError> { + let writer = check_mode!(mut self, Write); + match writer { + ExplicitWriters::FileWriter(writer) => { + let mut reader = std::io::Cursor::new(value); + writer.add_file(key, value.len() as u64, &mut reader)?; + Ok(()) + } + } + } + + /// Return the number of file in the archive + fn __len__(&self) -> Result { + let inner = check_mode!(self, Read); + Ok(inner.list_files()?.count()) + } + + /// Finalize the archive creation. This API *must* be called or essential records will no be written + /// An archive can only be finalized once + fn finalize(&mut self) -> Result<(), WrappedError> { + let inner = check_mode!(mut self, Write); + Ok(inner.finalize()?) + } + + // Context management protocol (PEP 0343) + // https://docs.python.org/3/reference/datamodel.html#context-managers + fn __enter__(slf: PyRef) -> PyRef { + slf + } + + fn __exit__( + &mut self, + exc_type: Option<&PyAny>, + _exc_value: Option<&PyAny>, + _traceback: Option<&PyAny>, + ) -> Result { + if exc_type.is_some() { + // An exception occured, let it be raised again + return Ok(false); + } + + match self.inner { + OpeningModeInner::Read(_) => { + // Nothing to do, dropping this object should close the inner stream + } + OpeningModeInner::Write(ref mut writer) => { + // Finalize. If an exception occured, raise it + writer.finalize()?; + } + } + Ok(false) + } +} + +// -------- Python module instanciation -------- + +/// Instanciate the Python module +#[pymodule] +#[pyo3(name = "mla")] +fn pymla(py: Python, m: &PyModule) -> PyResult<()> { + // Classes + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + // Exceptions + m.add("MLAError", py.get_type::())?; + m.add("WrongMagic", py.get_type::())?; + m.add("UnsupportedVersion", py.get_type::())?; + m.add("InvalidECCKeyFormat", py.get_type::())?; + m.add( + "WrongBlockSubFileType", + py.get_type::(), + )?; + m.add("UTF8ConversionError", py.get_type::())?; + m.add("FilenameTooLong", py.get_type::())?; + m.add( + "WrongArchiveWriterState", + py.get_type::(), + )?; + m.add("WrongReaderState", py.get_type::())?; + m.add("WrongWriterState", py.get_type::())?; + m.add("RandError", py.get_type::())?; + m.add("PrivateKeyNeeded", py.get_type::())?; + m.add( + "DeserializationError", + py.get_type::(), + )?; + m.add("SerializationError", py.get_type::())?; + m.add("MissingMetadata", py.get_type::())?; + m.add("BadAPIArgument", py.get_type::())?; + m.add("EndOfStream", py.get_type::())?; + m.add("ConfigError", py.get_type::())?; + m.add("DuplicateFilename", py.get_type::())?; + m.add( + "AuthenticatedDecryptionWrongTag", + py.get_type::(), + )?; + m.add( + "HKDFInvalidKeyLength", + py.get_type::(), + )?; + + // Add constants + m.add("LAYER_COMPRESS", Layers::COMPRESS.bits())?; + m.add("LAYER_ENCRYPT", Layers::ENCRYPT.bits())?; + m.add("LAYER_DEFAULT", Layers::DEFAULT.bits())?; + m.add("LAYER_EMPTY", Layers::EMPTY.bits())?; + m.add("DEFAULT_COMPRESSION_LEVEL", DEFAULT_COMPRESSION_LEVEL)?; + Ok(()) +} diff --git a/bindings/python/tests/test_mla.py b/bindings/python/tests/test_mla.py new file mode 100644 index 00000000..f2103cb5 --- /dev/null +++ b/bindings/python/tests/test_mla.py @@ -0,0 +1,245 @@ +import hashlib +import pytest +import tempfile + +import mla +from mla import MLAFile, MLAError + +# Test data +FILES = { + "file1": b"DATA1", + "file2": b"DATA_2", +} + +@pytest.fixture +def basic_archive(): + "Create a temporary archive and return its path" + fname = tempfile.mkstemp(suffix=".mla")[1] + archive = MLAFile(fname, "w") + for name, data in FILES.items(): + archive[name] = data + archive.finalize() + return fname + +def test_layers_bitflag_export(): + assert mla.LAYER_DEFAULT == mla.LAYER_COMPRESS | mla.LAYER_ENCRYPT + assert mla.LAYER_EMPTY == 0 + assert mla.LAYER_DEFAULT != mla.LAYER_EMPTY + +def test_bad_mode(): + "Ensure MLAFile with an unknown mode raise an error" + target_file = "/tmp/must_not_exists" + with pytest.raises(mla.BadAPIArgument): + MLAFile(target_file, "x") + # Ensure the file has not been created + with pytest.raises(FileNotFoundError): + open(target_file) + +def test_repr(): + "Ensure the repr is correct" + path = tempfile.mkstemp(suffix=".mla")[1] + archive = MLAFile(path, "w") + assert repr(archive) == "" % path + archive.finalize() + +def test_forbidden_in_write_mode(): + "Ensure some API cannot be called in write mode" + archive = MLAFile(tempfile.mkstemp(suffix=".mla")[1], "w") + + # .keys + with pytest.raises(mla.BadAPIArgument): + archive.keys() + + # __contains__ + with pytest.raises(mla.BadAPIArgument): + "name" in archive + + # __getitem__ + with pytest.raises(mla.BadAPIArgument): + archive["name"] + + # __len__ + with pytest.raises(mla.BadAPIArgument): + len(archive) + + # list_files + with pytest.raises(mla.BadAPIArgument): + archive.list_files() + +def test_forbidden_in_read_mode(basic_archive): + "Ensure some API cannot be called in write mode" + archive = MLAFile(basic_archive) + + # __setitem__ + with pytest.raises(mla.BadAPIArgument): + archive["file"] = b"data" + + # .finalize + with pytest.raises(mla.BadAPIArgument): + archive.finalize() + +def test_read_api(basic_archive): + "Test basics read APIs" + archive = MLAFile(basic_archive) + + # .keys + assert sorted(archive.keys()) == sorted(list(FILES.keys())) + + # __contains__ + assert "file1" in archive + assert "file3" not in archive + + # __getitem__ + assert archive["file1"] == FILES["file1"] + assert archive["file2"] == FILES["file2"] + with pytest.raises(KeyError): + archive["file3"] + + # __len__ + assert len(archive) == 2 + +def test_list_files(basic_archive): + "Test list files possibilities" + archive = MLAFile(basic_archive) + + # Basic + assert sorted(archive.list_files()) == sorted(list(FILES.keys())) + + # With size + assert sorted([ + (filename, info.size) for filename, info in archive.list_files(include_size=True).items() + ]) == sorted([ + (filename, len(data)) for filename, data in FILES.items() + ]) + + # With hash + assert sorted([ + (filename, info.hash) for filename, info in archive.list_files(include_hash=True).items() + ]) == sorted([ + (filename, hashlib.sha256(data).digest()) for filename, data in FILES.items() + ]) + + # With size and hash + assert sorted([ + (filename, info.size, info.hash) for filename, info in archive.list_files(include_size=True, include_hash=True).items() + ]) == sorted([ + (filename, len(data), hashlib.sha256(data).digest()) for filename, data in FILES.items() + ]) + +def test_write_api(): + "Test basics write APIs" + path = tempfile.mkstemp(suffix=".mla")[1] + archive = MLAFile(path, "w") + + # __setitem__ + for name, data in FILES.items(): + archive[name] = data + + # close + archive.finalize() + + # Check the resulting archive + archive = MLAFile(path) + assert sorted(archive.keys()) == sorted(list(FILES.keys())) + assert archive["file1"] == FILES["file1"] + assert archive["file2"] == FILES["file2"] + +def test_double_write(): + "Rewriting the file must raise an MLA error" + archive = MLAFile(tempfile.mkstemp(suffix=".mla")[1], "w") + archive["file1"] = FILES["file1"] + + with pytest.raises(mla.DuplicateFilename): + archive["file1"] = FILES["file1"] + +def test_context_read(basic_archive): + "Test reading using a `with` statement (context management protocol)" + with MLAFile(basic_archive) as mla: + assert sorted(mla.keys()) == sorted(list(FILES.keys())) + for name, data in FILES.items(): + assert mla[name] == data + +def test_context_write(): + "Test writing using a `with` statement (context management protocol)" + path = tempfile.mkstemp(suffix=".mla")[1] + with MLAFile(path, "w") as mla: + for name, data in FILES.items(): + mla[name] = data + + # Check the resulting file + with MLAFile(path) as mla: + assert sorted(mla.keys()) == sorted(list(FILES.keys())) + for name, data in FILES.items(): + assert mla[name] == data + +def test_context_write_error(): + "Raise an error during the context write __exit__" + with pytest.raises(mla.WrongArchiveWriterState): + with MLAFile(tempfile.mkstemp(suffix=".mla")[1], "w") as archive: + # INTENTIONNALY BUGGY + # .finalize will be called twice, causing an exception + archive.finalize() + +def test_context_write_error_in_with(): + "Raise an error in the with statement, it must be re-raised" + CustomException = type("CustomException", (Exception,), {}) + with pytest.raises(CustomException): + with MLAFile(tempfile.mkstemp(suffix=".mla")[1], "w") as mla: + # INTENTIONNALY BUGGY + raise CustomException + +def test_writer_config_layers(): + "Test writer config creation for layers" + # Enable and disable layers + config = mla.WriterConfig() + assert config.layers == mla.LAYER_EMPTY + + config = mla.WriterConfig(layers=mla.LAYER_COMPRESS) + assert config.layers == mla.LAYER_COMPRESS + + config.enable_layer(mla.LAYER_ENCRYPT) + assert config.layers == mla.LAYER_COMPRESS | mla.LAYER_ENCRYPT + + config.disable_layer(mla.LAYER_COMPRESS) + assert config.layers == mla.LAYER_ENCRYPT + + config.disable_layer(mla.LAYER_ENCRYPT) + assert config.layers == mla.LAYER_EMPTY + + # Check for error on unknown layer (0xFF) + with pytest.raises(mla.BadAPIArgument): + config.enable_layer(0xFF) + + with pytest.raises(mla.BadAPIArgument): + config.disable_layer(0xFF) + + with pytest.raises(mla.BadAPIArgument): + config.set_layers(0xFF) + + with pytest.raises(mla.BadAPIArgument): + config = mla.WriterConfig(layers=0xFF) + + # Chaining + config = mla.WriterConfig().enable_layer( + mla.LAYER_COMPRESS + ).enable_layer( + mla.LAYER_ENCRYPT + ).disable_layer( + mla.LAYER_COMPRESS + ).set_layers( + mla.LAYER_COMPRESS + ) + assert config.layers == mla.LAYER_COMPRESS + +def test_writer_config_compression(): + "Test compression API in WriterConfig creation" + config = mla.WriterConfig() + with pytest.raises(OverflowError): + config.with_compression_level(-1) + with pytest.raises(mla.ConfigError): + config.with_compression_level(0xFF) + + # Chaining + out = config.with_compression_level(mla.DEFAULT_COMPRESSION_LEVEL) + assert out is config +