Skip to content

Commit

Permalink
Merge pull request #673 from Blablablanca/hash-consing2
Browse files Browse the repository at this point in the history
Apply hash consing to all `BasicSymbolic` subtypes
  • Loading branch information
ChrisRackauckas authored Nov 30, 2024
2 parents 1af55d4 + d716d3b commit 59e3aa6
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 10 deletions.
64 changes: 54 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
# Call outer constructor because hash consing cannot be applied in inner constructor
@compactified obj::BasicSymbolic begin
Sym => Sym{T}(nt_new.name; nt_new...)
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new...)
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new...)
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new...)
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new...)
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new...)
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
end
end
Expand Down Expand Up @@ -298,6 +303,7 @@ Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymboli

## This is much faster than hash of an array of Any
hashvec(xs, z) = foldr(hash, xs, init=z)
hashvec2(xs, z) = foldr(hash2, xs, init=z)
const SYM_SALT = 0x4de7d7c66d41da43 % UInt
const ADD_SALT = 0xaddaddaddaddadda % UInt
const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt
Expand Down Expand Up @@ -344,10 +350,43 @@ objects. Unlike `Base.hash`, which only considers the expression structure, `has
includes the metadata and symtype in the hash calculation. This can be beneficial for hash
consing, allowing for more effective deduplication of symbolically equivalent expressions
with different metadata or symtypes.
Equivalent numbers of different types, such as `0.5::Float64` and
`(1 // 2)::Rational{Int64}`, have the same default `Base.hash` value. The `hash2` function
distinguishes these by including their numeric types in the hash calculation to ensure that
symbolically equivalent expressions with different numeric types are treated as distinct
objects.
"""
hash2(s, salt::UInt) = hash(s, salt)
function hash2(n::T, salt::UInt) where {T <: Number}
hash(T, hash(n, salt))
end
hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
hash(metadata(s), hash(T, hash(s, salt)))
E = exprtype(s)
h::UInt = 0
if E === SYM
h = hash(nameof(s), salt SYM_SALT)
elseif E === ADD || E === MUL
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
hv = Base.hasha_seed
for (k, v) in s.dict
hv ⊻= hash2(k, hash(v))
end
h = hash(hv, salt)
h = hash(hashoffset, hash2(s.coeff, h))
elseif E === DIV
h = hash2(s.num, hash2(s.den, salt DIV_SALT))
elseif E === POW
h = hash2(s.exp, hash2(s.base, salt POW_SALT))
elseif E === TERM
op = operation(s)
oph = op isa Function ? nameof(op) : op
h = hashvec2(arguments(s), hash(oph, salt))
else
error_on_type()
end
hash(metadata(s), hash(T, h))
end

###
Expand Down Expand Up @@ -395,7 +434,8 @@ function Term{T}(f, args; kw...) where T
args = convert(Vector{Any}, args)
end

Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...)
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...)
BasicSymbolic(s)
end

function Term(f, args; metadata=NO_METADATA)
Expand All @@ -415,7 +455,8 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
end
end

Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
BasicSymbolic(s)
end

function Mul(T, a, b; metadata=NO_METADATA, kw...)
Expand All @@ -430,7 +471,8 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
else
coeff = a
dict = b
Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
BasicSymbolic(s)
end
end

Expand Down Expand Up @@ -461,7 +503,7 @@ function maybe_intcoeff(x)
end
end

function Div{T}(n, d, simplified=false; metadata=nothing) where {T}
function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
if T<:Number && !(T<:SafeReal)
n, d = quick_cancel(n, d)
end
Expand Down Expand Up @@ -495,7 +537,8 @@ function Div{T}(n, d, simplified=false; metadata=nothing) where {T}
end
end

Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
BasicSymbolic(s)
end

function Div(n,d, simplified=false; kw...)
Expand All @@ -509,14 +552,15 @@ end

@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1]

function Pow{T}(a, b; metadata=NO_METADATA) where {T}
function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
_iszero(b) && return 1
_isone(b) && return a
Pow{T}(; base=a, exp=b, arguments=[], metadata)
s = Pow{T}(; base=a, exp=b, arguments=[], metadata)
BasicSymbolic(s)
end

function Pow(a, b; metadata=NO_METADATA)
Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)..., metadata=metadata)
function Pow(a, b; metadata = NO_METADATA, kwargs...)
Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)...; metadata, kwargs...)
end

function toterm(t::BasicSymbolic{T}) where T
Expand Down
84 changes: 84 additions & 0 deletions test/hash_consing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using SymbolicUtils, Test
using SymbolicUtils: Term, Add, Mul, Div, Pow, hash2

struct Ctx1 end
struct Ctx2 end
Expand All @@ -24,3 +25,86 @@ struct Ctx2 end
xm3 = setmetadata(x1, Ctx2, "meta_2")
@test xm1 !== xm3
end

@syms a b c

@testset "Term" begin
t1 = sin(a)
t2 = sin(a)
@test t1 === t2
t3 = Term(identity,[a])
t4 = Term(identity,[a])
@test t3 === t4
t5 = Term{Int}(identity,[a])
@test t3 !== t5
tm1 = setmetadata(t1, Ctx1, "meta_1")
@test t1 !== tm1
end

@testset "Add" begin
d1 = a + b
d2 = b + a
@test d1 === d2
d3 = b - 2 + a
d4 = a + b - 2
@test d3 === d4
d5 = Add(Int, 0, Dict(a => 1, b => 1))
@test d5 !== d1

dm1 = setmetadata(d1,Ctx1,"meta_1")
@test d1 !== dm1
end

@testset "Mul" begin
m1 = a*b
m2 = b*a
@test m1 === m2
m3 = 6*a*b
m4 = 3*a*2*b
@test m3 === m4
m5 = Mul(Int, 1, Dict(a => 1, b => 1))
@test m5 !== m1

mm1 = setmetadata(m1, Ctx1, "meta_1")
@test m1 !== mm1
end

@testset "Div" begin
v1 = a/b
v2 = a/b
@test v1 === v2
v3 = -1/a
v4 = -1/a
@test v3 === v4
v5 = 3a/6
v6 = 2a/4
@test v5 === v6
v7 = Div{Float64}(-1,a)
@test v7 !== v3

vm1 = setmetadata(v1,Ctx1, "meta_1")
@test vm1 !== v1
end

@testset "Pow" begin
p1 = a^b
p2 = a^b
@test p1 === p2
p3 = a^(2^-b)
p4 = a^(2^-b)
@test p3 === p4
p5 = Pow{Float64}(a,b)
@test p1 !== p5

pm1 = setmetadata(p1,Ctx1, "meta_1")
@test pm1 !== p1
end

@testset "Equivalent numbers" begin
f = 0.5
r = 1 // 2
@test hash(f) == hash(r)
u0 = zero(UInt)
@test hash2(f, u0) != hash2(r, u0)
@test f + a !== r + a
end

0 comments on commit 59e3aa6

Please sign in to comment.