Commit
Showing 5 changed files with 401 additions and 0 deletions.
@@ -0,0 +1,207 @@
use std::sync::Arc;

use bytes::{Buf, Bytes};
use http::{
    header::{CONTENT_LENGTH, CONTENT_TYPE, ETAG},
    Method, Request, Response,
};
use http_body::Body;
use http_body_util::{BodyExt, Empty, Full};
use itertools::Itertools;
use percent_encoding::utf8_percent_encode;

use crate::{
    path::Path,
    remotes::{
        aws::{options::S3Options, sign::Sign, S3Error, S3ResponseError, STRICT_PATH_ENCODE_SET},
        http::{BoxBody, DynHttpClient, HttpClient},
        serde::{
            CompleteMultipartUploadRequest, CompleteMultipartUploadRequestPart,
            InitiateMultipartUploadResult, MultipartPart,
        },
    },
    Error,
};

/// Low-level helper that signs and issues the S3 (multipart) upload requests
/// for a single object path.
pub(crate) struct MultipartUpload {
    options: Arc<S3Options>,
    path: Path,
    client: Arc<dyn DynHttpClient>,
}

impl MultipartUpload {
    pub fn new(options: Arc<S3Options>, path: Path, client: Arc<dyn DynHttpClient>) -> Self {
        Self {
            options,
            path,
            client,
        }
    }

    async fn check_response(response: Response<BoxBody>) -> Result<Response<BoxBody>, Error> {
        if !response.status().is_success() {
            return Err(Error::Other(
                format!(
                    "failed to write to S3, HTTP status: {} content: {}",
                    response.status(),
                    String::from_utf8_lossy(
                        &response.into_body().collect().await.unwrap().to_bytes()
                    )
                )
                .into(),
            ));
        }
        Ok(response)
    }

    /// Signs the request with the configured credentials, sends it, and maps
    /// transport errors or non-success responses into an `Error`.
    async fn send_request<B>(&self, mut request: Request<B>) -> Result<Response<BoxBody>, Error>
    where
        B: Body<Data = Bytes> + Clone + Unpin + Send + Sync + 'static,
        B::Error: std::error::Error + Send + Sync + 'static,
    {
        request
            .sign(&self.options)
            .await
            .map_err(|e| Error::S3Error(S3Error::from(e)))?;
        let response = self
            .client
            .send_request(request)
            .await
            .map_err(|e| Error::S3Error(S3Error::from(e)))?;
        Self::check_response(response).await
    }

    /// Uploads the whole object with a single `PUT`, bypassing multipart.
    pub(crate) async fn upload_once<B>(&self, size: usize, body: B) -> Result<(), Error>
    where
        B: Body<Data = Bytes> + Clone + Unpin + Send + Sync + 'static,
        B::Error: std::error::Error + Send + Sync + 'static,
    {
        let url = format!(
            "{}/{}",
            self.options.endpoint,
            utf8_percent_encode(self.path.as_ref(), &STRICT_PATH_ENCODE_SET)
        );
        let request = Request::builder()
            .uri(url)
            .method(Method::PUT)
            .header(CONTENT_LENGTH, size)
            .body(body)
            .map_err(|e| Error::Other(e.into()))?;
        let _ = self.send_request(request).await?;

        Ok(())
    }

    /// Starts a multipart upload (`POST ?uploads`) and returns the upload id.
    pub(crate) async fn initiate(&self) -> Result<String, Error> {
        let url = format!(
            "{}/{}?uploads",
            self.options.endpoint,
            utf8_percent_encode(self.path.as_ref(), &STRICT_PATH_ENCODE_SET)
        );
        let request = Request::builder()
            .uri(url)
            .method(Method::POST)
            .body(Empty::new())
            .map_err(|e| Error::Other(e.into()))?;
        let response = self.send_request(request).await?;
        let result: InitiateMultipartUploadResult = quick_xml::de::from_reader(
            response
                .collect()
                .await
                .map_err(S3Error::from)?
                .aggregate()
                .reader(),
        )
        .map_err(S3Error::from)?;

        Ok(result.upload_id)
    }

    /// Uploads one part. `part_num` is zero-based; S3 part numbers start at 1,
    /// so `part_num + 1` is sent on the wire.
    pub(crate) async fn upload_part<B>(
        &self,
        upload_id: &str,
        part_num: usize,
        size: usize,
        body: B,
    ) -> Result<MultipartPart, Error>
    where
        B: Body<Data = Bytes> + Clone + Unpin + Send + Sync + 'static,
        B::Error: std::error::Error + Send + Sync + 'static,
    {
        let url = format!(
            "{}/{}?partNumber={}&uploadId={}",
            self.options.endpoint,
            utf8_percent_encode(self.path.as_ref(), &STRICT_PATH_ENCODE_SET),
            part_num + 1,
            utf8_percent_encode(upload_id, &STRICT_PATH_ENCODE_SET),
        );
        let request = Request::builder()
            .uri(url)
            .method(Method::PUT)
            .header(CONTENT_LENGTH, size)
            .body(body)
            .map_err(|e| Error::Other(e.into()))?;
        let response = self.send_request(request).await?;
        let etag = response
            .headers()
            .get(ETAG)
            .ok_or_else(|| Error::Other("etag header not found".into()))?
            .to_str()
            .map_err(|e| Error::Other(e.into()))?;

        Ok(MultipartPart {
            part_num,
            etag: etag.to_string(),
        })
    }

    /// Completes the multipart upload by sending the collected part numbers
    /// and ETags.
    pub(crate) async fn complete_part(
        &self,
        upload_id: &str,
        parts: &[MultipartPart],
    ) -> Result<(), Error> {
        let url = format!(
            "{}/{}?uploadId={}",
            self.options.endpoint,
            utf8_percent_encode(self.path.as_ref(), &STRICT_PATH_ENCODE_SET),
            utf8_percent_encode(upload_id, &STRICT_PATH_ENCODE_SET),
        );
        let content = quick_xml::se::to_string(&CompleteMultipartUploadRequest {
            part: parts
                .iter()
                .map(|p| CompleteMultipartUploadRequestPart {
                    part_number: p.part_num + 1,
                    etag: p.etag.to_owned(),
                })
                .collect_vec(),
        })
        .map_err(S3Error::from)?;

        let request = Request::builder()
            .uri(url)
            .method(Method::POST)
            .header(CONTENT_LENGTH, content.len())
            .header(CONTENT_TYPE, "application/xml")
            .body(Full::new(Bytes::from(content)))
            .map_err(|e| Error::Other(e.into()))?;
        let response = self.send_request(request).await?;
        // Still inspect the body for an error document: S3 can return an error
        // for CompleteMultipartUpload even with a 200 status code, see
        // https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html#API_CompleteMultipartUpload_Example_4
        let (parts, body) = response.into_parts();
        let maybe_error: S3ResponseError = quick_xml::de::from_reader(
            body.collect()
                .await
                .map_err(S3Error::from)?
                .aggregate()
                .reader(),
        )
        .map_err(S3Error::from)?;
        if !maybe_error.code.is_empty() {
            return Err(Error::Other(
                format!("{:#?}, {:?}", parts, maybe_error).into(),
            ));
        }

        Ok(())
    }
}
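
For reference, here is a minimal sketch of how the three calls above fit together: `initiate` obtains the upload id, `upload_part` sends each part and returns its ETag, and `complete_part` asks S3 to assemble the object. The helper name `upload_in_two_parts` and its two `Bytes` arguments are hypothetical, and because the methods are `pub(crate)` this only makes sense inside the crate; a `MultipartUpload` is assumed to be constructed as in the test further down in this commit.

use bytes::Bytes;
use http_body_util::Full;

// Hypothetical helper, shown only to illustrate the call order.
async fn upload_in_two_parts(
    upload: &MultipartUpload,
    first: Bytes,
    second: Bytes,
) -> Result<(), Error> {
    // 1. Ask S3 for an upload id.
    let upload_id = upload.initiate().await?;

    // 2. Upload the parts; `upload_part` takes zero-based part numbers and
    //    returns the `MultipartPart` (number + ETag) needed for completion.
    let mut parts = Vec::new();
    parts.push(
        upload
            .upload_part(&upload_id, 0, first.len(), Full::new(first))
            .await?,
    );
    parts.push(
        upload
            .upload_part(&upload_id, 1, second.len(), Full::new(second))
            .await?,
    );

    // 3. Ask S3 to assemble the parts into the final object.
    upload.complete_part(&upload_id, &parts).await
}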
@@ -0,0 +1,153 @@
use std::{mem, sync::Arc};

use bytes::{BufMut, BytesMut};
use http_body_util::Full;
use tokio::{task, task::JoinHandle};

use crate::{
    remotes::{aws::multipart_upload::MultipartUpload, serde::MultipartPart},
    Error, IoBuf, Write,
};

/// S3 requires every part of a multipart upload except the last one to be at
/// least 5 MiB.
const S3_PART_MINIMUM_SIZE: usize = 5 * 1024 * 1024;

/// Buffered S3 writer: data accumulates until it exceeds the minimum part
/// size, then is uploaded as multipart parts on background tasks. Payloads
/// that never reach the threshold fall back to a single `PUT` on close.
pub struct S3Writer {
    inner: Arc<MultipartUpload>,
    upload_id: Option<Arc<String>>,
    next_part_number: usize,
    buf: BytesMut,

    handlers: Vec<JoinHandle<Result<MultipartPart, Error>>>,
}

impl S3Writer {
    pub fn new(inner: Arc<MultipartUpload>) -> Self {
        Self {
            inner,
            upload_id: None,
            next_part_number: 0,
            buf: BytesMut::with_capacity(S3_PART_MINIMUM_SIZE),
            handlers: vec![],
        }
    }

    /// Spawns a task that uploads the currently buffered bytes as the next
    /// part, initiating the multipart upload on the first call.
    async fn upload_part<F: FnOnce() -> BytesMut>(
        &mut self,
        fn_bytes_init: F,
    ) -> Result<(), Error> {
        let upload_id = match self.upload_id.clone() {
            None => {
                let upload_id = Arc::new(self.inner.initiate().await?);
                self.upload_id = Some(upload_id.clone());
                upload_id
            }
            Some(upload_id) => upload_id,
        };
        let part_num = self.next_part_number;
        self.next_part_number += 1;

        let upload = self.inner.clone();
        let bytes = mem::replace(&mut self.buf, fn_bytes_init()).freeze();
        self.handlers.push(task::spawn(async move {
            upload
                .upload_part(&upload_id, part_num, bytes.len(), Full::new(bytes))
                .await
        }));

        Ok(())
    }
}

impl Write for S3Writer {
    async fn write_all<B: IoBuf>(&mut self, buf: B) -> (Result<(), Error>, B) {
        // Flush the buffer as a part before appending, once it has grown past
        // the minimum part size.
        if self.buf.len() > S3_PART_MINIMUM_SIZE {
            if let Err(e) = self
                .upload_part(|| BytesMut::with_capacity(S3_PART_MINIMUM_SIZE))
                .await
            {
                return (Err(e), buf);
            }
        }
        self.buf.put(buf.as_slice());

        (Ok(()), buf)
    }

    async fn sync_data(&self) -> Result<(), Error> {
        Ok(())
    }

    async fn sync_all(&self) -> Result<(), Error> {
        Ok(())
    }

    async fn close(&mut self) -> Result<(), Error> {
        // No part has been uploaded yet: send everything as a single PUT.
        let Some(upload_id) = self.upload_id.take() else {
            if !self.buf.is_empty() {
                let bytes = mem::replace(&mut self.buf, BytesMut::new()).freeze();

                self.inner
                    .upload_once(bytes.len(), Full::new(bytes))
                    .await?;
            }
            return Ok(());
        };
        // Upload any remaining buffered bytes as the final part, then wait for
        // all in-flight parts and complete the multipart upload.
        if !self.buf.is_empty() {
            self.upload_part(BytesMut::new).await?;
        }
        let mut parts = Vec::with_capacity(self.handlers.len());
        for handle in self.handlers.drain(..) {
            parts.push(handle.await.map_err(|e| Error::Io(e.into()))??)
        }
        assert_eq!(self.next_part_number, parts.len());
        self.inner.complete_part(&upload_id, &parts).await?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    #[ignore]
    #[cfg(all(
        feature = "aws",
        feature = "tokio-http",
        not(feature = "completion-based")
    ))]
    #[tokio::test]
    async fn test_s3() {
        use std::sync::Arc;

        use bytes::Bytes;

        use crate::{
            remotes::aws::{
                multipart_upload::MultipartUpload, options::S3Options, writer::S3Writer,
                AwsCredential,
            },
            Write,
        };

        let region = "ap-southeast-2";
        let options = Arc::new(S3Options {
            endpoint: "endpoint".into(),
            credential: Some(AwsCredential {
                key_id: "key".to_string(),
                secret_key: "secret_key".to_string(),
                token: None,
            }),
            region: region.into(),
            sign_payload: true,
            checksum: false,
        });
        let client = Arc::new(crate::impls::remotes::http::tokio::TokioClient::new());
        let upload = MultipartUpload::new(options, "read-write.txt".into(), client);
        let mut writer = S3Writer::new(Arc::new(upload));

        let (result, _) = Write::write_all(&mut writer, Bytes::from("hello! Fusio!")).await;
        result.unwrap();
        let (result, _) = Write::write_all(&mut writer, Bytes::from("hello! World!")).await;
        result.unwrap();
        Write::close(&mut writer).await.unwrap();
    }
}
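
The test above stays well below the 5 MiB threshold, so `close` takes the single-`PUT` `upload_once` path. A hedged sketch of a write that is large enough to exercise the multipart path (the function name, chunk size, and total size are assumptions, not part of the commit):

use bytes::Bytes;

// Hypothetical driver: writes 16 MiB in 1 MiB chunks so that `write_all`
// crosses S3_PART_MINIMUM_SIZE and starts flushing parts on background tasks.
async fn write_large_object(mut writer: S3Writer) -> Result<(), Error> {
    let chunk = Bytes::from(vec![0u8; 1024 * 1024]); // 1 MiB of zeroes
    for _ in 0..16 {
        let (result, _) = Write::write_all(&mut writer, chunk.clone()).await;
        result?;
    }
    // Uploads the remaining buffer, awaits the spawned part uploads, and sends
    // CompleteMultipartUpload.
    Write::close(&mut writer).await
}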
@@ -2,3 +2,5 @@
pub mod aws;
#[cfg(feature = "http")]
pub mod http;
#[cfg(feature = "aws")]
pub(crate) mod serde;