Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply hash consing to all BasicSymbolic subtypes #673

Merged
merged 16 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
# Call outer constructor because hash consing cannot be applied in inner constructor
@compactified obj::BasicSymbolic begin
Sym => Sym{T}(nt_new.name; nt_new...)
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new...)
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new...)
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new...)
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new...)
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new...)
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
end
end
Expand Down Expand Up @@ -298,6 +303,7 @@ Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymboli

## This is much faster than hash of an array of Any
hashvec(xs, z) = foldr(hash, xs, init=z)
hashvec2(xs, z) = foldr(hash2, xs, init=z)
const SYM_SALT = 0x4de7d7c66d41da43 % UInt
const ADD_SALT = 0xaddaddaddaddadda % UInt
const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt
Expand Down Expand Up @@ -344,10 +350,43 @@ objects. Unlike `Base.hash`, which only considers the expression structure, `has
includes the metadata and symtype in the hash calculation. This can be beneficial for hash
consing, allowing for more effective deduplication of symbolically equivalent expressions
with different metadata or symtypes.

Equivalent numbers of different types, such as `0.5::Float64` and
`(1 // 2)::Rational{Int64}`, have the same default `Base.hash` value. The `hash2` function
distinguishes these by including their numeric types in the hash calculation to ensure that
symbolically equivalent expressions with different numeric types are treated as distinct
objects.
"""
hash2(s, salt::UInt) = hash(s, salt)
function hash2(n::T, salt::UInt) where {T <: Number}
hash(T, hash(n, salt))
end
hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
hash(metadata(s), hash(T, hash(s, salt)))
E = exprtype(s)
h::UInt = 0
if E === SYM
h = hash(nameof(s), salt ⊻ SYM_SALT)
elseif E === ADD || E === MUL
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
hv = Base.hasha_seed
for (k, v) in s.dict
hv ⊻= hash2(k, hash(v))
end
h = hash(hv, salt)
h = hash(hashoffset, hash2(s.coeff, h))
elseif E === DIV
h = hash2(s.num, hash2(s.den, salt ⊻ DIV_SALT))
elseif E === POW
h = hash2(s.exp, hash2(s.base, salt ⊻ POW_SALT))
elseif E === TERM
op = operation(s)
oph = op isa Function ? nameof(op) : op
h = hashvec2(arguments(s), hash(oph, salt))
else
error_on_type()
end
hash(metadata(s), hash(T, h))
end

###
Expand Down Expand Up @@ -395,7 +434,8 @@ function Term{T}(f, args; kw...) where T
args = convert(Vector{Any}, args)
end

Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...)
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...)
BasicSymbolic(s)
end

function Term(f, args; metadata=NO_METADATA)
Expand All @@ -415,7 +455,8 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
end
end

Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
BasicSymbolic(s)
end

function Mul(T, a, b; metadata=NO_METADATA, kw...)
Expand All @@ -430,7 +471,8 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
else
coeff = a
dict = b
Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
BasicSymbolic(s)
end
end

Expand Down Expand Up @@ -461,7 +503,7 @@ function maybe_intcoeff(x)
end
end

function Div{T}(n, d, simplified=false; metadata=nothing) where {T}
function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
if T<:Number && !(T<:SafeReal)
n, d = quick_cancel(n, d)
end
Expand Down Expand Up @@ -495,7 +537,8 @@ function Div{T}(n, d, simplified=false; metadata=nothing) where {T}
end
end

Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
BasicSymbolic(s)
end

function Div(n,d, simplified=false; kw...)
Expand All @@ -509,14 +552,15 @@ end

@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1]

function Pow{T}(a, b; metadata=NO_METADATA) where {T}
function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
_iszero(b) && return 1
_isone(b) && return a
Pow{T}(; base=a, exp=b, arguments=[], metadata)
s = Pow{T}(; base=a, exp=b, arguments=[], metadata)
BasicSymbolic(s)
end

function Pow(a, b; metadata=NO_METADATA)
Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)..., metadata=metadata)
function Pow(a, b; metadata = NO_METADATA, kwargs...)
Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)...; metadata, kwargs...)
end

function toterm(t::BasicSymbolic{T}) where T
Expand Down
84 changes: 84 additions & 0 deletions test/hash_consing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using SymbolicUtils, Test
using SymbolicUtils: Term, Add, Mul, Div, Pow, hash2

struct Ctx1 end
struct Ctx2 end
Expand All @@ -24,3 +25,86 @@ struct Ctx2 end
xm3 = setmetadata(x1, Ctx2, "meta_2")
@test xm1 !== xm3
end

@syms a b c

@testset "Term" begin
t1 = sin(a)
t2 = sin(a)
@test t1 === t2
t3 = Term(identity,[a])
t4 = Term(identity,[a])
@test t3 === t4
t5 = Term{Int}(identity,[a])
@test t3 !== t5
tm1 = setmetadata(t1, Ctx1, "meta_1")
@test t1 !== tm1
end

@testset "Add" begin
d1 = a + b
d2 = b + a
@test d1 === d2
d3 = b - 2 + a
d4 = a + b - 2
@test d3 === d4
d5 = Add(Int, 0, Dict(a => 1, b => 1))
@test d5 !== d1

dm1 = setmetadata(d1,Ctx1,"meta_1")
@test d1 !== dm1
end

@testset "Mul" begin
m1 = a*b
m2 = b*a
@test m1 === m2
m3 = 6*a*b
m4 = 3*a*2*b
@test m3 === m4
m5 = Mul(Int, 1, Dict(a => 1, b => 1))
@test m5 !== m1

mm1 = setmetadata(m1, Ctx1, "meta_1")
@test m1 !== mm1
end

@testset "Div" begin
v1 = a/b
v2 = a/b
@test v1 === v2
v3 = -1/a
v4 = -1/a
@test v3 === v4
v5 = 3a/6
v6 = 2a/4
@test v5 === v6
v7 = Div{Float64}(-1,a)
@test v7 !== v3

vm1 = setmetadata(v1,Ctx1, "meta_1")
@test vm1 !== v1
end

@testset "Pow" begin
p1 = a^b
p2 = a^b
@test p1 === p2
p3 = a^(2^-b)
p4 = a^(2^-b)
@test p3 === p4
p5 = Pow{Float64}(a,b)
@test p1 !== p5

pm1 = setmetadata(p1,Ctx1, "meta_1")
@test pm1 !== p1
end

@testset "Equivalent numbers" begin
f = 0.5
r = 1 // 2
@test hash(f) == hash(r)
u0 = zero(UInt)
@test hash2(f, u0) != hash2(r, u0)
@test f + a !== r + a
end
Loading