Skip to content

Commit

Permalink
feat: Initial commit of 'outputs' and 'metric_sets' sections of schem…
Browse files Browse the repository at this point in the history
…a. (#23)
  • Loading branch information
jetuk authored Sep 21, 2023
1 parent 9f4f89e commit 9d31a3a
Show file tree
Hide file tree
Showing 35 changed files with 1,314 additions and 129 deletions.
16 changes: 12 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ extern crate core;

use crate::node::NodeIndex;
use crate::parameters::{IndexParameterIndex, MultiValueParameterIndex, ParameterIndex};
use crate::recorders::RecorderIndex;
use crate::recorders::{MetricSetIndex, RecorderIndex};
use crate::schema::ConversionError;
use thiserror::Error;

Expand Down Expand Up @@ -51,17 +51,23 @@ pub enum PywrError {
MultiValueParameterKeyNotFound(String),
#[error("parameter {0} not found")]
ParameterNotFound(String),
#[error("metric set index not found")]
MetricSetIndexNotFound,
#[error("metric set with name {0} not found")]
MetricSetNotFound(String),
#[error("recorder index not found")]
RecorderIndexNotFound,
#[error("recorder not found")]
RecorderNotFound,
#[error("node name `{0}` already exists")]
NodeNameAlreadyExists(String),
#[error("parameter name `{0}` already exists on parameter {1}")]
#[error("parameter name `{0}` already exists at index {1}")]
ParameterNameAlreadyExists(String, ParameterIndex),
#[error("index parameter name `{0}` already exists on index parameter {1}")]
#[error("index parameter name `{0}` already exists at index {1}")]
IndexParameterNameAlreadyExists(String, IndexParameterIndex),
#[error("recorder name `{0}` already exists on parameter {1}")]
#[error("metric set name `{0}` already exists")]
MetricSetNameAlreadyExists(String),
#[error("recorder name `{0}` already exists at index {1}")]
RecorderNameAlreadyExists(String, RecorderIndex),
#[error("connections from output nodes are invalid. node: {0}")]
InvalidNodeConnectionFromOutput(String),
Expand Down Expand Up @@ -111,6 +117,8 @@ pub enum PywrError {
RecorderNotInitialised,
#[error("hdf5 error: {0}")]
HDF5Error(String),
#[error("csv error: {0}")]
CSVError(String),
#[error("not implemented by recorder")]
NotSupportedByRecorder,
#[error("invalid constraint value: {0}")]
Expand Down
9 changes: 6 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ enum Commands {
solver: Solver,
#[arg(short, long)]
data_path: Option<PathBuf>,
#[arg(short, long)]
output_path: Option<PathBuf>,
/// Use multiple threads for simulation.
#[arg(short, long, default_value_t = false)]
parallel: bool,
Expand All @@ -89,6 +91,7 @@ fn main() -> Result<()> {
model,
solver,
data_path,
output_path,
parallel,
threads,
} => {
Expand All @@ -98,7 +101,7 @@ fn main() -> Result<()> {
RunOptions::default()
};

run(model, solver, data_path.as_deref(), &options)
run(model, solver, data_path.as_deref(), output_path.as_deref(), &options)
}
},
None => {}
Expand Down Expand Up @@ -148,11 +151,11 @@ fn v1_to_v2(path: &Path) -> std::result::Result<(), ConversionError> {
Ok(())
}

fn run(path: &Path, solver: &Solver, data_path: Option<&Path>, options: &RunOptions) {
fn run(path: &Path, solver: &Solver, data_path: Option<&Path>, output_path: Option<&Path>, options: &RunOptions) {
let data = std::fs::read_to_string(path).unwrap();
let schema_v2: PywrModel = serde_json::from_str(data.as_str()).unwrap();

let (model, timestepper): (Model, Timestepper) = schema_v2.try_into_model(data_path).unwrap();
let (model, timestepper): (Model, Timestepper) = schema_v2.build_model(data_path, output_path).unwrap();

match *solver {
Solver::Clp => model.run::<ClpSolver>(&timestepper, options),
Expand Down
13 changes: 13 additions & 0 deletions src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ pub enum Metric {
VirtualStorageVolume(VirtualStorageIndex),
VirtualStorageProportionalVolume(VirtualStorageIndex),
VolumeBetweenControlCurves(VolumeBetweenControlCurves),
MultiNodeInFlow {
indices: Vec<NodeIndex>,
name: String,
sub_name: Option<String>,
},
// TODO implement other MultiNodeXXX variants
Constant(f64),
}

Expand Down Expand Up @@ -107,6 +113,13 @@ impl Metric {
// TODO handle divide by zero
Ok(volume / max_volume)
}
Metric::MultiNodeInFlow { indices, .. } => {
let flow = indices
.iter()
.map(|idx| state.get_network_state().get_node_in_flow(idx))
.sum::<Result<_, _>>()?;
Ok(flow)
}
Metric::NodeInFlowDeficit(idx) => {
let node = model.get_node(idx)?;
let flow = state.get_network_state().get_node_in_flow(idx)?;
Expand Down
90 changes: 61 additions & 29 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@ use crate::edge::{EdgeIndex, EdgeVec};
use crate::metric::Metric;
use crate::node::{ConstraintValue, Node, NodeVec, StorageInitialVolume};
use crate::parameters::{MultiValueParameterIndex, ParameterType};
use crate::recorders::{MetricSet, MetricSetIndex};
use crate::scenario::{ScenarioGroupCollection, ScenarioIndex};
use crate::solvers::{MultiStateSolver, Solver, SolverTimings};
use crate::state::{ParameterStates, State};
use crate::timestep::{Timestep, Timestepper};
use crate::virtual_storage::{VirtualStorage, VirtualStorageIndex, VirtualStorageReset, VirtualStorageVec};
use crate::{parameters, recorders, IndexParameterIndex, NodeIndex, ParameterIndex, PywrError, RecorderIndex};
use indicatif::ProgressIterator;
use tracing::{debug, info};
use rayon::prelude::*;
use std::any::Any;
use std::ops::Deref;
use std::time::Duration;
use std::time::Instant;
use tracing::{debug, info};

enum RunDuration {
Running(Instant),
Expand Down Expand Up @@ -168,6 +169,7 @@ pub struct Model {
parameters: Vec<Box<dyn parameters::Parameter>>,
index_parameters: Vec<Box<dyn parameters::IndexParameter>>,
multi_parameters: Vec<Box<dyn parameters::MultiValueParameter>>,
metric_sets: Vec<MetricSet>,
resolve_order: Vec<ComponentType>,
recorders: Vec<Box<dyn recorders::Recorder>>,
}
Expand Down Expand Up @@ -840,6 +842,22 @@ impl Model {
}
}

/// Get a `AggregatedNodeIndex` from a node's name
pub fn get_aggregated_node_index_by_name(
&self,
name: &str,
sub_name: Option<&str>,
) -> Result<AggregatedNodeIndex, PywrError> {
match self
.aggregated_nodes
.iter()
.find(|&n| n.full_name() == (name, sub_name))
{
Some(node) => Ok(node.index()),
None => Err(PywrError::NodeNotFound(name.to_string())),
}
}

pub fn set_aggregated_node_max_flow(
&mut self,
name: &str,
Expand Down Expand Up @@ -949,6 +967,16 @@ impl Model {
}
}

/// Get a `VirtualStorageNode` from a node's name
pub fn get_virtual_storage_node_index_by_name(
&self,
name: &str,
sub_name: Option<&str>,
) -> Result<VirtualStorageIndex, PywrError> {
let node = self.get_virtual_storage_node_by_name(name, sub_name)?;
Ok(node.index())
}

pub fn get_storage_node_metric(
&self,
name: &str,
Expand Down Expand Up @@ -979,34 +1007,6 @@ impl Model {
}
}

pub fn get_node_default_metrics(&self) -> Vec<(Metric, (String, Option<String>))> {
self.nodes
.iter()
.map(|n| {
let metric = n.default_metric();
let (name, sub_name) = n.full_name();
(metric, (name.to_string(), sub_name.map(|s| s.to_string())))
})
.chain(self.aggregated_nodes.iter().map(|n| {
let metric = n.default_metric();
let (name, sub_name) = n.full_name();
(metric, (name.to_string(), sub_name.map(|s| s.to_string())))
}))
.collect()
}

pub fn get_parameter_metrics(&self) -> Vec<(Metric, (String, Option<String>))> {
self.parameters
.iter()
.enumerate()
.map(|(idx, p)| {
let metric = Metric::ParameterValue(ParameterIndex::new(idx));

(metric, (format!("param-{}", p.name()), None))
})
.collect()
}

/// Get a `Parameter` from a parameter's name
pub fn get_parameter(&self, index: &ParameterIndex) -> Result<&dyn parameters::Parameter, PywrError> {
match self.parameters.get(*index.deref()) {
Expand Down Expand Up @@ -1266,6 +1266,38 @@ impl Model {
Ok(parameter_index)
}

/// Add a [`MetricSet`] to the model.
pub fn add_metric_set(&mut self, metric_set: MetricSet) -> Result<MetricSetIndex, PywrError> {
if let Ok(_) = self.get_metric_set_by_name(&metric_set.name()) {
return Err(PywrError::MetricSetNameAlreadyExists(metric_set.name().to_string()));
}

let metric_set_idx = MetricSetIndex::new(self.metric_sets.len());
self.metric_sets.push(metric_set);
Ok(metric_set_idx)
}

/// Get a [`MetricSet'] from its index.
pub fn get_metric_set(&self, index: MetricSetIndex) -> Result<&MetricSet, PywrError> {
self.metric_sets.get(*index).ok_or(PywrError::MetricSetIndexNotFound)
}

/// Get a ['MetricSet'] by its name.
pub fn get_metric_set_by_name(&self, name: &str) -> Result<&MetricSet, PywrError> {
self.metric_sets
.iter()
.find(|&m| m.name() == name)
.ok_or(PywrError::MetricSetNotFound(name.to_string()))
}

/// Get a ['MetricSetIndex'] by its name.
pub fn get_metric_set_index_by_name(&self, name: &str) -> Result<MetricSetIndex, PywrError> {
match self.metric_sets.iter().position(|m| m.name() == name) {
Some(idx) => Ok(MetricSetIndex::new(idx)),
None => Err(PywrError::MetricSetNotFound(name.to_string())),
}
}

/// Add a `recorders::Recorder` to the model
pub fn add_recorder(&mut self, recorder: Box<dyn recorders::Recorder>) -> Result<RecorderIndex, PywrError> {
// TODO reinstate this check
Expand Down
15 changes: 3 additions & 12 deletions src/python.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::aggregated_node::AggregatedNodeIndex;
use crate::model::{Model, RunOptions};
use crate::recorders::hdf::HDF5Recorder;
use crate::recorders::HDF5Recorder;
use crate::schema::model::PywrModel;
use crate::solvers::ClpSolver;
#[cfg(feature = "highs")]
Expand Down Expand Up @@ -666,7 +666,6 @@ fn run_model_from_path(
path: PathBuf,
solver_name: String,
data_path: Option<PathBuf>,
output_h5: Option<PathBuf>,
num_threads: Option<usize>,
) -> PyResult<()> {
let data = std::fs::read_to_string(path.clone()).unwrap();
Expand All @@ -676,7 +675,7 @@ fn run_model_from_path(
Some(dp) => Some(dp),
};

run_model_from_string(py, data, solver_name, data_path, output_h5, num_threads)
run_model_from_string(py, data, solver_name, data_path, num_threads)
}

#[pyfunction]
Expand All @@ -685,20 +684,12 @@ fn run_model_from_string(
data: String,
solver_name: String,
path: Option<PathBuf>,
output_h5: Option<PathBuf>,
num_threads: Option<usize>,
) -> PyResult<()> {
// TODO handle the serde error properly
let schema_v2: PywrModel = serde_json::from_str(data.as_str()).unwrap();

let (mut model, timestepper): (Model, Timestepper) = schema_v2.try_into_model(path.as_deref())?;

if let Some(pth) = output_h5 {
let metrics = model.get_node_default_metrics();
// metrics.extend(model.get_parameter_metrics());
let tables_rec = HDF5Recorder::new("tables", pth, metrics);
model.add_recorder(Box::new(tables_rec)).unwrap();
}
let (mut model, timestepper): (Model, Timestepper) = schema_v2.build_model(path.as_deref(), None)?;

let nt = num_threads.unwrap_or(1);
let options = if nt > 1 {
Expand Down
Loading

0 comments on commit 9d31a3a

Please sign in to comment.