Skip to content

Commit

Permalink
WIP moshi
Browse files Browse the repository at this point in the history
  • Loading branch information
akirakyle committed Nov 20, 2024
1 parent 4d86901 commit ef49132
Show file tree
Hide file tree
Showing 14 changed files with 387 additions and 153 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand All @@ -25,7 +26,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"

[weakdeps]
Expand All @@ -48,6 +48,7 @@ DocStringExtensions = "0.8, 0.9"
DynamicPolynomials = "0.5, 0.6"
IfElse = "0.1"
LabelledArrays = "1.5"
Moshi = "0.3.5"
MultivariatePolynomials = "0.5"
NaNMath = "0.3, 1"
ReverseDiff = "1"
Expand All @@ -57,7 +58,6 @@ StaticArrays = "0.12, 1.0"
SymbolicIndexingInterface = "0.3"
TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
WeakValueDicts = "0.1.0"
julia = "1.3"

Expand Down
3 changes: 2 additions & 1 deletion src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using DocStringExtensions

export @syms, term, showraw, hasmetadata, getmetadata, setmetadata

using Unityper
using Moshi.Data: @data, variant_type, variant_name
using Moshi.Match: @match
using TermInterface
using DataStructures
using Setfield
Expand Down
4 changes: 3 additions & 1 deletion src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
import ..SymbolicUtils
import ..SymbolicUtils.Rewriters
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
symtype, sorted_arguments, metadata, isterm, term, maketerm
isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm
import SymbolicIndexingInterface: symbolic_type, NotSymbolic

##== state management ==##
Expand Down Expand Up @@ -182,6 +182,8 @@ function toexpr(O, st)
if issym(O)
O = substitute_name(O, st)
return issym(O) ? nameof(O) : toexpr(O, st)
elseif isconst(O)
return toexpr(O.val, st)
end
O = substitute_name(O, st)

Expand Down
18 changes: 16 additions & 2 deletions src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,23 @@
# 3. Callback: takes arguments Dictionary × Number of elements matched
#
function matcher(val::Any)
iscall(val) && return term_matcher(val)
if isconst(val)
slot = val.val
return matcher(slot)
elseif iscall(val)
return term_matcher(val)
end
function literal_matcher(next, data, bindings)
islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
if islist(data)
cd = car(data)
if isconst(cd)
cd = cd.val
end
if isequal(cd, val)
return next(bindings, 1)
end
end
nothing
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ end
for f in [!, ~]
@eval begin
promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool
(::$(typeof(f)))(s::Symbolic{Bool}) = Term{Bool}(!, [s])
(::$(typeof(f)))(s::Symbolic{Bool}) = isconst(s) ? !s.val : Term{Bool}(!, [s])
end
end

Expand Down
7 changes: 5 additions & 2 deletions src/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function get_degrees(expr)
elseif iscall(expr)
op = operation(expr)
args = sorted_arguments(expr)
if op == (^) && args[2] isa Number
if op == (^) && (args[2] isa Number || (isconst(args[2]) && args[2].val isa Number))
return map(get_degrees(args[1])) do (base, pow)
(base => pow * args[2])
end
Expand Down Expand Up @@ -79,12 +79,15 @@ function <ₑ(a::Tuple, b::Tuple)
end

function <(a::BasicSymbolic, b::BasicSymbolic)
isconst(a) && isconst(b) && return a.val <ₑ b.val
isconst(a) && return a.val <ₑ b
isconst(b) && return a <ₑ b.val
da, db = get_degrees(a), get_degrees(b)
fw = monomial_lt(da, db)
bw = monomial_lt(db, da)
if fw === bw && !isequal(a, b)
if _arglen(a) == _arglen(b)
return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,)
return (operation(a), arguments(a)...) <ₑ (operation(b), arguments(b)...)
else
return _arglen(a) < _arglen(b)
end
Expand Down
1 change: 1 addition & 0 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ end
_isone(p::PolyForm) = isone(p.p)

function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
x = isconst(x) ? x.val : x
if x isa Number
return x
elseif iscall(x)
Expand Down
2 changes: 2 additions & 0 deletions src/substitute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ function substitute(expr, dict; fold=true)
canfold = !(op isa Symbolic)
args = map(arguments(expr)) do x
x′ = substitute(x, dict; fold=fold)
x′ = isconst(x) ? x′.val : x′
canfold = canfold && !(x′ isa Symbolic)
x′
end
Expand Down Expand Up @@ -54,6 +55,7 @@ function _occursin(needle, haystack)
if iscall(haystack)
args = arguments(haystack)
for arg in args
arg = isconst(arg) ? arg.val : arg
if needle isa Integer || needle isa AbstractFloat
isequal(needle, arg) && return true
else
Expand Down
Loading

0 comments on commit ef49132

Please sign in to comment.