Skip to content

Commit

Permalink
feat: WIP on metric sets & outputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
jetuk committed Aug 22, 2023
1 parent d600f53 commit 7fcc506
Show file tree
Hide file tree
Showing 10 changed files with 421 additions and 46 deletions.
14 changes: 10 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate core;

use crate::node::NodeIndex;
use crate::parameters::{IndexParameterIndex, MultiValueParameterIndex, ParameterIndex};
use crate::recorders::RecorderIndex;
use crate::recorders::{MetricSetIndex, RecorderIndex};
use crate::schema::ConversionError;
use thiserror::Error;

Expand Down Expand Up @@ -51,17 +51,23 @@ pub enum PywrError {
MultiValueParameterKeyNotFound(String),
#[error("parameter {0} not found")]
ParameterNotFound(String),
#[error("metric set index not found")]
MetricSetIndexNotFound,
#[error("metric set with name {0} not found")]
MetricSetNotFound(String),
#[error("recorder index not found")]
RecorderIndexNotFound,
#[error("recorder not found")]
RecorderNotFound,
#[error("node name `{0}` already exists")]
NodeNameAlreadyExists(String),
#[error("parameter name `{0}` already exists on parameter {1}")]
#[error("parameter name `{0}` already exists at index {1}")]
ParameterNameAlreadyExists(String, ParameterIndex),
#[error("index parameter name `{0}` already exists on index parameter {1}")]
#[error("index parameter name `{0}` already exists at index {1}")]
IndexParameterNameAlreadyExists(String, IndexParameterIndex),
#[error("recorder name `{0}` already exists on parameter {1}")]
#[error("metric set name `{0}` already exists")]
MetricSetNameAlreadyExists(String),
#[error("recorder name `{0}` already exists at index {1}")]
RecorderNameAlreadyExists(String, RecorderIndex),
#[error("connections from output nodes are invalid. node: {0}")]
InvalidNodeConnectionFromOutput(String),
Expand Down
34 changes: 34 additions & 0 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::edge::{EdgeIndex, EdgeVec};
use crate::metric::Metric;
use crate::node::{ConstraintValue, Node, NodeVec, StorageInitialVolume};
use crate::parameters::{MultiValueParameterIndex, ParameterType};
use crate::recorders::{MetricSet, MetricSetIndex};
use crate::scenario::{ScenarioGroupCollection, ScenarioIndex};
use crate::solvers::{MultiStateSolver, Solver, SolverTimings};
use crate::state::{ParameterStates, State};
Expand Down Expand Up @@ -168,6 +169,7 @@ pub struct Model {
parameters: Vec<Box<dyn parameters::Parameter>>,
index_parameters: Vec<Box<dyn parameters::IndexParameter>>,
multi_parameters: Vec<Box<dyn parameters::MultiValueParameter>>,
metric_sets: Vec<MetricSet>,
resolve_order: Vec<ComponentType>,
recorders: Vec<Box<dyn recorders::Recorder>>,
}
Expand Down Expand Up @@ -1273,6 +1275,38 @@ impl Model {
Ok(parameter_index)
}

/// Add a [`MetricSet`] to the model.
pub fn add_metric_set(&mut self, metric_set: MetricSet) -> Result<MetricSetIndex, PywrError> {
if let Ok(_) = self.get_metric_set_by_name(&metric_set.name()) {
return Err(PywrError::MetricSetNameAlreadyExists(metric_set.name().to_string()));
}

let metric_set_idx = MetricSetIndex::new(self.metric_sets.len());
self.metric_sets.push(metric_set);
Ok(metric_set_idx)
}

/// Get a [`MetricSet'] from its index.
pub fn get_metric_set(&self, index: MetricSetIndex) -> Result<&MetricSet, PywrError> {
self.metric_sets.get(*index).ok_or(PywrError::MetricSetIndexNotFound)
}

/// Get a ['MetricSet'] by its name.
pub fn get_metric_set_by_name(&self, name: &str) -> Result<&MetricSet, PywrError> {
self.metric_sets
.iter()
.find(|&m| m.name() == name)
.ok_or(PywrError::MetricSetNotFound(name.to_string()))
}

/// Get a ['MetricSetIndex'] by its name.
pub fn get_metric_set_index_by_name(&self, name: &str) -> Result<MetricSetIndex, PywrError> {
match self.metric_sets.iter().position(|m| m.name() == name) {
Some(idx) => Ok(MetricSetIndex::new(idx)),
None => Err(PywrError::MetricSetNotFound(name.to_string())),
}
}

/// Add a `recorders::Recorder` to the model
pub fn add_recorder(&mut self, recorder: Box<dyn recorders::Recorder>) -> Result<RecorderIndex, PywrError> {
// TODO reinstate this check
Expand Down
246 changes: 246 additions & 0 deletions src/recorders/aggregator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
use time::{Date, Duration, Month};

#[derive(Clone, Debug)]
enum AggregationPeriod {
Monthly,
Annual,
}

impl AggregationPeriod {
fn is_date_in_period(&self, period_start: &Date, date: &Date) -> bool {
match self {
Self::Monthly => (period_start.year() == date.year()) && (period_start.month() == date.month()),
Self::Annual => period_start.year() == date.year(),
}
}

fn start_of_next_period(&self, current_date: &Date) -> Date {
match self {
Self::Monthly => {
// Increment the year if we're in December
let year = if current_date.month() == Month::December {
current_date.year() + 1
} else {
current_date.year()
};
// 1st of the next month
Date::from_calendar_date(year, current_date.month().next(), 1).unwrap()
}
// 1st of January in the next year
Self::Annual => Date::from_calendar_date(current_date.year() + 1, Month::January, 1).unwrap(),
}
}

/// Split the value representing a period into multiple ['PeriodValue'] that do not cross the
/// boundary of the given period.
fn split_value_into_periods(&self, value: PeriodValue) -> Vec<PeriodValue> {
let mut sub_values = Vec::new();

let mut current_date = value.start;
let end_date = value.start + value.duration;

while current_date < end_date {
// This should be safe to unwrap as it will always create a valid date unless
// we are at the limit of dates that are representable.
let start_of_next_month = current_date
.replace_day(1)
.unwrap()
.replace_month(current_date.month().next())
.unwrap();

let current_duration = if start_of_next_month <= end_date {
start_of_next_month - current_date
} else {
end_date - current_date
};

sub_values.push(PeriodValue {
start: current_date,
duration: current_duration,
value: value.value,
});

current_date = start_of_next_month;
}

sub_values
}
}

#[derive(Clone, Debug)]
enum AggregationFunction {
Sum,
Mean,
Min,
Max,
}

impl AggregationFunction {
fn calc(&self, values: &[PeriodValue]) -> Option<f64> {
match self {
AggregationFunction::Sum => Some(values.iter().map(|v| v.value * v.duration.whole_days() as f64).sum()),
AggregationFunction::Mean => {
let ndays: i64 = values.iter().map(|v| v.duration.whole_days()).sum();
if ndays == 0 {
None
} else {
let sum: f64 = values.iter().map(|v| v.value * v.duration.whole_days() as f64).sum();

Some(sum / ndays as f64)
}
}
AggregationFunction::Min => values.iter().map(|v| v.value).min_by(|a, b| {
a.partial_cmp(b)
.expect("Failed to calculate minimum of values containing a NaN.")
}),
AggregationFunction::Max => values.iter().map(|v| v.value).max_by(|a, b| {
a.partial_cmp(b)
.expect("Failed to calculate maximum of values containing a NaN.")
}),
}
}
}

#[derive(Default)]
pub struct PeriodicAggregatorState {
current_values: Option<Vec<PeriodValue>>,
}

impl PeriodicAggregatorState {
fn process_value(
&mut self,
value: PeriodValue,
agg_period: &AggregationPeriod,
agg_func: &AggregationFunction,
) -> Option<PeriodValue> {
if let Some(current_values) = self.current_values.as_mut() {
let current_period_start = current_values
.get(0)
.expect("Aggregation state contains no values when at least one is expected.")
.start;

// Determine if the value is in the current period
if agg_period.is_date_in_period(&current_period_start, &value.start) {
// New value in the current aggregation period; just append it.
current_values.push(value);

None
} else {
// New value is part of a different period (assume the next one).

// Calculate the aggregated value of the previous period.
let agg_period = if let Some(agg_value) = agg_func.calc(&current_values) {
let agg_duration = value.start - current_period_start;
Some(PeriodValue::new(current_period_start, agg_duration, agg_value))
} else {
None
};

// Reset the state for the next period
current_values.clear();
current_values.push(value);

// Finally return the aggregated value from the previous period
agg_period
}
} else {
// No previous values defined; just append the value
self.current_values = Some(vec![value]);

None
}
}

// fn calc_aggregation(&self, agg_func: &AggregationFunction) -> f64 {
// match agg_func
// }
}

#[derive(Clone, Debug)]
pub struct PeriodicAggregator {
period: AggregationPeriod,
function: AggregationFunction,
}

#[derive(Debug, Copy, Clone)]
pub struct PeriodValue {
start: Date,
duration: Duration,
value: f64,
}

impl PeriodValue {
pub fn new(start: Date, duration: Duration, value: f64) -> Self {
Self { start, duration, value }
}
}

impl PeriodicAggregator {
/// Append a new value to the aggregator.
///
/// The new value should sequentially follow from the previously processed values. If the
/// value completes a new aggregation period then a value representing that aggregation is
/// returned.
pub fn process_value(
&self,
current_state: &mut PeriodicAggregatorState,
value: PeriodValue,
) -> Option<PeriodValue> {
// Split the given period into separate periods that align with the aggregation period.
let mut agg_value = None;

for v in self.period.split_value_into_periods(value) {
let av = current_state.process_value(v, &self.period, &self.function);
if av.is_some() {
if agg_value.is_some() {
panic!("Multiple aggregated values yielded from aggregator. This indicates that the given value spans multiple aggregation periods which is not supported.")
}
agg_value = av;
}
}

agg_value
}
}

#[cfg(test)]
mod tests {
use super::{AggregationFunction, AggregationPeriod, PeriodicAggregator, PeriodicAggregatorState};
use crate::recorders::aggregator::PeriodValue;
use time::macros::date;
use time::Duration;

#[test]
fn test_aggregator() {
let agg = PeriodicAggregator {
period: AggregationPeriod::Monthly,
function: AggregationFunction::Sum,
};

let mut state = PeriodicAggregatorState::default();

let agg_value = agg.process_value(
&mut state,
PeriodValue::new(date!(2023 - 01 - 30), Duration::days(1), 1.0),
);
assert!(agg_value.is_none());

let agg_value = agg.process_value(
&mut state,
PeriodValue::new(date!(2023 - 01 - 31), Duration::days(1), 1.0),
);
assert!(agg_value.is_none());

let agg_value = agg.process_value(
&mut state,
PeriodValue::new(date!(2023 - 02 - 01), Duration::days(1), 1.0),
);
assert!(agg_value.is_some());

let agg_value = agg.process_value(
&mut state,
PeriodValue::new(date!(2023 - 02 - 02), Duration::days(1), 1.0),
);
assert!(agg_value.is_none());
}
}
Loading

0 comments on commit 7fcc506

Please sign in to comment.