From 9fa405c503a11922fe524a900b4b193613eb5997 Mon Sep 17 00:00:00 2001 From: GabrielKS <23368820+GabrielKS@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:31:14 -0600 Subject: [PATCH 1/2] Fix bug 407, test that --- src/utils/utils.jl | 36 ++++++++++-------------------------- test/test_system_data.jl | 3 +++ 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/src/utils/utils.jl b/src/utils/utils.jl index a8160bb49..a47397b69 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -126,6 +126,8 @@ _fetch_match_fn(::Nothing) = isequivalent # Whether to stop recursing and apply the match_fn _is_compare_directly(::DataType, ::DataType) = true _is_compare_directly(::T, ::U) where {T, U} = true +# As of 1.11, Arrays have fields we don't want to touch +_is_compare_directly(::T, ::T) where {T <: AbstractArray} = true _is_compare_directly(::T, ::T) where {T} = isempty(fieldnames(T)) """ @@ -161,33 +163,15 @@ function compare_values(match_fn::Union{Function, Nothing}, x::T, y::U; field_name in exclude && continue val1 = getproperty(x, field_name) val2 = getproperty(y, field_name) - if !isempty(fieldnames(typeof(val1))) - if !compare_values( - match_fn, - val1, - val2; - compare_uuids = compare_uuids, - exclude = exclude, - ) - @error "values do not match" T field_name val1 val2 - match = false - end - elseif val1 isa AbstractArray - if !compare_values( - match_fn, - val1, - val2; - compare_uuids = compare_uuids, - exclude = exclude, - ) - @error "values do not match" T field_name val1 val2 - match = false - end + sub_result = if _is_compare_directly(val1, val2) + _fetch_match_fn(match_fn)(val1, val2) else - if !_fetch_match_fn(match_fn)(val1, val2) - @error "values do not match" T field_name val1 val2 - match = false - end + compare_values(match_fn, val1, val2; + compare_uuids = compare_uuids, exclude = exclude) + end + if !sub_result + @error "values do not match" T field_name val1 val2 + match = false end end diff --git a/test/test_system_data.jl b/test/test_system_data.jl index b8c4254ba..3c485a784 100644 --- a/test/test_system_data.jl +++ b/test/test_system_data.jl @@ -201,6 +201,9 @@ end special1, )) @test IS.compare_values(==, special1, special1) + + # https://github.com/NREL-Sienna/InfrastructureSystems.jl/issues/407 + @test InfrastructureSystems.compare_values([0 0], [0 0]) end @testset "Test compression settings" begin From 83da575fe3800d3dfb5b096973575d9975cfe05e Mon Sep 17 00:00:00 2001 From: GabrielKS <23368820+GabrielKS@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:57:24 -0600 Subject: [PATCH 2/2] Fix bug 407 better --- src/utils/utils.jl | 27 +++++++++++---------------- test/test_system_data.jl | 8 ++++++++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/utils/utils.jl b/src/utils/utils.jl index a47397b69..ebc424aba 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -126,8 +126,6 @@ _fetch_match_fn(::Nothing) = isequivalent # Whether to stop recursing and apply the match_fn _is_compare_directly(::DataType, ::DataType) = true _is_compare_directly(::T, ::U) where {T, U} = true -# As of 1.11, Arrays have fields we don't want to touch -_is_compare_directly(::T, ::T) where {T <: AbstractArray} = true _is_compare_directly(::T, ::T) where {T} = isempty(fieldnames(T)) """ @@ -163,12 +161,8 @@ function compare_values(match_fn::Union{Function, Nothing}, x::T, y::U; field_name in exclude && continue val1 = getproperty(x, field_name) val2 = getproperty(y, field_name) - sub_result = if _is_compare_directly(val1, val2) - _fetch_match_fn(match_fn)(val1, val2) - else - compare_values(match_fn, val1, val2; - compare_uuids = compare_uuids, exclude = exclude) - end + sub_result = compare_values(match_fn, val1, val2; + compare_uuids = compare_uuids, exclude = exclude) if !sub_result @error "values do not match" T field_name val1 val2 match = false @@ -178,20 +172,21 @@ function compare_values(match_fn::Union{Function, Nothing}, x::T, y::U; return match end +# compare_values of an AbstractArray: ignore the fields, iterate over all dimensions of the array function compare_values( match_fn::Union{Function, Nothing}, - x::Vector{T}, - y::Vector{T}; + x::AbstractArray, + y::AbstractArray; compare_uuids = false, exclude = Set{Symbol}(), -) where {T} - if length(x) != length(y) - @error "lengths do not match" T length(x) length(y) +) + if size(x) != size(y) + @error "sizes do not match" size(x) size(y) return false end match = true - for i in range(1; length = length(x)) + for i in keys(x) if !compare_values( match_fn, x[i], @@ -209,8 +204,8 @@ end function compare_values( match_fn::Union{Function, Nothing}, - x::Dict, - y::Dict; + x::AbstractDict, + y::AbstractDict; compare_uuids = false, exclude = Set{Symbol}(), ) diff --git a/test/test_system_data.jl b/test/test_system_data.jl index 3c485a784..8ec3acb55 100644 --- a/test/test_system_data.jl +++ b/test/test_system_data.jl @@ -204,6 +204,14 @@ end # https://github.com/NREL-Sienna/InfrastructureSystems.jl/issues/407 @test InfrastructureSystems.compare_values([0 0], [0 0]) + + # Test that for arrays and dicts we are actually comparing the values + my_match_fn_3(::Int64, ::Int64) = true + my_match_fn_3(::Any, ::Any) = false + @test IS.compare_values(my_match_fn_3, [0, 1], [0, 1]) + @test IS.compare_values(my_match_fn_3, [0 1], [0 1]) + @test IS.compare_values(my_match_fn_3, + Dict("a" => 0, "b" => 1), Dict("a" => 0, "b" => 1)) end @testset "Test compression settings" begin