diff --git a/base_engine/Cargo.toml b/base_engine/Cargo.toml index 72a46eab..71fc6f5b 100644 --- a/base_engine/Cargo.toml +++ b/base_engine/Cargo.toml @@ -22,6 +22,6 @@ futures = { version = "0.3", optional=true } dashmap = {workspace = true, optional=true} [features] -default = ["aws_s3"] +default = [] aws_s3 = ["dep:aws-config", "dep:aws-sdk-s3", "dep:tokio", "dep:futures"] cache = ["dep:dashmap"] \ No newline at end of file diff --git a/base_engine/src/api/execute_agg.rs b/base_engine/src/api/execute_agg.rs index 849f58c4..81ffcd3b 100644 --- a/base_engine/src/api/execute_agg.rs +++ b/base_engine/src/api/execute_agg.rs @@ -19,14 +19,7 @@ pub fn execute_aggregation( // Assuming Front End knows which columns can be in groupby, agg etc // Step 0.1 - let f1 = data.frame(); - //let tmp = f1.clone().lazy().filter(col("RiskClass").eq(lit("DRC_SecNonCTP"))).collect()?; - //dbg!(&tmp["SensWeights"]); - //let f1_cols = f1.get_column_names(); - - // Polars DataFrame clone is cheap: - // https://stackoverflow.com/questions/72320911/how-to-avoid-deep-copy-when-using-groupby-in-polars-rust - let mut f1 = f1.clone().lazy(); + let mut f1 = data.lazy_frame().clone(); // Step 1.0 Applying FILTERS: // TODO check if column is present in DF - ( is this "second line of defence" even needed?) @@ -91,7 +84,7 @@ pub fn execute_aggregation( let groups_fill_nulls: Vec = groups .clone() .into_iter() - .map(|e| e.fill_null(lit("EMPTY"))) + .map(|e| e.fill_null(lit(" "))) .collect(); // Step 2.5 Apply GroupBy and Agg diff --git a/base_engine/src/dataset.rs b/base_engine/src/dataset.rs index 12a8a593..b3d4c0be 100644 --- a/base_engine/src/dataset.rs +++ b/base_engine/src/dataset.rs @@ -7,13 +7,12 @@ use crate::{derive_measure_map, DataSourceConfig, MeasuresMap}; /// This is the default struct which implements Dataset /// Usually a client/user would overwrite it with their own DataSet -#[derive(Debug, Default, Serialize)] +#[derive(Default)] pub struct DataSetBase { - pub frame: DataFrame, + pub frame: LazyFrame, pub measures: MeasuresMap, /// build_params are used in .prepare() pub build_params: HashMap, - pub calc_params: Vec, } /// This struct is purely for DataSet descriptive purposes. @@ -30,15 +29,39 @@ pub struct CalcParameter { /// /// If you have your own DataSet, implement this pub trait DataSet: Send + Sync { - fn frame(&self) -> &DataFrame; + /// Polars DataFrame clone is cheap: + /// https://stackoverflow.com/questions/72320911/how-to-avoid-deep-copy-when-using-groupby-in-polars-rust + fn lazy_frame(&self) -> &LazyFrame; + fn measures(&self) -> &MeasuresMap; - fn build(conf: DataSourceConfig) -> Self + + // Cannot be defined since returns Self which is a Struct + // TODO potentially remove to keep things simple + fn from_config(conf: DataSourceConfig) -> Self where Self: Sized; + + /// See [DataSetBase] and [CalcParameter] for description of the parameters + fn new(frame: LazyFrame, mm: MeasuresMap, build_params: HashMap) -> Self + where + Self: Sized; + + fn collect(self) -> PolarsResult + where + Self: Sized, + { + Ok(self) + } + // These methods could be overwritten. - /// Prepare runs ONCE before server starts. - /// Any computations which are common to most queries could go in here. + /// Clones + fn frame(&self) -> PolarsResult { + self.lazy_frame().clone().collect() + } + + /// Prepare runs BEFORE any calculations. In eager mode it runs ONCE + /// Any pre-computations which are common to all queries could go in here. 
fn prepare(&mut self) {} fn calc_params(&self) -> Vec { @@ -46,7 +69,7 @@ pub trait DataSet: Send + Sync { } fn overridable_columns(&self) -> Vec { - overrides_columns(self.frame()) + overrides_columns(self.lazy_frame()) } /// Validate DataSet /// Runs once, making sure all the required columns, their contents, types etc are valid @@ -57,27 +80,35 @@ pub trait DataSet: Send + Sync { } impl DataSet for DataSetBase { - fn frame(&self) -> &DataFrame { + /// Polars DataFrame clone is cheap: + /// https://stackoverflow.com/questions/72320911/how-to-avoid-deep-copy-when-using-groupby-in-polars-rust + fn lazy_frame(&self) -> &LazyFrame { &self.frame } fn measures(&self) -> &MeasuresMap { &self.measures } - /// It's ok to clone. Function is only called upon serialisation, so very rarely - fn calc_params(&self) -> Vec { - self.calc_params.clone() - } - fn build(conf: DataSourceConfig) -> Self { - let (frames, measure_cols, build_params) = conf.build(); + fn from_config(conf: DataSourceConfig) -> Self { + let (frame, measure_cols, build_params) = conf.build(); let mm: MeasuresMap = derive_measure_map(measure_cols); Self { - frame: frames, + frame, measures: mm, build_params, - calc_params: vec![], } } + fn new(frame: LazyFrame, mm: MeasuresMap, build_params: HashMap) -> Self { + Self { + frame, + measures: mm, + build_params, + } + } + fn collect(self) -> PolarsResult { + let lf = self.frame.collect()?.lazy(); + Ok(Self { frame: lf, ..self }) + } // /// Validate Dataset contains columns // /// files_join_attributes and attributes_join_hierarchy @@ -89,43 +120,54 @@ impl DataSet for DataSetBase { // fn validate(&self) {} } -pub(crate) fn numeric_columns(df: &DataFrame) -> Vec { - let mut res = vec![]; - for c in df.get_columns() { - if c.dtype().is_numeric() { - res.push(c.name().to_string()) - } - } - res +// TODO return Result +pub(crate) fn numeric_columns(lf: &LazyFrame) -> Vec { + lf.schema().map_or_else( + |_| vec![], + |schema| { + schema + .iter_fields() + .filter(|f| f.data_type().is_numeric()) + .map(|f| f.name) + .collect::>() + }, + ) } -pub(crate) fn utf8_columns(df: &DataFrame) -> Vec { - let mut res = vec![]; - for c in df.get_columns() { - if let DataType::Utf8 = c.dtype() { - res.push(c.name().to_string()) - } - } - res +// TODO return Result +pub(crate) fn utf8_columns(lf: &LazyFrame) -> Vec { + lf.schema().map_or_else( + |_| vec![], + |schema| { + schema + .iter_fields() + .filter(|field| matches!(field.data_type(), DataType::Utf8)) + .map(|field| field.name) + .collect::>() + }, + ) } /// DataTypes supported for overrides are defined in [overrides::string_to_lit] -pub(crate) fn overrides_columns(df: &DataFrame) -> Vec { - let mut res = vec![]; - for c in df.get_columns() { - match c.dtype() { - DataType::Utf8 | DataType::Boolean | DataType::Float64 => { - res.push(c.name().to_string()) - } - DataType::List(x) => { - if let DataType::Float64 = x.as_ref() { - res.push(c.name().to_string()) - } - } - _ => (), - } - } - res +pub(crate) fn overrides_columns(lf: &LazyFrame) -> Vec { + //let mut res = vec![]; + lf.schema().map_or_else( + |_| vec![], + |schema| { + let res = schema + .iter_fields() + .filter(|c| match c.data_type() { + DataType::Utf8 | DataType::Boolean | DataType::Float64 => true, + DataType::List(x) => { + matches!(x.as_ref(), DataType::Float64) + } + _ => false, + }) + .map(|c| c.name) + .collect::>(); + res + }, + ) } impl Serialize for dyn DataSet { @@ -140,7 +182,7 @@ impl Serialize for dyn DataSet { .collect::>>(); let ordered_measures: BTreeMap<_, _> = 
measures.iter().collect(); - let utf8_cols = utf8_columns(self.frame()); + let utf8_cols = utf8_columns(self.lazy_frame()); let calc_params = self.calc_params(); let mut seq = serializer.serialize_map(Some(4))?; diff --git a/base_engine/src/datasource/helpers.rs b/base_engine/src/datasource/helpers.rs index 9e6ec882..78324050 100644 --- a/base_engine/src/datasource/helpers.rs +++ b/base_engine/src/datasource/helpers.rs @@ -2,7 +2,8 @@ use std::collections::HashMap; use polars::{ prelude::{ - col, DataFrame, DataType, Expr, Field, IntoLazy, JoinType, LazyCsvReader, NamedFrom, Schema, + col, concat, DataFrame, DataType, Expr, Field, JoinType, LazyCsvReader, LazyFrame, Literal, + NamedFrom, PolarsResult, Schema, NULL, }, series::Series, }; @@ -20,7 +21,7 @@ pub fn empty_frame(with_columns: &[String]) -> DataFrame { } /// reads DataFrame from path, casts cols to str and numeric cols to f64 -pub fn path_to_df(path: &str, cast_to_str: &[String], cast_to_f64: &[String]) -> DataFrame { +pub fn path_to_df(path: &str, cast_to_str: &[String], cast_to_f64: &[String]) -> LazyFrame { let mut vc = Vec::with_capacity(cast_to_str.len() + cast_to_f64.len()); for str_col in cast_to_str { vc.push(Field::new(str_col, DataType::Utf8)) @@ -33,51 +34,57 @@ pub fn path_to_df(path: &str, cast_to_str: &[String], cast_to_f64: &[String]) -> // if path provided, then we expect it to be of the correct format // unrecoverable. Panic if failed to read file - let df = LazyCsvReader::new(path) + let lf = LazyCsvReader::new(path) .has_header(true) .with_parse_dates(true) .with_dtype_overwrite(Some(&schema)) //.with_ignore_parser_errors(ignore) .finish() - .and_then(|lf| lf.collect()) .unwrap_or_else(|_| panic!("Error reading file: {path}")); - df + lf } pub fn finish( a2h: Vec, f2a: Vec, measures: Vec, - mut df_attr: DataFrame, - df_hms: DataFrame, - mut concatinated_frame: DataFrame, + mut df_attr: LazyFrame, + df_hms: LazyFrame, + mut concatinated_frame: LazyFrame, build_params: HashMap, -) -> (DataFrame, Vec, HashMap) { +) -> (LazyFrame, Vec, HashMap) { // join with hms if a2h was provided if !a2h.is_empty() { let a2h_expr = a2h.iter().map(|c| col(c)).collect::>(); - df_attr = df_attr.lazy() - .join(df_hms.lazy(), a2h_expr.clone(), a2h_expr, JoinType::Left) - .collect() - .expect("Could not join attributes to hms. Review attributes_join_hierarchy field in the setup"); + df_attr = df_attr.join(df_hms, a2h_expr.clone(), a2h_expr, JoinType::Left) + //.collect() + //.expect("Could not join attributes to hms. Review attributes_join_hierarchy field in the setup"); } // if files to attributes was provided if !f2a.is_empty() { let f2a_expr = f2a.iter().map(|c| col(c)).collect::>(); - concatinated_frame = concatinated_frame.lazy() - .join(df_attr.lazy(), f2a_expr.clone(), f2a_expr, JoinType::Outer) - .collect() - .expect("Could not join files with attributes-hms. Review files_join_attributes field in the setup"); + concatinated_frame = + concatinated_frame.join(df_attr, f2a_expr.clone(), f2a_expr, JoinType::Outer) + //.collect() + //.expect("Could not join files with attributes-hms. 
Review files_join_attributes field in the setup"); } // if measures were provided let measures = if !measures.is_empty() { + let schema = concatinated_frame + .schema() + .expect("Could not extract Schema"); + let fields = schema + .iter_fields() + .map(|f| f.name) + .collect::>(); + // Checking if each measure is present in DF measures.iter().for_each(|col| { - concatinated_frame - .column(col) - .unwrap_or_else(|_| panic!("Column {} not found", col)); + if !fields.contains(col) { + panic!("Measure: {}, is not part of the fields: {:?}", col, fields) + } }); derive_basic_measures_vec(measures) } @@ -89,3 +96,59 @@ pub fn finish( (concatinated_frame, measures, build_params) } + +/// TODO contribute to Polars +/// Concat [LazyFrame]s diagonally. +/// Calls [concat] internally. +pub fn diag_concat_lf>( + lfs: L, + rechunk: bool, + parallel: bool, +) -> PolarsResult { + let lfs = lfs.as_ref().to_vec(); + let upper_bound_width = lfs + .iter() + .map(|lf| Ok(lf.schema()?.len())) + .collect::>>()? + .iter() + .sum(); + // Use Vec instead of a HashSet to preserve order + let mut column_names = Vec::with_capacity(upper_bound_width); + let mut total_schema = Vec::with_capacity(upper_bound_width); + + for lf in lfs.iter() { + lf.schema()?.iter().for_each(|(name, dtype)| { + if !column_names.contains(name) { + column_names.push(name.clone()); + total_schema.push((name.clone(), dtype.clone())) + } + }); + } + + let dfs = lfs + .into_iter() + .map(|mut lf| { + // Get current frame's Schema + let lf_schema = lf.schema()?; + + for (name, dtype) in total_schema.iter() { + // If a name from Total Schema is not present - append + if lf_schema.get_field(name).is_none() { + lf = lf.with_column(NULL.lit().cast(dtype.clone()).alias(name)) + } + } + + // Now, reorder to match schema + let reordered_lf = lf.select( + column_names + .iter() + .map(|col_name| col(col_name)) + .collect::>(), + ); + + Ok(reordered_lf) + }) + .collect::>>()?; + + concat(dfs, rechunk, parallel) +} diff --git a/base_engine/src/datasource/mod.rs b/base_engine/src/datasource/mod.rs index 0e2c8334..a77f1cfb 100644 --- a/base_engine/src/datasource/mod.rs +++ b/base_engine/src/datasource/mod.rs @@ -8,6 +8,8 @@ use crate::Measure; pub mod helpers; use helpers::{empty_frame, finish, path_to_df}; +use self::helpers::diag_concat_lf; + #[cfg(feature = "aws_s3")] pub mod awss3; @@ -79,7 +81,7 @@ impl DataSourceConfig { /// Returns: /// /// (joined concatinated DataFrame, vec of base measures, build params) - pub fn build(self) -> (DataFrame, Vec, HashMap) { + pub fn build(self) -> (LazyFrame, Vec, HashMap) { match self { DataSourceConfig::CSV { file_paths: files, @@ -92,14 +94,19 @@ impl DataSourceConfig { f1_numeric_cols: f64_cols, build_params, } => { - // what if str_cols already contains f2a? 
- str_cols.extend(f2a.clone()); + for s in f2a.iter() { + if !str_cols.contains(s) { + str_cols.push(s.to_string()) + } + } - let concatinated_frame = diag_concat_df( + let concatinated_frame = diag_concat_lf( &files .iter() .map(|f| path_to_df(f, &str_cols, &f64_cols)) - .collect::>(), + .collect::>(), + true, + true, ) .expect("Failed to concatinate provided frames"); // <- Ok to panic upon server startup @@ -108,18 +115,19 @@ impl DataSourceConfig { let df_attr = match ta { Some(y) => path_to_df(&y, &tmp, &f64_cols) - .unique(Some(&f2a), UniqueKeepStrategy::First) - .unwrap(), - _ => empty_frame(&tmp), + .unique(Some(f2a.clone()), UniqueKeepStrategy::First), + //.unwrap(), + _ => empty_frame(&tmp).lazy(), }; //here we expect if hms is provided then a2h is not empty - let df_hms = match hms{ - Some(y) =>{ path_to_df(&y, &a2h, &[]) - .unique(Some(&a2h), UniqueKeepStrategy::First) - .expect("hms file path was provided, hence attributes_join_hierarchy list must also be provided - in the datasource_config.toml") }, - _ => empty_frame(&a2h) }; + let df_hms = match hms { + Some(y) => path_to_df(&y, &a2h, &[]) + .unique(Some(a2h.clone()), UniqueKeepStrategy::First), + //.expect("hms file path was provided, hence attributes_join_hierarchy list must also be provided + //in the datasource_config.toml") }, + _ => empty_frame(&a2h).lazy(), + }; finish( a2h, @@ -178,9 +186,9 @@ impl DataSourceConfig { a2h, f2a, measures, - df_attr, - df_hms, - concatinated_frame, + df_attr.lazy(), + df_hms.lazy(), + concatinated_frame.lazy(), build_params, ) } diff --git a/base_engine/src/overrides.rs b/base_engine/src/overrides.rs index 873c6395..8e5e1d3c 100644 --- a/base_engine/src/overrides.rs +++ b/base_engine/src/overrides.rs @@ -41,7 +41,6 @@ impl Override { pub fn lf_with_overwrite(&self, lf: LazyFrame) -> PolarsResult { let schema = lf.schema()?; let dt = schema.try_get(self.field.as_str())?; - //let dt = lf.column(&self.field)?.dtype(); let lt = string_to_lit(&self.value, dt, &self.field)?; let new_col_as_expr = self.override_builder(lt); Ok(lf.with_column(new_col_as_expr)) diff --git a/base_engine/src/prelude.rs b/base_engine/src/prelude.rs index 00ec42b1..998a974f 100644 --- a/base_engine/src/prelude.rs +++ b/base_engine/src/prelude.rs @@ -4,3 +4,6 @@ pub use super::datarequest::*; pub use super::dataset::*; pub use super::datasource::*; pub use super::measure::*; + +//Reexports +pub use polars::prelude::*; diff --git a/base_engine/tests/common/mod.rs b/base_engine/tests/common/mod.rs index 217eb6fb..b633b12b 100644 --- a/base_engine/tests/common/mod.rs +++ b/base_engine/tests/common/mod.rs @@ -7,7 +7,7 @@ pub static TEST_DASET: Lazy> = Lazy::new(|| { let conf_path = r"./tests/data/test_config.toml"; let conf = read_toml2::(conf_path) .expect("Can not proceed without valid Data Set Up"); //Unrecovarable error - let mut data: DataSetBase = DataSet::build(conf); + let mut data: DataSetBase = DataSet::from_config(conf); data.prepare(); Arc::new(data) }); diff --git a/base_engine/tests/datasource.rs b/base_engine/tests/datasource.rs index 92c0e693..070b78eb 100644 --- a/base_engine/tests/datasource.rs +++ b/base_engine/tests/datasource.rs @@ -12,15 +12,16 @@ fn config_build() { let conf_path = r"./tests/data/bad_config.toml"; let conf = read_toml2::(conf_path) .expect("Can not proceed without valid Data Set Up"); - let mut _data: DataSetBase = DataSet::build(conf); + let mut _data: DataSetBase = DataSet::from_config(conf); } /// In this config, files_join_attributes was provided but no such column is present 
#[test] -#[should_panic(expected = "NotFound")] +#[should_panic(expected = "Couldn't build")] fn config_build2() { let conf_path = r"./tests/data/bad_config2.toml"; let conf = read_toml2::(conf_path) .expect("Can not proceed without valid Data Set Up"); - let mut _data: DataSetBase = DataSet::build(conf); + let (lf, _, _) = conf.build(); + lf.collect().expect("Couldn't build"); } diff --git a/driver/Cargo.toml b/driver/Cargo.toml index 80816037..b4060b75 100644 --- a/driver/Cargo.toml +++ b/driver/Cargo.toml @@ -47,3 +47,4 @@ FRTB = ["frtb_engine"] # feature FRTB of my lib activate optional dep frtb_engin FRTB_CRR2 = ["FRTB", "frtb_engine/CRR2"] # feature FRTB_CRR2 activates optional dep frtb_engine with it's CRR2 feature # BUT also need to activate FRTB because one_off.rs and server.rs have #[cfg(feature = "FRTB")] cache = ["base_engine/cache"] +streaming = [] diff --git a/driver/src/api/mod.rs b/driver/src/api/mod.rs index 9a43da36..fa993607 100644 --- a/driver/src/api/mod.rs +++ b/driver/src/api/mod.rs @@ -25,7 +25,7 @@ use std::{net::TcpListener, sync::Arc}; use tokio::task; use base_engine::{ - api::aggregations::BASE_CALCS, prelude::PolarsResult, AggregationRequest, DataSet, + api::aggregations::BASE_CALCS, col, prelude::PolarsResult, AggregationRequest, DataSet, }; #[cfg(feature = "cache")] @@ -72,7 +72,9 @@ async fn column_search( let (page, pat) = (page.page, page.pattern.clone()); let res = task::spawn_blocking(move || { let d = data.get_ref(); - let srs = d.frame().column(&column_name)?; + let lf = d.lazy_frame(); + let df = lf.clone().select([col(&column_name)]).collect()?; + let srs = df.column(&column_name)?; let search = base_engine::searches::filter_contains_unique(srs, &pat)?; let first = page * PER_PAGE as usize; let last = first + PER_PAGE as usize; @@ -82,6 +84,7 @@ async fn column_search( .await .context("Failed to spawn blocking task.") .map_err(actix_web::error::ErrorInternalServerError)?; + match res { Ok(srs) => Ok(HttpResponse::Ok().json(Vec::from( srs.utf8() diff --git a/driver/src/bin/one_off.rs b/driver/src/bin/one_off.rs index d37c689e..2183e36c 100644 --- a/driver/src/bin/one_off.rs +++ b/driver/src/bin/one_off.rs @@ -43,19 +43,19 @@ fn main() -> anyhow::Result<()> { // Build Data let data = acquire::data::(setup_path.as_str()); - let x = Arc::new(data); + let arc_data = Arc::new(data); let json = fs::read_to_string(requests_path.as_str()).expect("Unable to read request file"); // Later this will be RequestE (to match other requests as well) - let requests: Vec = serde_json::from_str(&json).unwrap(); + let requests: Vec = serde_json::from_str(&json).expect("Bad requests"); // From here we do not panic for request in requests { let rqst_str = serde_json::to_string(&request); info!("{:?}", rqst_str); let now = Instant::now(); - match base_engine::execute_aggregation(request, Arc::clone(&x)) { + match base_engine::execute_aggregation(request, Arc::clone(&arc_data)) { Err(e) => { error!("On request: {:?}, Application error: {:#?}", rqst_str, e); continue; diff --git a/driver/src/helpers/acquire.rs b/driver/src/helpers/acquire.rs index 594e4fb5..54f2273c 100644 --- a/driver/src/helpers/acquire.rs +++ b/driver/src/helpers/acquire.rs @@ -1,6 +1,6 @@ use std::time::Instant; -use base_engine::{read_toml2, DataSet, DataSourceConfig}; +use base_engine::{derive_measure_map, read_toml2, DataSet, DataSourceConfig}; use log::info; /// Reads initial DataSet from Source @@ -14,14 +14,30 @@ pub fn data(config_path: &str) -> impl DataSet { .expect("Can not proceed without valid 
Data Set Up"); //Unrecovarable error info!("Data SetUp: {:?}", conf); - // Build data - let mut data = DS::build(conf); + let (lf, measure_vec, build_params) = conf.build(); + let mut data = DS::new(lf, derive_measure_map(measure_vec), build_params); + + // If cfg is streaming then we can't collect, otherwise collect to check errors + if !cfg!(feature = "streaming") { + let now = Instant::now(); + data = data.collect().expect("Failed to read frame"); + let elapsed = now.elapsed(); + println!("Time to Read/Aggregate DF: {:.6?}", elapsed); + } + + // Build DataSet + // TODO // data.validate().expect(); + // Pre build some columns, which you wish to store in memory alongside the original data - let now = Instant::now(); data.prepare(); - let elapsed = now.elapsed(); - println!("Time to Prepare DF: {:.6?}", elapsed); + if !cfg!(feature = "streaming") { + let now = Instant::now(); + data = data.collect().expect("Failed to Prepare Frame"); + let elapsed = now.elapsed(); + println!("Time to Prepare DF: {:.6?}", elapsed); + } + data } diff --git a/driver/src/request.json b/driver/src/request.json index c670c964..013b7819 100644 --- a/driver/src/request.json +++ b/driver/src/request.json @@ -3,12 +3,16 @@ "filters": [], "groupby": [ "Country", - "TradeId" + "BookId" ], "measures": [ [ "SBM_Charge", "first" + ], + [ + "EQ_TotalCharge_Medium", + "first" ] ], "overrides": [], diff --git a/frtb_engine/src/lib.rs b/frtb_engine/src/lib.rs index 4f7c3e28..16b679e2 100644 --- a/frtb_engine/src/lib.rs +++ b/frtb_engine/src/lib.rs @@ -19,16 +19,13 @@ use prelude::{calc_params::frtb_calc_params, drc::common::drc_scalinng, frtb_mea use risk_weights::*; use sbm::buckets; -use polars::prelude::*; -use serde::Serialize; use std::collections::HashMap; pub trait FRTBDataSetT { fn prepare(self) -> Self; } -#[derive(Debug, Serialize)] pub struct FRTBDataSet { - pub frame: DataFrame, + pub frame: LazyFrame, pub measures: MeasuresMap, pub build_params: HashMap, //pub calc_params: Vec @@ -45,7 +42,7 @@ impl FRTBDataSet { } impl DataSet for FRTBDataSet { - fn frame(&self) -> &DataFrame { + fn lazy_frame(&self) -> &LazyFrame { &self.frame } fn measures(&self) -> &MeasuresMap { @@ -55,176 +52,190 @@ impl DataSet for FRTBDataSet { frtb_calc_params() } - fn build(conf: DataSourceConfig) -> Self { + fn from_config(conf: DataSourceConfig) -> Self { let (frames, measure_cols, build_params) = conf.build(); let mm: MeasuresMap = derive_measure_map(measure_cols); let mut res = Self { frame: frames, measures: mm, build_params, - //calc_params: frtb_calc_params(), }; res.with_measures(frtb_measure_vec()); res } + + fn new(frame: LazyFrame, mm: MeasuresMap, build_params: HashMap) -> Self { + let mut res = Self { + frame, + measures: mm, + build_params, + }; + res.with_measures(frtb_measure_vec()); + res + } + + fn collect(self) -> PolarsResult { + let lf = self.frame.collect()?.lazy(); + Ok(Self { frame: lf, ..self }) + } /// Adds: BCBS buckets, CRR2 Buckets /// Adds: SensWeights, CurvatureRiskWeight, SensWeightsCRR2, SeniorityRank fn prepare(&mut self) { let f1 = &mut self.frame; - if f1.height() != 0 { - //First, identify buckets - let mut lf1 = f1 - .clone() - .lazy() - .with_column(buckets::sbm_buckets(&self.build_params)); - // If CRR2, then also provide CRR2 buckets - #[cfg(feature = "CRR2")] - if cfg!(feature = "CRR2") { - lf1 = lf1.with_column(buckets::sbm_buckets_crr2()) - }; - - // Then assign risk weights based on buckets - lf1 = lf1.with_column(weights_assign(&self.build_params).alias("SensWeights")); - //let tmp_frame = 
lf1.collect().expect("Failed to unwrap tmp_frame while .prepare()"); - - // Some risk weights assignments (DRC Sec Non CTP) would result in too many when().then() statements - // which panics: https://github.com/pola-rs/polars/issues/4827 - // Hence, for such scenarios we need to use left join - let drc_secnonctp_weights: DataFrame = drc_weights::drc_secnonctp_weights_frame(); - let left_on = concat_str( - [ - col("CreditQuality").map( - |s| Ok(s.utf8()?.to_uppercase().into_series()), - GetOutput::from_type(DataType::Utf8), - ), - col("RiskFactorType").map( - |s| Ok(s.utf8()?.to_uppercase().into_series()), - GetOutput::from_type(DataType::Utf8), - ), - ], - "_", + //First, identify buckets + let mut lf1 = f1 + .clone() + .with_column(buckets::sbm_buckets(&self.build_params)); + // If CRR2, then also provide CRR2 buckets + #[cfg(feature = "CRR2")] + if cfg!(feature = "CRR2") { + lf1 = lf1.with_column(buckets::sbm_buckets_crr2()) + }; + + // Then assign risk weights based on buckets + lf1 = lf1.with_column(weights_assign(&self.build_params).alias("SensWeights")); + //let tmp_frame = lf1.collect().expect("Failed to unwrap tmp_frame while .prepare()"); + + // Some risk weights assignments (DRC Sec Non CTP) would result in too many when().then() statements + // which panics: https://github.com/pola-rs/polars/issues/4827 + // Hence, for such scenarios we need to use left join + let drc_secnonctp_weights: DataFrame = drc_weights::drc_secnonctp_weights_frame(); + let left_on = concat_str( + [ + col("CreditQuality").map( + |s| Ok(s.utf8()?.to_uppercase().into_series()), + GetOutput::from_type(DataType::Utf8), + ), + col("RiskFactorType").map( + |s| Ok(s.utf8()?.to_uppercase().into_series()), + GetOutput::from_type(DataType::Utf8), + ), + ], + "_", + ) + .alias("LeftKey"); + + lf1 = lf1 + .left_join(drc_secnonctp_weights.lazy(), left_on, col("Key")) + .with_column(concat_lst([col("RiskWeightDRC")])); + //let tmp_frame = lf1 + // .collect() + // .expect("Failed to unwrap tmp_frame while .prepare()"); + + // lf1 = tmp_frame + // .lazy() + lf1 = lf1 + .with_column( + when(col("RiskClass").eq(lit("DRC_SecNonCTP"))) + .then(col("RiskWeightDRC")) + .otherwise(col("SensWeights")) + .alias("SensWeights"), ) - .alias("LeftKey"); - - lf1 = lf1 - .left_join(drc_secnonctp_weights.lazy(), left_on, col("Key")) - .with_column(concat_lst([col("RiskWeightDRC")])); - let tmp_frame = lf1 - .collect() - .expect("Failed to unwrap tmp_frame while .prepare()"); - - lf1 = tmp_frame - .lazy() - .with_column( - when(col("RiskClass").eq(lit("DRC_SecNonCTP"))) - .then(col("RiskWeightDRC")) - .otherwise(col("SensWeights")) - .alias("SensWeights"), - ) - .select([col("*").exclude(["RiskWeightDRC", "LeftKey"])]); - let tmp_frame = lf1 - .collect() - .expect("Failed to unwrap tmp_frame while .prepare()"); + .select([col("*").exclude(["RiskWeightDRC", "LeftKey"])]); + //let tmp_frame = lf1 + // .collect() + // .expect("Failed to unwrap tmp_frame while .prepare()"); + + // Curvature risk weight + //lf1 = tmp_frame.lazy() + lf1 = lf1.with_column( + when( + col("PnL_Up") + .is_not_null() + .or(col("PnL_Down").is_not_null()), + ) + .then(col("SensWeights").arr().max().alias("CurvatureRiskWeight")) + .otherwise(NULL.lit()), + ); - // Curvature risk weight - lf1 = tmp_frame.lazy().with_column( + // Now, ammend weights if required. 
ie has to be done after main assignment of risk weights + let mut other_cols: Vec = vec![]; + // 21.53 Footnote 17 + let csrnonsec_covered_bond_15 = self + .build_params + .get("csrnonsec_covered_bond_15") + .and_then(|s| s.parse::().ok()) + .unwrap_or_else(|| false); + + if csrnonsec_covered_bond_15 { + other_cols.push( when( - col("PnL_Up") - .is_not_null() - .or(col("PnL_Down").is_not_null()), + col("RiskClass") + .eq(lit("CSR_nonSec")) + .and(col("RiskCategory").eq(lit("Delta"))) + .and(col("BucketBCBS").eq(lit("8"))) + .and(col("CoveredBondReducedWeight").eq(lit::(true))), ) - .then(col("SensWeights").arr().max().alias("CurvatureRiskWeight")) - .otherwise(NULL.lit()), - ); - - // Now, ammend weights if required. ie has to be done after main assignment of risk weights - let mut other_cols: Vec = vec![]; - // 21.53 Footnote 17 - let csrnonsec_covered_bond_15 = self - .build_params - .get("csrnonsec_covered_bond_15") - .and_then(|s| s.parse::().ok()) - .unwrap_or_else(|| false); + .then(Series::new("", &[0.015]).lit().list()) + .otherwise(col("SensWeights")) + .alias("SensWeights"), + ) + }; + // If CRR2 config, we need to derive SensWeightsCRR2 + #[cfg(feature = "CRR2")] + if cfg!(feature = "CRR2") { + other_cols.push(weights_assign_crr2().alias("SensWeightsCRR2")) + }; + if !other_cols.is_empty() { + lf1 = lf1.with_columns(&other_cols) + }; + + // Now, we need to also ammend CRR2 weights + // Bucket 10 as per + // https://www.eba.europa.eu/regulation-and-policy/single-rulebook/interactive-single-rulebook/108776 + #[cfg(feature = "CRR2")] + if cfg!(feature = "CRR2") { + let mut with_cols = vec![col("SensWeightsCRR2") + .arr() + .max() + .alias("CurvatureRiskWeightCRR2")]; if csrnonsec_covered_bond_15 { - other_cols.push( + with_cols.push( when( col("RiskClass") .eq(lit("CSR_nonSec")) .and(col("RiskCategory").eq(lit("Delta"))) - .and(col("BucketBCBS").eq(lit("8"))) + .and(col("BucketCRR2").eq(lit("10"))) .and(col("CoveredBondReducedWeight").eq(lit::(true))), ) .then(Series::new("", &[0.015]).lit().list()) - .otherwise(col("SensWeights")) - .alias("SensWeights"), + .otherwise(col("SensWeightsCRR2")) + .alias("SensWeightsCRR2"), ) - }; - // If CRR2 config, we need to derive SensWeightsCRR2 - #[cfg(feature = "CRR2")] - if cfg!(feature = "CRR2") { - other_cols.push(weights_assign_crr2().alias("SensWeightsCRR2")) - }; - - if !other_cols.is_empty() { - lf1 = lf1.with_columns(&other_cols) - }; - - // Now, we need to also ammend CRR2 weights - // Bucket 10 as per - // https://www.eba.europa.eu/regulation-and-policy/single-rulebook/interactive-single-rulebook/108776 - #[cfg(feature = "CRR2")] - if cfg!(feature = "CRR2") { - let mut with_cols = vec![col("SensWeightsCRR2") - .arr() - .max() - .alias("CurvatureRiskWeightCRR2")]; - if csrnonsec_covered_bond_15 { - with_cols.push( - when( - col("RiskClass") - .eq(lit("CSR_nonSec")) - .and(col("RiskCategory").eq(lit("Delta"))) - .and(col("BucketCRR2").eq(lit("10"))) - .and(col("CoveredBondReducedWeight").eq(lit::(true))), - ) - .then(Series::new("", &[0.015]).lit().list()) - .otherwise(col("SensWeightsCRR2")) - .alias("SensWeightsCRR2"), - ) - } - - lf1 = lf1.with_columns(with_cols) } - // Have to collect into a tmp df, as the code panics otherwise - let tmp_frame = lf1 - .collect() - .expect("Failed to unwrap tmp_frame while .prepare()"); - lf1 = tmp_frame.lazy().with_columns(&[ - when( - col("RiskClass") - .eq(lit("GIRR")) - .and(col("RiskCategory").eq(lit("Vega"))), - ) - .then(col("GirrVegaUnderlyingMaturity").fill_null(col("RiskFactorType"))) - 
.otherwise(NULL.lit()), - drc_scalinng( - self.build_params - .get("DayCountConvention") - .and_then(|x| x.parse::().ok()), - self.build_params.get("DateFormat"), - ) - .alias("ScaleFactor"), - drc_seniority().alias("SeniorityRank"), - ]); - let tmp2_frame = lf1 - .collect() - .expect("Failed to unwrap tmp2_frame while .prepare()"); - - *f1 = tmp2_frame; + lf1 = lf1.with_columns(with_cols) } + + // Have to collect into a tmp df, as the code panics otherwise + //let tmp_frame = lf1 + // .collect() + // .expect("Failed to unwrap tmp_frame while .prepare()"); + //lf1 = tmp_frame.lazy() + lf1 = lf1.with_columns(&[ + when( + col("RiskClass") + .eq(lit("GIRR")) + .and(col("RiskCategory").eq(lit("Vega"))), + ) + .then(col("GirrVegaUnderlyingMaturity").fill_null(col("RiskFactorType"))) + .otherwise(NULL.lit()), + drc_scalinng( + self.build_params + .get("DayCountConvention") + .and_then(|x| x.parse::().ok()), + self.build_params.get("DateFormat"), + ) + .alias("ScaleFactor"), + drc_seniority().alias("SeniorityRank"), + ]); + //let tmp2_frame = lf1 + // .collect() + // .expect("Failed to unwrap tmp2_frame while .prepare()"); + + *f1 = lf1; } // TODO Validate: diff --git a/frtb_engine/src/sbm/buckets.rs b/frtb_engine/src/sbm/buckets.rs index 85458df8..ea53864f 100644 --- a/frtb_engine/src/sbm/buckets.rs +++ b/frtb_engine/src/sbm/buckets.rs @@ -15,21 +15,19 @@ pub fn sbm_buckets(_: &HashMap) -> Expr { .or(col("RiskClass").eq(lit("GIRR"))), ) .then(col("BucketBCBS").fill_null( - col("RiskFactor") - // Code will panic on a long onshore-offshore when().then() - // Hence, such mapping has to be done separately - /*.map( - move |srs| { - let mut res = srs.utf8()?.to_owned(); - for (k, v) in &offshore_onshore { - res = res.replace(k, v)?; - } - Ok(res.into_series()) - }, - GetOutput::from_type(DataType::Utf8), - ) */ - ) - ) + col("RiskFactor"), // Code will panic on a long onshore-offshore when().then() + // Hence, such mapping has to be done separately + /*.map( + move |srs| { + let mut res = srs.utf8()?.to_owned(); + for (k, v) in &offshore_onshore { + res = res.replace(k, v)?; + } + Ok(res.into_series()) + }, + GetOutput::from_type(DataType::Utf8), + ) */ + )) .otherwise(col("BucketBCBS")) } diff --git a/frtb_engine/src/sbm/common.rs b/frtb_engine/src/sbm/common.rs index fdbfc686..212d8ecc 100644 --- a/frtb_engine/src/sbm/common.rs +++ b/frtb_engine/src/sbm/common.rs @@ -7,7 +7,7 @@ use polars::export::num::Signed; use polars::lazy::dsl::{apply_multiple, GetOutput}; use polars::prelude::{ AnyValue, ChunkAgg, ChunkSet, DataType, FillNullStrategy, Float64Chunked, Float64Type, - NumOpsDispatch, PolarsError, TakeRandom, + NumOpsDispatch, PolarsError, }; use polars::series::{ChunkCompare, IntoSeries, Series}; diff --git a/frtb_engine/tests/common/mod.rs b/frtb_engine/tests/common/mod.rs index 6e3c3f30..8795306c 100644 --- a/frtb_engine/tests/common/mod.rs +++ b/frtb_engine/tests/common/mod.rs @@ -10,7 +10,7 @@ pub static LAZY_DASET: Lazy> = Lazy::new(|| { let conf_path = r"./tests/data/datasource_config.toml"; let conf = read_toml2::(conf_path) .expect("Can not proceed without valid Data Set Up"); //Unrecovarable error - let mut data: FRTBDataSet = DataSet::build(conf); + let mut data: FRTBDataSet = DataSet::from_config(conf); data.prepare(); Arc::new(data) });
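
Note on the reworked DataSet trait: the patch switches the trait from holding an eager DataFrame to a LazyFrame, splits the old `build` into `from_config`/`new`, and adds an explicit `collect`. A minimal usage sketch of that flow, mirroring `driver/src/helpers/acquire.rs` above — the config path, the filter column and the assumption that `read_toml2` is reachable alongside the prelude re-exports are illustrative, not part of the change:

    use base_engine::prelude::*;
    use base_engine::read_toml2;

    fn load() -> PolarsResult<DataSetBase> {
        // Read the datasource config and build the lazy plan; nothing is executed yet
        let conf = read_toml2::<DataSourceConfig>("./tests/data/test_config.toml")
            .expect("Can not proceed without valid Data Set Up");
        let mut data: DataSetBase = DataSet::from_config(conf);
        // prepare() now only extends the logical plan
        data.prepare();
        // Outside of the "streaming" feature, collect once to surface read errors early
        let data = data.collect()?;
        // Each query clones the LazyFrame (cheap) and stacks its own plan on top
        let _girr = data
            .lazy_frame()
            .clone()
            .filter(col("RiskClass").eq(lit("GIRR")))
            .collect()?;
        Ok(data)
    }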
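
Note on `diag_concat_lf`: the new helper replaces the eager `diag_concat_df` call — it unions the input schemas (keeping columns in order of first appearance), back-fills missing columns with typed nulls, reorders each frame to the common schema and then calls `concat`. A hypothetical illustration; the toy frames, the wrapper function and the module path are made up, only the signature is the one added in `helpers.rs` above:

    use polars::prelude::*;
    use base_engine::datasource::helpers::diag_concat_lf;

    fn example() -> PolarsResult<()> {
        let a = df!["Country" => ["UK"], "Notional" => [100.0_f64]]?.lazy();
        let b = df!["Country" => ["US"], "BookId" => ["B1"]]?.lazy();
        // Result carries the union of columns; the holes are filled with nulls
        let out = diag_concat_lf([a, b], true, true)?.collect()?;
        assert_eq!(out.get_column_names(), ["Country", "Notional", "BookId"]);
        Ok(())
    }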