Skip to content

Commit

Permalink
Merge pull request #658 from JuliaSymbolics/hash-consing
Browse files Browse the repository at this point in the history
Implement hash consing for `Sym`
  • Loading branch information
ChrisRackauckas authored Nov 7, 2024
2 parents f9b0ade + 13b642b commit a587847
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 11 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"

[weakdeps]
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Expand Down Expand Up @@ -57,6 +58,7 @@ SymbolicIndexingInterface = "0.3"
TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
WeakValueDicts = "0.1.0"
julia = "1.3"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm, sorted_arguments
# For ReverseDiffExt
import ArrayInterface
using WeakValueDicts: WeakValueDict

Base.@deprecate istree iscall
export istree, operation, arguments, sorted_arguments, iscall
Expand Down
95 changes: 85 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,38 @@ const EMPTY_DICT = sdict()
const EMPTY_DICT_T = typeof(EMPTY_DICT)

@compactify show_methods=false begin
@abstract struct BasicSymbolic{T} <: Symbolic{T}
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
metadata::Metadata = NO_METADATA
end
struct Sym{T} <: BasicSymbolic{T}
mutable struct Sym{T} <: BasicSymbolic{T}
name::Symbol = :OOF
end
struct Term{T} <: BasicSymbolic{T}
mutable struct Term{T} <: BasicSymbolic{T}
f::Any = identity # base/num if Pow; issorted if Add/Dict
arguments::Vector{Any} = EMPTY_ARGS
hash::RefValue{UInt} = EMPTY_HASH
end
struct Mul{T} <: BasicSymbolic{T}
mutable struct Mul{T} <: BasicSymbolic{T}
coeff::Any = 0 # exp/den if Pow
dict::EMPTY_DICT_T = EMPTY_DICT
hash::RefValue{UInt} = EMPTY_HASH
arguments::Vector{Any} = EMPTY_ARGS
issorted::RefValue{Bool} = NOT_SORTED
end
struct Add{T} <: BasicSymbolic{T}
mutable struct Add{T} <: BasicSymbolic{T}
coeff::Any = 0 # exp/den if Pow
dict::EMPTY_DICT_T = EMPTY_DICT
hash::RefValue{UInt} = EMPTY_HASH
arguments::Vector{Any} = EMPTY_ARGS
issorted::RefValue{Bool} = NOT_SORTED
end
struct Div{T} <: BasicSymbolic{T}
mutable struct Div{T} <: BasicSymbolic{T}
num::Any = 1
den::Any = 1
simplified::Bool = false
arguments::Vector{Any} = EMPTY_ARGS
end
struct Pow{T} <: BasicSymbolic{T}
mutable struct Pow{T} <: BasicSymbolic{T}
base::Any = 1
exp::Any = 1
arguments::Vector{Any} = EMPTY_ARGS
Expand All @@ -77,6 +77,8 @@ function exprtype(x::BasicSymbolic)
end
end

const wvd = WeakValueDict{UInt, BasicSymbolic}()

# Same but different error messages
@noinline error_on_type() = error("Internal error: unreachable reached!")
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
Expand All @@ -92,7 +94,11 @@ const SIMPLIFIED = 0x01 << 0
function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T
nt = getproperties(obj)
nt_new = merge(nt, patch)
Unityper.rt_constructor(obj){T}(;nt_new...)
# Call outer constructor because hash consing cannot be applied in inner constructor
@compactified obj::BasicSymbolic begin
Sym => Sym{T}(nt_new.name; nt_new...)
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
end
end

###
Expand Down Expand Up @@ -265,6 +271,26 @@ function _isequal(a, b, E)
end
end

"""
$(TYPEDSIGNATURES)
Checks for equality between two `BasicSymbolic` objects, considering both their
values and metadata.
The default `Base.isequal` function for `BasicSymbolic` only compares their expressions
and ignores metadata. This does not help deal with hash collisions when metadata is
relevant for distinguishing expressions, particularly in hashing contexts. This function
provides a stricter equality check that includes metadata comparison, preventing
such collisions.
Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` and
downstream packages like `ModelingToolkit.jl`, hence the need for this separate
function.
"""
function isequal_with_metadata(a::BasicSymbolic, b::BasicSymbolic)::Bool
isequal(a, b) && isequal(metadata(a), metadata(b))
end

Base.one( s::Symbolic) = one( symtype(s))
Base.zero(s::Symbolic) = zero(symtype(s))

Expand Down Expand Up @@ -307,12 +333,61 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
end
end

"""
$(TYPEDSIGNATURES)
Calculates a hash value for a `BasicSymbolic` object, incorporating both its metadata and
symtype.
This function provides an alternative hashing strategy to `Base.hash` for `BasicSymbolic`
objects. Unlike `Base.hash`, which only considers the expression structure, `hash2` also
includes the metadata and symtype in the hash calculation. This can be beneficial for hash
consing, allowing for more effective deduplication of symbolically equivalent expressions
with different metadata or symtypes.
"""
hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
hash(metadata(s), hash(T, hash(s, salt)))
end

###
### Constructors
###

function Sym{T}(name::Symbol; kw...) where T
Sym{T}(; name=name, kw...)
"""
$(TYPEDSIGNATURES)
Implements hash consing (flyweight design pattern) for `BasicSymbolic` objects.
This function checks if an equivalent `BasicSymbolic` object already exists. It uses a
custom hash function (`hash2`) incorporating metadata and symtypes to search for existing
objects in a `WeakValueDict` (`wvd`). Due to the possibility of hash collisions (where
different objects produce the same hash), a custom equality check (`isequal_with_metadata`)
which includes metadata comparison, is used to confirm the equivalence of objects with
matching hashes. If an equivalent object is found, the existing object is returned;
otherwise, the input `s` is returned. This reduces memory usage, improves compilation time
for runtime code generation, and supports built-in common subexpression elimination,
particularly when working with symbolic objects with metadata.
Using a `WeakValueDict` ensures that only weak references to `BasicSymbolic` objects are
stored, allowing objects that are no longer strongly referenced to be garbage collected.
Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.hash` and
`Base.isequal` to accommodate metadata without disrupting existing tests reliant on the
original behavior of those functions.
"""
function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
h = hash2(s)
t = get!(wvd, h, s)
if t === s || isequal_with_metadata(t, s)
return t
else
return s
end
end

function Sym{T}(name::Symbol; kw...) where {T}
s = Sym{T}(; name, kw...)
BasicSymbolic(s)
end

function Term{T}(f, args; kw...) where T
Expand Down
9 changes: 8 additions & 1 deletion test/basics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term
using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term, isequal_with_metadata
using SymbolicUtils
using IfElse: ifelse
using Setfield
Expand Down Expand Up @@ -336,6 +336,13 @@ end

@test !isequal(a, missing)
@test !isequal(missing, b)

a1 = setmetadata(a, Ctx1, "meta_1")
a2 = setmetadata(a, Ctx1, "meta_1")
a3 = setmetadata(a, Ctx2, "meta_2")
@test !isequal_with_metadata(a, a1)
@test isequal_with_metadata(a1, a2)
@test !isequal_with_metadata(a1, a3)
end

@testset "subtyping" begin
Expand Down
26 changes: 26 additions & 0 deletions test/hash_consing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using SymbolicUtils, Test

struct Ctx1 end
struct Ctx2 end

@testset "Sym" begin
x1 = only(@syms x)
x2 = only(@syms x)
@test x1 === x2
x3 = only(@syms x::Float64)
@test x1 !== x3
x4 = only(@syms x::Float64)
@test x1 !== x4
@test x3 === x4
x5 = only(@syms x::Int)
x6 = only(@syms x::Int)
@test x1 !== x5
@test x3 !== x5
@test x5 === x6

xm1 = setmetadata(x1, Ctx1, "meta_1")
xm2 = setmetadata(x1, Ctx1, "meta_1")
@test xm1 === xm2
xm3 = setmetadata(x1, Ctx2, "meta_2")
@test xm1 !== xm3
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ using Pkg, Test, SafeTestsets
# Disabled until https://github.com/JuliaMath/SpecialFunctions.jl/issues/446 is fixed
@safetestset "Fuzz" begin include("fuzz.jl") end
@safetestset "Adjoints" begin include("adjoints.jl") end
@safetestset "Hash Consing" begin include("hash_consing.jl") end
end
end

0 comments on commit a587847

Please sign in to comment.