Skip to content

Commit

Permalink
perf: Remove unneeded allocations in NodeCost::get_cost
Browse files Browse the repository at this point in the history
Refactor to remove collecting into a Vec, and more explicit
handling of the different cases.
  • Loading branch information
jetuk committed Oct 25, 2024
1 parent 929b77d commit 1e2ff69
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions pywr-core/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,34 +694,33 @@ impl Default for NodeCost {

impl NodeCost {
fn get_cost(&self, network: &Network, state: &State) -> Result<f64, PywrError> {
let local_cost = match &self.local {
// Initial local cost that has any virtual storage cost applied
let mut cost = match &self.local {
None => Ok(0.0),
Some(m) => m.get_value(network, state),
}?;

let vs_costs: Vec<f64> = self
.virtual_storage_nodes
.iter()
.map(|idx| {
let vs = network.get_virtual_storage_node(idx)?;
vs.get_cost(network, state)
})
.collect::<Result<_, _>>()?;

let cost = match self.agg_func {
CostAggFunc::Sum => local_cost + vs_costs.iter().sum::<f64>(),
CostAggFunc::Max => local_cost.max(
vs_costs
.into_iter()
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(f64::NEG_INFINITY),
),
CostAggFunc::Min => local_cost.min(
vs_costs
.into_iter()
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(f64::INFINITY),
),
let vs_costs = self.virtual_storage_nodes.iter().map(|idx| {
let vs = network.get_virtual_storage_node(idx)?;
vs.get_cost(network, state)
});

match self.agg_func {
CostAggFunc::Sum => {
for vs_cost in vs_costs {
cost += vs_cost?;
}
}
CostAggFunc::Max => {
for vs_cost in vs_costs {
cost = cost.max(vs_cost?);
}
}
CostAggFunc::Min => {
for vs_cost in vs_costs {
cost = cost.min(vs_cost?);
}
}
};

Ok(cost)
Expand Down

0 comments on commit 1e2ff69

Please sign in to comment.