Skip to content

Commit

Permalink
update: optimised fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
heemankv committed Nov 13, 2024
1 parent 1a60dc2 commit 3c3f83d
Show file tree
Hide file tree
Showing 24 changed files with 224 additions and 274 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ rstest = "0.22.0"
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
starknet = "0.11.0"
strum = "0.26.0"
strum_macros = "0.26.0"
tempfile = "3.12.0"
thiserror = "1.0.57"
tokio = { version = "1.37.0" }
Expand Down
4 changes: 2 additions & 2 deletions crates/orchestrator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ starknet = { workspace = true }
starknet-core = "0.9.0"
starknet-os = { workspace = true }
starknet-settlement-client = { workspace = true }
strum = "0.26.0"
strum_macros = "0.26.0"
strum = { workspace = true }
strum_macros = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["sync", "macros", "rt-multi-thread"] }
Expand Down
7 changes: 3 additions & 4 deletions crates/orchestrator/src/cli/cron/event_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ pub struct AWSEventBridgeCliArgs {
pub aws_event_bridge: bool,

/// The name of the S3 bucket.
#[arg(env = "MADARA_ORCHESTRATOR_EVENT_BRIDGE_TARGET_QUEUE_NAME", long, default_value = Some("madara-orchestrator-event-bridge-target-queue-name"))]
#[arg(env = "MADARA_ORCHESTRATOR_EVENT_BRIDGE_TARGET_QUEUE_NAME", long, default_value = Some("madara-orchestrator-event-bridge-target-queue-name"), help = "The name of the SNS queue to send messages to from the event bridge.")]
pub target_queue_name: Option<String>,

/// The cron time for the event bridge trigger rule.
#[arg(env = "MADARA_ORCHESTRATOR_EVENT_BRIDGE_CRON_TIME", long, default_value = Some("10"))]
#[arg(env = "MADARA_ORCHESTRATOR_EVENT_BRIDGE_CRON_TIME", long, default_value = Some("10"), help = "The cron time for the event bridge trigger rule. Defaults to 10 seconds.")]
pub cron_time: Option<String>,

/// The name of the event bridge trigger rule.
#[arg(env = "MADARA_ORCHESTRATOR_EVENT_BRIDGE_TRIGGER_RULE_NAME", long, default_value = Some("madara-orchestrator-event-bridge-trigger-rule-name"))]
#[arg(env = "MADARA_ORCHESTRATOR_EVENT_BRIDGE_TRIGGER_RULE_NAME", long, default_value = Some("madara-orchestrator-event-bridge-trigger-rule-name"), help = "The name of the event bridge trigger rule.")]
pub trigger_rule_name: Option<String>,
}
140 changes: 67 additions & 73 deletions crates/orchestrator/src/cli/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::str::FromStr as _;
use std::time::Duration;

use alert::AlertValidatedArgs;
use alloy::primitives::Address;
use clap::{ArgGroup, Parser, Subcommand};
use cron::event_bridge::AWSEventBridgeCliArgs;
use cron::CronValidatedArgs;
Expand Down Expand Up @@ -179,7 +181,7 @@ impl RunCmd {
}

pub fn validate_alert_params(&self) -> Result<AlertValidatedArgs, String> {
if self.aws_sns_args.aws_sns {
if self.aws_sns_args.aws_sns && self.aws_config_args.aws {
Ok(AlertValidatedArgs::AWSSNS(AWSSNSValidatedArgs {
topic_arn: self.aws_sns_args.sns_arn.clone().expect("SNS ARN is required"),
}))
Expand All @@ -189,9 +191,12 @@ impl RunCmd {
}

pub fn validate_queue_params(&self) -> Result<QueueValidatedArgs, String> {
if self.aws_sqs_args.aws_sqs {
if self.aws_sqs_args.aws_sqs && self.aws_config_args.aws {
Ok(QueueValidatedArgs::AWSSQS(AWSSQSValidatedArgs {
queue_base_url: self.aws_sqs_args.queue_base_url.clone().expect("Queue base URL is required"),
queue_base_url: Url::parse(
&self.aws_sqs_args.queue_base_url.clone().expect("Queue base URL is required"),
)
.expect("Invalid queue base URL"),
sqs_prefix: self.aws_sqs_args.sqs_prefix.clone().expect("SQS prefix is required"),
sqs_suffix: self.aws_sqs_args.sqs_suffix.clone().expect("SQS suffix is required"),
}))
Expand All @@ -201,7 +206,7 @@ impl RunCmd {
}

pub fn validate_storage_params(&self) -> Result<StorageValidatedArgs, String> {
if self.aws_s3_args.aws_s3 {
if self.aws_s3_args.aws_s3 && self.aws_config_args.aws {
Ok(StorageValidatedArgs::AWSS3(AWSS3ValidatedArgs {
bucket_name: self.aws_s3_args.bucket_name.clone().expect("Bucket name is required"),
}))
Expand All @@ -213,11 +218,10 @@ impl RunCmd {
pub fn validate_database_params(&self) -> Result<DatabaseValidatedArgs, String> {
if self.mongodb_args.mongodb {
Ok(DatabaseValidatedArgs::MongoDB(MongoDBValidatedArgs {
connection_url: self
.mongodb_args
.mongodb_connection_url
.clone()
.expect("MongoDB connection URL is required"),
connection_url: Url::parse(
&self.mongodb_args.mongodb_connection_url.clone().expect("MongoDB connection URL is required"),
)
.expect("Invalid MongoDB connection URL"),
database_name: self
.mongodb_args
.mongodb_database_name
Expand All @@ -244,67 +248,55 @@ impl RunCmd {
}

pub fn validate_settlement_params(&self) -> Result<settlement::SettlementValidatedArgs, String> {
match (self.ethereum_args.settle_on_ethereum, self.starknet_args.settle_on_starknet) {
(true, false) => {
let ethereum_params = EthereumSettlementValidatedArgs {
ethereum_rpc_url: self
.ethereum_args
.ethereum_rpc_url
.clone()
.expect("Ethereum RPC URL is required"),
ethereum_private_key: self
.ethereum_args
.ethereum_private_key
.clone()
.expect("Ethereum private key is required"),
l1_core_contract_address: self
.ethereum_args
.l1_core_contract_address
.clone()
.expect("L1 core contract address is required"),
starknet_operator_address: self
.ethereum_args
.starknet_operator_address
.clone()
.expect("Starknet operator address is required"),
};
Ok(SettlementValidatedArgs::Ethereum(ethereum_params))
}
(false, true) => {
let starknet_params = StarknetSettlementValidatedArgs {
starknet_rpc_url: self
.starknet_args
.starknet_rpc_url
.clone()
.expect("Starknet RPC URL is required"),
starknet_private_key: self
.starknet_args
.starknet_private_key
.clone()
.expect("Starknet private key is required"),
starknet_account_address: self
.starknet_args
.starknet_account_address
.clone()
.expect("Starknet account address is required"),
starknet_cairo_core_contract_address: self
if self.ethereum_args.settle_on_ethereum {
let l1_core_contract_address = Address::from_str(
&self.ethereum_args.l1_core_contract_address.clone().expect("L1 core contract address is required"),
)
.expect("Invalid L1 core contract address");
let starknet_operator_address = Address::from_str(
&self.ethereum_args.starknet_operator_address.clone().expect("Starknet operator address is required"),
)
.expect("Invalid Starknet operator address");

let ethereum_params = EthereumSettlementValidatedArgs {
ethereum_rpc_url: self.ethereum_args.ethereum_rpc_url.clone().expect("Ethereum RPC URL is required"),
ethereum_private_key: self
.ethereum_args
.ethereum_private_key
.clone()
.expect("Ethereum private key is required"),
l1_core_contract_address,
starknet_operator_address,
};
Ok(SettlementValidatedArgs::Ethereum(ethereum_params))
} else if self.starknet_args.settle_on_starknet {
let starknet_params = StarknetSettlementValidatedArgs {
starknet_rpc_url: self.starknet_args.starknet_rpc_url.clone().expect("Starknet RPC URL is required"),
starknet_private_key: self
.starknet_args
.starknet_private_key
.clone()
.expect("Starknet private key is required"),
starknet_account_address: Address::from_str(
&self.starknet_args.starknet_account_address.clone().expect("Starknet account address is required"),
)
.expect("Invalid Starknet account address"),
starknet_cairo_core_contract_address: Address::from_str(
&self
.starknet_args
.starknet_cairo_core_contract_address
.clone()
.expect("Starknet Cairo core contract address is required"),
starknet_finality_retry_wait_in_secs: self
.starknet_args
.starknet_finality_retry_wait_in_secs
.expect("Starknet finality retry wait in seconds is required"),
madara_binary_path: self
.starknet_args
.starknet_madara_binary_path
.clone()
.expect("Starknet Madara binary path is required"),
};
Ok(SettlementValidatedArgs::Starknet(starknet_params))
}
(true, true) | (false, false) => Err("Exactly one settlement layer must be selected".to_string()),
)
.expect("Invalid Starknet Cairo core contract address"),
starknet_finality_retry_wait_in_secs: self
.starknet_args
.starknet_finality_retry_wait_in_secs
.expect("Starknet finality retry wait in seconds is required"),
};
Ok(SettlementValidatedArgs::Starknet(starknet_params))
} else {
Err("Settlement layer is required".to_string())
}
}

Expand Down Expand Up @@ -339,7 +331,7 @@ impl RunCmd {
.instrumentation_args
.otel_service_name
.clone()
.expect("OTel service name is required"),
.expect("Otel service name is required"),
otel_collector_endpoint: self.instrumentation_args.otel_collector_endpoint.clone(),
log_level: self.instrumentation_args.log_level,
})
Expand Down Expand Up @@ -440,7 +432,7 @@ impl SetupCmd {
}

pub fn validate_alert_params(&self) -> Result<AlertValidatedArgs, String> {
if self.aws_sns_args.aws_sns {
if self.aws_sns_args.aws_sns && self.aws_config_args.aws {
Ok(AlertValidatedArgs::AWSSNS(AWSSNSValidatedArgs {
topic_arn: self.aws_sns_args.sns_arn.clone().expect("SNS ARN is required"),
}))
Expand All @@ -450,9 +442,12 @@ impl SetupCmd {
}

pub fn validate_queue_params(&self) -> Result<QueueValidatedArgs, String> {
if self.aws_sqs_args.aws_sqs {
if self.aws_sqs_args.aws_sqs && self.aws_config_args.aws {
Ok(QueueValidatedArgs::AWSSQS(AWSSQSValidatedArgs {
queue_base_url: self.aws_sqs_args.queue_base_url.clone().expect("Queue base URL is required"),
queue_base_url: Url::parse(
&self.aws_sqs_args.queue_base_url.clone().expect("Queue base URL is required"),
)
.expect("Invalid queue base URL"),
sqs_prefix: self.aws_sqs_args.sqs_prefix.clone().expect("SQS prefix is required"),
sqs_suffix: self.aws_sqs_args.sqs_suffix.clone().expect("SQS suffix is required"),
}))
Expand All @@ -462,7 +457,7 @@ impl SetupCmd {
}

pub fn validate_storage_params(&self) -> Result<StorageValidatedArgs, String> {
if self.aws_s3_args.aws_s3 {
if self.aws_s3_args.aws_s3 && self.aws_config_args.aws {
Ok(StorageValidatedArgs::AWSS3(AWSS3ValidatedArgs {
bucket_name: self.aws_s3_args.bucket_name.clone().expect("Bucket name is required"),
}))
Expand All @@ -472,7 +467,7 @@ impl SetupCmd {
}

pub fn validate_cron_params(&self) -> Result<CronValidatedArgs, String> {
if self.aws_event_bridge_args.aws_event_bridge {
if self.aws_event_bridge_args.aws_event_bridge && self.aws_config_args.aws {
Ok(CronValidatedArgs::AWSEventBridge(AWSEventBridgeValidatedArgs {
target_queue_name: self
.aws_event_bridge_args
Expand Down Expand Up @@ -580,7 +575,6 @@ impl SetupCmd {
// starknet_account_address: Some("".to_string()),
// starknet_cairo_core_contract_address: Some("".to_string()),
// starknet_finality_retry_wait_in_secs: Some(0),
// starknet_madara_binary_path: Some("".to_string()),
// settle_on_starknet: false,
// },

Expand Down
4 changes: 0 additions & 4 deletions crates/orchestrator/src/cli/settlement/starknet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,4 @@ pub struct StarknetSettlementCliArgs {
/// The number of seconds to wait for finality.
#[arg(env = "MADARA_ORCHESTRATOR_STARKNET_FINALITY_RETRY_WAIT_IN_SECS", long)]
pub starknet_finality_retry_wait_in_secs: Option<u64>,

/// The path to the Madara binary.
#[arg(env = "MADARA_ORCHESTRATOR_MADARA_BINARY_PATH", long)]
pub starknet_madara_binary_path: Option<String>,
}
25 changes: 13 additions & 12 deletions crates/orchestrator/src/cron/event_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@ pub struct AWSEventBridgeValidatedArgs {
}

pub struct AWSEventBridge {
target_queue_name: String,
cron_time: Duration,
trigger_rule_name: String,
client: EventBridgeClient,
queue_client: SqsClient,
}

impl AWSEventBridge {
pub fn new_with_args(_params: &AWSEventBridgeValidatedArgs, aws_config: &SdkConfig) -> Self {
pub fn new_with_args(params: &AWSEventBridgeValidatedArgs, aws_config: &SdkConfig) -> Self {
Self {
target_queue_name: params.target_queue_name.clone(),
cron_time: params.cron_time,
trigger_rule_name: params.trigger_rule_name.clone(),
client: aws_sdk_eventbridge::Client::new(aws_config),
queue_client: aws_sdk_sqs::Client::new(aws_config),
}
Expand All @@ -33,24 +39,19 @@ impl AWSEventBridge {
#[async_trait]
#[allow(unreachable_patterns)]
impl Cron for AWSEventBridge {
async fn create_cron(&self, cron_time: Duration, trigger_rule_name: String) -> color_eyre::Result<()> {
async fn create_cron(&self) -> color_eyre::Result<()> {
self.client
.put_rule()
.name(&trigger_rule_name)
.schedule_expression(duration_to_rate_string(cron_time))
.name(&self.trigger_rule_name)
.schedule_expression(duration_to_rate_string(self.cron_time))
.state(RuleState::Enabled)
.send()
.await?;

Ok(())
}
async fn add_cron_target_queue(
&self,
target_queue_name: String,
message: String,
trigger_rule_name: String,
) -> color_eyre::Result<()> {
let queue_url = self.queue_client.get_queue_url().queue_name(target_queue_name).send().await?;
async fn add_cron_target_queue(&self, message: String) -> color_eyre::Result<()> {
let queue_url = self.queue_client.get_queue_url().queue_name(&self.target_queue_name).send().await?;

let queue_attributes = self
.queue_client
Expand All @@ -67,7 +68,7 @@ impl Cron for AWSEventBridge {

self.client
.put_targets()
.rule(trigger_rule_name)
.rule(&self.trigger_rule_name)
.targets(
Target::builder()
.id(uuid::Uuid::new_v4().to_string())
Expand Down
Loading

0 comments on commit 3c3f83d

Please sign in to comment.