Skip to content

Commit

Permalink
Reuse the Client in all cases (#327)
Browse files Browse the repository at this point in the history
* Reuse the Client

* Working, with client per ironoxide

* Reuse the runtime on blocking initializations

* Add constructor for BlockingDeviceContext

* Add more functions to BlockingDeviceContext

* Generate device returns DeviceAddResult again

* Remove double comment

---------

Co-authored-by: Craig Colegrove <[email protected]>
  • Loading branch information
coltfred and giarc3 authored Oct 28, 2024
1 parent 6050eb3 commit b8b426a
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 41 deletions.
70 changes: 59 additions & 11 deletions src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,63 @@ use crate::{
InitAndRotationCheck::{NoRotationNeeded, RotationNeeded},
Result,
};
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};

#[cfg(feature = "beta")]
use crate::search::{BlindIndexSearchInitialize, EncryptedBlindIndexSalt};
use tokio::runtime::Runtime;

/// Struct that is used to hold the regular DeviceContext as well as a runtime that will be used
/// when initializing a BlockingIronOxide. This was added to fix a bug where initializing multiple
/// SDK instances with a single device would hang indefinitely (as each initialization call would
/// create its own runtime but share a request client)
#[derive(Clone, Debug)]
pub struct BlockingDeviceContext {
pub device: DeviceContext,
pub(crate) rt: Arc<Runtime>,
}

impl From<DeviceAddResult> for BlockingDeviceContext {
fn from(value: DeviceAddResult) -> Self {
Self {
device: value.into(),
rt: Arc::new(create_runtime()),
}
}
}

impl BlockingDeviceContext {
pub fn new(device: DeviceContext) -> Self {
Self {
device,
rt: Arc::new(create_runtime()),
}
}
/// ID of the device's owner
pub fn account_id(&self) -> &UserId {
&self.device.auth().account_id()
}
/// ID of the segment
pub fn segment_id(&self) -> usize {
self.device.auth().segment_id()
}
/// Private signing key of the device
pub fn signing_private_key(&self) -> &DeviceSigningKeyPair {
&self.device.auth().signing_private_key()
}
/// Private encryption key of the device
pub fn device_private_key(&self) -> &PrivateKey {
&self.device.device_private_key()
}
}

/// Struct that is used to make authenticated requests to the IronCore API. Instantiated with the details
/// of an account's various ids, device, and signing keys. Once instantiated all operations will be
/// performed in the context of the account provided. Identical to IronOxide but also contains a Runtime.
#[derive(Debug)]
pub struct BlockingIronOxide {
pub(crate) ironoxide: IronOxide,
pub(crate) runtime: tokio::runtime::Runtime,
pub(crate) runtime: Arc<tokio::runtime::Runtime>,
}

impl BlockingIronOxide {
Expand Down Expand Up @@ -293,35 +338,38 @@ fn create_runtime() -> tokio::runtime::Runtime {
/// Initialize the BlockingIronOxide SDK with a device. Verifies that the provided user/segment exists and the provided device
/// keys are valid and exist for the provided account. If successful, returns instance of the BlockingIronOxide SDK.
pub fn initialize(
device_context: &DeviceContext,
device_context: &BlockingDeviceContext,
config: &IronOxideConfig,
) -> Result<BlockingIronOxide> {
let rt = create_runtime();
let maybe_io = rt.block_on(crate::initialize(device_context, config));
let maybe_io = device_context
.rt
.block_on(crate::initialize(&device_context.device, config));
maybe_io.map(|io| BlockingIronOxide {
ironoxide: io,
runtime: rt,
runtime: device_context.rt.clone(),
})
}

/// Initialize the BlockingIronOxide SDK and check to see if the user that owns this `DeviceContext` is
/// marked for private key rotation, or if any of the groups that the user is an admin of are marked
/// for private key rotation.
pub fn initialize_check_rotation(
device_context: &DeviceContext,
device_context: &BlockingDeviceContext,
config: &IronOxideConfig,
) -> Result<InitAndRotationCheck<BlockingIronOxide>> {
let rt = create_runtime();
let maybe_init = rt.block_on(crate::initialize_check_rotation(device_context, config));
let maybe_init = device_context.rt.block_on(crate::initialize_check_rotation(
&device_context.device,
config,
));
maybe_init.map(|init| match init {
NoRotationNeeded(io) => NoRotationNeeded(BlockingIronOxide {
ironoxide: io,
runtime: rt,
runtime: device_context.rt.clone(),
}),
RotationNeeded(io, rot) => RotationNeeded(
BlockingIronOxide {
ironoxide: io,
runtime: rt,
runtime: device_context.rt.clone(),
},
rot,
),
Expand Down
6 changes: 2 additions & 4 deletions src/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::internal::{
group_api::GroupId,
rest::{Authorization, IronCoreRequest, SignatureUrlString},
rest::{Authorization, SignatureUrlString},
user_api::UserId,
};
use base64::engine::Engine;
Expand Down Expand Up @@ -34,6 +34,7 @@ pub mod document_api;
pub mod group_api;
mod rest;
pub mod user_api;
pub use rest::IronCoreRequest;

lazy_static! {
pub static ref URL_STRING: String = match std::env::var("IRONCORE_ENV") {
Expand All @@ -45,9 +46,6 @@ lazy_static! {
.to_string(),
_ => "https://api.ironcorelabs.com/api/1/".to_string(),
};
static ref SHARED_CLIENT: reqwest::Client = reqwest::Client::new();
pub static ref OUR_REQUEST: IronCoreRequest =
IronCoreRequest::new(URL_STRING.as_str(), &SHARED_CLIENT);
}

#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
Expand Down
44 changes: 25 additions & 19 deletions src/internal/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::internal::{
auth_v2::AuthV2Builder,
user_api::{Jwt, UserId},
DeviceSigningKeyPair, IronOxideErr, RequestErrorCode, OUR_REQUEST,
DeviceSigningKeyPair, IronOxideErr, RequestErrorCode, URL_STRING,
};
use base64::engine::Engine;
use base64::prelude::BASE64_STANDARD;
Expand Down Expand Up @@ -303,20 +303,20 @@ impl<'a> HeaderIronCoreRequestSig<'a> {
}

///A struct which holds the basic info that will be needed for making requests to an ironcore service. Currently just the base_url.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IronCoreRequest {
base_url: &'static str,
#[serde(skip_serializing, skip_deserializing, default = "default_client")]
client: &'static reqwest::Client,
pub(crate) client: reqwest::Client,
}

fn default_client() -> &'static reqwest::Client {
OUR_REQUEST.client
fn default_client() -> reqwest::Client {
Client::new()
}

impl Default for IronCoreRequest {
fn default() -> Self {
*OUR_REQUEST
IronCoreRequest::new(&URL_STRING, default_client())
}
}
impl Hash for IronCoreRequest {
Expand All @@ -332,7 +332,7 @@ impl PartialEq for IronCoreRequest {
impl Eq for IronCoreRequest {}

impl IronCoreRequest {
pub const fn new(base_url: &'static str, client: &'static reqwest::Client) -> IronCoreRequest {
pub const fn new(base_url: &'static str, client: reqwest::Client) -> IronCoreRequest {
IronCoreRequest { base_url, client }
}

Expand Down Expand Up @@ -415,7 +415,7 @@ impl IronCoreRequest {
replace_headers(req.headers_mut(), auth.to_auth_header());
replace_headers(req.headers_mut(), request_sig.to_header());

Self::send_req(req, error_code, move |server_resp| {
self.send_req(req, error_code, move |server_resp| {
IronCoreRequest::deserialize_body(server_resp, error_code)
})
.await
Expand Down Expand Up @@ -551,7 +551,7 @@ impl IronCoreRequest {
Q: Serialize + ?Sized,
F: FnOnce(&Bytes) -> Result<B, IronOxideErr>,
{
let client = Client::new();
let client = self.client.clone();
let mut builder = client.request(
method,
format!("{}{}", self.base_url, relative_url).as_str(),
Expand Down Expand Up @@ -632,7 +632,7 @@ impl IronCoreRequest {
replace_headers(req.headers_mut(), auth.to_auth_header());
replace_headers(req.headers_mut(), request_sig.to_header());

Self::send_req(req, error_code, resp_handler).await
self.send_req(req, error_code, resp_handler).await
} else {
panic!("authorized requests must use version 2 of API authentication")
}
Expand All @@ -653,6 +653,7 @@ impl IronCoreRequest {
}

async fn send_req<B, F>(
&self,
req: Request,
error_code: RequestErrorCode,
resp_handler: F,
Expand All @@ -661,7 +662,7 @@ impl IronCoreRequest {
B: DeserializeOwned,
F: FnOnce(&Bytes) -> Result<B, IronOxideErr>,
{
let client = Client::new();
let client = self.client.clone();
let server_res = client.execute(req).await;
let res = server_res.map_err(|e| (e, error_code))?;
//Parse the body content into bytes
Expand Down Expand Up @@ -1049,12 +1050,11 @@ mod tests {

use recrypt::api::{Ed25519Signature, PublicSigningKey};

lazy_static! {
static ref SHARED_CLIENT: reqwest::Client = reqwest::Client::new();
static ref TEST_REQUEST: IronCoreRequest = IronCoreRequest {
fn create_test_request() -> IronCoreRequest {
IronCoreRequest {
base_url: "https://example.com",
client: &SHARED_CLIENT
};
client: Client::new(),
}
}

#[test]
Expand Down Expand Up @@ -1238,7 +1238,8 @@ mod tests {
public_signing_key: signing_keys.public_key(),
};

let build_url = |relative_url| format!("{}{}", OUR_REQUEST.base_url(), relative_url);
let build_url =
|relative_url| format!("{}{}", IronCoreRequest::default().base_url(), relative_url);
let signing_url_string = SignatureUrlString::new(&build_url("users?id=user-10")).unwrap();

// note that this and the expected value must correspond
Expand Down Expand Up @@ -1378,7 +1379,7 @@ mod tests {
fn query_params_encoded_correctly() {
let mut req = Request::new(
Method::GET,
url::Url::parse(&format!("{}/{}", TEST_REQUEST.base_url(), "users")).unwrap(),
url::Url::parse(&format!("{}/{}", create_test_request().base_url(), "users")).unwrap(),
);
let q = "!\"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~";
IronCoreRequest::req_add_query(&mut req, &[("id".to_string(), url_encode(q))]);
Expand All @@ -1391,7 +1392,12 @@ mod tests {
fn empty_query_params_encoded_correctly() {
let mut req = Request::new(
Method::GET,
url::Url::parse(&format!("{}/{}", TEST_REQUEST.base_url(), "policies")).unwrap(),
url::Url::parse(&format!(
"{}/{}",
create_test_request().base_url(),
"policies"
))
.unwrap(),
);
IronCoreRequest::req_add_query(&mut req, &[]);
assert_eq!(req.url().query(), None);
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ pub mod search;
#[cfg(feature = "blocking")]
pub mod blocking;

pub use crate::internal::IronOxideErr;
pub use crate::internal::{IronCoreRequest, IronOxideErr};

use crate::{
common::{DeviceContext, DeviceSigningKeyPair, PublicKey, SdkOperation},
Expand Down
10 changes: 5 additions & 5 deletions src/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ pub use crate::internal::user_api::{
};
use crate::{
common::{PublicKey, SdkOperation},
internal::{add_optional_timeout, user_api, OUR_REQUEST},
IronOxide, Result,
internal::{add_optional_timeout, user_api},
IronCoreRequest, IronOxide, Result,
};
use futures::Future;
use recrypt::api::Recrypt;
Expand Down Expand Up @@ -299,7 +299,7 @@ impl UserOps for IronOxide {
jwt,
password.try_into()?,
user_create_opts.needs_rotation,
*OUR_REQUEST,
IronCoreRequest::default(),
),
timeout,
SdkOperation::UserCreate,
Expand All @@ -324,7 +324,7 @@ impl UserOps for IronOxide {
password.try_into()?,
device_create_options.device_name,
&std::time::SystemTime::now().into(),
&OUR_REQUEST,
&IronCoreRequest::default(),
),
timeout,
SdkOperation::GenerateNewDevice,
Expand All @@ -337,7 +337,7 @@ impl UserOps for IronOxide {
timeout: Option<std::time::Duration>,
) -> Result<Option<UserResult>> {
add_optional_timeout(
user_api::user_verify(jwt, *OUR_REQUEST),
user_api::user_verify(jwt, IronCoreRequest::default()),
timeout,
SdkOperation::UserVerify,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/blocking_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod common;
// Note: The blocking functions need minimal testing as they primarily just call their async counterparts

#[cfg(feature = "blocking")]
mod integration_tests {
mod blocking_integration_tests {
use crate::common::{create_id_all_classes, gen_jwt, USER_PASSWORD};
use galvanic_assert::{matchers::*, *};
use ironoxide::prelude::*;
Expand Down

0 comments on commit b8b426a

Please sign in to comment.