
Add test to guarantee that replays are handled correctly
mendess committed Oct 6, 2023
1 parent 4a18a1e commit 03460b8
Showing 2 changed files with 179 additions and 7 deletions.
173 changes: 173 additions & 0 deletions daphne/src/roles/helper.rs
@@ -453,3 +453,176 @@ fn resolve_agg_job_id<'id, S>(
        _ => unreachable!("unhandled resource {:?}", req.resource),
    }
}

#[cfg(test)]
mod tests {
    use std::borrow::Cow;
    use std::sync::Arc;

    use assert_matches::assert_matches;
    use futures::StreamExt;
    use prio::codec::ParameterizedDecode;

    use crate::messages::{
        AggregationJobInitReq, AggregationJobResp, ReportShare, Transition, TransitionVar,
    };
    use crate::roles::DapHelper;
    use crate::MetaAggregationJobId;
    use crate::{roles::test::TestData, DapVersion};

    #[tokio::test]
    async fn replay_reports_when_continuing_aggregation() {
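        // Scenario: three reports are initialized in one aggregation job, then continued in
        // two requests whose report sets overlap. The helper must flag the overlapping
        // report as a replay rather than aggregating it twice.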
        let mut data = TestData::new(DapVersion::Draft02);
        let task_id = data.insert_task(
            DapVersion::Draft02,
            crate::VdafConfig::Prio2 { dimension: 100_000 },
        );
        let helper = data.new_helper();
        let test = data.with_leader(Arc::clone(&helper));

        let report_shares = futures::stream::iter(0..3)
            .then(|_| async {
                let mut report = test.gen_test_report(&task_id).await;
                ReportShare {
                    report_metadata: report.report_metadata,
                    public_share: report.public_share,
                    encrypted_input_share: report.encrypted_input_shares.remove(1),
                }
            })
            .collect::<Vec<_>>()
            .await;

        let report_ids = report_shares
            .iter()
            .map(|r| r.report_metadata.id.clone())
            .collect::<Vec<_>>();

        let req = test
            .gen_test_agg_job_init_req(&task_id, DapVersion::Draft02, report_shares)
            .await;

        let meta_agg_job_id = MetaAggregationJobId::Draft02(Cow::Owned(
            AggregationJobInitReq::get_decoded_with_param(&DapVersion::Draft02, &req.payload)
                .unwrap()
                .draft02_agg_job_id
                .unwrap(),
        ));

        helper
            .handle_agg_job_init_req(&req, helper.metrics.with_host("test"), &task_id)
            .await
            .unwrap();

        let helper_state = helper
            .get_helper_state(&task_id, &meta_agg_job_id)
            .await
            .unwrap()
            .unwrap();
        // Restore the helper state so that the aggregation job id is found. This is needed
        // because the test implementation removes the state when get is called.
        helper
            .put_helper_state_if_not_exists(&task_id, &meta_agg_job_id, &helper_state)
            .await
            .unwrap();
        {
            let req = test
                .gen_test_agg_job_cont_req(
                    &meta_agg_job_id,
                    report_ids[0..2]
                        .iter()
                        .map(|id| Transition {
                            report_id: id.clone(),
                            var: TransitionVar::Continued(vec![]),
                        })
                        .collect(),
                    DapVersion::Draft02,
                )
                .await;

            let resp = helper
                .handle_agg_job_cont_req(&req, helper.metrics.with_host("test"), &task_id)
                .await
                .unwrap();

            let a_job_resp =
                AggregationJobResp::get_decoded_with_param(&DapVersion::Draft02, &resp.payload)
                    .unwrap();
            assert_eq!(a_job_resp.transitions.len(), 2);
            assert!(a_job_resp
                .transitions
                .iter()
                .all(|t| matches!(t.var, TransitionVar::Finished)));
        }
        {
            // Restore the helper state again so that the aggregation job id is found. This is
            // needed because the test implementation removes the state when get is called.
            helper
                .put_helper_state_if_not_exists(&task_id, &meta_agg_job_id, &helper_state)
                .await
                .unwrap();

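            // This continuation overlaps with the previous one: report_ids[1] was already
            // aggregated above, so the helper should reject it as a replay and only finish
            // report_ids[2].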
            let req = test
                .gen_test_agg_job_cont_req(
                    &meta_agg_job_id,
                    report_ids[1..3]
                        .iter()
                        .map(|id| Transition {
                            report_id: id.clone(),
                            var: TransitionVar::Continued(vec![]),
                        })
                        .collect(),
                    DapVersion::Draft02,
                )
                .await;

            let resp = helper
                .handle_agg_job_cont_req(&req, helper.metrics.with_host("test"), &task_id)
                .await
                .unwrap();

            let a_job_resp =
                AggregationJobResp::get_decoded_with_param(&DapVersion::Draft02, &resp.payload)
                    .unwrap();
            assert_eq!(a_job_resp.transitions.len(), 2);
            assert_matches!(
                a_job_resp.transitions[0].var,
                TransitionVar::Failed(crate::messages::TransitionFailure::ReportReplayed)
            );
            assert_matches!(a_job_resp.transitions[1].var, TransitionVar::Finished);
        };

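        // Across both continuations, 3 reports were aggregated (2 in the first, 1 in the
        // second) and 1 was rejected as a replay; the report counter metric should reflect
        // exactly that.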
        let Some(metric) = test
            .prometheus_registry
            .gather()
            .into_iter()
            .find(|metric| metric.get_name().ends_with("report_counter"))
            .map(|mut m| m.take_metric())
        else {
            panic!("report_counter metric not found");
        };

        let Some(aggregated_counter) = metric
            .iter()
            .find(|m| m.get_label().iter().any(|l| l.get_value() == "aggregated"))
            .map(|m| m.get_counter())
        else {
            panic!("no aggregated metric found");
        };

        assert_eq!(aggregated_counter.get_value(), 3.0);

        let Some(rejected_counter) = metric
            .iter()
            .find(|m| {
                m.get_label()
                    .iter()
                    .any(|l| l.get_value() == "rejected_report_replayed")
            })
            .map(|m| m.get_counter())
        else {
            panic!("no rejected_report_replayed metric found");
        };

        assert_eq!(rejected_counter.get_value(), 1.0);
    }
}
13 changes: 6 additions & 7 deletions daphne/src/roles/mod.rs
@@ -278,7 +278,6 @@ mod test {
         let vdaf_config = VdafConfig::Prio3(Prio3Config::Count);
         let leader_url = Url::parse("https://leader.com/v02/").unwrap();
         let helper_url = Url::parse("http://helper.org:8788/v02/").unwrap();
-        let time_precision = 3600;
         let collector_hpke_receiver_config =
             HpkeReceiverConfig::gen(rng.gen(), HpkeKemId::X25519HkdfSha256).unwrap();
@@ -294,8 +293,8 @@ mod test {
                 collector_hpke_config: collector_hpke_receiver_config.config.clone(),
                 leader_url: leader_url.clone(),
                 helper_url: helper_url.clone(),
-                time_precision,
-                expiration: now + 3600,
+                time_precision: Self::TASK_TIME_PRECISION,
+                expiration: now + Self::TASK_TIME_PRECISION,
                 min_batch_size: 1,
                 query: DapQueryConfig::TimeInterval,
                 vdaf: vdaf_config.clone(),
@@ -310,8 +309,8 @@ mod test {
                 collector_hpke_config: collector_hpke_receiver_config.config.clone(),
                 leader_url: leader_url.clone(),
                 helper_url: helper_url.clone(),
-                time_precision,
-                expiration: now + 3600,
+                time_precision: Self::TASK_TIME_PRECISION,
+                expiration: now + Self::TASK_TIME_PRECISION,
                 min_batch_size: 1,
                 query: DapQueryConfig::FixedSize { max_batch_size: 2 },
                 vdaf: vdaf_config.clone(),
@@ -326,7 +325,7 @@ mod test {
                 collector_hpke_config: collector_hpke_receiver_config.config.clone(),
                 leader_url,
                 helper_url,
-                time_precision,
+                time_precision: Self::TASK_TIME_PRECISION,
                 expiration: now, // Expires this second
                 min_batch_size: 1,
                 query: DapQueryConfig::TimeInterval,
@@ -452,7 +451,7 @@ mod test {
         fixed_size_task_id: TaskId,
         expired_task_id: TaskId,
         version: DapVersion,
-        prometheus_registry: prometheus::Registry,
+        pub prometheus_registry: prometheus::Registry,
     }
 
     impl Test {