Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Class checking performance fixes #240

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import System.FilePath.Lens (extension)
data CompileOpts = CompileOpts
{ _input :: FilePath
, _output :: FilePath
, _debug :: Bool
}
deriving stock (Eq, Show)

Expand All @@ -33,7 +34,7 @@ compile :: CompileOpts -> IO ()
compile opts = do
logInfo "" $ "Reading Compiler Input from " <> (opts ^. input)
compInp <- readCompilerInput (opts ^. input)
let compOut = runCompiler compInp
let compOut = runCompiler (opts ^. debug) compInp
case compOut ^. maybe'error of
Nothing -> do
logInfo (opts ^. input) "Compilation succeeded"
Expand Down
18 changes: 17 additions & 1 deletion lambda-buffers-compiler/app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@ import Options.Applicative (
info,
long,
metavar,
option,
prefs,
progDesc,
short,
showDefault,
showHelpOnEmpty,
showHelpOnError,
strOption,
subparser,
value,
)
import Options.Applicative.Builder (auto)

newtype Command = Compile CompileOpts

Expand All @@ -43,8 +47,20 @@ outputPathP =
<> help "File to write the output to (lambdabuffers.compiler.CompilerOutput in .textproto format)"
)

debugP :: Parser Bool
debugP =
option
auto
( long "debug"
<> short 'd'
<> metavar "DEBUG"
<> help "Run everything in debug mode"
<> value False
<> showDefault
)

compileOptsP :: Parser CompileOpts
compileOptsP = CompileOpts <$> inputPathP <*> outputPathP
compileOptsP = CompileOpts <$> inputPathP <*> outputPathP <*> debugP

optionsP :: Parser Command
optionsP =
Expand Down
2 changes: 1 addition & 1 deletion lambda-buffers-compiler/lambda-buffers-compiler.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ library
, prettyprinter
, proto-lens
, text
, unification-fd
, unification-fd >=0.11

exposed-modules:
LambdaBuffers.Compiler
Expand Down
6 changes: 3 additions & 3 deletions lambda-buffers-compiler/src/LambdaBuffers/Compiler.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ import LambdaBuffers.ProtoCompat (
import Proto.Compiler qualified as P
import Proto.Compiler_Fields qualified as P

runCompiler :: P.Input -> P.Output
runCompiler compInp = do
runCompiler :: Bool -> P.Input -> P.Output
runCompiler debug compInp = do
case compilerInputFromProto compInp of
Left err -> defMessage & P.error .~ err
Right compInp' -> case KindCheck.runCheck compInp' of
Left err -> defMessage & P.error .~ toProto err
Right _ -> case TyClassCheck.runCheck compInp' of
Right _ -> case TyClassCheck.runCheck debug compInp' of
Left err -> defMessage & P.error .~ err
Right _ -> defMessage & P.result .~ defMessage
2 changes: 2 additions & 0 deletions lambda-buffers-compiler/src/LambdaBuffers/Compiler/MiniLog.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# LANGUAGE StrictData #-}

{- | MiniLog is a simple first order syntax encoding of a Prolog-like logic language without non-determinism and backtracking abilities.
It is used to represent LambdaBuffers Type Class rules (`InstanceClause` and `Derive`) and to check for their logical consistency.
-}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{-# LANGUAGE StrictData #-}

-- | unification-fd based solver.
module LambdaBuffers.Compiler.MiniLog.UniFdSolver (solve) where

import Control.Monad (filterM, foldM)
import Control.Monad (filterM, foldM, when, (>=>))
import Control.Monad.Error.Class (MonadError (catchError, throwError))
import Control.Monad.Except (ExceptT, runExceptT)
import Control.Monad.Reader (MonadReader (local), ReaderT (runReaderT), asks)
Expand All @@ -11,6 +13,7 @@ import Control.Unification (Fallible, Unifiable (zipMatch))
import Control.Unification qualified as U
import Control.Unification qualified as Unif
import Control.Unification.IntVar (IntBindingT, IntVar, runIntBindingT)
import Control.Unification.IntVar qualified as U
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (isJust)
Expand Down Expand Up @@ -49,15 +52,16 @@ type UTerm fun atom = U.UTerm (Term' fun atom) IntVar
type Scope fun atom = Map ML.VarName IntVar

-- | A clause context consists of available clauses (knowledge base) and a trace of all the called goals that led up to it.
data ClauseContext fun atom = MkClauseCtx
{ cctxTrace :: [UTerm fun atom]
, cctxClauses :: [ML.Clause fun atom]
data UContext fun atom = UContext
{ uCtx'trace :: [UTerm fun atom]
, uCtx'clauses :: [ML.Clause fun atom]
, uCtx'doTracing :: Bool
}
deriving stock (Show)

type UniM fun atom a =
ReaderT
(ClauseContext fun atom)
(UContext fun atom)
( ExceptT
(UError fun atom)
( IntBindingT
Expand All @@ -68,16 +72,17 @@ type UniM fun atom a =
a

runUniM ::
Bool ->
[ML.Clause fun atom] ->
UniM fun atom a ->
(Either (UError fun atom) a, [ML.MiniLogTrace fun atom])
runUniM clauses p =
let (errOrRes, logs) = runWriter . runIntBindingT . runExceptT . (`runReaderT` MkClauseCtx mempty clauses) $ p
runUniM doTracing clauses p =
let (errOrRes, logs) = runWriter . runIntBindingT . runExceptT . (`runReaderT` UContext mempty clauses doTracing) $ p
in (fst errOrRes, logs)

-- | Implements `ML.MiniLogSolver`.
solve :: (Show fun, Show atom) => ML.MiniLogSolver fun atom
solve clauses goals = case runUniM clauses (top goals) of
solve :: (Show fun, Show atom) => Bool -> ML.MiniLogSolver fun atom
solve doTracing clauses goals = case runUniM doTracing clauses (top goals) of
(Left err, logs) -> case err of
MLError mlErr -> (Left mlErr, logs)
other -> (Left $ ML.InternalError . Text.pack . show $ other, logs)
Expand All @@ -89,37 +94,64 @@ solve clauses goals = case runUniM clauses (top goals) of
top :: (Eq fun, Eq atom, Show fun, Show atom) => [ML.Term fun atom] -> UniM fun atom (Map ML.VarName (ML.Term fun atom))
top goals = do
(goals', scope) <- interpretTerms mempty goals
-- I think these goals don't need to be forced as they are already ground
_ <- solveGoal `traverse` goals'
(fromUTerm . U.UVar) `traverse` scope
-- Resolve scope (applyBindings to get variables resolved to ground terms)
traverse (fromUTerm . U.UVar) scope

-- | Solving a goal means looking up a matching clause and `callClause` on it with the given goal as the argument.
{- | Solving a goal means looking up a matching clause and `callClause` on it with the given goal as the argument.
WARN(bladyjoker): This expects a forced goal!!!
-}
solveGoal :: (Eq fun, Eq atom, Show fun, Show atom) => UTerm fun atom -> UniM fun atom (UTerm fun atom)
solveGoal goal' = do
-- TODO(bladyjoker): Reach out to the author for this issue.
-- WARN(bladyjoker): Needed to resolve from UVar otherwise we try to do a lookup with a variable that everything unifies with.
goal <- force goal'
mlGoal <- fromUTerm goal
trace $ ML.SolveGoal mlGoal
clause <- lookupClause goal
mayAncestor <- checkCycle goal
solveGoal goal'forced = do
traceSolveGoal goal'forced
clause <- lookupClause goal'forced
mayAncestor <- checkCycle goal'forced
retGoal <- case mayAncestor of
Nothing -> local (\r -> r {cctxTrace = goal : cctxTrace r}) (callClause clause goal)
Just ancestorGoal -> ancestorGoal `unify` goal
trace $ ML.DoneGoal mlGoal
Nothing -> do
local (\r -> r {uCtx'trace = goal'forced : uCtx'trace r}) (callClause clause goal'forced)
Just ancestorGoal -> return ancestorGoal
traceDoneGoal goal'forced
return retGoal

{- | Given a unifiable term and the knowledge base (clauses) find a next `MiniLog.Clause` to `callClause` on.
This is a delicate operation, the search simply tries to unify the heads of `MiniLog.Clause`s with the given goal.
However, before unifying with the goal, `duplicateTerm` is used to make sure the original goal variables are not affected (unified on) by the search.
WARN(bladyjoker): This expects a forced goal!!!
-}
lookupClause :: (Eq fun, Eq atom, Show fun, Show atom) => UTerm fun atom -> UniM fun atom (ML.Clause fun atom)
lookupClause goal'forced = do
traceLookupClause goal'forced
clauses <- asks uCtx'clauses
matched <-
filterM
( \cl -> do
clauseHead' <- toUTerm $ ML.clauseHead cl
goal' <- duplicateTerm goal'forced
catchError
(goal' `unify` clauseHead' >> return True)
( \case
MismatchFailure _ _ -> return False
err -> throwError err
)
)
clauses
case matched of
[] -> fromUTerm goal'forced >>= throwError . MLError . ML.MissingClauseError
[clause] -> traceFoundClause goal'forced clause >> return clause
overlaps -> fromUTerm goal'forced >>= throwError . MLError . ML.OverlappingClausesError overlaps

{- | In functional speak, this is like a function call (application), where clause is a function and a goal is the argument.
We simply `interpretClause` and unify the given argument with the head of the clause.
After that we proceed to call all sub-goals in the body of the clause.
-}
callClause :: (Eq fun, Eq atom, Show fun, Show atom) => ML.Clause fun atom -> UTerm fun atom -> UniM fun atom (UTerm fun atom)
callClause clause arg = do
mlArg <- fromUTerm arg
trace $ ML.CallClause clause mlArg
traceCallClause clause arg
(clauseHead', clauseBody') <- interpretClause clause
retGoal <- clauseHead' `unify` arg
_ <- solveGoal `traverse` clauseBody'
trace $ ML.DoneClause clause mlArg
_ <- (force >=> solveGoal) `traverse` clauseBody'
traceDoneClause clause arg
return retGoal

{- | Checks if the supplied goal was already visited.
Expand All @@ -128,17 +160,20 @@ callClause clause arg = do
- https://www.swi-prolog.org/pldoc/doc/_SWI_/library/coinduction.pl
- https://personal.utdallas.edu/~gupta/courses/acl/2021/other-papers/colp.pdf
- https://arxiv.org/pdf/1511.09394.pdf

WARN(bladyjoker): This expects a forced goal!!!
-}
checkCycle :: (Eq fun, Eq atom, Show fun, Show atom) => UTerm fun atom -> UniM fun atom (Maybe (UTerm fun atom))
checkCycle goal = do
visitedGoals <- asks cctxTrace
checkCycle goal'forced = do
visitedGoals <- asks uCtx'trace -- THESE ARE ALL FORCED
-- WARN(bladyjoker): Because of variable sharing, this has to be forced otherwise you'd yield a variable here that's not real bound by any `unify` that were applied to it beforehand
foldM
( \mayCycle visited -> do
if isJust mayCycle
then return mayCycle
else do
visited' <- duplicateTerm visited
goal' <- duplicateTerm goal
goal' <- duplicateTerm goal'forced
catchError
( do
_ <- goal' `unify` visited'
Expand All @@ -152,52 +187,26 @@ checkCycle goal = do
Nothing
visitedGoals

{- | Given a unifiable term and the knowledge base (clauses) find a next `MiniLog.Clause` to `callClause` on.
This is a delicate operation, the search simply tries to unify the heads of `MiniLog.Clause`s with the given goal.
However, before unifying with the goal, `duplicateTerm` is used to make sure the original goal variables are not affected (unified on) by the search.
-}
lookupClause :: (Eq fun, Eq atom, Show fun, Show atom) => UTerm fun atom -> UniM fun atom (ML.Clause fun atom)
lookupClause goal = do
-- WARN(bladyjoker): Goal has to be `force`d.
mlGoal <- fromUTerm goal
trace $ ML.LookupClause mlGoal
clauses <- asks cctxClauses
matched <-
filterM
( \cl -> do
clauseHead' <- toUTerm $ ML.clauseHead cl
goal' <- duplicateTerm goal
catchError
(goal' `unify` clauseHead' >> return True)
( \case
MismatchFailure _ _ -> return False
err -> throwError err
)
)
clauses
case matched of
[] -> throwError . MLError . ML.MissingClauseError $ mlGoal
[clause] -> trace (ML.FoundClause mlGoal clause) >> return clause
overlaps -> throwError . MLError . ML.OverlappingClausesError overlaps $ mlGoal

{- | Duplicate a unifiable term (basically copies the structure and instantiates new variables).
See https://www.swi-prolog.org/pldoc/doc_for?object=duplicate_term/2.
-}
duplicateTerm :: (Eq fun, Eq atom) => UTerm fun atom -> UniM fun atom (UTerm fun atom)
duplicateTerm (U.UVar _) = freeVar
duplicateTerm (U.UVar _) = freeVar -- WARN(bladyjoker): The UTerm must be `forced` before calling this, otherwise you'd just copy a variable that's not bound to the original (and all its unifications)
duplicateTerm at@(U.UTerm (Atom' _)) = return at
duplicateTerm (U.UTerm (Struct' f args)) = U.UTerm . Struct' f <$> (duplicateTerm `traverse` args)

{- | Turn a unifiable term back into it's original MiniLog.Term.
The term needs to be `force`d otherwise you'd get back a variable.
For showing/debugging/testing purposes.
-}
fromUTerm' :: (Eq fun, Eq atom, Show fun, Show atom) => UTerm fun atom -> UniM fun atom (ML.Term fun atom)
fromUTerm' (U.UVar v) = return $ ML.Var $ Text.pack $ show (U.getVarID v)
fromUTerm' (U.UTerm (Atom' at)) = return $ ML.Atom at
fromUTerm' (U.UTerm (Struct' f args)) = ML.Struct f <$> (fromUTerm' `traverse` args)

fromUTerm :: (Eq fun, Eq atom, Show fun, Show atom) => UTerm fun atom -> UniM fun atom (ML.Term fun atom)
fromUTerm t = force t >>= fromUTerm'
fromUTerm uv@(U.UVar _) = do
-- WARN(bladyjoker): Because of variable sharing, this has to be forced otherwise you'd yield a variable here that's not real bound by any `unify` that were applied to it beforehand
uv'forced <- force uv
case uv'forced of
U.UVar v -> return $ ML.Var $ Text.pack $ show (U.getVarID v)
other -> fromUTerm other
fromUTerm (U.UTerm (Atom' at)) = return $ ML.Atom at
fromUTerm (U.UTerm (Struct' f args)) = ML.Struct f <$> (fromUTerm `traverse` args)

-- | Turn a `MiniLog.Term` into a unifiable term.
toUTerm :: (Eq fun, Eq atom) => ML.Term fun atom -> UniM fun atom (UTerm fun atom)
Expand Down Expand Up @@ -247,10 +256,54 @@ interpretTerm scope (ML.Atom at) = return (U.UTerm $ Atom' at, scope)
debug :: Show a => a -> UniM fun atom ()
debug = trace . ML.InternalTrace . show

traceSolveGoal :: (Eq atom, Eq fun, Show fun, Show atom) => UTerm fun atom -> UniM fun atom ()
traceSolveGoal goal = do
doTracing <- asks uCtx'doTracing
when doTracing $ do
mlGoal <- fromUTerm goal
trace $ ML.SolveGoal mlGoal

traceDoneGoal :: (Eq atom, Eq fun, Show fun, Show atom) => UTerm fun atom -> UniM fun atom ()
traceDoneGoal goal = do
doTracing <- asks uCtx'doTracing
when doTracing $ do
mlGoal <- fromUTerm goal
trace $ ML.DoneGoal mlGoal

traceLookupClause :: (Eq atom, Eq fun, Show fun, Show atom) => UTerm fun atom -> UniM fun atom ()
traceLookupClause goal = do
doTracing <- asks uCtx'doTracing
when doTracing $ do
mlGoal <- fromUTerm goal
trace $ ML.LookupClause mlGoal

traceFoundClause :: (Eq atom, Eq fun, Show fun, Show atom) => UTerm fun atom -> ML.Clause fun atom -> UniM fun atom ()
traceFoundClause goal clause = do
doTracing <- asks uCtx'doTracing
when doTracing $ do
mlGoal <- fromUTerm goal
trace (ML.FoundClause mlGoal clause)

traceCallClause :: (Eq atom, Eq fun, Show fun, Show atom) => ML.Clause fun atom -> UTerm fun atom -> UniM fun atom ()
traceCallClause clause arg = do
doTracing <- asks uCtx'doTracing
when doTracing $ do
mlGoal <- fromUTerm arg
trace (ML.CallClause clause mlGoal)

traceDoneClause :: (Eq atom, Eq fun, Show fun, Show atom) => ML.Clause fun atom -> UTerm fun atom -> UniM fun atom ()
traceDoneClause clause arg = do
doTracing <- asks uCtx'doTracing
when doTracing $ do
mlGoal <- fromUTerm arg
trace (ML.DoneClause clause mlGoal)

trace :: ML.MiniLogTrace fun atom -> UniM fun atom ()
trace x = lift . lift . lift $ tell [x]
trace x = do
doTracing <- asks uCtx'doTracing
when doTracing $ lift . lift . lift $ tell [x]

force :: (Eq fun, Eq atom) => UTerm fun atom -> UniM fun atom (UTerm fun atom)
force :: (Eq atom, Eq fun) => UTerm fun atom -> UniM fun atom (UTerm fun atom)
force = lift . U.applyBindings

unify :: (Eq fun, Eq atom, Show atom, Show fun) => UTerm fun atom -> UTerm fun atom -> UniM fun atom (UTerm fun atom)
Expand All @@ -263,6 +316,6 @@ freeVar = U.UVar <$> freeVar'

freeVar' :: (Eq fun, Eq atom) => UniM fun atom IntVar
freeVar' = do
v <- lift . lift $ Unif.freeVar
debug ("new var" :: String, v)
v@(U.IntVar i) <- lift . lift $ Unif.freeVar
debug ("new var" :: String, i)
return v
Loading