-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from dvdblk/feature/mg-haskell
Feature/mg haskell
- Loading branch information
Showing
13 changed files
with
460 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
*.code-workspace | ||
|
||
# Created by https://www.gitignore.io/api/visualstudiocode | ||
# Edit at https://www.gitignore.io/?templates=visualstudiocode | ||
|
||
### VisualStudioCode ### | ||
.vscode/* # Maybe .vscode/**/* instead - see comments | ||
!.vscode/settings.json | ||
!.vscode/tasks.json | ||
!.vscode/launch.json | ||
!.vscode/extensions.json | ||
|
||
### VisualStudioCode Patch ### | ||
# Ignore all local history of files | ||
**/.history | ||
|
||
# End of https://www.gitignore.io/api/visualstudiocode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
"haskell.formattingProvider": "ormolu", | ||
"haskell.manageHLS": "GHCup", | ||
"haskell.serverExecutablePath": "" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,13 @@ | ||
# micrograd-haskell | ||
|
||
A port of [micrograd](https://github.com/karpathy/micrograd) written in Haskell. | ||
|
||
## Building | ||
``` | ||
$ stack build | ||
``` | ||
|
||
## Testing | ||
``` | ||
$ stack test | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,4 @@ | ||
module Main (main) where | ||
|
||
import Lib | ||
|
||
main :: IO () | ||
main = someFunc | ||
main = putStrLn "temp" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} | ||
|
||
{-# HLINT ignore "Use -" #-} | ||
module EngineReverse ( | ||
Value (..), | ||
defaultValue, | ||
valueInit, | ||
changeValueOperation, | ||
relu, | ||
EngineReverse.tanh, | ||
incrementGrad, | ||
_backward, | ||
backward | ||
) | ||
where | ||
|
||
import Data.Graph qualified as G | ||
|
||
data Value a = Value | ||
{ -- | Numerical data for this `Value`. | ||
-- Need to use `_data` because 'data' is a reserved keyword | ||
_data :: a | ||
-- | Gradient of the `Value` | ||
, grad :: a | ||
-- | Operation that created this `Value` | ||
, _op :: String | ||
-- | List of previous values that created this `Value` | ||
, _prev :: [Value a] | ||
} | ||
deriving (Show, Eq, Ord) | ||
|
||
-- | defaultValue creates a new `Value` with the given data and no gradient | ||
defaultValue :: Num a => a -> Value a | ||
defaultValue v = Value v 0 "" [] | ||
|
||
-- | valueInit creates a new `Value` with the given data, operation, and ensures unique previous values. | ||
valueInit :: (Eq a, Num a) => a -> String -> [Value a] -> Value a | ||
valueInit v op [x, y] | x == y && op /= "**" = Value v 0 op [x] | ||
| otherwise = Value v 0 op [x, y] | ||
valueInit v op prev = Value v 0 op prev | ||
|
||
-- | changeValueOperation changes the operation of a `Value` to the given operation. | ||
-- It takes one argument of type 'String' and one argument of type `Value`. | ||
changeValueOperation :: String -> Value a -> Value a | ||
changeValueOperation op v = v{_op = op} | ||
|
||
instance (Num a, Ord a) => Num (Value a) where | ||
(+) :: Value a -> Value a -> Value a | ||
v1@(Value x _ _ _) + v2@(Value y _ _ _) = valueInit (x + y) "+" [v1, v2] | ||
(*) :: Value a -> Value a -> Value a | ||
v1@(Value x _ _ _) * v2@(Value y _ _ _) = valueInit (x * y) "*" [v1, v2] | ||
(-) :: Value a -> Value a -> Value a | ||
v1 - v2 = (changeValueOperation "-" . (v1 +) . negate) v2 | ||
negate :: Value a -> Value a | ||
negate v = changeValueOperation "neg" $ defaultValue (-1) * v | ||
|
||
-- | abs is not supported by OG micrograd | ||
abs :: Value a -> Value a | ||
abs v@(Value x _ _ _) = changeValueOperation "abs" $ if x >= 0 then v else negate v | ||
|
||
-- | signum is not supported by OG micrograd | ||
signum :: Value a -> Value a | ||
signum v@(Value x _ _ _) = valueInit (signum x) "signum" [v] | ||
fromInteger :: Integer -> Value a | ||
fromInteger = defaultValue . fromInteger | ||
|
||
instance (Floating a, Ord a) => Fractional (Value a) where | ||
(/) :: Value a -> Value a -> Value a | ||
v1 / v2 = changeValueOperation "/" $ v1 * (v2 ** (-1)) | ||
|
||
-- | recip is not supported by OG micrograd | ||
recip :: Value a -> Value a | ||
recip v = changeValueOperation "recip" $ defaultValue 1 / v | ||
fromRational :: Rational -> Value a | ||
fromRational = defaultValue . fromRational | ||
|
||
instance (Floating a, Ord a) => Floating (Value a) where | ||
(**) :: Value a -> Value a -> Value a | ||
v1@(Value x _ _ _) ** v2@(Value y _ _ _) = valueInit (x ** y) "**" [v1, v2] | ||
|
||
-- | Applies the exponential function to the given `Value`. | ||
exp :: Value a -> Value a | ||
exp v@(Value x _ _ _) = valueInit (Prelude.exp x) "exp" [v] | ||
|
||
-- | functions which are not supported by OG micrograd | ||
pi = undefined | ||
log = undefined | ||
sin = undefined | ||
cos = undefined | ||
asin = undefined | ||
acos = undefined | ||
atan = undefined | ||
sinh = undefined | ||
cosh = undefined | ||
asinh = undefined | ||
acosh = undefined | ||
atanh = undefined | ||
|
||
-- | Applies the rectified linear unit function to the given `Value`. | ||
relu :: (Num a, Ord a) => Value a -> Value a | ||
relu v@(Value x _ _ _) = valueInit (max 0 x) "relu" [v] | ||
|
||
-- | Applies the hyperbolic tangent function to the given `Value`. | ||
tanh :: (Floating a, Ord a) => Value a -> Value a | ||
tanh v@(Value x _ _ _) = valueInit (Prelude.tanh x) "tanh" [v] | ||
|
||
-- | incrementGrad increments the gradient of a `Value` by the given amount. | ||
incrementGrad :: Num a => a -> Value a -> Value a | ||
incrementGrad x v@(Value _ g _ _) = v { grad = g + x } | ||
|
||
-- | _backward updates gradients of the children Values based on the operation | ||
_backward :: (Floating a, Ord a, Show a) => Value a -> Value a | ||
_backward v@(Value _ g "+" [v1, v2]) = v { _prev = [incrementGrad g v1, incrementGrad g v2] } | ||
_backward v@(Value _ g "*" [v1, v2]) = v { _prev = [incrementGrad (g * _data v2) v1, incrementGrad (g * _data v1) v2] } | ||
_backward v@(Value _ g "**" [v1, v2]) = v { _prev = [incrementGrad (_data v2 * _data v1 ** (_data v2 - 1) * g) v1, v2] } | ||
_backward v@(Value _ g "relu" [v1]) = v { _prev = [incrementGrad (g * if _data v1 > 0 then 1 else 0) v1] } | ||
_backward v@(Value _ g "exp" [v1]) = v { _prev = [incrementGrad (_data v1 * g) v1] } | ||
_backward v@(Value _ g "tanh" [v1]) = v { _prev = [incrementGrad ((1 - _data v1 ** 2) * g) v1] } | ||
-- workaround for duplicit values in required binary operations (+), (*) | ||
_backward v@(Value _ g "*" [v1]) = v { _prev = [incrementGrad (2 * g * _data v1) v1] } | ||
_backward v@(Value _ g "+" [v1]) = v { _prev = [incrementGrad (2 * g) v1] } | ||
_backward v@(Value _ _ op prev) = case (op, prev) of | ||
-- redirect negation to multiplication by -1 | ||
("neg", _) -> redirect "neg" "*" v | ||
-- redirect subtraction to addition of negation | ||
("-", _) -> redirect "-" "+" v | ||
-- redirect division to multiplication by reciprocal | ||
("/", _) -> redirect "/" "*" v | ||
-- default case | ||
("", _) -> v | ||
-- catch-all for unsupported operations | ||
(_, _) -> error $ "Invalid operation (" ++ op ++ ") in _backward for prev: " ++ show prev | ||
where | ||
-- | redirect redirects a given operation to another operation and backpropagates the gradient. | ||
redirect from to = changeValueOperation from . _backward . changeValueOperation to | ||
|
||
-- | prevToEdges converts a computation graph from a given `Value` to a list of edges. | ||
prevToEdges :: Value a -> [(Value a, Int, [Int])] | ||
prevToEdges = traverseComputationGraph 0 | ||
where traverseComputationGraph :: Int -> Value a -> [(Value a, Int, [Int])] | ||
-- | If the Value has no previous Values, it is a leaf node and has no edges. | ||
traverseComputationGraph i v@(Value _ _ _ []) = [(v, i, [])] | ||
-- | If the Value has one previous Value, it has one edge. | ||
traverseComputationGraph i v@(Value _ _ _ [v1]) = (v, i, [i + 1]) : traverseComputationGraph (i + 1) v1 | ||
-- | If the Value has two previous Values, it has two edges. | ||
traverseComputationGraph i v@(Value _ _ _ [v1, v2]) = (v, i, [i + 1, i + 2]) : (traverseComputationGraph (i + 1) v1) ++ (traverseComputationGraph (i + 2) v2) | ||
traverseComputationGraph _ _ = error "Invalid computation graph" | ||
|
||
reverseList :: [a] -> [a] | ||
reverseList = foldl (flip (:)) [] | ||
|
||
-- | backward computes the gradient of every `Value` in the computation graph in a topological order. | ||
-- It also sets the gradient of the topmost node to 1. | ||
backward :: (Show a, Floating a, Ord a) => Value a -> Value a | ||
backward v = head . map ((\(val, _, _) -> _backward val) . vertexToNode) . reverseList . G.topSort $ graph | ||
where topNode = v { grad = 1 } | ||
(graph, vertexToNode, _) = G.graphFromEdges . prevToEdges $ topNode | ||
|
||
-- for debugging | ||
-- a = (defaultValue (3 :: Double) * defaultValue 4) * (defaultValue (-5) + defaultValue 16) | ||
-- bw v = map vertexToNode . reverseList . G.topSort $ graph | ||
-- where topNode = v { grad = 1 } | ||
-- (graph, vertexToNode, keyToVertex) = G.graphFromEdges . prevToEdges $ topNode |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
module NN ( | ||
) | ||
where |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,6 @@ | ||
main :: IO () | ||
main = putStrLn "Test suite not yet implemented" | ||
import TestEngineReverse (testsEngineReverse) | ||
|
||
import Test.HUnit | ||
|
||
main :: IO Counts | ||
main = do runTestTT testsEngineReverse |
Oops, something went wrong.