Skip to content

Commit

Permalink
Fixed bug in filtration functions; updated aggregation to save lost e…
Browse files Browse the repository at this point in the history
…dge data (#17)

* Updated aggregation and tests

* updated examples to be PlasmoDataPlots.jl

* added to tests; updated docs; fixed minor bug
  • Loading branch information
dlcole3 authored Mar 7, 2024
1 parent ca8eaf8 commit 0d3efae
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 40 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
| [![doc](https://img.shields.io/badge/docs-dev-blue.svg)](https://zavalab.github.io/PlasmoData.jl/dev) | [![build](https://github.com/zavalab/PlasmoData.jl/actions/workflows/ci.yml/badge.svg)](https://github.com/zavalab/PlasmoData.jl/actions) | [![codecov](https://codecov.io/gh/zavalab/PlasmoData.jl/branch/main/graph/badge.svg?token=LZJ3T1XQZ0)](https://app.codecov.io/gh/zavalab/PlasmoData.jl) |


PlasmoData.jl is a package for [Julia](https://julialang.org/) designed for representing and modeling data as graphs and for building graph models that contain large amounts of data on the nodes or edges of the graph. This package also has an accompanying package [DataGraphPlots.jl](https://github.com/zavalab/DataGraphPlots.jl) which can be used for plotting the graphs.
PlasmoData.jl is a package for [Julia](https://julialang.org/) designed for representing and modeling data as graphs and for building graph models that contain large amounts of data on the nodes or edges of the graph. This package also has an accompanying package [PlasmoDataPlots.jl](https://github.com/zavalab/PlasmoDataPlots.jl) which can be used for plotting the graphs.

PlasmoData.jl is built on the abstraction called a `DataGraph`. The manuscript ["PlasmoData.jl -- A Julia Framework for Modeling and Analyzing Complex Data as Graphs"](https://arxiv.org/abs/2401.11404) details the abstraction and this package.

## Bug Reports and Support

Expand Down
4 changes: 2 additions & 2 deletions examples/basic_functions.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Revise
using PlasmoData, Graphs
using DataGraphPlots
using PlasmoDataPlots

dg = DataGraph()

Expand Down Expand Up @@ -28,4 +28,4 @@ add_edge_data!(dg, "node4", 1, 1.0, "weight")
add_edge_data!(dg, :node5, 2, -.00001, "weight")
add_edge_data!(dg, 3, "node4", 1, "weight")

DataGraphPlots.plot_graph(dg; xdim = 400, ydim = 400)
PlasmoDataPlots.plot_graph(dg; xdim = 400, ydim = 400)
2 changes: 1 addition & 1 deletion examples/edge_weighted_EC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Revise
using PlasmoData, Graphs
using JLD, LinearAlgebra
using Plots, Statistics
using DataGraphPlots
using PlasmoDataPlots

# Data for this example comes from Alex Smith's paper on the Euler Characteristic:
# https://doi.org/10.1016/j.compchemeng.2021.107463
Expand Down
2 changes: 1 addition & 1 deletion examples/matrix_to_graph.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Revise
using PlasmoData, Graphs
using DataGraphPlots
using PlasmoDataPlots

mat = rand(10, 10)

Expand Down
2 changes: 1 addition & 1 deletion examples/tensor_to_graph.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Revise
using PlasmoData, Graphs
using Statistics, DelimitedFiles
using DataGraphPlots
using PlasmoDataPlots

abc = rand(10, 4, 5)

Expand Down
112 changes: 95 additions & 17 deletions src/datadigraphs/utils.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
"""
filter_nodes(datadigraph, filter_value, attribute = dg.node_data.attributes[1]; fn = isless)
Removes the nodes of the graph whose weight value of `attribute` is greater than the given
`filter_value`. If `attribute` is not specified, this defaults to the first attribute within
the DataGraph's `NodeData`.
Removes the nodes of the graph whose data on `attribute` is does not meet the criteria of `fn`
with respect to `filter_value`. If `attribute` is not specified, this defaults to the first
attribute within the DataDiGraph's `NodeData`.
`fn` is a function that takes an input of two scalar values and is broadcast to the data vector.
For example, isless, isgreater, isequal
`fn` is a function that takes an input of a node's data on attribute and the `filter_value`
and returns a true or false
"""
function filter_nodes(
dg::DataDiGraph,
filter_val::R,
attribute::String=dg.node_data.attributes[1];
fn::Function = isless
) where {R <: Real}
) where {R <: Any}

node_attributes = dg.node_data.attributes
edge_attributes = dg.edge_data.attributes
Expand All @@ -27,7 +27,7 @@ function filter_nodes(
edges = dg.edges

if length(node_attributes) == 0
error("No node weights are defined")
error("No node data are defined")
end

T = eltype(dg)
Expand Down Expand Up @@ -107,19 +107,19 @@ end
"""
filter_edges(datadigraph, filter_value, attribute = dg.edge-data.attributes[1]; fn = isless)
Removes the edges of the graph whose weight value of `attribute` is greater than the given
`filter_value`. If `attribute` is not specified, this defaults to the first attribute within
the DataGraph's `EdgeData`.
Removes the edges of the graph whose data on `attribute` is does not meet the criteria of `fn`
with respect to `filter_value`. If `attribute` is not specified, this defaults to the first
attribute within the DataDiGraph's `EdgeData`.
`fn` is a function that takes an input of two scalar values and is broadcast to the data vector.
For example, isless, isgreater, isequal
`fn` is a function that takes an input of a edge's data on attribute and the `filter_value`
and returns a true or false
"""
function filter_edges(
dg::DataDiGraph,
filter_val::R,
attribute::String = dg.edge_data.attributes[1];
fn::Function = isless
) where {R <: Real}
) where {R <: Any}

nodes = dg.nodes
edges = dg.edges
Expand All @@ -132,7 +132,7 @@ function filter_edges(
edge_attribute_map = dg.edge_data.attribute_map

if length(edge_attributes) == 0
error("No node weights are defined")
error("No edge data are defined")
end

T = eltype(dg)
Expand Down Expand Up @@ -353,22 +353,36 @@ function remove_edge!(
end

"""
aggregate(datadigraph, node_list, aggregated_node_name; node_fn = mean, edge_fn = mean)
aggregate(datadigraph, node_list, aggregated_node_name;
node_fn = mean, edge_fn = mean, save_agg_edge_data = false,
agg_edge_fn = mean, agg_edge_val = 0, node_attributes_to_add = String[]
)
Aggregates all the nodes in `node_list` into a single node which is called `aggregated_node_name`.
If nodes have any weight/attribute values defined, these values are combined via the `node_fn` function.
The default for `node_fn` is Statistics.mean which averages the data for the nodes in `node_list`.
Edge data are also are also combined via the `edge_fn` when two or more nodes in the `node_list` are
connected to the same node and these edges have data defined on them. The `edge_fn` also defaults
to `Statistics.mean`
If edges exist between nodes in `node_list`, the data on these edges can optionally be saved on
the `aggregated_node_name` node by setting `save_agg_edge_data = true`. If true, then the edge data
on these edges is aggregated using `agg_edge_fn`. If the user wants to define new attribute names for
this data, they can pass a vector to `node_attributes_to_add`; if no vector is defined, the data will
be aggregated under the names of the `edge_data` attributes. All other nodes except the aggregated
nodes will have these attributes initialized as `agg_edge_val`.
"""
function aggregate(
dg::DataDiGraph,
node_set::Vector,
new_name::N;
node_fn::Function = _default_mean,
edge_fn::Function = _default_mean
) where {N <: Any}
edge_fn::Function = _default_mean,
save_agg_edge_data::Bool = false,
agg_edge_fn::Function = _default_mean,
agg_edge_val::R = 0.,
node_attributes_to_add::Vector{String} = String[]
) where {N <: Any, R <: Any}

nodes = dg.nodes
node_map = dg.node_map
Expand All @@ -389,6 +403,26 @@ function aggregate(
error("New node name already exists in set of non-aggregated nodes")
end

if save_agg_edge_data
if length(dg.edge_data.attributes) > 0 && length(node_attributes_to_add) > 0
if length(dg.edge_data.attributes) != length(node_attributes_to_add)
error("Length of the node_attributes_to_add does not match the edge_data attributes")
end
for i in 1:length(node_attributes_to_add)
if node_attributes_to_add[i] in node_attributes
error("Attribute name $(node_attributes_to_add[i]) is already defined in node_attributes")
end
end
elseif length(dg.edge_data.attributes) > 0
attribute_names = dg.edge_data.attributes
for i in 1:length(attribute_names)
if attribute_names[i] in node_attributes
error("Edge data attribute names conflict with node data attributes; user must pass node_attributes_to_add")
end
end
end
end

T = eltype(dg)
T1 = eltype(get_node_data(dg))
M1 = typeof(get_node_data(dg))
Expand Down Expand Up @@ -445,6 +479,8 @@ function aggregate(
edge_bool_avg_index = Dict{Tuple{T, T}, Vector{T}}()
new_edge_data = fill(0, (0, length(edge_attributes)))

removed_edge_bool_vec = [false for i in 1:length(edges)]

for i in 1:length(nodes)
node_name_mapping[node_map[nodes[i]]] = nodes[i]
end
Expand Down Expand Up @@ -532,6 +568,10 @@ function aggregate(
end
end
end
else
if save_agg_edge_data
removed_edge_bool_vec[i] = true
end
end
end

Expand All @@ -548,6 +588,44 @@ function aggregate(
new_dg.edge_data.attributes = copy(edge_attributes)
new_dg.edge_data.attribute_map = copy(edge_attribute_map)
new_dg.edge_data.data = copy(new_edge_data)

if save_agg_edge_data
new_node_data = new_dg.node_data.data
new_node_attributes = new_dg.node_data.attributes
new_node_attribute_map = new_dg.node_data.attribute_map
if length(node_attributes) > 0
edge_data_to_avg = edge_data[removed_edge_bool_vec, :]
if length(node_attributes_to_add) > 0
attributes_to_add = node_attributes_to_add
else
attributes_to_add = edge_attributes
end
for j in 1:length(attributes_to_add)
push!(new_node_attributes, attributes_to_add[j])
new_node_attribute_map[attributes_to_add[j]] = length(new_node_attributes)
end
data_to_add = fill(agg_edge_val, (length(new_nodes), length(edge_attributes)))
data_to_add[length(new_nodes), :] .= agg_edge_fn(edge_data_to_avg)
old_data = new_node_data
new_dg.node_data.data = hcat(old_data, data_to_add)
else
edge_data_to_avg = edge_data[removed_edge_bool_vec, :]
if length(node_attributes_to_add) > 0
attributes_to_add = node_attributes_to_add
else
attributes_to_add = edge_attributes
end
new_dg.node_data.attributes = attributes_to_add
new_node_attribute_map = new_dg.node_data.attribute_map
for j in 1:length(attributes_to_add)
new_node_attribute_map[attributes_to_add[j]] = j
end
data_to_add = fill(agg_edge_val, (length(new_nodes), length(edge_attributes)))
data_to_add[length(new_nodes), :] .= agg_edge_fn(edge_data_to_avg)
old_data = zeros(T1, (length(new_nodes), 0))
new_dg.node_data.data = hcat(old_data, data_to_add)
end
end
end

simple_digraph = Graphs.SimpleDiGraph(T(length(new_edges)), fadjlist, badjlist)
Expand Down
Loading

0 comments on commit 0d3efae

Please sign in to comment.