Skip to content

Commit

Permalink
feat: impl S3Writer for append
Browse files Browse the repository at this point in the history
  • Loading branch information
KKould committed Oct 18, 2024
1 parent d0f4b06 commit c9fe5f7
Show file tree
Hide file tree
Showing 5 changed files with 401 additions and 0 deletions.
12 changes: 12 additions & 0 deletions fusio/src/impls/remotes/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ pub mod credential;
mod error;
#[cfg(feature = "fs")]
pub mod fs;
pub(crate) mod multipart_upload;
pub(crate) mod options;
mod s3;
pub(crate) mod sign;
pub(crate) mod writer;

pub use credential::AwsCredential;
pub use error::S3Error;
pub use s3::S3File;
use serde::Deserialize;

const STRICT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::NON_ALPHANUMERIC
.remove(b'-')
Expand All @@ -17,3 +20,12 @@ const STRICT_ENCODE_SET: percent_encoding::AsciiSet = percent_encoding::NON_ALPH
.remove(b'~');
const STRICT_PATH_ENCODE_SET: percent_encoding::AsciiSet = STRICT_ENCODE_SET.remove(b'/');
const CHECKSUM_HEADER: &str = "x-amz-checksum-sha256";

#[derive(Default, Debug, Deserialize, PartialEq, Eq)]
#[serde(default, rename_all = "PascalCase")]
pub(crate) struct S3ResponseError {
pub code: String,
pub message: String,
pub resource: String,
pub request_id: String,
}
207 changes: 207 additions & 0 deletions fusio/src/impls/remotes/aws/multipart_upload.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
use std::sync::Arc;

use bytes::{Buf, Bytes};
use http::{
header::{CONTENT_LENGTH, CONTENT_TYPE, ETAG},
Method, Request, Response,
};
use http_body::Body;
use http_body_util::{BodyExt, Empty, Full};
use itertools::Itertools;
use percent_encoding::utf8_percent_encode;

use crate::{
path::Path,
remotes::{
aws::{options::S3Options, sign::Sign, S3Error, S3ResponseError, STRICT_PATH_ENCODE_SET},
http::{BoxBody, DynHttpClient, HttpClient},
serde::{
CompleteMultipartUploadRequest, CompleteMultipartUploadRequestPart,
InitiateMultipartUploadResult, MultipartPart,
},
},
Error,
};

pub(crate) struct MultipartUpload {
options: Arc<S3Options>,
path: Path,
client: Arc<dyn DynHttpClient>,
}

impl MultipartUpload {
pub fn new(options: Arc<S3Options>, path: Path, client: Arc<dyn DynHttpClient>) -> Self {
Self {
options,
path,
client,
}
}

async fn check_response(response: Response<BoxBody>) -> Result<Response<BoxBody>, Error> {
if !response.status().is_success() {
return Err(Error::Other(
format!(
"failed to write to S3, HTTP status: {} content: {}",
response.status(),
String::from_utf8_lossy(
&response.into_body().collect().await.unwrap().to_bytes()
)
)
.into(),
));
}
Ok(response)
}

async fn send_request<B>(&self, mut request: Request<B>) -> Result<Response<BoxBody>, Error>
where
B: Body<Data = Bytes> + Clone + Unpin + Send + Sync + 'static,
B::Error: std::error::Error + Send + Sync + 'static,
{
request
.sign(&self.options)
.await
.map_err(|e| Error::S3Error(S3Error::from(e)))?;
let response = self
.client
.send_request(request)
.await
.map_err(|e| Error::S3Error(S3Error::from(e)))?;
Self::check_response(response).await
}

pub(crate) async fn upload_once<B>(&self, size: usize, body: B) -> Result<(), Error>
where
B: Body<Data = Bytes> + Clone + Unpin + Send + Sync + 'static,
B::Error: std::error::Error + Send + Sync + 'static,
{
let url = format!(
"{}/{}",
self.options.endpoint,
utf8_percent_encode(self.path.as_ref(), &STRICT_PATH_ENCODE_SET)
);
let request = Request::builder()
.uri(url)
.method(Method::PUT)
.header(CONTENT_LENGTH, size)
.body(body)
.map_err(|e| Error::Other(e.into()))?;
let _ = self.send_request(request).await?;

Ok(())
}

pub(crate) async fn initiate(&self) -> Result<String, Error> {
let url = format!(
"{}/{}?uploads",
self.options.endpoint,
utf8_percent_encode(self.path.as_ref(), &STRICT_PATH_ENCODE_SET)
);
let request = Request::builder()
.uri(url)
.method(Method::POST)
.body(Empty::new())
.map_err(|e| Error::Other(e.into()))?;
let response = self.send_request(request).await?;
let result: InitiateMultipartUploadResult = quick_xml::de::from_reader(
response
.collect()
.await
.map_err(S3Error::from)?
.aggregate()
.reader(),
)
.map_err(S3Error::from)?;

Ok(result.upload_id)
}

pub(crate) async fn upload_part<B>(
&self,
upload_id: &str,
part_num: usize,
size: usize,
body: B,
) -> Result<MultipartPart, Error>
where
B: Body<Data = Bytes> + Clone + Unpin + Send + Sync + 'static,
B::Error: std::error::Error + Send + Sync + 'static,
{
let url = format!(
"{}/{}?partNumber={}&uploadId={}",
self.options.endpoint,
utf8_percent_encode(self.path.as_ref(), &STRICT_PATH_ENCODE_SET),
part_num + 1,
utf8_percent_encode(upload_id, &STRICT_PATH_ENCODE_SET),
);
let request = Request::builder()
.uri(url)
.method(Method::PUT)
.header(CONTENT_LENGTH, size)
.body(body)
.map_err(|e| Error::Other(e.into()))?;
let response = self.send_request(request).await?;
let etag = response
.headers()
.get(ETAG)
.ok_or_else(|| Error::Other("etag header not found".into()))?
.to_str()
.map_err(|e| Error::Other(e.into()))?;

Ok(MultipartPart {
part_num,
etag: etag.to_string(),
})
}

pub(crate) async fn complete_part(
&self,
upload_id: &str,
parts: &[MultipartPart],
) -> Result<(), Error> {
let url = format!(
"{}/{}?uploadId={}",
self.options.endpoint,
utf8_percent_encode(self.path.as_ref(), &STRICT_PATH_ENCODE_SET),
utf8_percent_encode(upload_id, &STRICT_PATH_ENCODE_SET),
);
let content = quick_xml::se::to_string(&CompleteMultipartUploadRequest {
part: parts
.iter()
.map(|p| CompleteMultipartUploadRequestPart {
part_number: p.part_num + 1,
etag: p.etag.to_owned(),
})
.collect_vec(),
})
.map_err(S3Error::from)?;

let request = Request::builder()
.uri(url)
.method(Method::POST)
.header(CONTENT_LENGTH, content.len())
.header(CONTENT_TYPE, "application/xml")
.body(Full::new(Bytes::from(content)))
.map_err(|e| Error::Other(e.into()))?;
let response = self.send_request(request).await?;
// still check if there is any error because S3 might return error for status code 200
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html#API_CompleteMultipartUpload_Example_4
let (parts, body) = response.into_parts();
let maybe_error: S3ResponseError = quick_xml::de::from_reader(
body.collect()
.await
.map_err(S3Error::from)?
.aggregate()
.reader(),
)
.map_err(S3Error::from)?;
if !maybe_error.code.is_empty() {
return Err(Error::Other(
format!("{:#?}, {:?}", parts, maybe_error).into(),
));
}

Ok(())
}
}
153 changes: 153 additions & 0 deletions fusio/src/impls/remotes/aws/writer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
use std::{mem, sync::Arc};

use bytes::{BufMut, BytesMut};
use http_body_util::Full;
use tokio::{task, task::JoinHandle};

use crate::{
remotes::{aws::multipart_upload::MultipartUpload, serde::MultipartPart},
Error, IoBuf, Write,
};

const S3_PART_MINIMUM_SIZE: usize = 5 * 1024 * 1024;

pub struct S3Writer {
inner: Arc<MultipartUpload>,
upload_id: Option<Arc<String>>,
next_part_numer: usize,
buf: BytesMut,

handlers: Vec<JoinHandle<Result<MultipartPart, Error>>>,
}

impl S3Writer {
pub fn new(inner: Arc<MultipartUpload>) -> Self {
Self {
inner,
upload_id: None,
next_part_numer: 0,
buf: BytesMut::with_capacity(S3_PART_MINIMUM_SIZE),
handlers: vec![],
}
}

async fn upload_part<F: FnOnce() -> BytesMut>(
&mut self,
fn_bytes_init: F,
) -> Result<(), Error> {
let upload_id = match self.upload_id.clone() {
None => {
let upload_id = Arc::new(self.inner.initiate().await?);
self.upload_id = Some(upload_id.clone());
upload_id
}
Some(upload_id) => upload_id,
};
let part_num = self.next_part_numer;
self.next_part_numer += 1;

let upload = self.inner.clone();
let bytes = mem::replace(&mut self.buf, fn_bytes_init()).freeze();
self.handlers.push(task::spawn(async move {
upload
.upload_part(&upload_id, part_num, bytes.len(), Full::new(bytes))
.await
}));

Ok(())
}
}

impl Write for S3Writer {
async fn write_all<B: IoBuf>(&mut self, buf: B) -> (Result<(), Error>, B) {
if self.buf.len() > S3_PART_MINIMUM_SIZE {
if let Err(e) = self
.upload_part(|| BytesMut::with_capacity(S3_PART_MINIMUM_SIZE))
.await
{
return (Err(e), buf);
}
}
self.buf.put(buf.as_slice());

(Ok(()), buf)
}

async fn sync_data(&self) -> Result<(), Error> {
Ok(())
}

async fn sync_all(&self) -> Result<(), Error> {
Ok(())
}

async fn close(&mut self) -> Result<(), Error> {
let Some(upload_id) = self.upload_id.take() else {
if !self.buf.is_empty() {
let bytes = mem::replace(&mut self.buf, BytesMut::new()).freeze();

self.inner
.upload_once(bytes.len(), Full::new(bytes))
.await?;
}
return Ok(());
};
if !self.buf.is_empty() {
self.upload_part(BytesMut::new).await?;
}
let mut parts = Vec::with_capacity(self.handlers.len());
for handle in self.handlers.drain(..) {
parts.push(handle.await.map_err(|e| Error::Io(e.into()))??)
}
assert_eq!(self.next_part_numer, parts.len());
self.inner.complete_part(&upload_id, &parts).await?;

Ok(())
}
}

#[cfg(test)]
mod tests {
#[ignore]
#[cfg(all(
feature = "aws",
feature = "tokio-http",
not(feature = "completion-based")
))]
#[tokio::test]
async fn test_s3() {
use std::sync::Arc;

use bytes::Bytes;

use crate::{
remotes::aws::{
multipart_upload::MultipartUpload, options::S3Options, writer::S3Writer,
AwsCredential,
},
Write,
};

let region = "ap-southeast-2";
let options = Arc::new(S3Options {
endpoint: "endpoint".into(),
credential: Some(AwsCredential {
key_id: "key".to_string(),
secret_key: "secret_key".to_string(),
token: None,
}),
region: region.into(),
sign_payload: true,
checksum: false,
});
let client = Arc::new(crate::impls::remotes::http::tokio::TokioClient::new());
let upload = MultipartUpload::new(options, "read-write.txt".into(), client);
let mut writer = S3Writer::new(Arc::new(upload));

let (result, _) = Write::write_all(&mut writer, Bytes::from("hello! Fusio!")).await;
result.unwrap();
let (result, _) = Write::write_all(&mut writer, Bytes::from("hello! World!")).await;
result.unwrap();
Write::close(&mut writer).await.unwrap();
}
}
2 changes: 2 additions & 0 deletions fusio/src/impls/remotes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
pub mod aws;
#[cfg(feature = "http")]
pub mod http;
#[cfg(feature = "aws")]
pub(crate) mod serde;
Loading

0 comments on commit c9fe5f7

Please sign in to comment.