diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index bf6048180..914a786fe 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -45,4 +45,5 @@ pages: - Symbolic API: api/symbolic-node.md - Neural Networks Factory: api/nn-factory.md - Executor: api/executor.md + - Modules: api/modules.md - Network Visualization: api/visualize.md diff --git a/docs/src/api/modules.md b/docs/src/api/modules.md new file mode 100644 index 000000000..6bb3a4b53 --- /dev/null +++ b/docs/src/api/modules.md @@ -0,0 +1,6 @@ +# Modules + +```@autodocs +Modules = [MXNet.mx] +Pages = ["module/Module.jl"] +``` diff --git a/src/MXNet.jl b/src/MXNet.jl index b9de52a58..192f6ed51 100644 --- a/src/MXNet.jl +++ b/src/MXNet.jl @@ -1,4 +1,4 @@ -__precompile__() +#__precompile__() module MXNet @@ -37,6 +37,8 @@ include("kvstore.jl") include("callback.jl") include("model.jl") +include("executor-group.jl") +include("module/Module.jl") include("visualize.jl") diff --git a/src/executor-group.jl b/src/executor-group.jl new file mode 100644 index 000000000..8d86f2084 --- /dev/null +++ b/src/executor-group.jl @@ -0,0 +1,426 @@ +""" + AbstractExecutorGroup +Executor group is a convenient tool for managing a group of executors. +""" +abstract AbstractExecutorGroup + +function forward(self::AbstractExecutorGroup, data_provider :: AbstractDataProvider, + data_batch :: AbstractDataBatch, is_train) + throw(MethodError(forward, (self, ))) +end + +""" + DataParallelExecutorGroup + +Supports: + - Fixed parameters (freezing) + - Shape inference + - Type inference +""" +type DataParallelExecutorGroup <: AbstractExecutorGroup + symbol :: SymbolicNode + context :: Vector{Context} + execs :: Vector{Executor} + + data_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} + label_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} + + for_training :: Bool + slices :: Vector{UnitRange{Int}} + batch_size :: Int + + shared_group :: Nullable{DataParallelExecutorGroup} + inputs_need_grad :: Bool + fixed_param_names :: Nullable{Vector{Symbol}} + grad_req :: Dict{Symbol, GRAD_REQ} + freeze_idx + + data_arrays :: Vector{Vector{SlicedNDArray}} + label_arrays :: Vector{Vector{SlicedNDArray}} + param_arrays :: Vector{Vector{NDArray}} + grad_arrays :: Vector{Vector{NDArray}} + aux_arrays :: Vector{Vector{NDArray}} + input_grad_arrays :: Vector{Vector{NDArray}} + + arg_params :: Dict{Symbol, NDArray} + aux_params :: Dict{Symbol, NDArray} + param_names :: Vector{Symbol} + aux_names :: Vector{Symbol} +end + +function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context}, + data_shapes, data_names, data_types, label_shapes, label_names, label_types, + for_training, inputs_need_grad, shared_group, fixed_param_names, grad_req) + + num_dev = length(context) + arg_names = list_arguments(symbol) + input_names = [data_names; label_names] + param_names = setdiff(arg_names, input_names) + aux_names = list_auxiliary_states(symbol) + + batch_size = data_shapes[1][end] + for shape in data_shapes + @assert batch_size == shape[end] + end + if !isempty(label_shapes) + for shape in label_shapes + @assert batch_size == shape[end] + end + end + + # TODO implement workload + slices = _split_inputs(batch_size, num_dev) + + execs = Vector{Executor}(num_dev) + + # Shape inference based on data_shapes and label_shapes + provided_shapes = merge( + Dict(name => shape for (name, shape) in zip(data_names, data_shapes)), + Dict(name => shape for (name, shape) in zip(label_names, label_shapes)) + ) + + # Run shape inference globally + arg_shapes, out_shapes, aux_shapes = infer_shape(symbol, provided_shapes) + @assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference") + + # Type inference based on data_types and lable_types + provided_types = merge( + Dict(name => T for (name, T) in zip(data_names, data_types)), + Dict(name => T for (name, T) in zip(label_names, label_types)) + ) + + arg_types, out_types, aux_types = infer_type(symbol, provided_types) + + # Check for what arg we needs gradients and which are frozen + grad_req, freeze_idx = get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, fixed_param_names, grad_req) + + arg_params = Dict{Symbol, NDArray}() + aux_params = Dict{Symbol, NDArray}() + + for (name, shape, T) in filter(x -> in(x[1], param_names), zip(arg_names, arg_shapes, arg_types)) + arg_params[name] = empty(T, shape) + end + + for (name, shape, T) in zip(aux_names, aux_shapes, aux_types) + aux_params[name] = empty(T, shape) + end + + dev_shapes(shapes, slice) = (tuple(shape[1:end-1]..., slice) for shape in shapes) + + for i = 1:num_dev + slice = length(slices[i]) + # Shape inference based on data_shapes and label_shapes per device + provided_shapes_dev = merge( + Dict(name => shape for (name, shape) in zip(data_names, dev_shapes(data_shapes, slice))), + Dict(name => shape for (name, shape) in zip(label_names, dev_shapes(label_shapes, slice))) + ) + + # Run shape inference locally (per-device) + arg_shapes_dev, out_shapes_dev, aux_shapes_dev = infer_shape(symbol, provided_shapes_dev) + @assert(!isa(arg_shapes_dev, Void), "Information not enough to perform complete shape inference") + + arg_arrays = NDArray[zeros(T, shape, context[i]) for (shape, T) in zip(arg_shapes_dev, arg_types)] + aux_arrays = NDArray[zeros(T, shape, context[i]) for (shape, T) in zip(aux_shapes_dev, aux_types)] + + # Process arguments to create gradient arrays + grad_arrays = Dict{Symbol,NDArray}() + arg_info = zip(arg_names, arg_shapes_dev, arg_types) + + # if not in provided data, should be parameters + if inputs_need_grad + provided_data_names = label_names + else + provided_data_names = [data_names; label_names] + end + arg_info = filter(x -> !in(x[1], provided_data_names), arg_info) + + # Remove all gradients for nop params + arg_info = filter(x -> grad_req[x[1]] != GRAD_NOP, arg_info) + + for (name, shape, T) in arg_info + grad_arrays[name] = zeros(T, shape, context[i]) + end + + execs[i] = bind(symbol, context[i], arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays) + end + + # set up input data structures + data_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(execs)] for name in data_names] + label_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(execs)] for name in label_names] + + param_idx = filter(i -> in(arg_names[i], param_names), 1:length(arg_names)) + name_idx = filter(i -> in(arg_names[i], data_names), 1:length(arg_names)) + + param_arrays = [NDArray[exec.arg_arrays[i] for exec in execs] for i in param_idx] + grad_arrays = [NDArray[exec.grad_arrays[i] for exec in execs] for i in param_idx] + aux_arrays = [NDArray[exec.aux_arrays[i] for exec in execs] for i = 1:length(aux_names)] + + if inputs_need_grad + input_grad_arrays = [NDArray[exec.grad_arrays[i] for exec in execs] for i in name_idx] + else + input_grad_arrays = [] + end + + data_shapes = Dict(name => shape for (name, shape) in zip(data_names, data_shapes)) + label_shapes = Dict(name => shape for (name, shape) in zip(label_names, label_shapes)) + + return DataParallelExecutorGroup( + symbol, context, execs, + data_shapes, label_shapes, for_training, slices, batch_size, + shared_group, inputs_need_grad, fixed_param_names, grad_req, freeze_idx, + data_arrays, label_arrays, param_arrays, grad_arrays, aux_arrays, + input_grad_arrays, arg_params, aux_params, param_names, aux_names) +end + +""" + forward(exec_group, data_batch, is_train) +Split `data_batch` according to workload and run forward on each devices. +# Arguments +* `data_batch` : AbstractDataBatch +* `is_train` : `Bool` + The hint for the backend, indicating whether we are during training phase. + Default is `nothing`, then the value `self.for_training` will be used. +""" +function forward(self:: DataParallelExecutorGroup, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool = self.for_training) + + load_data!(data_provider, data_batch, self.data_arrays) + + if is_train && !isempty(get_label(data_provider, data_batch)) + load_label!(data_provider, data_batch, self.label_arrays) + end + + for exec in self.execs + forward(exec, is_train=is_train) + end + # TODO add callbacks here +end + +# TODO Add description +backward(self::DataParallelExecutorGroup, out_grads::Void) = backward(self, NDArray[]) +backward(self::DataParallelExecutorGroup, out_grads::NDArray) = backward(self, [out_grads]) +function backward(self::DataParallelExecutorGroup, out_grads::Vector{NDArray}) + @assert(self.for_training, "re-bind with for_training=true to run backward") + + for (i, exec) in enumerate(self.execs) + out_grad_slices = NDArray[] + for grad in out_grads + push!(out_grad_slices, copy(grad, self.context[i])) + end + backward(exec, out_grad_slices) + end +end + +""" + set_params!(self::DataParallelExecutorGroup, arg_params, aux_params; allow_extra_params) + +Assign, i.e. copy parameters to all the executors. +# Arguments +* `arg_params` : `Dict{Symbol, NDArray}` + A dictionary of name to `NDArray` parameter mapping. +* `aux_params` : `Dict{Symbol, NDArray}` + A dictionary of name to `NDArray` auxiliary variable mapping. +* `allow_extra_params`: `Bool`, default `false`, allow parameters in `arg_params` or `aux_params` that not exists in `self`. +""" +function set_params!(self::DataParallelExecutorGroup, + arg_params, aux_params; allow_extra_params::Bool = false) + for exec in self.execs + copy_params_from(exec, arg_params, aux_params, allow_extra_params=allow_extra_params) + end +end + +## +# Utility +## + +update_params(self::DataParallelExecutorGroup, updater, update_on_kvstore, kvstore::Void = nothing) = update_params(self, updater, update_on_kvstore, Nullable{KVStore}()) +update_params(self::DataParallelExecutorGroup, updater, update_on_kvstore, kvstore::KVStore) = update_params(self, updater, update_on_kvstore, Nullable(kvstore)) +function update_params(self::DataParallelExecutorGroup, updater, update_on_kvstore, kvstore::Nullable{KVStore}) + num_dev = length(self.context) + for idx = 1:length(self.param_names) + #= if isa(self.grad_arrays[i][1], Void) =# + #= continue =# + #= end =# + if in(idx, self.freeze_idx) + continue # Skip parameter update entirely + end + if !isnull(kvstore) + kvstore = get(kvstore) + # push gradient, priority is negative index + push!(kvstore, idx, self.param_arrays[idx], priority=-idx) + if update_on_kvstore + # pull back the weights + pull!(kvstore, idx, self.param_arrays[idx], priority=-idx) + else + # pull back the sum-ed gradients, to the same locations + pull!(kvstore, idx, self.grad_arrays[idx], priority=-idx) + end + end + + if !update_on_kvstore + # manual updating + for i_dev = 1:num_dev + # create a fake index, so that the updater create states + # for different param AND different devices, TODO(mli) + # use a better solution later + fake_idx = idx * num_dev + i_dev + get(updater)(fake_idx, self.grad_arrays[idx][i_dev], self.param_arrays[idx][i_dev]) + end + end + end +end + +""" + get_params!(self, arg_params, aux_params) + +Copy data from each executor to `arg_params` and `aux_params`. +# Arguments +* `arg_params` : Dict{Symbol, Vector{NDArray}}. Target parameter arrays +* `aux_params` : Dict{Symbol, Vector{NDArray}}. Target aux arrays + +# Notes +This function will inplace update the NDArrays in arg_params and aux_params. +""" +function get_params!(self::DataParallelExecutorGroup, arg_params::Dict{Symbol, NDArray}, + aux_params::Dict{Symbol, NDArray}) + for (name, block) in zip(self.param_names, self.param_arrays) + w = empty(size(block[1])) + for i in 1:length(block) + @inplace w .+= copy(block[i], cpu()) + end + @inplace w ./= length(block) + copy!(arg_params[name], w) + end + for (name, block) in zip(self.aux_names, self.aux_arrays) + w = empty(size(block[1])) + for i in 1:length(block) + @inplace w .+= copy(block[i], cpu()) + end + @inplace w ./= length(block) + copy!(aux_params[name], w) + end +end + +""" + update_metric + +Accumulate the performance according to `eval_metric` on all devices. +# Parameters +* eval_metric : EvalMetric + The metric used for evaluation. +* labels : list of NDArray + Typically comes from `label` of a `DataBatch`. +""" +function update_metric(self::DataParallelExecutorGroup, eval_metric::AbstractEvalMetric, provider::AbstractDataProvider, batch::AbstractDataBatch) + + # XXX: there is a possibility, that label arrays lie in different + # context than cpu_output_arrays. It should be checked and labels + # should be copied to corresponding context + cpu_output_arrays = get_outputs(self) + update!(eval_metric, get_label(provider, batch), cpu_output_arrays) +end + +""" + get_outputs + +Get outputs of the previous forward computation. + +# Arguments +merge_multi_context : Bool +Default is `True`. In the case when data-parallelism is used, the outputs +will be collected from multiple devices. A `True` value indicate that we +should merge the collected results so that they look like from a single +executor. +# Returns +If `merge_multi_context` is `true`, it is like `[out1, out2]`. Otherwise, it +is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output +elements are `NDArray`. +""" +function get_outputs(self::DataParallelExecutorGroup, merge_multi_context::Bool=true) + outputs = [[exec.outputs[i] for exec in self.execs] for i in 1:length(self.execs[1].outputs)] + + if merge_multi_context + # TODO In original FeedForward model single predefined + # output was used. _merge_multi_context creates new array + # each time it is called. Need to benchmark, may be it's better + # to predefine cpu_output_arrays in self. + return [concatenate(tensors, always_copy=false) for tensors in outputs] + else + return outputs + end +end + +""" + get_input_grads(self, merge_multi_context) + +Get the gradients with respect to the inputs of the module. + +# Arguments +* `merge_multi_context` : `Bool` + Default is `true`. In the case when data-parallelism is used, the outputs + will be collected from multiple devices. A `true` value indicate that we + should merge the collected results so that they look like from a single + executor. + +# Returns +If `merge_multi_context` is `True`, it is like `[grad1, grad2]`. Otherwise, it +is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output +elements are `NDArray`. +""" +function get_input_grads(self::DataParallelExecutorGroup, merge_multi_context::Bool=true) + !self.inputs_need_grad && NDArray[] + + if merge_multi_context + return [concatenate(tensors, always_copy=false) for tensors in self.input_grad_arrays] + end + + return self.input_grad_arrays +end + +function output_shapes(self:: DataParallelExecutorGroup) + outputs = [size(out) for out in self.execs[1].outputs] + return Dict(key => shape for (key, shape) in zip(list_outputs(self.symbol), outputs)) +end + +## +# Internals +## + +function get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, fixed_param_names, grad_req) + if isnull(fixed_param_names) + # get grad attribute to allow for freezing + fixed_param_names = Symbol[] + for (attr, value) in list_all_attr(symbol) + sattr = string(attr) + if endswith(sattr, "grad") && value == "freeze" + push!(fixed_param_names, Symbol(sattr[1:end-5])) + end + end + else + fixed_param_names = get(fixed_param_names) + end + + # Needs to correspond to the correct id in the update loop layer idx=1:length(param_names). + freeze_idx = filter(i -> in(param_names[i], fixed_param_names), 1:length(param_names)) + + # Setup grad_req as a dictionary + grad_req_dict = Dict{Symbol, GRAD_REQ}() + for param in arg_names + if param in param_names + if in(param, fixed_param_names) + grad_req_dict[param] = GRAD_NOP + else + grad_req_dict[param] = grad_req + end + elseif param in data_names + if inputs_need_grad + grad_req_dict[param] = grad_req + else + grad_req_dict[param] = GRAD_NOP + end + else + grad_req_dict[param] = GRAD_NOP + end + end + + return grad_req_dict, freeze_idx +end diff --git a/src/executor.jl b/src/executor.jl index 3ae5301a6..077491216 100644 --- a/src/executor.jl +++ b/src/executor.jl @@ -152,7 +152,7 @@ function simple_bind(self :: SymbolicNode, ctx :: Context; end end - aux_arrays = [zeros(shape, ctx) for shape in aux_shapes] + aux_arrays = NDArray[zeros(shape, ctx) for shape in aux_shapes] return bind(self, ctx, arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays) end @@ -179,8 +179,7 @@ function backward(self :: Executor, out_grads :: Vector{NDArray}) end -function copy_params_from(self::Executor, arg_params::Dict{Base.Symbol,NDArray}, - aux_params::Union{Void,Dict{Base.Symbol,NDArray}}=nothing; +function copy_params_from(self::Executor, arg_params, aux_params=nothing; allow_extra_params::Bool=false) for (name, array) in arg_params if haskey(self.arg_dict, name) @@ -201,7 +200,6 @@ function copy_params_from(self::Executor, arg_params::Dict{Base.Symbol,NDArray}, end end - """ Get a debug string about internal execution plan. diff --git a/src/io.jl b/src/io.jl index f65314e67..218827247 100644 --- a/src/io.jl +++ b/src/io.jl @@ -17,6 +17,9 @@ Normally this involves defining: """ abstract AbstractDataProvider +type StubProvider <: AbstractDataProvider +end + """ get_batch_size(provider) -> Int @@ -118,10 +121,37 @@ type DataBatch <: AbstractDataBatch label :: Vector{NDArray} count :: Int end +DataBatch(provider::AbstractDataProvider, batch::AbstractDataBatch) = + DataBatch(get_data(provider, batch), get_label(provider, batch), count_samples(provider, batch)) +DataBatch(batch::DataBatch) = DataBatch(batch.data, batch.label, batch.count) count_samples(batch :: DataBatch) = batch.count get_data{Provider<:AbstractDataProvider}(::Provider, batch :: DataBatch) = batch.data get_label{Provider<:AbstractDataProvider}(::Provider, batch :: DataBatch) = batch.label +type DataBatchProvider <: AbstractDataProvider + provider :: AbstractDataProvider +end + +eachdatabatch(provider :: AbstractDataProvider) = DataBatchProvider(provider) + +function Base.eltype(provider :: DataBatchProvider) + DataBatch +end +function Base.start(provider :: DataBatchProvider) + return Base.start(provider.provider) +end +function Base.next(provider :: DataBatchProvider, state :: AbstractDataProviderState) + (inner_batch, next_state) = Base.next(provider.provider, state) + batch = DataBatch(get_data(provider.provider, inner_batch), + get_label(provider.provider, inner_batch), + count_samples(provider.provider, inner_batch)) + + return (batch, next_state) +end +function Base.done(provider :: DataBatchProvider, state :: AbstractDataProviderState) + return Base.done(provider.provider, state) +end + """ SlicedNDArray diff --git a/src/module/Module.jl b/src/module/Module.jl new file mode 100644 index 000000000..5bacf04b2 --- /dev/null +++ b/src/module/Module.jl @@ -0,0 +1,545 @@ +module Module +import ....MXNet: mx +import ..mx: DataBatch, AbstractDataProvider, AbstractDataBatch, DataBatchProvider, + SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, + GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, + KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, + AbstractOptimizer, get_updater, update_params, provide_data, + provide_label, AbstractEvalMetric, StubProvider, init, copy!, + concatenate, eachdatabatch, reset!, Accuracy, @defstruct, AbstractInitializer + +""" + AbstractModule + +A module represnets a computation component. The design purpose of a module is +that abstracts a computation unit, that one can run forward, backward, update parameters, etc. +We aim to make the APIs easy to use, especially in the case when we need to use +an imperative API to work with multiple modules (e.g. stochastic depth networks). + +A module has several states: + +* Initial state. Memory is not allocated yet, not ready for computation. +* Binded. Shapes for inputs, outputs, and parameters are all known, memory allocated. +* Parameter initialized. For modules with parameters, doing computation before intitializing + the parameters might result in undefined outputs. +* Optimizer installed. An optimizer can be installed to a module. After this, the parameters. + of the module can be updated according to the optimizers after gradients are computed + (forward-backward). + +In order for a module to interact with others, a module should be able to report the following +information in its raw stage (before binded): + +* [`data_names`](@ref): Names of required data. +* [`output_names`](@ref): Names of the defined outputs. + +And also the following richer information after being binded: + +* State information: + * [`isbinded`](@ref): indicating whether the memory buffers needed for computation + have been allocated. + * [`allows_training`](@ref): whether the module is binded for training (if binded). + * [`isinitialized`](@ref): indicating whether the parameters of this module have + been initialized. + * [`hasoptimizer`](@ref): indicating wherger an optimizers is defined and intialized. +* Input/Output information: + * [`data_shapes`](@ref): + * [`label_shapes`](@ref): + * [`output_shapes`](@ref): +* Parameters (for modules with parameters) + * [`get_params`](@ref): + * [`set_params`](@ref): + * [`init_params`](@ref): +* Setup: + * [`bind`](@ref): + * [`init_optimizer`](@ref): +* Computation: + * [`forward`](@ref): + * [`backward`](@ref): + * [`update!`](@ref): + * [`get_outputs`](@ref): + * [`get_input_grads`](@ref): + * [`update_metric`](@ref): +* Optional: + * [`get_symbol`](@ref): Access the associated `SymbolicNode` if the module has one. + The returned value needs not to be constant (or defined) + +Based on the underlying API a high-level API is implemented: +* [`fit`](@ref): +* [`predict`](@ref): +* [`score`](@ref): +* [`forward_backward`](@ref): +""" +abstract AbstractModule + +## +# Names +## +""" + data_names(self::AbstractModule) -> Vector{Symbol} +""" +function data_names(self::AbstractModule) + throw(MethodError(data_names, (self,))) +end + +""" + output_names(self::AbstractModule) -> Vector{Symbol} +""" +function output_names(self::AbstractModule) + throw(MethodError(output_names, (self,))) +end + +## +# State information +## + +""" + isbinded(self::AbstractModule) -> Bool +""" +function isbinded(self::AbstractModule) + throw(MethodError(isbinded, (self,))) +end + +""" + allows_training(self::AbstractModule) -> Bool +""" +function allows_training(self::AbstractModule) + throw(MethodError(allows_training, (self,))) +end + +""" + isinitialized(self::AbstractModule) -> Bool +""" +function isinitialized(self::AbstractModule) + throw(MethodError(isinitialized, (self,))) +end + +""" + hasoptimizer(self::AbstractModule) -> Bool +""" +function hasoptimizer(self::AbstractModule) + throw(MethodError(hasoptimizer, (self,))) +end + +## +# Input/Output information +## + +""" + data_shapes(AbstractModule) + +A Dict of (name, shape) pairs specifying the data inputs to this module. +""" +function data_shapes(self :: AbstractModule) + throw(MethodError(data_shapes, (self,))) +end + +""" + label_shapes(AbstractModule) + +A Dict of (name, shape) pairs specifying the label inputs to this module. +If this module does not accept labels -- either it is a module without loss function, +or it is not binded for training, then this should return an empty Dict. +""" +function label_shapes(self :: AbstractModule) + throw(MethodError(label_shapes, (self,))) +end + +""" + output_shapes(AbstractModule) + +A Dict of (name, shape) pairs specifying the outputs of this module. +""" +function output_shapes(self :: AbstractModule) + throw(MethodError(output_shapes, (self,))) +end + +## +# Parameters +## + +""" + get_params(self::AbstractModule) + +Get parameters, those are potentially copies of the the actual parameters used to do computation on the device. + +# Returns +`(arg_params, aux_params)`, a pair of dictionary of name to value mapping. +""" +function get_params(self :: AbstractModule) + throw(MethodError(get_params, (self,))) +end + +""" + set_params(self::AbstractModule; arg_params, aux_params, allow_missing, force_init) + +Assign parameter and aux state values. + +# Arguments +* `arg_params` : `Dict`. Dictionary of name to value (`NDArray`) mapping. +* `aux_params` : `Dict`. Dictionary of name to value (`NDArray`) mapping. +* `allow_missing` : `Bool`. If true, params could contain missing values, and the initializer will be called to fill those missing params. +* `force_init` : `Bool`. If true, will force re-initialize even if already initialized. +""" +function set_params(self::AbstractModule, + arg_params::Dict{Symbol, NDArray}, + aux_params::Dict{Symbol, NDArray}; + allow_missing=false, force_init=false) + init_params(self, arg_params=arg_params, aux_params=aux_params, allow_missing=allow_missing, force_init=force_init) +end + +@defstruct ModuleInitParamsOptions ( + initializer::AbstractInitializer=UniformInitializer(0.07), + arg_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), + aux_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), + allow_missing::Bool=false, + force_init::Bool=false +) +""" + init_params!(self; kwargs...) + +Initialize the parameters and auxiliary states. + +# Arguments +* `self` : `AbstractModule` +* `initializer` : `AbstractInitializer`. Called to initialize parameters if needed. +* `arg_params` : `Dict{Symbol, NDArray}`. If not empty, should be a dictionary of existing `arg_params`. Initialization will be copied from that. +* `aux_params` : `Dict{Symbol, NDArray}`. If not empty, should be a dictionary of existing `aux_params`. Initialization will be copied from that. +* `allow_missing` : `Bool`. If true, params could contain missing values, and the initializer will be called to fill those missing params. +* `force_init` : `Bool`. If true, will force re-initialize even if already initialized. +""" +init_params(self :: AbstractModule; kwargs...) = init_params(self, ModuleInitParamsOptions(; kwargs...)) +function init_params(self :: AbstractModule, opts::ModuleInitParamsOptions) + throw(MethodError(init_params, (self, opts))) +end + +### +# Setup +### + +@defstruct ModuleBindOptions ( + for_training::Bool = true, + inputs_need_grad::Bool = true, + force_rebind::Bool = false, + grad_req::mx.GRAD_REQ = mx.GRAD_WRITE, + shared_module::Union{Void, AbstractModule} = nothing +) +""" +""" +bind(self::AbstractModule, data_provider::AbstractDataProvider; kwargs...) = + bind(self, + [x[2] for x in provide_data(data_provider)], + [x[2] for x in provide_label(data_provider)]; kwargs...) +bind(self::AbstractModule, data_shapes, label_shapes = Tuple{Int}[]; kwargs...) = bind(self, data_shapes, label_shapes, ModuleBindOptions(;kwargs...)) +function bind(self :: AbstractModule, data_shapes, label_shapes, opts::ModuleBindOptions) + throw(MethodError(bind, (self, data_shapes, label_shapes, opts))) +end + +""" +""" +function init_optimizer(self :: AbstractModule, ) +end + +### +# Computation +### +""" +""" +forward{T <: AbstractModule}(self :: T, data_batch :: DataBatch, is_train) = forward(self, StubProvider(), data_batch, is_train) +function forward(self :: AbstractModule, provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train) + throw(MethodError(forward, (self, ))) +end + +""" +""" +function backward(self :: AbstractModule, ) +end + +""" +""" +function update(self :: AbstractModule, ) +end + +""" +""" +function get_outputs(self :: AbstractModule, ) +end + +""" +""" +function get_input_grads(self :: AbstractModule, ) +end + +""" +""" + +update_metric(self :: AbstractModule, eval_metric::AbstractEvalMetric, batch::AbstractDataBatch) = update_metric(self, eval_metric, StubProvider(), batch) +function update_metric(self :: AbstractModule, eval_metric::AbstractEvalMetric, provider::AbstractDataProvider, batch::AbstractDataBatch) + throw(MethodError(update_metric, (self, ))) +end + +### +# Optional +## +""" + get_symbol(self::AbstractModule) -> Nullable{SymbolicNode} + +Returns the associated [`SymbolicNode`](@ref) of the module. It might not be defined or change over time. +""" +function get_symbol(self::AbstractModule) + return Nullable{SymbolicNode}() +end + +### +# High-level +### + +""" + fit(self::AbstractModule, train_data::AbstractDataProvider, num_epoch::Int; kwargs...) + +Train the module parameters. + +# Arguments +* `train_data` : AbstractDataProvider +* `eval_data` : AbstractDataProvider + If not `nothing`, will be used as validation set and evaluate the performance + after each epoch. +* `eval_metric` : str or EvalMetric + Default `'acc'`. The performance measure used to display during training. +* `epoch_end_callback` : function or list of function + Each callback will be called with the current `epoch`, `symbol`, `arg_params` + and `aux_params`. +* `batch_end_callback` : function or list of function + Each callback will be called with a `BatchEndParam`. +* `kvstore` : Symbol or KVStore + Default `:local`. +* `optimizer` : AbstractOptimizer + Default `ADAM` +* `eval_end_callback` : function or list of function + These will be called at the end of each full evaluation, with the metrics over + the entire evaluation set. +* `eval_batch_end_callback` : function or list of function + These will be called at the end of each minibatch during evaluation +* `initializer` : Initializer + Will be called to initialize the module parameters if not already initialized. +* `arg_params` : dict + Default `nothing`, if not `nothing`, should be existing parameters from a trained + model or loaded from a checkpoint (previously saved model). In this case, + the value here will be used to initialize the module parameters, unless they + are already initialized by the user via a call to `init_params` or `fit`. +`arg_params` has higher priority to `initializer`. +* `aux_params` : dict + Default `None`. Similar to `arg_params`, except for auxiliary states. +* `allow_missing` : bool + Default `False`. Indicate whether we allow missing parameters when `arg_params` + and `aux_params` are not `None`. If this is `True`, then the missing parameters + will be initialized via the `initializer`. +* `force_rebind` : bool + Default `False`. Whether to force rebinding the executors if already binded. +* `force_init` : bool + Default `False`. Indicate whether we should force initialization even if the + parameters are already initialized. +* `begin_epoch` : int + Default `1`. Indicate the starting epoch. Usually, if we are resuming from a + checkpoint saved at a previous training phase at epoch N, then we should specify + this value as N+1. +* `num_epoch` : int + Number of epochs to run training. + +# Examples +An example of using fit for training:: +```julia +# Assume training train_data and validation eval_data are ready +mx.Module.fit(mod, train_data, 10, eval_data=eval_data, + optimizer=mx.SGD(lr=0.01, momentum=0.9)) +``` +""" +function fit(self::AbstractModule, train_data, num_epoch; + initializer=UniformInitializer(0.07), + optimizer = ADAM(), + eval_data=nothing, + eval_metric::AbstractEvalMetric=Accuracy(), + validation_metric::AbstractEvalMetric = eval_metric, + epoch_end_callback = nothing, + batch_end_callback = nothing, + kvstore = :local, + eval_end_callback = nothing, + eval_batch_end_callback = nothing, + arg_params = Dict{Symbol, NDArray}(), + aux_params = Dict{Symbol, NDArray}(), + allow_missing = false, + force_rebind = false, + force_init = false, + begin_epoch = 1) + bind(self, train_data, for_training=true, force_rebind=force_rebind) + init_params(self, initializer=initializer, arg_params=arg_params, + aux_params=aux_params, allow_missing=allow_missing, + force_init=force_init) + init_optimizer(self, kvstore=kvstore, optimizer=optimizer, force_init=force_init) + + if validation_metric == nothing + validation_metric = eval_metric + end + + #################################################################### + # training loop + #################################################################### + for epoch in begin_epoch:num_epoch + time_start = time() + reset!(eval_metric) + for (nbatch, batch) in enumerate(eachdatabatch(train_data)) + forward_backward(self, batch) + update(self) + update_metric(self, eval_metric, batch) + + if batch_end_callback !== nothing + error("Not implemented yet") + end + end + + # one epoch of training is finished + for (name, val) in get(eval_metric) + info("Epoch[$epoch] Train-$name=$val") + end + time_stop = time() + info("Epoch[$epoch] Time cost=$(time_stop - time_start)") + + # sync aux params across devices + arg_params, aux_params = get_params(self) + set_params(self, arg_params, aux_params) + + if epoch_end_callback !== nothing + error("Not implemented yet") + end + + ################################################################## + # evaluation on validation set + ################################################################## + if eval_data !== nothing + res = score(self, eval_data, validation_metric, + score_end_callback=eval_end_callback, + batch_end_callback=eval_batch_end_callback, + epoch=epoch) + #TODO: pull this into default + for (name, val) in res + info("Epoch[$epoch] Validation-$name=$val") + end + end + end +end + +# XXX: warning, this function is not type stable. +""" + predict + +Run prediction and collect the outputs. + +# Arguments + +* `eval_data` : `AbstractDataProvider` +* `num_batch` : Int + Default is `None`, indicating running all the batches in the data iterator. +* `merge_batches` : `Bool` + Default is `true`, see the doc for return values. +* `always_output_list` : `Bool` + Default is `false`, see the doc for return values. + +# Returns +When `merge_batches` is `true` (by default), the return value will be a vector +`[out1, out2, out3]`. Where each element is concatenation of the outputs for +all the mini-batches. If further that `always_output_list` is `false` (by default), +then in the case of a single output, `out1` is returned instead of `[out1]`. +When `merge_batches` is `false`, the return value will be a nested list like +`[[out1_batch1, out2_batch1], [out1_batch2], ...]`. This mode is useful because +in some cases (e.g. bucketing), the module does not necessarily produce the same +number of outputs. +The objects in the results are `NDArray`s. If you need to work with julia array, +just call `Array{Float32}` on each of the `NDArray`. + +# Examples +An example of using predict for prediction:: +```julia +# Predict on the first 10 batches of `data` DataProvider +predict(m1, data, num_batch=10) +``` +""" +function predict(self::AbstractModule, eval_data::AbstractDataProvider; + num_batch=nothing, merge_batches=true, always_output_list::Bool=false) + @assert isbinded(self) && isinitialized(self) + + output_list = [] + for (nbatch, eval_batch) in enumerate(eachdatabatch(eval_data)) + if num_batch !== nothing && nbatch == num_back + break + end + forward(self, eval_batch, false) + + outputs = get_outputs(self) + push!(output_list, outputs) + + end + + if length(output_list) == 0 + return output_list + end + + if merge_batches + num_outputs = length(output_list[1]) + for out in output_list + @assert(length(out) == num_outputs, + "Cannot merge batches, as num of outputs is not the same in mini-batches. Maybe bucketing is used?") + end + output_list2 = [concatenate([out[i] for out in output_list]) for i = 1:num_outputs] + + if num_outputs == 1 && !always_output_list + return output_list2[1] + end + + return output_list2 + end + + return output_list +end + +""" + score(self::AbstractModule, eval_data, eval_metric; num_batch, batch_end_callback, reset=true, epoch=0) +""" +function score(self :: AbstractModule, eval_data, eval_metric; num_batch=nothing, batch_end_callback=nothing, score_end_callback=nothing, + epoch=0) + @assert isbinded(self) && isinitialized(self) + + reset!(eval_metric) + + for (nbatch, eval_batch) in enumerate(eachdatabatch(eval_data)) + if num_batch !== nothing && nbatch == num_back + break + end + + forward(self, eval_batch, false) + update_metric(self, eval_metric, eval_batch) + + if batch_end_callback !== nothing + error("Not implemented yet!") + end + end + if score_end_callback !== nothing + error("Not implemented yet!") + end + + get(eval_metric) +end + +""" + forward_backward(self :: AbstractModule, data_batch) +""" +function forward_backward(self :: AbstractModule, data_batch) + forward(self, data_batch, true) + backward(self) +end + +# include implementations +include("symbol_module.jl") +# include("pipeline.jl") +include("native_module.jl") +include("sequential_module.jl") + +end diff --git a/src/module/native_module.jl b/src/module/native_module.jl new file mode 100644 index 000000000..99e2b8699 --- /dev/null +++ b/src/module/native_module.jl @@ -0,0 +1,10 @@ +""" + NativeModule + +Allows the implementation of a MXNet module in native Julia. NDArrays +will be translated into native Julia arrays. +""" +type NativeModule{F<:Function,B<:Function} <: AbstractModule + forward :: F + backward :: B +end diff --git a/src/module/pipeline.jl b/src/module/pipeline.jl new file mode 100644 index 000000000..394fa16a7 --- /dev/null +++ b/src/module/pipeline.jl @@ -0,0 +1,39 @@ +abstract PipelineModule <: AbstractModule + +""" + SimplePipelineModule + +Allows the pipelining of several modules. + +# Arguments: +* `pipeline :: Vector{Module}` + The elements that are called sequentially + +# Functionality +* +""" +type SimplePipelineModule <: PipelineModule + pipeline :: Vector{Module} +end + +type ModuleDataProvider <: mx.AbstractDataProvider + mod :: Module +end + + +function forward(self :: SimplePipelineModule) + for mod in self.pipeline + forward(mod) + end +end + +function backward(self :: SimplePipelineModule) + for i in length(self.pipeline):-1:1 + mod = self.pipeline[i] + backward(mod) + end +end + +function get_outputs(self :: SimplePipelineModule) + return get_outputs(last(self.pipeline)) +end diff --git a/src/module/sequential_module.jl b/src/module/sequential_module.jl new file mode 100644 index 000000000..811e21c2d --- /dev/null +++ b/src/module/sequential_module.jl @@ -0,0 +1,304 @@ +import ....MXNet: mx # in order to use mx. +import Base.push! + +@defstruct SequentialModuleMetas ( + take_labels :: Bool = false, + auto_wiring :: Bool = false, + join :: Dict{Symbol, Symbol} = Dict{Symbol, Symbol}() +) + +""" + SequentialModule + +A SequentialModule is a container module that can chain multiple modules together. +Note building a computation graph with this kind of imperative container is less +flexible and less efficient than the symbolic graph. So this should be only used as a +handy utility. + +# Parameters + +""" +type SequentialModule <: AbstractModule + modules :: Vector{AbstractModule} + metas :: Vector{SequentialModuleMetas} + + binded :: Bool + for_training :: Bool + inputs_need_grad :: Bool + params_initialized :: Bool + optimizer_initialized :: Bool + + label_names :: Vector{Symbol} + label_shapes :: Vector{Tuple{Vararg{Int}}} + function SequentialModule(label_names = [:softmax_label]) + new(Vector{AbstractModule}(), + Vector{Symbol}[], + false, false, false, false, false, + label_names) + end +end + +### default API +isbinded(self::SequentialModule) = self.binded +allows_training(self::SequentialModule) = self.for_training +isinitialized(self::SequentialModule) = self.params_initialized +hasoptimizer(self::SequentialModule) = self.optimizer_initialized + +data_names(self::SequentialModule) = length(self.modules) > 0 ? data_names(self.modules[1]) : Symbol[] +label_names(self::SequentialModule) = self.label_names +output_names(self::SequentialModule) = length(self.modules) > 0 ? output_names(self.modules[end]) : Symbol[] + +""" + push!(self, module; kwargs...) + +Add a module to the chain. +# Arguments +* `module` : AbstractModule + The new module to add. +* `kwargs` : keywords + All the keyword arguments are saved as meta information + for the added module. The currently known meta includes + * `take_labels`: `Bool` indicating whether the module expect to + take labels when doing computation. Note any module in + the chain can take labels (not necessarily only the top + most one), and they all take the same labels passed + from the original data batch for the `SequentialModule`. + * `auto_wiring`: `Bool` transfer data outputs of previous module + to the data inputs of next module, with the order given by + `output_names()` + * `join`: `Dict{Symbol, Symbol}`. + +# Returns + +This function returns `self` to allow us to easily chain a +series of `add` calls. + +# Examples +An example of addinging two modules to a chain:: +```julia +seq_mod = @mx.chain mx.Module.SequentialModule() => + mx.Module.push!(mod1) => + mx.Module.push!(mod2) +``` +""" +function push!(self::SequentialModule, mod::AbstractModule; kwargs...) + push!(self.modules, mod) + + metas = SequentialModuleMetas(;kwargs...) + for (key, _) in (kwargs...) + @assert(key ∈ fieldnames(metas), "Unknown meta '$key', a typo?") + end + push!(self.metas, metas) + + # after adding new modules, we are reset back to raw states, needs + # to bind, init_params, etc. + self.binded = false + self.params_initialized = false + self.optimizer_initialized = false + + return self # for easier chaining +end + +function data_shapes(self::SequentialModule) + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return data_shapes(self.modules[1]) +end + +function label_shapes(self::SequentialModule) + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return self.label_shapes +end + +function output_shapes(self::SequentialModule) + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return output_shapes(self.modules[end]) +end + +function get_params(self::SequentialModule) + @assert isbinded(self) && isinitialized(self) + + reduce((Dict{Symbol, NDArray}(), Dict{Symbol, NDArray}()), self.modules) do acc, mod + arg, aux = get_params(mod) + merge(acc[1], arg) + merge(acc[2], aux) + end +end + +""" +""" +function init_params(self::SequentialModule, opts::ModuleInitParamsOptions) + if isinitialized(self) && !opts.force_init + return self + end + + @assert(isbinded(self), "call bind before initializing the parameters") + # make sure we do not have duplicated parameter names + arg_dict = Dict() + aux_dict = Dict() + duplicates = false + for (i, mod) in enumerate(self.modules) + arg_params, aux_params = get_params(mod) + map((arg_dict, arg_params), (aux_dict, aux_params)) do arg + dict, params = arg + for name in keys(params) + if haskey(dict, name) + info("Name $name in layer $i ($(typeof(mod))) is already used in layer $(dict[name][1])($(typeof(dict[name][2])))") + duplicates = true + else + dict[name] = (i, typeof(mod)) + end + end + end + end + if duplicates + error("Duplicates in layer names") + end + + for mod in self.modules + init_params(mod, opts) + end + + self.params_initialized = true + + return self +end + +""" +""" +function bind(self::SequentialModule, data_shapes, label_shapes, opts::ModuleBindOptions) + if opts.inputs_need_grad + @assert opts.for_training + end + if opts.shared_module !== nothing + info("Shared module is not supported") + end + @assert(length(self.modules) > 0, "Attempting to bind empty SequentialModule") + + wrap_in_dict(x::Vector, names) = Dict(k => v for (k, v) in zip(names, x)) + wrap_in_dict(x::Dict, names) = x + data_shapes = wrap_in_dict(data_shapes, data_names(self)) + + # the same label shapes are used for all chained modules + self.label_shapes = label_shapes + + module_data_shapes = data_shapes + module_data_names = data_names(self) + anybody_ever_needs_label = false + for (i, mod) in enumerate(self.modules) + meta = self.metas[i] + if meta.take_labels + module_label_shapes = label_shapes + anybody_ever_needs_label = true + else + module_label_shapes = Tuple{Int}[] + end + + module_inputs_need_grad = opts.inputs_need_grad || (opts.for_training && i > 1) + + @assert length(module_data_names) == length(data_names(mod)) + if length(module_data_names) == 1 + module_data_shapes = collect(values(module_data_shapes)) + elseif Set(module_data_names) != Set(data_names(mod)) + if meta.auto_wiring + module_data_shapes = [module_data_shapes[name] for name in module_data_names] + else + @assert union(setdiff(Set(module_data_names), Set(keys(meta.join))), Set(values(meta.join))) == Set(data_names(mod)) + module_data_shapes = Dict(get(meta.join, name, name) => value for (name, value) in zip(module_data_names, module_data_shapes)) + end + end + + bind(mod, module_data_shapes, module_label_shapes, + for_training=opts.for_training, inputs_need_grad=module_inputs_need_grad, + force_rebind=opts.force_rebind, shared_module=nothing, grad_req=opts.grad_req) + + # the output of the previous module is the data of the next module + module_data_names = output_names(mod) + module_data_shapes = output_shapes(mod) + end + + if !anybody_ever_needs_label + # then I do not need label either + self.label_shapes = Tuple{Int}[] + end + self.binded = true + + return self +end + +""" +""" +function init_optimizer(self::SequentialModule; optimizer::AbstractOptimizer=ADAM(), kvstore :: Union{Base.Symbol, KVStore}=:local, force_init :: Bool=false) + @assert isbinded(self) && isinitialized(self) + if hasoptimizer(self) && !force_init + warn("Optimizer already initialized, ignoring.") + end + for mod in self.modules + init_optimizer(mod, optimizer=optimizer, kvstore=kvstore, force_init=force_init) + end + + self.optimizer_initialized = true + return self +end + +""" +""" +function get_outputs(self::SequentialModule, merge_multi_context::Bool=true) + @assert isbinded(self) && isinitialized(self) + return get_outputs(last(self.modules), merge_multi_context) +end + +""" +""" +function forward(self::SequentialModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool=self.for_training) + @assert isbinded(self) && isinitialized(self) + + batch = DataBatch(data_provider, data_batch) + for (i, mod) in enumerate(self.modules) + forward(mod, batch, is_train) + batch.data = get_outputs(mod) + end +end + +""" +""" +function backward(self::SequentialModule, out_grads::Vector{NDArray}) + @assert isbinded(self) && isinitialized(self) + + for (i, mod) in reverse(zip(1:length(self.modules), self.modules)) + backward(mod, out_grads) + if i == 1 + break + end + + out_grads = get_input_grads(mod) + end +end + +""" +""" +function update(self::SequentialModule) + @assert isbinded(self) && isinitialized(self) && hasoptimizer(self) + + for mod in self.modules + update(mod) + end +end + +""" +""" +function update_metric(self::SequentialModule, eval_metric::AbstractEvalMetric, provider::AbstractDataProvider, batch::AbstractDataBatch) + @assert isbinded(self) && isinitialized(self) + for (meta, mod) in zip(self.metas, self.modules) + if meta.take_labels + update_metric(mod, eval_metric, provider, batch) + end + end +end + +""" +""" +function get_input_grads(self::SequentialModule, merge_multi_context::Bool=true) + @assert isbinded(self) && isinitialized(self) && self.inputs_need_grad + + return get_input_grads(self.modules[1], merge_multi_context) +end diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl new file mode 100644 index 000000000..f5935b821 --- /dev/null +++ b/src/module/symbol_module.jl @@ -0,0 +1,389 @@ +import ....MXNet: mx # in order to use mx. + +""" + SymbolModule + +SymbolModule is a basic module that wraps a `SymbolicNode`. It is functionally the same +as the `FeedForward` model, except using the module API. + +A current limitation is that it only supports one context. + +# Parameters +* `symbol :: SymbolicNode`: The wrapped `SymbolicNode` +* `data_names :: Vector{Symbol}`: +""" +type SymbolModule <: AbstractModule + symbol :: SymbolicNode + data_names :: Vector{Symbol} + label_names :: Vector{Symbol} + aux_names :: Vector{Symbol} + context :: Vector{Context} + + binded :: Bool + for_training :: Bool + inputs_need_grad :: Bool + params_initialized :: Bool + optimizer_initialized :: Bool + + data_shapes :: Vector{Tuple{Vararg{Int}}} + label_shapes :: Vector{Tuple{Vararg{Int}}} + + arg_arrays :: Nullable{Vector{NDArray}} + aux_arrays :: Nullable{Vector{NDArray}} + grad_arrays :: Nullable{Vector{NDArray}} + params_dirty :: Bool + + fixed_param_names :: Nullable{Vector{Symbol}} + optimizer + updater + kvstore + update_on_kvstore + + arg_params + aux_params + + exec_group :: AbstractExecutorGroup + + function SymbolModule(symbol::SymbolicNode, data_names::Vector{Symbol}, + label_names::Vector{Symbol}, context :: Vector{Context}, + fixed_param_names::Nullable{Vector{Symbol}}) + + aux_names = mx.list_auxiliary_states(symbol) + return new(symbol, data_names, label_names, aux_names, context, + false, false, false, false, false, + Vector{Tuple{Int}}(), + Vector{Tuple{Int}}(), + Nullable{Vector{NDArray}}(), + Nullable{Vector{NDArray}}(), + Nullable{Vector{NDArray}}(), + false, + fixed_param_names) + end +end + +function SymbolModule(symbol::SymbolicNode; + data_names = [:data], + label_names = [:softmax_label], + context::Union{Context, Vector{Context}} = [mx.cpu()], + fixed_param_names = nothing) + fixed_param_names = Nullable{Vector{Symbol}}(fixed_param_names) + label_names = Vector{Symbol}(label_names) + context = Vector{Context}(context) + @assert !isempty(data_names) + @assert !isempty(context) + return SymbolModule(symbol, data_names, label_names, context, fixed_param_names) +end + +### default API +isbinded(self::SymbolModule) = self.binded +allows_training(self::SymbolModule) = self.for_training +isinitialized(self::SymbolModule) = self.params_initialized +hasoptimizer(self::SymbolModule) = self.optimizer_initialized + +data_names(self::SymbolModule) = self.data_names +label_names(self::SymbolModule) = self.label_names +output_names(self::SymbolModule) = list_outputs(self.symbol) + +""" + get_symbol(self::SymbolModule) -> Nullable{SymbolicNode} + +Returns the associated [`SymbolicNode`](@ref) of the module. It might not be defined or change over time. +""" +function get_symbol(self::SymbolModule) + return Nullable{SymbolicNode}(self.symbol) +end + +function data_shapes(self::SymbolModule) + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return Dict(k => v for (k, v) in zip(data_names(self), self.data_shapes)) +end + +function label_shapes(self::SymbolModule) + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return Dict(k => v for (k, v) in zip(label_names(self), self.label_shapes)) +end + +function output_shapes(self::SymbolModule) + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return mx.output_shapes(self.exec_group) +end + +function get_params(self::SymbolModule) + if !(isbinded(self) && isinitialized(self)) + return (Dict{Symbol, NDArray}(), Dict{Symbol, NDArray}()) + end + if self.params_dirty + mx.get_params!(self.exec_group, self.arg_params, self.aux_params) + self.params_dirty = false + end + + return (self.arg_params, self.aux_params) +end + +function init_params(self::SymbolModule, opts::ModuleInitParamsOptions) + if isinitialized(self) && !opts.force_init + return self + end + @assert isbinded(self) "Call `bind` before initialization" + + if !isdefined(self, :arg_params) || isempty(self.arg_params) + self.arg_params = Dict(k => mx.empty(size(v)) for (k, v) in self.exec_group.arg_params) + end + + if !isdefined(self, :aux_params) || isempty(self.aux_params) + self.aux_params = Dict(k => mx.empty(size(v)) for (k, v) in self.exec_group.aux_params) + end + + map([[self.arg_params, opts.arg_params], [self.aux_params, opts.aux_params]]) do param_arr + dst, src = param_arr + for (name, arr) in dst + if isempty(src) + init(opts.initializer, name, arr) + else + src = get(src) + if name in keys(src) + if src[name] != arr + copy!(arr, src[name]) + end + else + @assert(!opts.allow_missing, "$name is not presented") + init(opts.initializer, name, arr) + end + end + end + end + + # copy the initialized parameters to devices + set_params!(self.exec_group, self.arg_params, self.aux_params) + + self.params_dirty = false + self.params_initialized = true + + return self +end + +function bind(self::SymbolModule, data_shapes, label_shapes, opts::ModuleBindOptions) + if opts.force_rebind + reset_bind(self) + end + + if isbinded(self) + warn("Already bound, ignoring bind()") + return self + end + + if !opts.for_training + @assert !opts.inputs_need_grad + end + + self.for_training = opts.for_training + self.inputs_need_grad = opts.inputs_need_grad + self.binded = true + + wrap_in_vector(x::Vector, names) = x + wrap_in_vector(x::Dict, names) = [x[name] for name in names] + data_shapes = wrap_in_vector(data_shapes, self.data_names) + label_shapes = wrap_in_vector(label_shapes, self.label_names) + + @assert length(self.data_names) == length(data_shapes) + @assert length(self.label_names) == length(label_shapes) + + self.data_shapes = data_shapes + self.label_shapes = label_shapes + + # TODO propagate type information + data_types = [Float32 for _ in 1:length(self.data_names)] + label_types = [Float32 for _ in 1:length(self.label_names)] + + if opts.shared_module !== nothing + shared_group = Nullable(opts.shared_module.exec_group) + else + shared_group = Nullable{DataParallelExecutorGroup}() + end + + self.exec_group = DataParallelExecutorGroup(self.symbol, self.context, + self.data_shapes, self.data_names, data_types, + self.label_shapes, self.label_names, label_types, + self.for_training, self.inputs_need_grad, shared_group, + self.fixed_param_names, opts.grad_req) + return self +end + +# TODO add description +function init_optimizer(self::SymbolModule; optimizer::AbstractOptimizer=ADAM(), kvstore :: Union{Base.Symbol, KVStore}=:local, force_init :: Bool=false) + @assert isbinded(self) && isinitialized(self) + + if hasoptimizer(self) && !force_init + warn("Optimizer already initialized, ignoring...") + return self + end + + # TODO initialize KV store + # setup kvstore + #= kvstore, update_on_kvstore = _create_kvstore(kvstore, length(self.context), self.arg_params) =# + kvstore, update_on_kvstore = nothing, false + + self.optimizer = optimizer + self.kvstore = kvstore + self.update_on_kvstore = update_on_kvstore + self.updater = Nullable() + + # add adequate calculation of batch_size + op_state = OptimizationState(self.data_shapes[1][end]) + optimizer.state = op_state + + if !update_on_kvstore + self.updater = Nullable(get_updater(optimizer)) + end + if !isa(kvstore, Void) + if update_on_kvstore + set_optimizer(kvstore, optimizer) + end + + info("Initializing KVStore...") + # init kv with gradients + for idx = 1:length(param_arrays) + param_on_devs = param_arrays[idx] + + init!(kvstore, idx, self.arg_params[param_names[idx]]) + + if update_on_kvstore + # pull weights back + pull!(kvstore, idx, param_on_devs, priority=-idx) + end + end + end + + # TODO add preloaded states + #= if !isa(self.preload_opt_states, Void) =# + #= load_optimizer_states!(self, self.preload_opt_states) =# + #= self.preload_opt_states = nothing =# + #= end =# + + self.optimizer_initialized = true + + return self +end + +""" + get_outputs + +Get outputs of the previous forward computation. + +# Arguments +* merge_multi_context : bool + Default is `True`. In the case when data-parallelism is used, the outputs + will be collected from multiple devices. A `True` value indicate that we + should merge the collected results so that they look like from a single + executor. + +# Returns +If `merge_multi_context` is `true`, it is like `[out1, out2]`. Otherwise, it +is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output +elements are `NDArray`. +""" +function get_outputs(self::SymbolModule, merge_multi_context::Bool=true) + @assert isbinded(self) && isinitialized(self) + mx.get_outputs(self.exec_group, merge_multi_context) +end + +""" + forward(module, data_provider, data_batch; is_train) + +Forward computation. + +# Arguments +* `data_batch` : AbstractDataBatch +* `is_train` : Bool + Default is `nothing`, which means `is_train` takes the value of `self.for_training`. +""" +forward(self::SymbolModule, data_batch::DataBatch) = forward(self, data_batch, self.for_training) +function forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool=self.for_training) + @assert isbinded(self) && isinitialized(self) + mx.forward(self.exec_group, data_provider, data_batch, is_train) +end + +""" + backward(module, out_grads) +Backward computation. +# Arguments +* `out_grads` : `NDArray` or vector of `NDArray`, default `nothing`. + Gradient on the outputs to be propagated back. + This parameter is only needed when bind is called + on outputs that are not a loss function. +""" +backward(self::SymbolModule, out_grads::Void=nothing) = backward(self, NDArray[]) +backward(self::SymbolModule, out_grads::NDArray) = backward(self, [out_grads]) +function backward(self::SymbolModule, out_grads::Vector{NDArray}) + @assert isbinded(self) && isinitialized(self) + mx.backward(self.exec_group, out_grads) +end + +""" + update(module) +Update parameters according to the installed optimizer and the gradients computed +in the previous forward-backward batch. +""" +function update(self::SymbolModule) + @assert isbinded(self) && isinitialized(self) && hasoptimizer(self) + self.params_dirty = true + update_params(self.exec_group, self.updater, self.update_on_kvstore, self.kvstore) +end + +""" + update_metric() + +Evaluate and accumulate evaluation metric on outputs of the last forward computation. +# Arguments +* eval_metric : EvalMetric +* labels : Dict of NDArray + Typically `data_batch.label`. +""" +function update_metric(self::SymbolModule, eval_metric::AbstractEvalMetric, provider::AbstractDataProvider, batch::AbstractDataBatch) + mx.update_metric(self.exec_group, eval_metric, provider, batch) +end + +""" + get_input_grads(self::SymbolModule, merge_multi_context=true) + +Get the gradients with respect to the inputs of the module. + +# Arguments + +* `merge_multi_context` : `Bool` + Default is `true`. In the case when data-parallelism is used, the outputs +will be collected from multiple devices. A `true` value indicate that we +should merge the collected results so that they look like from a single +executor. + +# Returns +If `merge_multi_context` is `true`, it is like `[grad1, grad2]`. Otherwise, it +is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output +elements are `NDArray`. +""" +function get_input_grads(self::SymbolModule, merge_multi_context::Bool=true) + @assert isbinded(self) && isinitialized(self) && self.inputs_need_grad + return mx.get_input_grads(self.exec_group, merge_multi_context) +end + +## +# Internals +## + +""" + borrow_optimizer!(module, shared_module) + +Borrow optimizer from a shared module. Used in bucketing, where exactly the same +optimizer (esp. kvstore) is used. +# Arguments +* `module` : SymbolModule +* `shared_module` : SymbolModule +""" +function borrow_optimizer!(self::SymbolModule, shared_module::SymbolModule) + @assert hasoptimizer(shared_module) + self.optimizer = shared_module.optimizer + self.kvstore = shared_module.kvstore + self.update_on_kvstore = shared_module.update_on_kvstore + self.updater = shared_module.updater + self.optimizer_initialized = true +end diff --git a/src/ndarray.jl b/src/ndarray.jl index d37b321a2..f86d30337 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -353,7 +353,10 @@ end function setindex!(arr :: NDArray, val :: NDArray, ::Colon) copy!(arr, val) end -function setindex!{T<:Real}(arr :: NDArray, val :: Union{T,Array{T},NDArray}, idx::UnitRange{Int}) +function setindex!(arr :: NDArray, val :: NDArray, idx::UnitRange{Int}) + setindex!(slice(arr, idx), val, Colon()) +end +function setindex!{T<:Real}(arr :: NDArray, val :: Union{T,Array{T}}, idx::UnitRange{Int}) setindex!(slice(arr, idx), val, Colon()) end @@ -679,6 +682,44 @@ function /(arg0 :: NDArray, arg :: Real) ./(arg0, arg) end +""" + concatenate(arrays; always_copy=true, context=cpu) + +Concatenate a list of NDArrays along the last dimension. + +# Arguments +* `arrays` : vector of NDArray + Arrays to be concatenate. They must have identical shape except + the last dimension. +* `always_copy` : `Bool`, default `true`. When `false`, if the arrays only contain one `NDArray`, that element will be returned directly, avoid copying. +* `context`: `Context`, context of output NDArray. + +# Returns +An `NDArray` that lives on the `context`. +""" +function concatenate(arrays::Vector{NDArray}; always_copy=true, context=cpu()) + if isempty(arrays) || (!always_copy && length(arrays) == 1) + return arrays + end + + shape_axis = size(arrays[1])[end] + shape_rest = size(arrays[1])[1:end-1] + for i in 2:length(arrays) + shape_axis += size(arrays[i])[end] + @assert shape_rest == size(arrays[i])[1:end-1] + end + + ret_shape = tuple(shape_rest..., shape_axis) + ret = empty(ret_shape, context) + + idx = 1 + for arr in arrays + ret[idx:(idx + size(arr)[end] - 1)] = arr + idx += size(arr)[end] + end + + return ret +end """ Manipulating as Julia Arrays diff --git a/src/symbolic-node.jl b/src/symbolic-node.jl index f5a518c35..d5b805ee1 100644 --- a/src/symbolic-node.jl +++ b/src/symbolic-node.jl @@ -297,6 +297,7 @@ end """ infer_shape(self :: SymbolicNode, args...) + infer_shape(self :: SymbolicNode, args::Dict) infer_shape(self :: SymbolicNode; kwargs...) Do shape inference according to the input shapes. The input shapes could be provided @@ -308,7 +309,8 @@ Returns a 3-tuple containing shapes of all the arguments, shapes of all the outp shapes of all the auxiliary variables. If shape inference failed due to incomplete or incompatible inputs, the return value will be `(nothing, nothing, nothing)`. """ -function infer_shape(self :: SymbolicNode; kwargs...) + +function infer_shape(self :: SymbolicNode, kwargs::Dict) sdata = MX_uint[] indptr = MX_uint[0] for (k,v) in kwargs @@ -318,6 +320,8 @@ function infer_shape(self :: SymbolicNode; kwargs...) keys = AbstractString[string(x[1]) for x in kwargs] _infer_shape(self, keys, indptr, sdata) end +infer_shape(self :: SymbolicNode; kwargs...) = infer_shape(self, Dict(kwargs)) + function infer_shape(self :: SymbolicNode, args :: Union{Tuple, Void}...) sdata = MX_uint[] indptr = MX_uint[0] @@ -345,7 +349,7 @@ function _infer_type(self, keys, arg_type_data) Ref{MX_uint}, Ref{Ptr{Cint}}, Ref{MX_uint}, Ref{Ptr{Cint}}, Ref{Cint}), - self, length(arg_type_data)-1, keys, arg_type_data, + self, length(arg_type_data), keys, arg_type_data, ref_in_type_size, ref_in_type_data, ref_out_type_size, ref_out_type_data, ref_aux_type_size, ref_aux_type_data, @@ -365,6 +369,7 @@ end """ infer_type(self :: SymbolicNode; kwargs...) + infer_type(self :: SymbolicNode, args::Dict) infer_type(self :: SymbolicNode, args...) Do type inference according to the input types. The input types could be provided @@ -376,11 +381,12 @@ Returns a 3-tuple containing types of all the arguments, types of all the output types of all the auxiliary variables. If type inference failed due to incomplete or incompatible inputs, the return value will be `(nothing, nothing, nothing)`. """ -function infer_type(self :: SymbolicNode; kwargs...) +function infer_type(self :: SymbolicNode, kwargs::Dict) types = Cint[toTypeFlag(x[2]) for x in kwargs] keys = AbstractString[string(x[1]) for x in kwargs] _infer_type(self, keys, types) end +infer_type(self :: SymbolicNode; kwargs...) = infer_type(self, Dict(kwargs)) function infer_type(self :: SymbolicNode, args :: Union{Tuple, Void}...) types = Cint[] diff --git a/test/test-module.jl b/test/test-module.jl new file mode 100644 index 000000000..da0daaf73 --- /dev/null +++ b/test/test-module.jl @@ -0,0 +1,17 @@ +using MXNet +using Base.Test + +# run test in the whole directory, latest modified files +# are run first, this makes waiting time shorter when writing +# or modifying unit-tests +function test_dir(dir) + jl_files = sort(filter(x -> ismatch(r".*module\.jl$", x), readdir(dir)), by = fn -> stat(joinpath(dir,fn)).mtime) + map(reverse(jl_files)) do file + include("$dir/$file") + end +end + +include(joinpath(dirname(@__FILE__), "common.jl")) +@testset "Modules Test" begin + test_dir(joinpath(dirname(@__FILE__), "unittest")) +end diff --git a/test/unittest/sequential-module.jl b/test/unittest/sequential-module.jl new file mode 100644 index 000000000..e4318424c --- /dev/null +++ b/test/unittest/sequential-module.jl @@ -0,0 +1,51 @@ +module TestSequentialModule +using MXNet +using Base.Test + +using ..Main: reldiff + +################################################################################ +# Utils +################################################################################ + +################################################################################ +# Test Implementations +################################################################################ + +function test_basic() + info("SequentialModule::basic") + + net1 = @mx.chain mx.Variable(:data) => + mx.FullyConnected(name=:fc1, num_hidden=4) + net2 = @mx.chain mx.Variable(:fc1_output) => + mx.FullyConnected(name=:fc2, num_hidden=1) => + mx.LinearRegressionOutput(name=:linout) + + m1 = mx.Module.SymbolModule(net1, label_names=Symbol[]) + m2 = mx.Module.SymbolModule(net2, data_names=[:fc1_output], label_names=[:linout_label]) + seq_mod = mx.Module.SequentialModule() + mx.Module.push!(seq_mod, m1) + mx.Module.push!(seq_mod, m2, take_labels=true) + @test !mx.Module.isbinded(seq_mod) + @test !mx.Module.allows_training(seq_mod) + @test !mx.Module.isinitialized(seq_mod) + @test !mx.Module.hasoptimizer(seq_mod) + + @test mx.Module.data_names(seq_mod) == [:data] + @test mx.Module.output_names(seq_mod) == [:linout_output] + + mx.Module.bind(seq_mod, [(4, 10)], [(1, 10)]) + @test mx.Module.isbinded(seq_mod) + @test !mx.Module.isinitialized(seq_mod) + @test !mx.Module.hasoptimizer(seq_mod) +end + +################################################################################ +# Run tests +################################################################################ + +@testset " Sequential Module Test" begin + test_basic() +end + +end diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl new file mode 100644 index 000000000..f9ad74b08 --- /dev/null +++ b/test/unittest/symbol-module.jl @@ -0,0 +1,150 @@ +module TestSymbolModule +using MXNet +using Base.Test + +using ..Main: reldiff + +################################################################################ +# Utils +################################################################################ + +function create_network() + arch = mx.@chain mx.Variable(:data) => + mx.Convolution(kernel = (3,3), pad = (1,1), stride = (1,1), num_filter = 64) => + mx.SoftmaxOutput(name=:softmax, multi_output = true) + + return arch +end + +function create_linreg(num_hidden::Int=1) + arch = @mx.chain mx.Variable(:data) => + mx.FullyConnected(name=:fc1, num_hidden=num_hidden) => + mx.FullyConnected(name=:fc2, num_hidden=1) => + mx.LinearRegressionOutput(name=:linout) + return arch +end + +################################################################################ +# Test Implementations +################################################################################ + +function test_basic() + info("SymbolModule::basic") + + m1 = mx.Module.SymbolModule(create_network()) + + @test !mx.Module.isbinded(m1) + @test !mx.Module.allows_training(m1) + @test !mx.Module.isinitialized(m1) + @test !mx.Module.hasoptimizer(m1) + + @test mx.Module.data_names(m1) == [:data] + @test mx.Module.output_names(m1) == [:softmax_output] + + mx.Module.bind(m1, [(20, 20, 1, 10)], [(20, 20, 1, 10)]) + @test mx.Module.isbinded(m1) + @test !mx.Module.isinitialized(m1) + @test !mx.Module.hasoptimizer(m1) + + mx.Module.init_params(m1) + @test mx.Module.isinitialized(m1) + + mx.Module.init_optimizer(m1) + @test mx.Module.hasoptimizer(m1) +end + +function test_shapes() + info("SymbolModule::Shapes") + + m1 = mx.Module.SymbolModule(create_network()) + mx.Module.bind(m1, [(20, 20, 1, 10)], [(400, 10)]) + + @test mx.Module.data_shapes(m1) == Dict(:data => (20, 20, 1, 10)) + @test mx.Module.label_shapes(m1) == Dict(:softmax_label => (400, 10)) + @test mx.Module.output_shapes(m1) == Dict(:softmax_output => (20, 20, 64, 10)) + + m2 = mx.Module.SymbolModule(create_network(), label_names=[]) + mx.Module.bind(m2, [(20, 20, 1, 10)]) + @test isempty(mx.Module.label_shapes(m2)) +end + +function test_linear_regression(n_epoch::Int = 10) + info("SymbolModule::LinearRegression") + + srand(123456) + epsilon = randn(1, 10) + x = rand(4, 10) + y = mapslices(sum, [1, 2, 3, 4] .* x, 1) .+ epsilon + data = mx.ArrayDataProvider(:data => x, :linout_label => y; batch_size = 5) + + metric = mx.MSE() + m1 = mx.Module.SymbolModule(create_linreg(4), + label_names = [:linout_label], + context=[mx.cpu(), mx.cpu()]) + mx.Module.bind(m1, data) + mx.Module.init_params(m1) + mx.Module.init_optimizer(m1) + + # TODO Should be changed to tests + #= info(mx.Module.get_params(m1)) =# + + for i in 1:n_epoch + for batch in mx.eachdatabatch(data) + mx.Module.forward(m1, batch) + mx.Module.backward(m1) + mx.Module.update(m1) + + mx.Module.update_metric(m1, metric, batch) + end + + for (name, value) in get(metric) + info("Epoch: $i: $name = $value") + end + mx.reset!(metric) + end + + y_pred = Float64[] + for batch in mx.eachdatabatch(data) + mx.Module.forward(m1, batch, false) + append!(y_pred, Array{Float64}(mx.Module.get_outputs(m1)[1])) + end + y_pred = reshape(y_pred, 1, 10) + + info("Prediction: $y_pred") + info("Actual: $y") + info("No noise: $(mapslices(sum, [1, 2, 3, 4] .* x, 1))") + + # High Level Api + name, score = mx.Module.score(m1, data, metric)[1] + ha_pred = mx.copy(mx.Module.predict(m1, data)) + info("Predict result: ", ha_pred) + info("Score $name : ", score) + + @test sum(abs(ha_pred-y_pred)) < 1e-6 + + m2 = mx.Module.SymbolModule(create_linreg(4), + label_names = [:linout_label], + context=[mx.cpu(), mx.cpu()]) + mx.Module.fit(m2, data, 10, eval_metric=mx.MSE()) + name, score = mx.Module.score(m2, data, metric)[1] + ha_pred = mx.copy(mx.Module.predict(m2, data)) + info("Predict result: ", ha_pred) + info("Score $name : ", score) + @test sum(abs(ha_pred-y_pred)) < 1e-1 +end + +function test_simplepipeline() +end + +################################################################################ +# Run tests +################################################################################ + +@testset " Symbol Module Test" begin + test_basic() + test_shapes() + #= test_init_params(500) =# + test_linear_regression() +end + +end