This library builds a general neural network solver on top of the raw MXNet C APIs and operators.
The `Symbol` API of MXNet synthesizes a symbolic graph of the neural network. To solve such a graph, every `Symbol` must be backed by two `NDArray`s, one for the forward propagation and one for the backward propagation. Calling the API `mxExecutorBind` binds the symbolic graph and the backing NDArrays together, producing an `Executor`. With this executor, `mxExecutorForward` and `mxExecutorBackward` can be run. Through optimization, the backing NDArrays of the neural network are updated in each iteration.
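To make the flow concrete, here is a conceptual sketch of the bind/forward/backward cycle. The simplified signatures below are assumptions for illustration only; the actual raw bindings take more arguments.

```haskell
-- Conceptual sketch only: simplified stand-ins for the raw bindings.
data Symbol
data NDArray
data Executor

bind     :: Symbol -> [NDArray] -> [NDArray] -> IO Executor  -- ~ mxExecutorBind
bind     = undefined
forward  :: Executor -> IO ()                                -- ~ mxExecutorForward
forward  = undefined
backward :: Executor -> IO ()                                -- ~ mxExecutorBackward
backward = undefined

-- Bind the graph to its backing arrays, then run one forward/backward.
-- In practice the executor is bound once and reused across iterations;
-- an optimizer then updates `args` in place using `grads`.
trainStep :: Symbol -> [NDArray] -> [NDArray] -> IO ()
trainStep graph args grads = do
    exec <- bind graph args grads
    forward exec
    backward exec
```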
MXNet provides data iterators, and these can be wrapped in a `Stream` or a `Conduit`. The fei-dataiter package provides such an implementation.
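For example, a batch stream could be consumed with conduit as sketched below; `batchSource` is a hypothetical source standing in for what fei-dataiter provides.

```haskell
import Conduit
import MXNet.Base (NDArray)   -- assuming fei-base exports NDArray here

-- Hypothetical: a stream of (data, label) batches, e.g. from fei-dataiter.
batchSource :: ConduitT () (NDArray Float, NDArray Float) IO ()
batchSource = undefined

consumeBatches :: IO ()
consumeBatches = runConduit $ batchSource .| mapM_C trainOnBatch
  where
    trainOnBatch (_x, _y) = pure ()  -- placeholder for a real training step
```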
`Module` is a tagged `StateT` monad, whose internal state `TaggedModuleState` holds a hashmap of `NDArray`s backing the symbolic graph, together with a few other pieces of information.
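Roughly, the types have the shape sketched below. This is a simplification for intuition only; the actual definitions carry more fields and the names may differ.

```haskell
import Control.Monad.State (StateT)
import qualified Data.HashMap.Strict as M
import Data.Text (Text)
import MXNet.Base (NDArray)   -- assuming fei-base exports NDArray here

-- Simplified sketch, not the real definitions:
data ModuleState dty = ModuleState
    { _mod_arrays :: M.HashMap Text (NDArray dty)  -- arrays backing the graph
    -- ... plus executor, shapes, and other bookkeeping
    }

-- the phantom tag distinguishes independent models at the type level
newtype TaggedModuleState dty tag = TaggedModuleState (ModuleState dty)

type Module tag dty m = StateT (TaggedModuleState dty tag) m
```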
```haskell
initialize :: forall tag dty. (HasCallStack, FloatDType dty)
           => SymbolHandle
           -> Config dty
           -> IO (TaggedModuleState dty tag)
```
It takes the symbolic graph and a configuration specifying:

- the shapes of the placeholders in the training phase, and
- how to initialize the NDArrays.

It returns an initial state for the `Module` monad, as the sketch below illustrates.
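A call might look like the following. The tag is picked via a type application; `buildSymbol` and the `Config` value are hypothetical placeholders.

```haskell
{-# LANGUAGE DataKinds, TypeApplications #-}

setup :: IO (TaggedModuleState Float "demo")
setup = do
    sym <- buildSymbol           -- construct the SymbolHandle for the network
    initialize @"demo" sym cfg   -- pick the tag via a type application
  where
    buildSymbol :: IO SymbolHandle
    buildSymbol = undefined
    cfg :: Config Float
    cfg = undefined              -- placeholder shapes and initializers go here
```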
```haskell
fit :: (FloatDType dty, MonadIO m)
    => M.HashMap Text (NDArray dty)
    -> Module tag dty m ()
```
Given bindings of the variables, `fit` carries out a complete forward/backward pass.
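For instance, a single training step could be phrased like this; "x" and "y" are hypothetical placeholder names from the graph.

```haskell
{-# LANGUAGE OverloadedStrings #-}
import qualified Data.HashMap.Strict as M

-- Module, fit, NDArray, and FloatDType come from this library's API.
step :: (FloatDType dty, MonadIO m)
     => NDArray dty -> NDArray dty -> Module tag dty m ()
step x y = fit (M.fromList [("x", x), ("y", y)])
```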
```haskell
forwardOnly :: (FloatDType dty, MonadIO m)
            => M.HashMap Text (NDArray dty)
            -> Module tag dty m [NDArray dty]
```
Given bindings of the variables, `forwardOnly` carries out the forward phase only, returning the outputs of the neural network.
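Inference is then a matter of binding only the input placeholders, as in this sketch (again with a hypothetical placeholder name "x"):

```haskell
{-# LANGUAGE OverloadedStrings #-}
import qualified Data.HashMap.Strict as M

predict :: (FloatDType dty, MonadIO m)
        => NDArray dty -> Module tag dty m [NDArray dty]
predict x = forwardOnly (M.singleton "x" x)
```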
```haskell
fitAndEval :: (FloatDType dty, Optimizer opt, EvalMetricMethod mtr, MonadIO m)
           => opt dty
           -> M.HashMap Text (NDArray dty)
           -> MetricData mtr dty
           -> Module tag dty m ()
```
It fits the neural network, updates the parameters from the gradients with the given optimizer, and then evaluates and records the metrics.
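Putting the pieces together, one epoch of training inside the `Module` monad might look like the sketch below; representing the epoch as a plain list of binding maps is an assumption.

```haskell
import Control.Monad (forM_)
import qualified Data.HashMap.Strict as M
import Data.Text (Text)

-- One epoch of training: run fitAndEval over a list of batches, where
-- each batch is already a map of variable bindings.
trainEpoch :: (FloatDType dty, Optimizer opt, EvalMetricMethod mtr, MonadIO m)
           => opt dty
           -> MetricData mtr dty
           -> [M.HashMap Text (NDArray dty)]
           -> Module tag dty m ()
trainEpoch opt metric batches =
    forM_ batches $ \bindings -> fitAndEval opt bindings metric
```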
It is possible to write a training loop with the above `Module` operations. But it is still a bit difficult to connect it with other pieces of code, such as data loading and logging/debugging. The major obstacle is that `Module` has a `StateT` monad under the hood, which rules out working in contexts that require a `MonadUnliftIO` instance.
Therefore, we embed the `Module` state in a top-level environment `FeiApp`. An application is written in a `ReaderT` monad and calls `askSession` whenever it needs to "enter" the `Module` monad.
```haskell
data FeiApp t n x = FeiApp
    { _fa_log_func        :: !LogFunc
    , _fa_process_context :: !ProcessContext
    , _fa_session         :: MVar (TaggedModuleState t n)
    , _fa_extra           :: x
    }
```
```haskell
initSession :: forall n t m x. (FloatDType t, Feiable m, MonadIO m, MonadReader (FeiApp t n x) m)
            => SymbolHandle -> Config t -> m ()

askSession :: (MonadIO m, MonadReader e m, HasSessionRef e s, Session sess s)
           => sess m r -> m r
```
There are two pre-made top-level `ReaderT` monads. `SimpleFeiM` uses `FeiApp` without extra information, while `NeptFeiM` holds an extra `NeptExtra` data structure. As the name suggests, `NeptFeiM` is augmented with the capability to record logs to Neptune.
```haskell
newtype SimpleFeiM t n a = SimpleFeiM (ReaderT (FeiApp t n ()) (ResourceT IO) a)
    deriving (Functor, Applicative, Monad, MonadIO, MonadFail)

newtype NeptFeiM t n a = NeptFeiM (ReaderT (FeiApp t n NeptExtra) (ResourceT IO) a)
    deriving (Functor, Applicative, Monad, MonadIO, MonadFail)
```
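With `SimpleFeiM` in hand, a session might be structured as below. This sketch assumes `SimpleFeiM` provides the `MonadReader` and session instances that `initSession` and `askSession` require; `sym`, `cfg`, and `bindings` are placeholders.

```haskell
{-# LANGUAGE DataKinds, TypeApplications #-}
import qualified Data.HashMap.Strict as M
import Data.Text (Text)

session :: SymbolHandle
        -> Config Float
        -> M.HashMap Text (NDArray Float)
        -> SimpleFeiM Float "demo" ()
session sym cfg bindings = do
    initSession @"demo" sym cfg   -- set up the Module state in _fa_session
    askSession $ fit bindings     -- enter the Module monad for one step
```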
The type class `Feiable` is then introduced to unify the common interface of `SimpleFeiM` and `NeptFeiM`, and possibly to support future extensions as well. `runFeiM` is supposed to properly initialize MXNet before running the action and to clean up before termination.
```haskell
class Feiable (m :: * -> *) where
    data FeiMType m a :: *
    runFeiM :: FeiMType m a -> IO a
```
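An entry point could then look like this sketch. The `Simple` data constructor for `FeiMType (SimpleFeiM t n)` is an assumption based on the fei-examples repository, and the placeholders are left undefined.

```haskell
import Control.Monad.IO.Class (liftIO)

main :: IO ()
main = runFeiM $ Simple $ do   -- `Simple` assumed to be SimpleFeiM's FeiMType constructor
    sym <- liftIO buildSymbol  -- hypothetical graph construction
    session sym cfg bindings   -- the session sketch from above
  where
    buildSymbol :: IO SymbolHandle
    buildSymbol = undefined
    cfg :: Config Float
    cfg = undefined
    bindings = undefined
```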
See the examples in the fei-examples repository.