Skip to content

Commit

Permalink
feat: Add functionality to export a JSON schema of PywrModel.
Browse files Browse the repository at this point in the history
Use the schemars crate to export a JSON schema via the CLI. This
needs implementing for the remaining model types.
  • Loading branch information
jetuk committed Apr 26, 2024
1 parent 03fb32d commit 90e42b3
Show file tree
Hide file tree
Showing 45 changed files with 177 additions and 116 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ csv = "1.1"
hdf5 = { git = "https://github.com/aldanor/hdf5-rust.git", package = "hdf5", features = ["static", "zlib"] }
pywr-v1-schema = { git = "https://github.com/pywr/pywr-schema/", tag = "v0.12.0", package = "pywr-schema" }
chrono = { version = "0.4.34" }
schemars = { version = "0.8.16", features = ["chrono"] }
5 changes: 3 additions & 2 deletions pywr-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ categories = ["science", "simulation"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
clap = { version="4.0", features=["derive"] }
clap = { version = "4.0", features = ["derive"] }
anyhow = "1.0.69"
tracing = { workspace = true }
tracing-subscriber = { version ="0.3.17", features=["env-filter"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
rand = "0.8.5"
rand_chacha = "0.3.1"
serde = { workspace = true }
serde_json = { workspace = true }
pywr-v1-schema = { workspace = true }
schemars = { workspace = true }

pywr-core = { path = "../pywr-core" }
pywr-schema = { path = "../pywr-schema" }
19 changes: 18 additions & 1 deletion pywr-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod tracing;

use crate::tracing::setup_tracing;
use ::tracing::info;
use anyhow::Result;
use anyhow::{Context, Result};
use clap::{Parser, Subcommand, ValueEnum};
#[cfg(feature = "ipm-ocl")]
use pywr_core::solvers::{ClIpmF32Solver, ClIpmF64Solver, ClIpmSolverSettings};
Expand All @@ -15,6 +15,7 @@ use pywr_core::test_utils::make_random_model;
use pywr_schema::model::{PywrModel, PywrMultiNetworkModel};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use schemars::schema_for;
use std::fmt::{Display, Formatter};
use std::path::{Path, PathBuf};

Expand Down Expand Up @@ -109,6 +110,10 @@ enum Commands {
#[arg(short, long, default_value_t=Solver::Clp)]
solver: Solver,
},
ExportSchema {
/// Path to save the JSON schema.
out: PathBuf,
},
}

fn main() -> Result<()> {
Expand Down Expand Up @@ -140,6 +145,7 @@ fn main() -> Result<()> {
num_scenarios,
solver,
} => run_random(*num_systems, *density, *num_scenarios, solver),
Commands::ExportSchema { out } => export_schema(out)?,
},
None => {}
}
Expand Down Expand Up @@ -254,3 +260,14 @@ fn run_random(num_systems: usize, density: usize, num_scenarios: usize, solver:
}
.unwrap();
}

fn export_schema(out_path: &Path) -> Result<()> {
let schema = schema_for!(PywrModel);
std::fs::write(
out_path,
serde_json::to_string_pretty(&schema).with_context(|| "Failed serialise Pywr schema".to_string())?,
)
.with_context(|| format!("Failed to write file: {:?}", out_path))?;

Ok(())
}
2 changes: 1 addition & 1 deletion pywr-schema/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pyo3 = { workspace = true, optional = true }
pyo3-polars = { workspace = true, optional = true }
strum = "0.26"
strum_macros = "0.26"

schemars = { workspace = true }
hdf5 = { workspace = true, optional = true }
csv = { workspace = true, optional = true }
tracing = { workspace = true, optional = true }
Expand Down
11 changes: 6 additions & 5 deletions pywr-schema/src/data_tables/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use pywr_v1_schema::parameters::TableDataRef as TableDataRefV1;
use scalar::{
load_csv_row2_scalar_table_one, load_csv_row_col_scalar_table_one, load_csv_row_scalar_table_one, LoadedScalarTable,
};
use schemars::JsonSchema;
#[cfg(feature = "core")]
use std::collections::HashMap;
use std::path::{Path, PathBuf};
Expand All @@ -19,7 +20,7 @@ use tracing::{debug, info};
#[cfg(feature = "core")]
use vec::{load_csv_row2_vec_table_one, load_csv_row_vec_table_one, LoadedVecTable};

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum DataTableType {
Scalar,
Expand All @@ -31,7 +32,7 @@ pub enum DataTableFormat {
CSV,
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
#[serde(tag = "format", rename_all = "lowercase")]
pub enum DataTable {
CSV(CsvDataTable),
Expand All @@ -52,7 +53,7 @@ impl DataTable {
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum CsvDataTableLookup {
Row(usize),
Expand All @@ -61,7 +62,7 @@ pub enum CsvDataTableLookup {
}

/// An external table of data that can be referenced
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
pub struct CsvDataTable {
pub name: String,
#[serde(rename = "type")]
Expand Down Expand Up @@ -234,7 +235,7 @@ impl LoadedTableCollection {
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
pub struct TableDataRef {
pub table: String,
pub column: Option<TableIndex>,
Expand Down
4 changes: 3 additions & 1 deletion pywr-schema/src/edge.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#[derive(serde::Deserialize, serde::Serialize, Clone)]
use schemars::JsonSchema;

#[derive(serde::Deserialize, serde::Serialize, Clone, JsonSchema)]
pub struct Edge {
pub from_node: String,
pub to_node: String,
Expand Down
2 changes: 2 additions & 0 deletions pywr-schema/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub enum SchemaError {
Timeseries(#[from] TimeseriesError),
#[error("The output of literal constant values is not supported. This is because they do not have a unique identifier such as a name. If you would like to output a constant value please use a `Constant` parameter.")]
LiteralConstantOutputNotSupported,
#[error("Chrono out of range error: {0}")]
OutOfRange(#[from] chrono::OutOfRange),
}

#[cfg(feature = "core")]
Expand Down
11 changes: 6 additions & 5 deletions pywr-schema/src/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ use crate::ConversionError;
#[cfg(feature = "core")]
use pywr_core::{metric::MetricF64, models::MultiNetworkTransferIndex, recorders::OutputMetric};
use pywr_v1_schema::parameters::ParameterValue as ParameterValueV1;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum_macros::Display;

/// Output metrics that can be recorded from a model run.
#[derive(Deserialize, Serialize, Clone, Debug, Display)]
#[derive(Deserialize, Serialize, Clone, Debug, Display, JsonSchema)]
#[serde(tag = "type")]
pub enum Metric {
Constant {
Expand Down Expand Up @@ -226,14 +227,14 @@ impl TryFromV1Parameter<ParameterValueV1> for Metric {
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
#[serde(tag = "type", content = "name")]
pub enum TimeseriesColumns {
Scenario(String),
Column(String),
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
pub struct TimeseriesReference {
name: String,
columns: TimeseriesColumns,
Expand All @@ -249,7 +250,7 @@ impl TimeseriesReference {
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
pub struct NodeReference {
/// The name of the node
pub name: String,
Expand Down Expand Up @@ -295,7 +296,7 @@ impl NodeReference {
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, JsonSchema)]
pub struct ParameterReference {
/// The name of the parameter
pub name: String,
Expand Down
9 changes: 5 additions & 4 deletions pywr-schema/src/metric_sets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use crate::error::SchemaError;
use crate::metric::Metric;
#[cfg(feature = "core")]
use crate::model::LoadArgs;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::num::NonZeroUsize;

/// Aggregation function to apply over metric values.
#[derive(serde::Deserialize, serde::Serialize, Debug, Copy, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Copy, Clone, JsonSchema)]
#[serde(tag = "type")]
pub enum MetricAggFunc {
Sum,
Expand All @@ -30,7 +31,7 @@ impl From<MetricAggFunc> for pywr_core::recorders::AggregationFunction {
}
}

#[derive(serde::Deserialize, serde::Serialize, Debug, Copy, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Debug, Copy, Clone, JsonSchema)]
#[serde(tag = "type")]
pub enum MetricAggFrequency {
Monthly,
Expand Down Expand Up @@ -58,7 +59,7 @@ impl From<MetricAggFrequency> for pywr_core::recorders::AggregationFrequency {
///
/// If the metric set has a child aggregator then the aggregation will be performed over the
/// aggregated values of the child aggregator.
#[derive(Deserialize, Serialize, Clone)]
#[derive(Deserialize, Serialize, Clone, JsonSchema)]
pub struct MetricAggregator {
/// Optional aggregation frequency.
pub freq: Option<MetricAggFrequency>,
Expand All @@ -84,7 +85,7 @@ impl From<MetricAggregator> for pywr_core::recorders::Aggregator {
/// A metric set can optionally have an aggregator, which will apply an aggregation function
/// over metrics set. If the aggregator has a defined frequency then the aggregation will result
/// in multiple values (i.e. per each period implied by the frequency).
#[derive(Deserialize, Serialize, Clone)]
#[derive(Deserialize, Serialize, Clone, JsonSchema)]
pub struct MetricSet {
pub name: String,
pub metrics: Vec<Metric>,
Expand Down
15 changes: 8 additions & 7 deletions pywr-schema/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ use chrono::NaiveTime;
use chrono::{NaiveDate, NaiveDateTime};
#[cfg(feature = "core")]
use pywr_core::{models::ModelDomain, timestep::TimestepDuration, PywrError};
use schemars::JsonSchema;
use std::path::{Path, PathBuf};
use std::str::FromStr;

#[derive(serde::Deserialize, serde::Serialize, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Clone, JsonSchema)]
pub struct Metadata {
pub title: String,
pub description: Option<String>,
Expand Down Expand Up @@ -51,7 +52,7 @@ impl TryFrom<pywr_v1_schema::model::Metadata> for Metadata {
}
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, JsonSchema)]
#[serde(untagged)]
pub enum Timestep {
Days(i64),
Expand All @@ -67,7 +68,7 @@ impl From<pywr_v1_schema::model::Timestep> for Timestep {
}
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Copy, Debug)]
#[derive(serde::Deserialize, serde::Serialize, Clone, Copy, Debug, JsonSchema)]
#[serde(untagged)]
pub enum DateType {
Date(NaiveDate),
Expand All @@ -83,7 +84,7 @@ impl From<pywr_v1_schema::model::DateType> for DateType {
}
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, JsonSchema)]
pub struct Timestepper {
pub start: DateType,
pub end: DateType,
Expand Down Expand Up @@ -132,7 +133,7 @@ impl From<Timestepper> for pywr_core::timestep::Timestepper {
}
}

#[derive(serde::Deserialize, serde::Serialize, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Clone, JsonSchema)]
pub struct Scenario {
pub name: String,
pub size: usize,
Expand All @@ -150,7 +151,7 @@ pub struct LoadArgs<'a> {
pub inter_network_transfers: &'a [PywrMultiNetworkTransfer],
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Default)]
#[derive(serde::Deserialize, serde::Serialize, Clone, Default, JsonSchema)]
pub struct PywrNetwork {
pub nodes: Vec<Node>,
pub edges: Vec<Edge>,
Expand Down Expand Up @@ -370,7 +371,7 @@ pub enum PywrNetworkRef {
///
///
///
#[derive(serde::Deserialize, serde::Serialize, Clone)]
#[derive(serde::Deserialize, serde::Serialize, Clone, JsonSchema)]
pub struct PywrModel {
pub metadata: Metadata,
pub timestepper: Timestepper,
Expand Down
16 changes: 8 additions & 8 deletions pywr-schema/src/nodes/annual_virtual_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,27 @@ use pywr_core::{
};
use pywr_schema_macros::PywrNode;
use pywr_v1_schema::nodes::AnnualVirtualStorageNode as AnnualVirtualStorageNodeV1;
use schemars::JsonSchema;
use std::collections::HashMap;

#[derive(serde::Deserialize, serde::Serialize, Clone, Debug)]
#[derive(serde::Deserialize, serde::Serialize, Clone, Debug, JsonSchema)]
pub struct AnnualReset {
pub day: u8,
pub month: chrono::Month,
pub month: u8,
pub use_initial_volume: bool,
}

impl Default for AnnualReset {
fn default() -> Self {
Self {
day: 1,
month: chrono::Month::January,
month: 1,
use_initial_volume: false,
}
}
}

#[derive(serde::Deserialize, serde::Serialize, Clone, Default, Debug, PywrNode)]
#[derive(serde::Deserialize, serde::Serialize, Clone, Default, Debug, PywrNode, JsonSchema)]
pub struct AnnualVirtualStorageNode {
#[serde(flatten)]
pub meta: NodeMeta,
Expand Down Expand Up @@ -85,9 +86,10 @@ impl AnnualVirtualStorageNode {
.map(|name| network.get_node_index_by_name(name.as_str(), None))
.collect::<Result<Vec<_>, _>>()?;

let reset_month = self.reset.month.try_into()?;
let reset = VirtualStorageReset::DayOfYear {
day: self.reset.day as u32,
month: self.reset.month,
month: reset_month,
};

network.add_virtual_storage_node(
Expand Down Expand Up @@ -168,8 +170,6 @@ impl TryFrom<AnnualVirtualStorageNodeV1> for AnnualVirtualStorageNode {
});
};

let month = chrono::Month::try_from(v1.reset_month as u8)?;

let n = Self {
meta,
nodes: v1.nodes,
Expand All @@ -180,7 +180,7 @@ impl TryFrom<AnnualVirtualStorageNodeV1> for AnnualVirtualStorageNode {
initial_volume,
reset: AnnualReset {
day: v1.reset_day as u8,
month,
month: v1.reset_month as u8,
use_initial_volume: v1.reset_to_initial_volume,
},
};
Expand Down
Loading

0 comments on commit 90e42b3

Please sign in to comment.