From a4b33d40c817496ce27e3c3f99f7a7e858ac3fe9 Mon Sep 17 00:00:00 2001 From: James Tomlinson Date: Fri, 6 Dec 2024 12:27:29 +0000 Subject: [PATCH 1/2] WIP IndexMetric --- pywr-core/src/metric.rs | 24 ++ pywr-core/src/network.rs | 8 + pywr-core/src/parameters/array.rs | 126 +++++++--- .../parameters/control_curves/piecewise.rs | 2 +- pywr-core/src/parameters/delay.rs | 2 +- pywr-core/src/parameters/discount_factor.rs | 2 +- pywr-core/src/parameters/mod.rs | 17 +- pywr-core/src/state.rs | 32 ++- pywr-core/src/test_utils.rs | 4 +- pywr-core/src/timestep.rs | 2 +- pywr-schema/src/error.rs | 6 +- pywr-schema/src/metric.rs | 221 ++++++++++++++++-- pywr-schema/src/parameters/aggregated.rs | 8 +- .../src/parameters/asymmetric_switch.rs | 11 +- pywr-schema/src/parameters/control_curves.rs | 8 +- pywr-schema/src/parameters/indexed_array.rs | 8 +- pywr-schema/src/parameters/mod.rs | 91 +------- pywr-schema/src/parameters/python.rs | 8 +- pywr-schema/src/parameters/tables.rs | 4 +- pywr-schema/src/timeseries/mod.rs | 117 +++++++++- pywr-schema/src/visit.rs | 9 +- 21 files changed, 525 insertions(+), 185 deletions(-) diff --git a/pywr-core/src/metric.rs b/pywr-core/src/metric.rs index 02396577..6116b154 100644 --- a/pywr-core/src/metric.rs +++ b/pywr-core/src/metric.rs @@ -265,6 +265,18 @@ impl From<(ParameterIndex, String)> for MetricF64 { } } +impl From<(ParameterIndex, String)> for MetricUsize { + fn from((idx, key): (ParameterIndex, String)) -> Self { + match idx { + ParameterIndex::General(idx) => Self::MultiParameterValue((idx, key)), + ParameterIndex::Simple(idx) => Self::Simple(SimpleMetricUsize::MultiParameterValue((idx, key))), + ParameterIndex::Const(idx) => Self::Simple(SimpleMetricUsize::Constant( + ConstantMetricUsize::MultiParameterValue((idx, key)), + )), + } + } +} + impl TryFrom> for SimpleMetricF64 { type Error = PywrError; fn try_from(idx: ParameterIndex) -> Result { @@ -288,6 +300,7 @@ impl TryFrom> for SimpleMetricUsize { #[derive(Clone, Debug, PartialEq)] pub enum ConstantMetricUsize { IndexParameterValue(ConstParameterIndex), + MultiParameterValue((ConstParameterIndex, String)), Constant(usize), } @@ -295,6 +308,9 @@ impl ConstantMetricUsize { pub fn get_value(&self, values: &ConstParameterValues) -> Result { match self { ConstantMetricUsize::IndexParameterValue(idx) => values.get_const_parameter_usize(*idx), + ConstantMetricUsize::MultiParameterValue((idx, key)) => { + Ok(values.get_const_multi_parameter_usize(*idx, key)?) + } ConstantMetricUsize::Constant(v) => Ok(*v), } } @@ -303,6 +319,7 @@ impl ConstantMetricUsize { #[derive(Clone, Debug, PartialEq)] pub enum SimpleMetricUsize { IndexParameterValue(SimpleParameterIndex), + MultiParameterValue((SimpleParameterIndex, String)), Constant(ConstantMetricUsize), } @@ -310,6 +327,9 @@ impl SimpleMetricUsize { pub fn get_value(&self, values: &SimpleParameterValues) -> Result { match self { SimpleMetricUsize::IndexParameterValue(idx) => values.get_simple_parameter_usize(*idx), + SimpleMetricUsize::MultiParameterValue((idx, key)) => { + Ok(values.get_simple_multi_parameter_usize(*idx, key)?) + } SimpleMetricUsize::Constant(m) => m.get_value(values.get_constant_values()), } } @@ -319,13 +339,17 @@ impl SimpleMetricUsize { pub enum MetricUsize { IndexParameterValue(GeneralParameterIndex), Simple(SimpleMetricUsize), + MultiParameterValue((GeneralParameterIndex, String)), + InterNetworkTransfer(MultiNetworkTransferIndex), } impl MetricUsize { pub fn get_value(&self, _network: &Network, state: &State) -> Result { match self { Self::IndexParameterValue(idx) => state.get_parameter_index(*idx), + Self::MultiParameterValue((idx, key)) => Ok(state.get_multi_parameter_index(*idx, key)?), Self::Simple(s) => s.get_value(&state.get_simple_parameter_values()), + Self::InterNetworkTransfer(_idx) => todo!("Support usize for inter-network transfers"), } } } diff --git a/pywr-core/src/network.rs b/pywr-core/src/network.rs index 98a1d27b..076c1341 100644 --- a/pywr-core/src/network.rs +++ b/pywr-core/src/network.rs @@ -1363,6 +1363,14 @@ impl Network { self.parameters.add_simple_f64(parameter) } + /// Add a [`parameters::SimpleParameter`] to the network + pub fn add_simple_index_parameter( + &mut self, + parameter: Box>, + ) -> Result, PywrError> { + self.parameters.add_simple_usize(parameter) + } + /// Add a [`parameters::ConstParameter`] to the network pub fn add_const_parameter( &mut self, diff --git a/pywr-core/src/parameters/array.rs b/pywr-core/src/parameters/array.rs index 2163f109..b69a675d 100644 --- a/pywr-core/src/parameters/array.rs +++ b/pywr-core/src/parameters/array.rs @@ -1,44 +1,55 @@ -use crate::network::Network; -use crate::parameters::{GeneralParameter, Parameter, ParameterMeta, ParameterName, ParameterState}; +use crate::parameters::{Parameter, ParameterMeta, ParameterName, ParameterState, SimpleParameter}; use crate::scenario::ScenarioIndex; -use crate::state::State; -use crate::timestep::Timestep; +use crate::state::SimpleParameterValues; +use crate::timestep::{Timestep, TimestepIndex}; use crate::PywrError; use ndarray::{Array1, Array2, Axis}; -pub struct Array1Parameter { +pub struct Array1Parameter { meta: ParameterMeta, - array: Array1, + array: Array1, timestep_offset: Option, } -impl Array1Parameter { - pub fn new(name: ParameterName, array: Array1, timestep_offset: Option) -> Self { +impl Array1Parameter { + pub fn new(name: ParameterName, array: Array1, timestep_offset: Option) -> Self { Self { meta: ParameterMeta::new(name), array, timestep_offset, } } + + /// Compute the time-step index to use accounting for any defined offset. + /// + /// The offset is applied to the time-step index and then clamped to the bounds of the array. + /// This ensures that the time-step index is always within the bounds of the array. + fn timestep_index(&self, timestep: &Timestep) -> TimestepIndex { + match self.timestep_offset { + None => timestep.index, + Some(offset) => (timestep.index as i32 + offset) + .max(0) + .min(self.array.len_of(Axis(0)) as i32 - 1) as usize, + } + } } -impl Parameter for Array1Parameter { +impl Parameter for Array1Parameter +where + T: Send + Sync + Clone, +{ fn meta(&self) -> &ParameterMeta { &self.meta } } -impl GeneralParameter for Array1Parameter { +impl SimpleParameter for Array1Parameter { fn compute( &self, timestep: &Timestep, _scenario_index: &ScenarioIndex, - _model: &Network, - _state: &State, + _values: &SimpleParameterValues, _internal_state: &mut Option>, ) -> Result { - let idx = match self.timestep_offset { - None => timestep.index, - Some(offset) => (timestep.index as i32 + offset).max(0).min(self.array.len() as i32 - 1) as usize, - }; + let idx = self.timestep_index(timestep); // This panics if out-of-bounds let value = self.array[[idx]]; Ok(value) @@ -52,17 +63,39 @@ impl GeneralParameter for Array1Parameter { } } -pub struct Array2Parameter { +impl SimpleParameter for Array1Parameter { + fn compute( + &self, + timestep: &Timestep, + _scenario_index: &ScenarioIndex, + _values: &SimpleParameterValues, + _internal_state: &mut Option>, + ) -> Result { + let idx = self.timestep_index(timestep); + // This panics if out-of-bounds + let value = self.array[[idx]]; + Ok(value as usize) + } + + fn as_parameter(&self) -> &dyn Parameter + where + Self: Sized, + { + self + } +} + +pub struct Array2Parameter { meta: ParameterMeta, - array: Array2, + array: Array2, scenario_group_index: usize, timestep_offset: Option, } -impl Array2Parameter { +impl Array2Parameter { pub fn new( name: ParameterName, - array: Array2, + array: Array2, scenario_group_index: usize, timestep_offset: Option, ) -> Self { @@ -73,30 +106,40 @@ impl Array2Parameter { timestep_offset, } } + + /// Compute the time-step index to use accounting for any defined offset. + /// + /// The offset is applied to the time-step index and then clamped to the bounds of the array. + /// This ensures that the time-step index is always within the bounds of the array. + fn timestep_index(&self, timestep: &Timestep) -> TimestepIndex { + match self.timestep_offset { + None => timestep.index, + Some(offset) => (timestep.index as i32 + offset) + .max(0) + .min(self.array.len_of(Axis(0)) as i32 - 1) as usize, + } + } } -impl Parameter for Array2Parameter { +impl Parameter for Array2Parameter +where + T: Send + Sync + Clone, +{ fn meta(&self) -> &ParameterMeta { &self.meta } } -impl GeneralParameter for Array2Parameter { +impl SimpleParameter for Array2Parameter { fn compute( &self, timestep: &Timestep, scenario_index: &ScenarioIndex, - _model: &Network, - _state: &State, + _values: &SimpleParameterValues, _internal_state: &mut Option>, ) -> Result { // This panics if out-of-bounds - let t_idx = match self.timestep_offset { - None => timestep.index, - Some(offset) => (timestep.index as i32 + offset) - .max(0) - .min(self.array.len_of(Axis(0)) as i32 - 1) as usize, - }; + let t_idx = self.timestep_index(timestep); let s_idx = scenario_index.indices[self.scenario_group_index]; Ok(self.array[[t_idx, s_idx]]) @@ -109,3 +152,26 @@ impl GeneralParameter for Array2Parameter { self } } + +impl SimpleParameter for Array2Parameter { + fn compute( + &self, + timestep: &Timestep, + scenario_index: &ScenarioIndex, + _values: &SimpleParameterValues, + _internal_state: &mut Option>, + ) -> Result { + // This panics if out-of-bounds + let t_idx = self.timestep_index(timestep); + let s_idx = scenario_index.indices[self.scenario_group_index]; + + Ok(self.array[[t_idx, s_idx]] as usize) + } + + fn as_parameter(&self) -> &dyn Parameter + where + Self: Sized, + { + self + } +} diff --git a/pywr-core/src/parameters/control_curves/piecewise.rs b/pywr-core/src/parameters/control_curves/piecewise.rs index 5a548c28..70e02a86 100644 --- a/pywr-core/src/parameters/control_curves/piecewise.rs +++ b/pywr-core/src/parameters/control_curves/piecewise.rs @@ -87,7 +87,7 @@ mod test { // Create an artificial volume series to use for the interpolation test let volume = Array1Parameter::new("test-x".into(), Array1::linspace(1.0, 0.0, 21), None); - let volume_idx = model.network_mut().add_parameter(Box::new(volume)).unwrap(); + let volume_idx = model.network_mut().add_simple_parameter(Box::new(volume)).unwrap(); let parameter = PiecewiseInterpolatedParameter::new( "test-parameter".into(), diff --git a/pywr-core/src/parameters/delay.rs b/pywr-core/src/parameters/delay.rs index bc828fa4..9ea63385 100644 --- a/pywr-core/src/parameters/delay.rs +++ b/pywr-core/src/parameters/delay.rs @@ -176,7 +176,7 @@ mod test { let volumes = Array1::linspace(1.0, 0.0, 21); let volume = Array1Parameter::new("test-x".into(), volumes.clone(), None); - let volume_idx = model.network_mut().add_parameter(Box::new(volume)).unwrap(); + let volume_idx = model.network_mut().add_simple_parameter(Box::new(volume)).unwrap(); const DELAY: usize = 3; // 3 time-step delay let parameter = DelayParameter::new( diff --git a/pywr-core/src/parameters/discount_factor.rs b/pywr-core/src/parameters/discount_factor.rs index ce8de25c..b9ab6e55 100644 --- a/pywr-core/src/parameters/discount_factor.rs +++ b/pywr-core/src/parameters/discount_factor.rs @@ -68,7 +68,7 @@ mod test { let volumes = Array1::linspace(1.0, 0.0, 21); let volume = Array1Parameter::new("test-x".into(), volumes.clone(), None); - let _volume_idx = network.add_parameter(Box::new(volume)).unwrap(); + let _volume_idx = network.add_simple_parameter(Box::new(volume)).unwrap(); let parameter = DiscountFactorParameter::new( "test-parameter".into(), diff --git a/pywr-core/src/parameters/mod.rs b/pywr-core/src/parameters/mod.rs index 427db3e6..70eeb947 100644 --- a/pywr-core/src/parameters/mod.rs +++ b/pywr-core/src/parameters/mod.rs @@ -1008,7 +1008,7 @@ impl ParameterCollection { } match parameter.try_into_simple() { - Some(simple) => self.add_simple_usize(simple).map(|idx| idx.into()), + Some(simple) => self.add_simple_usize(simple), None => { let index = GeneralParameterIndex::new(self.general_usize.len()); self.general_usize.push(parameter); @@ -1020,17 +1020,22 @@ impl ParameterCollection { pub fn add_simple_usize( &mut self, parameter: Box>, - ) -> Result, PywrError> { + ) -> Result, PywrError> { if self.has_name(parameter.name()) { return Err(PywrError::ParameterNameAlreadyExists(parameter.meta().name.to_string())); } - let index = SimpleParameterIndex::new(self.simple_usize.len()); + match parameter.try_into_const() { + Some(constant) => self.add_const_usize(constant), + None => { + let index = SimpleParameterIndex::new(self.simple_f64.len()); - self.simple_usize.push(parameter); - self.simple_resolve_order.push(SimpleParameterType::Index(index)); + self.simple_usize.push(parameter); + self.simple_resolve_order.push(SimpleParameterType::Index(index)); - Ok(index) + Ok(index.into()) + } + } } pub fn add_const_usize( diff --git a/pywr-core/src/state.rs b/pywr-core/src/state.rs index 88a076c7..5cb66c4c 100644 --- a/pywr-core/src/state.rs +++ b/pywr-core/src/state.rs @@ -416,7 +416,7 @@ pub struct ParameterValuesRef<'a> { multi_values: &'a [MultiValue], } -impl<'a> ParameterValuesRef<'a> { +impl ParameterValuesRef<'_> { fn get_value(&self, idx: usize) -> Option<&f64> { self.values.get(idx) } @@ -428,6 +428,10 @@ impl<'a> ParameterValuesRef<'a> { fn get_multi_value(&self, idx: usize, key: &str) -> Option<&f64> { self.multi_values.get(idx).and_then(|s| s.get_value(key)) } + + fn get_multi_index(&self, idx: usize, key: &str) -> Option<&usize> { + self.multi_values.get(idx).and_then(|s| s.get_index(key)) + } } pub struct SimpleParameterValues<'a> { @@ -435,7 +439,7 @@ pub struct SimpleParameterValues<'a> { simple: ParameterValuesRef<'a>, } -impl<'a> SimpleParameterValues<'a> { +impl SimpleParameterValues<'_> { pub fn get_simple_parameter_f64(&self, idx: SimpleParameterIndex) -> Result { self.simple .get_value(*idx.deref()) @@ -461,6 +465,17 @@ impl<'a> SimpleParameterValues<'a> { .copied() } + pub fn get_simple_multi_parameter_usize( + &self, + idx: SimpleParameterIndex, + key: &str, + ) -> Result { + self.simple + .get_multi_index(*idx.deref(), key) + .ok_or(PywrError::SimpleMultiValueParameterIndexNotFound(idx)) + .copied() + } + pub fn get_constant_values(&self) -> &ConstParameterValues { &self.constant } @@ -470,7 +485,7 @@ pub struct ConstParameterValues<'a> { constant: ParameterValuesRef<'a>, } -impl<'a> ConstParameterValues<'a> { +impl ConstParameterValues<'_> { pub fn get_const_parameter_f64(&self, idx: ConstParameterIndex) -> Result { self.constant .get_value(*idx.deref()) @@ -495,6 +510,17 @@ impl<'a> ConstParameterValues<'a> { .ok_or(PywrError::ConstMultiValueParameterIndexNotFound(idx)) .copied() } + + pub fn get_const_multi_parameter_usize( + &self, + idx: ConstParameterIndex, + key: &str, + ) -> Result { + self.constant + .get_multi_index(*idx.deref(), key) + .ok_or(PywrError::ConstMultiValueParameterIndexNotFound(idx)) + .copied() + } } // State of the nodes and edges diff --git a/pywr-core/src/test_utils.rs b/pywr-core/src/test_utils.rs index e0c89930..6f571371 100644 --- a/pywr-core/src/test_utils.rs +++ b/pywr-core/src/test_utils.rs @@ -64,7 +64,7 @@ pub fn simple_network(network: &mut Network, inflow_scenario_index: usize, num_i let inflow = Array::from_shape_fn((366, num_inflow_scenarios), |(i, j)| 1.0 + i as f64 + j as f64); let inflow = Array2Parameter::new("inflow".into(), inflow, inflow_scenario_index, None); - let inflow = network.add_parameter(Box::new(inflow)).unwrap(); + let inflow = network.add_simple_parameter(Box::new(inflow)).unwrap(); let input_node = network.get_mut_node_by_name("input", None).unwrap(); input_node.set_max_flow_constraint(Some(inflow.into())).unwrap(); @@ -287,7 +287,7 @@ fn make_simple_system( inflow_scenario_group_index, None, ); - let idx = network.add_parameter(Box::new(inflow))?; + let idx = network.add_simple_parameter(Box::new(inflow))?; network.set_node_max_flow("input", Some(suffix), Some(idx.into()))?; diff --git a/pywr-core/src/timestep.rs b/pywr-core/src/timestep.rs index 028e9c20..c7391198 100644 --- a/pywr-core/src/timestep.rs +++ b/pywr-core/src/timestep.rs @@ -96,7 +96,7 @@ impl PywrDuration { } } -type TimestepIndex = usize; +pub type TimestepIndex = usize; #[pyclass] #[derive(Debug, Copy, Clone)] diff --git a/pywr-schema/src/error.rs b/pywr-schema/src/error.rs index a52a3130..5bb9d5f7 100644 --- a/pywr-schema/src/error.rs +++ b/pywr-schema/src/error.rs @@ -17,8 +17,10 @@ pub enum SchemaError { name: String, attr: NodeAttribute, }, - #[error("parameter {0} not found")] + #[error("Parameter `{0}` not found")] ParameterNotFound(String), + #[error("Expected an index parameter, but found a regular parameter: {0}")] + IndexParameterExpected(String), #[error("Loading a local parameter reference (name: {0}) requires a parent name space.")] LocalParameterReferenceRequiresParent(String), #[error("network {0} not found")] @@ -44,8 +46,6 @@ pub enum SchemaError { HDF5Error(String), #[error("Missing metric set: {0}")] MissingMetricSet(String), - #[error("unexpected parameter type: {0}")] - UnexpectedParameterType(String), #[error("mismatch in the length of data provided. expected: {expected}, found: {found}")] DataLengthMismatch { expected: usize, found: usize }, #[error("Failed to estimate epsilon for use in the radial basis function.")] diff --git a/pywr-schema/src/metric.rs b/pywr-schema/src/metric.rs index 167279f1..2d1bbc2b 100644 --- a/pywr-schema/src/metric.rs +++ b/pywr-schema/src/metric.rs @@ -17,7 +17,10 @@ use crate::v1::{ConversionData, TryFromV1, TryIntoV2}; use crate::ConversionError; #[cfg(feature = "core")] use pywr_core::{ - metric::MetricF64, models::MultiNetworkTransferIndex, parameters::ParameterName, recorders::OutputMetric, + metric::{MetricF64, MetricUsize}, + models::MultiNetworkTransferIndex, + parameters::ParameterName, + recorders::OutputMetric, }; use pywr_schema_macros::PywrVisitAll; use pywr_v1_schema::parameters::ParameterValue as ParameterValueV1; @@ -25,23 +28,33 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use strum_macros::Display; -/// Output metrics that can be recorded from a model run. +/// A floating point value representing different model metrics. +/// +/// Metrics can be used in various places in a model to create dynamic behaviour. For example, +/// parameter can use an arbitrary [`Metric`] for its calculation giving the user the ability +/// to configure the source of that value. Therefore, metrics are the primary way in which +/// dynamic behaviour is created. +/// +/// See also [`IndexMetric`] for integer values. #[derive(Deserialize, Serialize, Clone, Debug, Display, JsonSchema)] #[serde(tag = "type")] pub enum Metric { - Constant { - value: f64, - }, + /// A constant floating point value. + Constant { value: f64 }, + /// A reference to a constant value in a table. Table(TableDataRef), /// An attribute of a node. Node(NodeReference), + /// An attribute of an edge. Edge(EdgeReference), + /// A reference to a value from a timeseries. Timeseries(TimeseriesReference), + /// A reference to a global parameter. Parameter(ParameterReference), + /// A reference to a local parameter. LocalParameter(ParameterReference), - InterNetworkTransfer { - name: String, - }, + /// A reference to an inter-network transfer by name. + InterNetworkTransfer { name: String }, } impl Default for Metric { @@ -65,9 +78,9 @@ impl Metric { parent: Option<&str>, ) -> Result { match self { - Self::Node(node_ref) => node_ref.load(network, args), + Self::Node(node_ref) => node_ref.load_f64(network, args), // Global parameter with no parent - Self::Parameter(parameter_ref) => parameter_ref.load(network, None), + Self::Parameter(parameter_ref) => parameter_ref.load_f64(network, None), // Local parameter loaded from parent's namespace Self::LocalParameter(parameter_ref) => { if parent.is_none() { @@ -76,7 +89,7 @@ impl Metric { )); } - parameter_ref.load(network, parent) + parameter_ref.load_f64(network, parent) } Self::Constant { value } => Ok((*value).into()), Self::Table(table_ref) => { @@ -93,13 +106,13 @@ impl Metric { let param_idx = match &ts_ref.columns { Some(TimeseriesColumns::Scenario(scenario)) => { args.timeseries - .load_df(network, ts_ref.name.as_ref(), args.domain, scenario.as_str())? + .load_df_f64(network, ts_ref.name.as_ref(), args.domain, scenario.as_str())? } Some(TimeseriesColumns::Column(col)) => { args.timeseries - .load_column(network, ts_ref.name.as_ref(), col.as_str())? + .load_column_f64(network, ts_ref.name.as_ref(), col.as_str())? } - None => args.timeseries.load_single_column(network, ts_ref.name.as_ref())?, + None => args.timeseries.load_single_column_f64(network, ts_ref.name.as_ref())?, }; Ok(param_idx.into()) } @@ -234,8 +247,13 @@ impl NodeReference { Self { name, attribute } } + /// Load a node reference into a [`MetricF64`]. #[cfg(feature = "core")] - pub fn load(&self, network: &mut pywr_core::network::Network, args: &LoadArgs) -> Result { + pub fn load_f64( + &self, + network: &mut pywr_core::network::Network, + args: &LoadArgs, + ) -> Result { // This is the associated node in the schema let node = args .schema @@ -245,6 +263,22 @@ impl NodeReference { node.create_metric(network, self.attribute, args) } + /// Load a node reference into a [`MetricUsize`]. + #[cfg(feature = "core")] + pub fn load_usize( + &self, + _network: &mut pywr_core::network::Network, + args: &LoadArgs, + ) -> Result { + // This is the associated node in the schema + let _node = args + .schema + .get_node_by_name(&self.name) + .ok_or_else(|| SchemaError::NodeNotFound(self.name.clone()))?; + + todo!("Support usize attributes on nodes.") + } + /// Return the attribute of the node. If the attribute is not specified then the default /// attribute of the node is returned. Note that this does not check if the attribute is /// valid for the node. @@ -345,7 +379,7 @@ impl ParameterReference { /// from the `network`. If `parent` is the optional parameter name space from which to load /// the parameter. #[cfg(feature = "core")] - pub fn load( + pub fn load_f64( &self, network: &mut pywr_core::network::Network, parent: Option<&str>, @@ -369,6 +403,34 @@ impl ParameterReference { } } + /// Load a parameter reference into a [`MetricUsize`] by attempting to retrieve the parameter + /// from the `network`. If `parent` is the optional parameter name space from which to load + /// the parameter. + #[cfg(feature = "core")] + pub fn load_usize( + &self, + network: &mut pywr_core::network::Network, + parent: Option<&str>, + ) -> Result { + let name = ParameterName::new(&self.name, parent); + + match &self.key { + Some(key) => { + // Key given; this should be a multi-valued parameter + Ok((network.get_multi_valued_parameter_index_by_name(&name)?, key.clone()).into()) + } + None => { + if let Ok(idx) = network.get_index_parameter_index_by_name(&name) { + Ok(idx.into()) + } else if network.get_parameter_index_by_name(&name).is_ok() { + // Inform the user we found the parameter, but it was the wrong type + Err(SchemaError::IndexParameterExpected(self.name.to_string())) + } else { + Err(SchemaError::ParameterNotFound(self.name.to_string())) + } + } + } + } #[cfg(feature = "core")] pub fn parameter_type(&self, args: &LoadArgs) -> Result { let parameter = args @@ -394,3 +456,130 @@ impl EdgeReference { self.edge.create_metric(network, args) } } + +/// An unsigned integer value representing different model metrics. +/// +/// This struct is the integer equivalent of [`Metric`] and is used in places where an integer +/// value is required. See [`Metric`] for more information. +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, Display)] +#[serde(untagged)] +pub enum IndexMetric { + Constant { + value: usize, + }, + Table(TableDataRef), + /// An attribute of a node. + Node(NodeReference), + Timeseries(TimeseriesReference), + Parameter(ParameterReference), + LocalParameter(ParameterReference), + InterNetworkTransfer { + name: String, + }, +} + +impl IndexMetric { + pub fn from_usize(v: usize) -> Self { + Self::Constant { value: v } + } +} + +#[cfg(feature = "core")] +impl IndexMetric { + pub fn load( + &self, + network: &mut pywr_core::network::Network, + args: &LoadArgs, + parent: Option<&str>, + ) -> Result { + match self { + Self::Node(node_ref) => node_ref.load_usize(network, args), + // Global parameter with no parent + Self::Parameter(parameter_ref) => parameter_ref.load_usize(network, None), + // Local parameter loaded from parent's namespace + Self::LocalParameter(parameter_ref) => { + if parent.is_none() { + return Err(SchemaError::LocalParameterReferenceRequiresParent( + parameter_ref.name.clone(), + )); + } + + parameter_ref.load_usize(network, parent) + } + Self::Constant { value } => Ok((*value).into()), + Self::Table(table_ref) => { + let value = args + .tables + .get_scalar_usize(table_ref) + .map_err(|error| SchemaError::TableRefLoad { + table_ref: table_ref.clone(), + error, + })?; + Ok(value.into()) + } + Self::Timeseries(ts_ref) => { + let param_idx = match &ts_ref.columns { + Some(TimeseriesColumns::Scenario(scenario)) => { + args.timeseries + .load_df_usize(network, ts_ref.name.as_ref(), args.domain, scenario.as_str())? + } + Some(TimeseriesColumns::Column(col)) => { + args.timeseries + .load_column_usize(network, ts_ref.name.as_ref(), col.as_str())? + } + None => args + .timeseries + .load_single_column_usize(network, ts_ref.name.as_ref())?, + }; + Ok(param_idx.into()) + } + Self::InterNetworkTransfer { name } => { + // Find the matching inter model transfer + match args.inter_network_transfers.iter().position(|t| &t.name == name) { + Some(idx) => Ok(MetricUsize::InterNetworkTransfer(MultiNetworkTransferIndex(idx))), + None => Err(SchemaError::InterNetworkTransferNotFound(name.to_string())), + } + } + } + } +} + +impl TryFromV1 for IndexMetric { + type Error = ConversionError; + + fn try_from_v1( + v1: ParameterValueV1, + parent_node: Option<&str>, + conversion_data: &mut ConversionData, + ) -> Result { + let p = match v1 { + // There was no such thing as s constant index in Pywr v1 + // TODO this could print a warning and do a cast to usize instead. + ParameterValueV1::Constant(_) => return Err(ConversionError::FloatToIndex), + ParameterValueV1::Reference(p_name) => Self::Parameter(ParameterReference { + name: p_name, + key: None, + }), + ParameterValueV1::Table(tbl) => Self::Table(tbl.try_into()?), + ParameterValueV1::Inline(param) => { + // Inline parameters are converted to either a parameter or a timeseries + // The actual component is extracted into the conversion data leaving a reference + // to the component in the metric. + let definition: ParameterOrTimeseriesRef = (*param).try_into_v2(parent_node, conversion_data)?; + match definition { + ParameterOrTimeseriesRef::Parameter(p) => { + let reference = ParameterReference { + name: p.name().to_string(), + key: None, + }; + conversion_data.parameters.push(*p); + + Self::Parameter(reference) + } + ParameterOrTimeseriesRef::Timeseries(t) => Self::Timeseries(t), + } + } + }; + Ok(p) + } +} diff --git a/pywr-schema/src/parameters/aggregated.rs b/pywr-schema/src/parameters/aggregated.rs index f2ea048d..aa91f12e 100644 --- a/pywr-schema/src/parameters/aggregated.rs +++ b/pywr-schema/src/parameters/aggregated.rs @@ -1,10 +1,10 @@ use crate::error::ConversionError; #[cfg(feature = "core")] use crate::error::SchemaError; -use crate::metric::Metric; +use crate::metric::{IndexMetric, Metric}; #[cfg(feature = "core")] use crate::model::LoadArgs; -use crate::parameters::{ConversionData, DynamicIndexValue, ParameterMeta}; +use crate::parameters::{ConversionData, ParameterMeta}; use crate::v1::{IntoV2, TryFromV1, TryIntoV2}; #[cfg(feature = "core")] use pywr_core::parameters::ParameterIndex; @@ -173,7 +173,7 @@ pub struct AggregatedIndexParameter { pub meta: ParameterMeta, pub agg_func: IndexAggFunc, // TODO this should be `DynamicIntValues` - pub parameters: Vec, + pub parameters: Vec, } impl AggregatedIndexParameter { @@ -201,7 +201,7 @@ impl AggregatedIndexParameter { let parameters = self .parameters .iter() - .map(|v| v.load(network, args)) + .map(|v| v.load(network, args, None)) .collect::, _>>()?; let p = pywr_core::parameters::AggregatedIndexParameter::new( diff --git a/pywr-schema/src/parameters/asymmetric_switch.rs b/pywr-schema/src/parameters/asymmetric_switch.rs index e128549c..474dfeb7 100644 --- a/pywr-schema/src/parameters/asymmetric_switch.rs +++ b/pywr-schema/src/parameters/asymmetric_switch.rs @@ -1,9 +1,10 @@ use crate::error::ConversionError; #[cfg(feature = "core")] use crate::error::SchemaError; +use crate::metric::IndexMetric; #[cfg(feature = "core")] use crate::model::LoadArgs; -use crate::parameters::{ConversionData, DynamicIndexValue, ParameterMeta}; +use crate::parameters::{ConversionData, ParameterMeta}; use crate::v1::{IntoV2, TryFromV1, TryIntoV2}; #[cfg(feature = "core")] use pywr_core::parameters::ParameterIndex; @@ -15,8 +16,8 @@ use schemars::JsonSchema; #[serde(deny_unknown_fields)] pub struct AsymmetricSwitchIndexParameter { pub meta: ParameterMeta, - pub on_index_parameter: DynamicIndexValue, - pub off_index_parameter: DynamicIndexValue, + pub on_index_parameter: IndexMetric, + pub off_index_parameter: IndexMetric, } #[cfg(feature = "core")] @@ -26,8 +27,8 @@ impl AsymmetricSwitchIndexParameter { network: &mut pywr_core::network::Network, args: &LoadArgs, ) -> Result, SchemaError> { - let on_index_parameter = self.on_index_parameter.load(network, args)?; - let off_index_parameter = self.off_index_parameter.load(network, args)?; + let on_index_parameter = self.on_index_parameter.load(network, args, None)?; + let off_index_parameter = self.off_index_parameter.load(network, args, None)?; let p = pywr_core::parameters::AsymmetricSwitchIndexParameter::new( self.meta.name.as_str().into(), diff --git a/pywr-schema/src/parameters/control_curves.rs b/pywr-schema/src/parameters/control_curves.rs index 42e0836d..7748aacf 100644 --- a/pywr-schema/src/parameters/control_curves.rs +++ b/pywr-schema/src/parameters/control_curves.rs @@ -34,7 +34,7 @@ impl ControlCurveInterpolatedParameter { network: &mut pywr_core::network::Network, args: &LoadArgs, ) -> Result, SchemaError> { - let metric = self.storage_node.load(network, args)?; + let metric = self.storage_node.load_f64(network, args)?; let control_curves = self .control_curves @@ -134,7 +134,7 @@ impl ControlCurveIndexParameter { network: &mut pywr_core::network::Network, args: &LoadArgs, ) -> Result, SchemaError> { - let metric = self.storage_node.load(network, args)?; + let metric = self.storage_node.load_f64(network, args)?; let control_curves = self .control_curves @@ -246,7 +246,7 @@ impl ControlCurveParameter { network: &mut pywr_core::network::Network, args: &LoadArgs, ) -> Result, SchemaError> { - let metric = self.storage_node.load(network, args)?; + let metric = self.storage_node.load_f64(network, args)?; let control_curves = self .control_curves @@ -342,7 +342,7 @@ impl ControlCurvePiecewiseInterpolatedParameter { network: &mut pywr_core::network::Network, args: &LoadArgs, ) -> Result, SchemaError> { - let metric = self.storage_node.load(network, args)?; + let metric = self.storage_node.load_f64(network, args)?; let control_curves = self .control_curves diff --git a/pywr-schema/src/parameters/indexed_array.rs b/pywr-schema/src/parameters/indexed_array.rs index 153bc8e0..3a192a98 100644 --- a/pywr-schema/src/parameters/indexed_array.rs +++ b/pywr-schema/src/parameters/indexed_array.rs @@ -1,10 +1,10 @@ use crate::error::ConversionError; #[cfg(feature = "core")] use crate::error::SchemaError; -use crate::metric::Metric; +use crate::metric::{IndexMetric, Metric}; #[cfg(feature = "core")] use crate::model::LoadArgs; -use crate::parameters::{ConversionData, DynamicIndexValue, ParameterMeta}; +use crate::parameters::{ConversionData, ParameterMeta}; use crate::v1::{IntoV2, TryFromV1, TryIntoV2}; #[cfg(feature = "core")] use pywr_core::parameters::ParameterIndex; @@ -18,7 +18,7 @@ pub struct IndexedArrayParameter { pub meta: ParameterMeta, #[serde(alias = "params")] pub metrics: Vec, - pub index_parameter: DynamicIndexValue, + pub index_parameter: IndexMetric, } #[cfg(feature = "core")] @@ -28,7 +28,7 @@ impl IndexedArrayParameter { network: &mut pywr_core::network::Network, args: &LoadArgs, ) -> Result, SchemaError> { - let index_parameter = self.index_parameter.load(network, args)?; + let index_parameter = self.index_parameter.load(network, args, None)?; let metrics = self .metrics diff --git a/pywr-schema/src/parameters/mod.rs b/pywr-schema/src/parameters/mod.rs index 5f6b4226..72a5644b 100644 --- a/pywr-schema/src/parameters/mod.rs +++ b/pywr-schema/src/parameters/mod.rs @@ -29,7 +29,7 @@ pub use super::data_tables::TableDataRef; use crate::error::ConversionError; #[cfg(feature = "core")] use crate::error::SchemaError; -use crate::metric::{Metric, ParameterReference}; +use crate::metric::Metric; #[cfg(feature = "core")] use crate::model::LoadArgs; use crate::timeseries::TimeseriesReference; @@ -59,8 +59,6 @@ pub use profiles::{ #[cfg(feature = "core")] pub use python::try_json_value_into_py; pub use python::{PythonParameter, PythonReturnType, PythonSource}; -#[cfg(feature = "core")] -use pywr_core::{metric::MetricUsize, parameters::ParameterIndex}; use pywr_schema_macros::PywrVisitAll; use pywr_v1_schema::parameters::{ CoreParameter, DataFrameParameter as DataFrameParameterV1, Parameter as ParameterV1, @@ -617,93 +615,6 @@ impl TryFrom for ConstantValue { } } -/// An integer (i64) value from another parameter -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll, Display)] -#[serde(untagged)] -pub enum ParameterIndexValue { - Reference(String), -} - -#[cfg(feature = "core")] -impl ParameterIndexValue { - pub fn load(&self, network: &mut pywr_core::network::Network) -> Result, SchemaError> { - match self { - Self::Reference(name) => { - // This should be an existing parameter - Ok(network.get_index_parameter_index_by_name(&name.as_str().into())?) - } - } - } -} - -/// A potentially dynamic integer (usize) value -/// -/// This value can be a constant (literal or otherwise) or a dynamic value provided -/// by another parameter. -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll, Display)] -#[serde(untagged)] -pub enum DynamicIndexValue { - Constant(ConstantValue), - Dynamic(ParameterIndexValue), -} - -impl DynamicIndexValue { - pub fn from_usize(v: usize) -> Self { - Self::Constant(ConstantValue::Literal(v)) - } -} - -#[cfg(feature = "core")] -impl DynamicIndexValue { - pub fn load(&self, network: &mut pywr_core::network::Network, args: &LoadArgs) -> Result { - let parameter_ref = match self { - DynamicIndexValue::Constant(v) => v.load(args.tables)?.into(), - DynamicIndexValue::Dynamic(v) => v.load(network)?.into(), - }; - Ok(parameter_ref) - } -} - -impl TryFromV1 for DynamicIndexValue { - type Error = ConversionError; - - fn try_from_v1( - v1: ParameterValueV1, - parent_node: Option<&str>, - conversion_data: &mut ConversionData, - ) -> Result { - let p = match v1 { - // There was no such thing as s constant index in Pywr v1 - // TODO this could print a warning and do a cast to usize instead. - ParameterValueV1::Constant(_) => return Err(ConversionError::FloatToIndex), - ParameterValueV1::Reference(p_name) => Self::Dynamic(ParameterIndexValue::Reference(p_name)), - ParameterValueV1::Table(tbl) => Self::Constant(ConstantValue::Table(tbl.try_into()?)), - ParameterValueV1::Inline(param) => { - // Inline parameters are converted to either a parameter or a timeseries - // The actual component is extracted into the conversion data leaving a reference - // to the component in the metric. - let definition: ParameterOrTimeseriesRef = (*param).try_into_v2(parent_node, conversion_data)?; - match definition { - ParameterOrTimeseriesRef::Parameter(p) => { - let reference = ParameterReference { - name: p.name().to_string(), - key: None, - }; - conversion_data.parameters.push(*p); - - Self::Dynamic(ParameterIndexValue::Reference(reference.name)) - } - ParameterOrTimeseriesRef::Timeseries(_) => { - // TODO create an error for this - panic!("Timeseries do not support indexes yet") - } - } - } - }; - Ok(p) - } -} - /// An non-variable vector of constant floating-point (f64) values /// /// This value can be a literal vector of floats or an external reference to an input table. diff --git a/pywr-schema/src/parameters/python.rs b/pywr-schema/src/parameters/python.rs index 9983c4c8..36be9942 100644 --- a/pywr-schema/src/parameters/python.rs +++ b/pywr-schema/src/parameters/python.rs @@ -2,10 +2,10 @@ use crate::data_tables::make_path; #[cfg(feature = "core")] use crate::error::SchemaError; -use crate::metric::Metric; +use crate::metric::{IndexMetric, Metric}; #[cfg(feature = "core")] use crate::model::LoadArgs; -use crate::parameters::{DynamicFloatValueType, DynamicIndexValue, ParameterMeta}; +use crate::parameters::{DynamicFloatValueType, ParameterMeta}; use crate::visit::{VisitMetrics, VisitPaths}; #[cfg(feature = "core")] use pyo3::prelude::{PyAnyMethods, PyModule}; @@ -103,7 +103,7 @@ pub struct PythonParameter { pub metrics: Option>, /// Index values to pass to the calculation method of the initialised object (i.e. /// indices that the Python calculation is dependent on). - pub indices: Option>, + pub indices: Option>, } #[cfg(feature = "core")] @@ -242,7 +242,7 @@ impl PythonParameter { let indices = match &self.indices { Some(indices) => indices .iter() - .map(|(k, v)| Ok((k.to_string(), v.load(network, args)?))) + .map(|(k, v)| Ok((k.to_string(), v.load(network, args, None)?))) .collect::, SchemaError>>()?, None => HashMap::new(), }; diff --git a/pywr-schema/src/parameters/tables.rs b/pywr-schema/src/parameters/tables.rs index 6919b446..9d31aa79 100644 --- a/pywr-schema/src/parameters/tables.rs +++ b/pywr-schema/src/parameters/tables.rs @@ -75,7 +75,7 @@ impl TablesArrayParameter { scenario_group_index, self.timestep_offset, ); - Ok(network.add_parameter(Box::new(p))?) + Ok(network.add_simple_parameter(Box::new(p))?) } else { let array = array.slice_move(s![.., 0]); let p = pywr_core::parameters::Array1Parameter::new( @@ -83,7 +83,7 @@ impl TablesArrayParameter { array, self.timestep_offset, ); - Ok(network.add_parameter(Box::new(p))?) + Ok(network.add_simple_parameter(Box::new(p))?) } } } diff --git a/pywr-schema/src/timeseries/mod.rs b/pywr-schema/src/timeseries/mod.rs index 244cac13..fb115cc9 100644 --- a/pywr-schema/src/timeseries/mod.rs +++ b/pywr-schema/src/timeseries/mod.rs @@ -13,7 +13,11 @@ pub use pandas::PandasDataset; #[cfg(feature = "core")] use polars::error::PolarsError; #[cfg(feature = "core")] -use polars::prelude::{DataFrame, DataType::Float64, Float64Type, IndexOrder}; +use polars::prelude::{ + DataFrame, + DataType::{Float64, UInt64}, + Float64Type, IndexOrder, UInt64Type, +}; pub use polars_dataset::PolarsDataset; #[cfg(feature = "core")] use pywr_core::{ @@ -129,7 +133,7 @@ impl LoadedTimeseriesCollection { Ok(Self { timeseries }) } - pub fn load_column( + pub fn load_column_f64( &self, network: &mut pywr_core::network::Network, name: &str, @@ -149,14 +153,41 @@ impl LoadedTimeseriesCollection { Err(e) => match e { PywrError::ParameterNotFound(_) => { let p = Array1Parameter::new(name, array, None); - Ok(network.add_parameter(Box::new(p))?) + Ok(network.add_simple_parameter(Box::new(p))?) + } + _ => Err(TimeseriesError::PywrCore(e)), + }, + } + } + + pub fn load_column_usize( + &self, + network: &mut pywr_core::network::Network, + name: &str, + col: &str, + ) -> Result, TimeseriesError> { + let df = self + .timeseries + .get(name) + .ok_or(TimeseriesError::TimeseriesNotFound(name.to_string()))?; + let series = df.column(col)?; + + let array = series.cast(&UInt64)?.u64()?.to_ndarray()?.to_owned(); + let name = ParameterName::new(col, Some(name)); + + match network.get_index_parameter_index_by_name(&name) { + Ok(idx) => Ok(idx), + Err(e) => match e { + PywrError::ParameterNotFound(_) => { + let p = Array1Parameter::new(name, array, None); + Ok(network.add_simple_index_parameter(Box::new(p))?) } _ => Err(TimeseriesError::PywrCore(e)), }, } } - pub fn load_single_column( + pub fn load_single_column_f64( &self, network: &mut pywr_core::network::Network, name: &str, @@ -187,14 +218,53 @@ impl LoadedTimeseriesCollection { Err(e) => match e { PywrError::ParameterNotFound(_) => { let p = Array1Parameter::new(name, array, None); - Ok(network.add_parameter(Box::new(p))?) + Ok(network.add_simple_parameter(Box::new(p))?) } _ => Err(TimeseriesError::PywrCore(e)), }, } } - pub fn load_df( + pub fn load_single_column_usize( + &self, + network: &mut pywr_core::network::Network, + name: &str, + ) -> Result, TimeseriesError> { + let df = self + .timeseries + .get(name) + .ok_or(TimeseriesError::TimeseriesNotFound(name.to_string()))?; + + let cols = df.get_column_names(); + + if cols.len() > 1 { + return Err(TimeseriesError::TimeseriesColumnOrScenarioRequired(name.to_string())); + }; + + let col = cols.first().ok_or(TimeseriesError::ColumnNotFound { + col: "".to_string(), + name: name.to_string(), + })?; + + let series = df.column(col)?; + + let array = series.cast(&UInt64)?.u64()?.to_ndarray()?.to_owned(); + let name = ParameterName::new(col, Some(name)); + + match network.get_index_parameter_index_by_name(&name) { + Ok(idx) => Ok(idx), + Err(e) => match e { + PywrError::ParameterNotFound(_) => { + let p = Array1Parameter::new(name, array, None); + Ok(network.add_simple_index_parameter(Box::new(p))?) + } + _ => Err(TimeseriesError::PywrCore(e)), + }, + } + } + + /// Load a timeseries dataframe as a 2D array F64 parameter. + pub fn load_df_f64( &self, network: &mut pywr_core::network::Network, name: &str, @@ -219,7 +289,40 @@ impl LoadedTimeseriesCollection { Err(e) => match e { PywrError::ParameterNotFound(_) => { let p = Array2Parameter::new(name, array, scenario_group_index, None); - Ok(network.add_parameter(Box::new(p))?) + Ok(network.add_simple_parameter(Box::new(p))?) + } + _ => Err(TimeseriesError::PywrCore(e)), + }, + } + } + + /// Load a timeseries dataframe as a 2D array Usize parameter. + pub fn load_df_usize( + &self, + network: &mut pywr_core::network::Network, + name: &str, + domain: &ModelDomain, + scenario: &str, + ) -> Result, TimeseriesError> { + let scenario_group_index = domain + .scenarios() + .group_index(scenario) + .ok_or(TimeseriesError::ScenarioGroupNotFound(scenario.to_string()))?; + + let df = self + .timeseries + .get(name) + .ok_or(TimeseriesError::TimeseriesNotFound(name.to_string()))?; + + let array: Array2 = df.to_ndarray::(IndexOrder::default()).unwrap(); + let name = ParameterName::new(scenario, Some(name)); + + match network.get_index_parameter_index_by_name(&name) { + Ok(idx) => Ok(idx), + Err(e) => match e { + PywrError::ParameterNotFound(_) => { + let p = Array2Parameter::new(name, array, scenario_group_index, None); + Ok(network.add_simple_index_parameter(Box::new(p))?) } _ => Err(TimeseriesError::PywrCore(e)), }, diff --git a/pywr-schema/src/visit.rs b/pywr-schema/src/visit.rs index 8c97ff04..c34616b8 100644 --- a/pywr-schema/src/visit.rs +++ b/pywr-schema/src/visit.rs @@ -1,4 +1,4 @@ -use crate::metric::Metric; +use crate::metric::{IndexMetric, Metric}; use std::collections::HashMap; use std::num::NonZeroUsize; use std::path::{Path, PathBuf}; @@ -26,6 +26,12 @@ impl VisitMetrics for Metric { } } +impl VisitMetrics for IndexMetric { + fn visit_metrics(&self, _visitor: &mut F) {} + + fn visit_metrics_mut(&mut self, _visitor: &mut F) {} +} + impl VisitMetrics for Option where T: VisitMetrics, @@ -125,6 +131,7 @@ pub trait VisitPaths { } impl VisitPaths for Metric {} +impl VisitPaths for IndexMetric {} impl VisitPaths for Option where From e0b7840af3fa56c4b83181f34be6c993190844b8 Mon Sep 17 00:00:00 2001 From: James Tomlinson Date: Mon, 6 Jan 2025 12:14:16 +0000 Subject: [PATCH 2/2] fix: Correct conversion of constant indices. --- pywr-schema/src/data_tables/mod.rs | 2 +- pywr-schema/src/edge.rs | 2 +- pywr-schema/src/metric.rs | 37 +++++++++++++++++++++++++----- pywr-schema/src/nodes/mod.rs | 2 +- pywr-schema/src/parameters/mod.rs | 2 +- pywr-schema/src/timeseries/mod.rs | 4 ++-- 6 files changed, 37 insertions(+), 12 deletions(-) diff --git a/pywr-schema/src/data_tables/mod.rs b/pywr-schema/src/data_tables/mod.rs index 24c8f069..243765c8 100644 --- a/pywr-schema/src/data_tables/mod.rs +++ b/pywr-schema/src/data_tables/mod.rs @@ -255,7 +255,7 @@ impl LoadedTableCollection { } } -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll)] +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll, PartialEq)] #[serde(deny_unknown_fields)] pub struct TableDataRef { pub table: String, diff --git a/pywr-schema/src/edge.rs b/pywr-schema/src/edge.rs index 1b6e1b10..c163129b 100644 --- a/pywr-schema/src/edge.rs +++ b/pywr-schema/src/edge.rs @@ -7,7 +7,7 @@ use pywr_core::{edge::EdgeIndex, metric::MetricF64, node::NodeIndex}; use schemars::JsonSchema; use std::fmt::{Display, Formatter}; -#[derive(serde::Deserialize, serde::Serialize, Clone, JsonSchema, Debug)] +#[derive(serde::Deserialize, serde::Serialize, Clone, JsonSchema, Debug, PartialEq)] pub struct Edge { pub from_node: String, pub to_node: String, diff --git a/pywr-schema/src/metric.rs b/pywr-schema/src/metric.rs index dd40fed6..535425fc 100644 --- a/pywr-schema/src/metric.rs +++ b/pywr-schema/src/metric.rs @@ -37,7 +37,7 @@ use strum_macros::Display; /// dynamic behaviour is created. /// /// See also [`IndexMetric`] for integer values. -#[derive(Deserialize, Serialize, Clone, Debug, Display, JsonSchema)] +#[derive(Deserialize, Serialize, Clone, Debug, Display, JsonSchema, PartialEq)] #[serde(tag = "type")] pub enum Metric { /// A constant floating point value. @@ -240,7 +240,7 @@ impl TryFromV1 for Metric { } /// A reference to a node with an optional attribute. -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll)] +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll, PartialEq)] #[serde(deny_unknown_fields)] pub struct NodeReference { /// The name of the node @@ -365,7 +365,7 @@ impl From for SimpleNodeReference { } } -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)] +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PartialEq)] #[serde(deny_unknown_fields)] pub struct ParameterReference { /// The name of the parameter @@ -449,7 +449,7 @@ impl ParameterReference { } } -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)] +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PartialEq)] #[serde(deny_unknown_fields)] pub struct EdgeReference { /// The edge referred to by this reference. @@ -468,7 +468,7 @@ impl EdgeReference { /// /// This struct is the integer equivalent of [`Metric`] and is used in places where an integer /// value is required. See [`Metric`] for more information. -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, Display)] +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, Display, PartialEq)] #[serde(untagged)] pub enum IndexMetric { Constant { @@ -564,7 +564,7 @@ impl TryFromV1 for IndexMetric { // TODO this could print a warning and do a cast to usize instead. ParameterValueV1::Constant(value) => { // Check if the value is not a whole non-negative number - if value.fract() != 0.0 && value >= 0.0 { + if value.fract() != 0.0 || value < 0.0 { return Err(ConversionError::FloatToIndex {}); } @@ -603,3 +603,28 @@ impl TryFromV1 for IndexMetric { Ok(p) } } + +#[cfg(test)] +mod test { + use super::{ConversionError, IndexMetric, ParameterValueV1, TryFromV1}; + + /// Test conversion of `ParameterValueV1::Constant` to `IndexMetric`. + #[test] + fn test_index_metric_try_from_v1_constant() { + let v1 = ParameterValueV1::Constant(0.0); + let result = IndexMetric::try_from_v1(v1, None, &mut Default::default()); + assert_eq!(result, Ok(IndexMetric::Constant { value: 0 })); + + let v1 = ParameterValueV1::Constant(1.0); + let result = IndexMetric::try_from_v1(v1, None, &mut Default::default()); + assert_eq!(result, Ok(IndexMetric::Constant { value: 1 })); + + let v1 = ParameterValueV1::Constant(1.5); + let result = IndexMetric::try_from_v1(v1, None, &mut Default::default()); + assert_eq!(result, Err(ConversionError::FloatToIndex {})); + + let v1 = ParameterValueV1::Constant(-1.0); + let result = IndexMetric::try_from_v1(v1, None, &mut Default::default()); + assert_eq!(result, Err(ConversionError::FloatToIndex {})); + } +} diff --git a/pywr-schema/src/nodes/mod.rs b/pywr-schema/src/nodes/mod.rs index b5018c44..3262d23d 100644 --- a/pywr-schema/src/nodes/mod.rs +++ b/pywr-schema/src/nodes/mod.rs @@ -89,7 +89,7 @@ impl From for NodeMeta { /// All possible attributes that could be produced by a node. /// /// -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Copy, Display, JsonSchema, PywrVisitAll)] +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Copy, Display, JsonSchema, PywrVisitAll, PartialEq)] pub enum NodeAttribute { Inflow, Outflow, diff --git a/pywr-schema/src/parameters/mod.rs b/pywr-schema/src/parameters/mod.rs index d0ee737b..7d8eca9b 100644 --- a/pywr-schema/src/parameters/mod.rs +++ b/pywr-schema/src/parameters/mod.rs @@ -638,7 +638,7 @@ impl ConstantFloatVec { } } -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll, Display)] +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PywrVisitAll, Display, PartialEq)] #[serde(untagged)] pub enum TableIndex { Single(String), diff --git a/pywr-schema/src/timeseries/mod.rs b/pywr-schema/src/timeseries/mod.rs index 10f363fd..b69f91e4 100644 --- a/pywr-schema/src/timeseries/mod.rs +++ b/pywr-schema/src/timeseries/mod.rs @@ -399,14 +399,14 @@ impl LoadedTimeseriesCollection { // ts.into_values().collect::>() // } -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, strum_macros::Display)] +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, strum_macros::Display, PartialEq)] #[serde(tag = "type", content = "name")] pub enum TimeseriesColumns { Scenario(String), Column(String), } -#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)] +#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema, PartialEq)] #[serde(deny_unknown_fields)] pub struct TimeseriesReference { pub name: String,