Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invoice callback -> channel, log removal #7

Merged
1 commit merged into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
577 changes: 381 additions & 196 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ sha2 = "0.10.8"
tokio = "1.37.0"
uuid = {version="1.8.0",features=["v4"]}
reqwest = "0.12.4"
log4rs = "1.3.0"
log = "0.4.21"
zeroize = {version="1.7.0",features=["zeroize_derive"]}
async-std = "1.12.0"
24 changes: 10 additions & 14 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::common::DatabaseError;
use crate::types::Serializable;
use crate::{common::DatabaseError, types::Invoice};
use sled::Tree;

/// Retrieve a value by key from a tree.
Expand All @@ -24,11 +23,10 @@ async fn get_last_from_tree(db: &Tree) -> Result<(Vec<u8>, Vec<u8>), DatabaseErr
db.last()?
.map(|(key, value)| (key.to_vec(), value.to_vec()))
.ok_or(DatabaseError::NotFound)

}

/// Wrapper for retrieving the last added item to the tree
pub async fn get_last<T: Serializable>(tree: &sled::Tree) -> Result<(String, T), DatabaseError> {
pub async fn get_last(tree: &sled::Tree) -> Result<(String, Invoice), DatabaseError> {
let binary_data = get_last_from_tree(tree).await?;
// Convert binary key to String
let key = String::from_utf8(binary_data.0).map_err(|error| {
Expand All @@ -37,17 +35,15 @@ pub async fn get_last<T: Serializable>(tree: &sled::Tree) -> Result<(String, T),
})?;

// Deserialize binary value to T
let value = T::from_bin(binary_data.1).map_err(|error| {
let value = bincode::deserialize::<Invoice>(&binary_data.1).map_err(|error| {
log::error!("Db Interaction Error: {}", error);
DatabaseError::Deserialize
})?;
Ok((key, value))
}

/// Wrapper for retrieving all key value pairs from a tree
pub async fn get_all<T: Serializable>(
tree: &sled::Tree,
) -> Result<Vec<(String, T)>, DatabaseError> {
pub async fn get_all(tree: &sled::Tree) -> Result<Vec<(String, Invoice)>, DatabaseError> {
let binary_data = get_all_from_tree(tree).await?;
let mut all = Vec::with_capacity(binary_data.len());
for (binary_key, binary_value) in binary_data {
Expand All @@ -57,8 +53,8 @@ pub async fn get_all<T: Serializable>(
DatabaseError::Deserialize
})?;

// Deserialize binary value to T
let value = T::from_bin(binary_value).map_err(|error| {
// Deserialize binary value to invoice
let value = bincode::deserialize::<Invoice>(&binary_value).map_err(|error| {
log::error!("Db Interaction Error: {}", error);
DatabaseError::Deserialize
})?;
Expand All @@ -69,9 +65,9 @@ pub async fn get_all<T: Serializable>(
}

/// Wrapper for retrieving a value from a tree
pub async fn get<T: Serializable>(tree: &Tree, key: &str) -> Result<T, DatabaseError> {
pub async fn get(tree: &Tree, key: &str) -> Result<Invoice, DatabaseError> {
let binary_data = get_from_tree(tree, key).await?;
T::from_bin(binary_data).map_err(|error| {
bincode::deserialize::<Invoice>(&binary_data).map_err(|error| {
log::error!("Db Interaction Error: {}", error);
DatabaseError::Deserialize
})
Expand All @@ -89,8 +85,8 @@ async fn set_to_tree(db: &Tree, key: &str, bin: Vec<u8>) -> Result<(), DatabaseE
}

/// Wrapper for setting a value to a tree
pub async fn set<T: Serializable>(tree: &Tree, key: &str, data: &T) -> Result<(), DatabaseError> {
let binary_data = T::to_bin(data).map_err(|error| {
pub async fn set(tree: &Tree, key: &str, data: &Invoice) -> Result<(), DatabaseError> {
let binary_data = bincode::serialize::<Invoice>(data).map_err(|error| {
log::error!("Db Interaction Error: {}", error);
DatabaseError::Serialize
})?;
Expand Down
103 changes: 56 additions & 47 deletions src/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@ use alloy::{
signers::wallet::LocalWallet,
transports::http::Http,
};
use log::LevelFilter;
use log4rs::{
append::file::FileAppender,
config::{Appender, Root},
encode::pattern::PatternEncoder,
Config,
};

use async_std::channel::Sender;
use reqwest::{Client, Url};
use sled::Tree;

Expand Down Expand Up @@ -44,10 +39,32 @@ pub struct PaymentGatewayConfiguration {
pub provider: RootProvider<Http<Client>>,
pub treasury_address: Address,
pub invoice_delay_millis: u64,
pub callback: AsyncCallback,
pub reflector: Reflector,
pub transfer_gas_limit: Option<u128>,
}

/// ## Reflector
/// The reflector allows your payment gateway to be used in a more flexible way.
///
/// In its current state you can pass a Sender from an unbound async-std channel
/// which you can create by doing:
/// ```rust
/// use async_std::channel::unbounded;
/// use acceptevm::gateway::Reflector;
///
/// let (sender, receiver) = unbounded();
///
/// let reflector=Reflector::Sender(sender);
/// ```
///
/// You may clone the receiver as many times as you want but do not use the sender
/// for anything other than passing it to the try_new() method.
#[derive(Clone)]
pub enum Reflector {
/// A sender from async-std
Sender(Sender<Invoice>),
}

// Type alias for the underlying Web3 type.
pub type Wei = U256;

Expand All @@ -62,57 +79,49 @@ impl PaymentGateway {
/// - `treasury_address`: the address of the treasury for all paid invoices, on this EVM network.
/// - `invoice_delay_millis`: how long to wait before checking the next invoice in milliseconds.
/// This is used to prevent potential rate limits from the node.
/// - `callback`: an async function that is called when an invoice is paid.
/// - `reflector`: The reflector is an enum that allows you to receive the paid invoices.
/// At the moment, the only reflector available is the `Sender` from the async-std channel.
/// This means that you will need to create a channel and pass the sender as the reflector.
/// - `sled_path`: The path of the sled database where the pending invoices will
/// be stored. In the event of a crash the invoices are saved and will be
/// checked on reboot.
/// - `name`: A name that describes this gateway. Perhaps the EVM network used?
/// - `transfer_gas_limit`: An optional gas limit used when transferring gas from paid invoices to
/// the treasury. Useful in case your treasury address is a contract address
/// that implements custom functionality for handling incoming gas.
pub fn new<F, Fut>(
///
/// Example:
/// ```rust
/// use acceptevm::gateway::{PaymentGateway, Reflector};
/// use async_std::channel::unbounded;
/// let (sender, _receiver) = unbounded();
/// let reflector = Reflector::Sender(sender);
///
/// PaymentGateway::new(
/// "https://123.com",
/// "0xdac17f958d2ee523a2206206994597c13d831ec7".to_string(),
/// 10,
/// reflector,
/// "./your-wanted-db-path",
/// "test".to_string(),
/// Some(21000),
/// );
/// ```


pub fn new(
rpc_url: &str,
treasury_address: String,
invoice_delay_millis: u64,
callback: F,
reflector: Reflector,
sled_path: &str,
name: String,
transfer_gas_limit: Option<u128>,
) -> PaymentGateway
where
F: Fn(Invoice) -> Fut + 'static + Send + Sync,
Fut: Future<Output = ()> + 'static + Send,
{
// Send allows ownership to be transferred across threads
// Sync allows references to be shared

) -> PaymentGateway {
let db = sled::open(sled_path).unwrap();
let tree = db.open_tree("invoices").unwrap();
let provider = ProviderBuilder::new().on_http(Url::from_str(rpc_url).unwrap());

// Wrap the callback in Arc<Mutex<>> to allow sharing across threads and state mutation
// We have to create a pinned box to prevent the future from being moved around in heap memory.
let callback = Arc::new(move |invoice: Invoice| {
Box::pin(callback(invoice)) as Pin<Box<dyn Future<Output = ()> + Send>>
});

// Setup logging
let logfile = FileAppender::builder()
.encoder(Box::new(PatternEncoder::new("{l} - {m}\n")))
.build("./acceptevm.log")
.unwrap();

let config = Config::builder()
.appender(Appender::builder().build("logfile", Box::new(logfile)))
.build(Root::builder().appender("logfile").build(LevelFilter::Info))
.unwrap();

// Try to initialize and catch error silently if already initialized
// during tests this make this function throw error
if log4rs::init_config(config).is_err() {
println!("Logger already initialized.");
}

// TODO: When implementing token transfers allow the user to add their gas wallet here.

PaymentGateway {
Expand All @@ -122,7 +131,7 @@ impl PaymentGateway {
.parse()
.unwrap_or_else(|_| panic!("Invalid treasury address")),
invoice_delay_millis,
callback,
reflector,
transfer_gas_limit,
},
tree,
Expand All @@ -132,20 +141,20 @@ impl PaymentGateway {

/// Retrieves the last invoice
pub async fn get_last_invoice(&self) -> Result<(String, Invoice), DatabaseError> {
get_last::<Invoice>(&self.tree).await
get_last(&self.tree).await
}

/// Retrieves all invoices in the form of a tuple: String,Invoice
/// where the first element is the key that was used in the database
/// and the second part is the invoice. The key is a SHA256 hash of the
/// creation timestamp and the recipient address.
pub async fn get_all_invoices(&self) -> Result<Vec<(String, Invoice)>, DatabaseError> {
get_all::<Invoice>(&self.tree).await
get_all(&self.tree).await
}

/// Retrieve an invoice from the payment gateway
pub async fn get_invoice(&self, key: String) -> Result<Invoice, DatabaseError> {
get::<Invoice>(&self.tree, &key).await
get(&self.tree, &key).await
}

/// Spawns an asynchronous task that checks all the pending invoices
Expand Down Expand Up @@ -189,7 +198,7 @@ impl PaymentGateway {
let seed = format!("{}{}", signer.address(), get_unix_time_millis());
let invoice_id = hash_now(seed);
// Save the invoice in db.
set::<Invoice>(&self.tree, &invoice_id, &invoice).await?;
set(&self.tree, &invoice_id, &invoice).await?;
Ok(invoice)
}
}
27 changes: 7 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,23 @@ mod tests {
use std::{fs, path::Path, str::FromStr};

use alloy::primitives::U256;
use async_std::channel::unbounded;

use crate::{
common::DatabaseError,
gateway::PaymentGateway,
gateway::{PaymentGateway, Reflector},
types::{Invoice, PaymentMethod},
};

struct Foo {
bar: std::sync::Mutex<i64>,
}

impl Foo {
async fn increase(&self) {
*self.bar.lock().unwrap() += 1;
}
}

fn setup_test_gateway(db_path: &str) -> PaymentGateway {
let foo = std::sync::Arc::new(Foo {
bar: Default::default(),
});
let foo_clone = foo.clone();
let callback = move |_| {
let foo = foo_clone.clone();
async move { foo.increase().await }
};
let (sender, _receiver) = unbounded();
let reflector = Reflector::Sender(sender);

PaymentGateway::new(
"https://123.com",
"0xdac17f958d2ee523a2206206994597c13d831ec7".to_string(),
10,
callback,
reflector,
db_path,
"test".to_string(),
Some(21000),
Expand Down Expand Up @@ -86,4 +72,5 @@ mod tests {
assert_eq!(address_length, 42);
remove_test_db("./test-assert-valid-address-length");
}

}
19 changes: 12 additions & 7 deletions src/poller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use alloy::{
rpc::types::eth::TransactionReceipt,
transports::http::Http,
};
use crate::gateway::Reflector::Sender;
use reqwest::Client;
use sled::Tree;

Expand Down Expand Up @@ -70,11 +71,8 @@ async fn check_and_process(provider: RootProvider<Http<Client>>, invoice: &Invoi

async fn delete_invoice(tree: &Tree, key: String) {
// Optimistically delete the old invoice.
match delete(tree, &key).await {
Ok(()) => {}
Err(error) => {
log::error!("Could not remove invoice, did not callback: {}", error);
}
if let Err(delete_error) = delete(tree, &key).await {
log::error!("Could not remove invoice: {}", delete_error);
}
}

Expand All @@ -89,7 +87,7 @@ async fn transfer_to_treasury(
/// to the specified polling interval.
pub async fn poll_payments(gateway: PaymentGateway) {
loop {
match get_all::<Invoice>(&gateway.tree).await {
match get_all(&gateway.tree).await {
Ok(all) => {
// Loop through all invoices
for (key, mut invoice) in all {
Expand Down Expand Up @@ -121,7 +119,14 @@ pub async fn poll_payments(gateway: PaymentGateway) {
// lock to the callback function.
delete_invoice(&gateway.tree, key).await;
invoice.paid_at_timestamp = get_unix_time_seconds();
(gateway.config.callback)(invoice).await;// Execute callback function
match gateway.config.reflector {
Sender(ref sender) => {
// Attempt to send the PriceData through the channel.
if let Err(error) = sender.send(invoice).await {
log::error!("Failed sending data: {}", error);
}
}
}
}
// To prevent rate limitations on certain Web3 RPC's we sleep here for the specified amount.
tokio::time::sleep(std::time::Duration::from_millis(
Expand Down
7 changes: 0 additions & 7 deletions src/types/errors.rs

This file was deleted.

19 changes: 0 additions & 19 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
mod errors;
use std::ops::{Deref, DerefMut};

use self::errors::SerializableError;
use alloy::{
primitives::{B256, U256},
rpc::types::eth::TransactionReceipt,
};
use serde::{Deserialize, Serialize};
use zeroize::ZeroizeOnDrop;
pub trait Serializable {
fn to_bin(&self) -> Result<Vec<u8>, Box<bincode::ErrorKind>>;
fn from_bin(data: Vec<u8>) -> Result<Self, SerializableError>
where
Self: Sized;
}

/// Describes the structure of a payment method in
/// a gateway
Expand Down Expand Up @@ -62,14 +54,3 @@ pub struct Invoice {
pub receipt: Option<TransactionReceipt>,
}

impl Serializable for Invoice {
/// Serializes invoice to bytes
fn to_bin(&self) -> Result<Vec<u8>, Box<bincode::ErrorKind>> {
bincode::serialize(&self)
}

/// Deserializes invoice from bytes
fn from_bin(data: Vec<u8>) -> Result<Self, SerializableError> {
bincode::deserialize(&data).map_err(|_| SerializableError::Deserialize)
}
}