From b11299e74ba917c7b5dac70cbd792022ee473910 Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Mon, 19 Sep 2016 02:44:19 +0900
Subject: [PATCH 01/18] add basic outline

---
 src/module/Module.jl | 114 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 114 insertions(+)
 create mode 100644 src/module/Module.jl

diff --git a/src/module/Module.jl b/src/module/Module.jl
new file mode 100644
index 000000000..f35965509
--- /dev/null
+++ b/src/module/Module.jl
@@ -0,0 +1,114 @@
+"""
+    AbstractModule
+
+A module represents a computation component. The design purpose of a module is
+to abstract a computation unit, on which one can run forward, backward, update parameters, etc.
+We aim to make the APIs easy to use, especially in the case when we need to use
+an imperative API to work with multiple modules (e.g. stochastic depth networks).
+
+A module has several states:
+
+* Initial state. Memory is not allocated yet, not ready for computation.
+* Binded. Shapes for inputs, outputs, and parameters are all known, memory allocated.
+* Parameter initialized. For modules with parameters, doing computation before initializing
+  the parameters might result in undefined outputs.
+* Optimizer installed. An optimizer can be installed to a module. After this, the parameters
+  of the module can be updated according to the optimizer after gradients are computed
+  (forward-backward).
+
+In order for a module to interact with others, a module should be able to report the following
+information in its raw stage (before being binded):
+
+* [`data_names`](@ref): Names of required data.
+* [`output_names`](@ref): Names of the defined outputs.
+
+And also the following richer information after being binded:
+
+* State information:
+    * [`isbinded`](@ref): indicating whether the memory buffers needed for computation
+      have been allocated.
+    * [`allows_training`](@ref): whether the module is binded for training (if binded).
+    * [`isinitialized`](@ref): indicating whether the parameters of this module have
+      been initialized.
+    * [`hasoptimizer`](@ref): indicating whether an optimizer is defined and initialized.
+* Input/Output information: + * [`data_shapes`](@ref): + * [`label_shapes`](@ref): + * [`output_shapes`](@ref): +* Parameters (for modules with parameters) + * [`get_params`](@ref): + * [`set_params`](@ref): + * [`init_params`](@ref): +* Setup: + * [`bind`](@ref): + * [`init_optimizer`](@ref): +* Computation: + * [`forward`](@ref): + * [`backward`](@ref): + * [`update!`](@ref): + * [`get_outputs`](@ref): + * [`get_input_grads`](@ref): + * [`update_metric`](@ref): + +Based on the underlyin API a high-level API is implemented: +* [`fit`](@ref): +* [`predict`](@ref): +* [`score`](@ref): +""" +abstract AbstractModule + +function isbinded(self::AbstractModule) + throw(MethodError(isbinded, (self,))) +end + +function allows_training(self::AbstractModule) + throw(MethodError(allows_training, (self,))) +end + +function isinitialized(self::AbstractModule) + throw(MethodError(isinitialized, (self,))) +end + +function hasoptimizer(self::AbstractModule) + throw(MethodError(hasoptimizer, (self,))) +end + + +function forward_backward(self :: AbstractModule, data_batch) + forward(self, is_train=true) + backward(self) +end + +function score(self :: AbstractModule, eval_data, eval_metric, num_batch=nothing, batch_end_callback=nothing, reset=true, epoch=0) + @assert isbinded(self) && isinitialized(self) + + reset && reset!(eval_data) + reset!(eval_metric) + + for (nbatch, eval_batch) in enumerate(eval_data) + if num_batch !== nothing && nbatch == num_back + break + end + + forward(self, eval_batch, is_train=false) + update_metric(self, eval_metric, label(eval_batch)) + + if batch_end_callback !== nothing + error("Not implemented yet!") + end + end + get(eval_metric) +end + +function iter_predict(self :: AbstractModule, eval_data, num_batch=nothing, reset=true) + @assert isbinded(self) && isinitialized(self) + + reset && reset!(eval_data) + + for (nbatch, eval_batch) in enumerate(eval_data) + if num_batch !== nothing && nbatch == num_back + break + end + forward(self, eval_batch, is_train=false) + samples = count_samples(eval_batch) +end From f214bf0c9d89d0b8c4f70b0d1982a3dbd6e1f98b Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Sep 2016 06:01:42 +0900 Subject: [PATCH 02/18] add basic interface prototypes --- src/module/Module.jl | 178 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 162 insertions(+), 16 deletions(-) diff --git a/src/module/Module.jl b/src/module/Module.jl index f35965509..221ef3222 100644 --- a/src/module/Module.jl +++ b/src/module/Module.jl @@ -49,37 +49,187 @@ And also the following richer information after being binded: * [`get_outputs`](@ref): * [`get_input_grads`](@ref): * [`update_metric`](@ref): +* Optional: + * [`get_symbol`](@ref): Access the associated `SymbolicNode` if the module has one. 
+ The returned value needs not to be constant (or defined) -Based on the underlyin API a high-level API is implemented: +Based on the underlying API a high-level API is implemented: * [`fit`](@ref): * [`predict`](@ref): * [`score`](@ref): +* [`forward_backward`](@ref): """ abstract AbstractModule +## +# Names +## +""" + data_names(self::AbstractModule) -> Vector{Symbol} +""" +function data_names(self::AbstractModule) + throw(MethodError(data_names, (self,))) +end + +""" + output_names(self::AbstractModule) -> Vector{Symbol} +""" +function output_names(self::AbstractModule) + throw(MethodError(output_names, (self,))) +end + +## +# State information +## + +""" + isbinded(self::AbstractModule) -> Bool +""" function isbinded(self::AbstractModule) throw(MethodError(isbinded, (self,))) end +""" + allows_training(self::AbstractModule) -> Bool +""" function allows_training(self::AbstractModule) throw(MethodError(allows_training, (self,))) end +""" + isinitialized(self::AbstractModule) -> Bool +""" function isinitialized(self::AbstractModule) throw(MethodError(isinitialized, (self,))) end +""" + hasoptimizer(self::AbstractModule) -> Bool +""" function hasoptimizer(self::AbstractModule) throw(MethodError(hasoptimizer, (self,))) end +## +# Input/Output information +## -function forward_backward(self :: AbstractModule, data_batch) - forward(self, is_train=true) - backward(self) +""" +""" +function data_shapes(self :: AbstractModule) + throw(MethodError(data_shapes, (self,))) +end + +""" +""" +function label_shapes(self :: AbstractModule) + throw(MethodError(label_shapes, (self,))) +end + +""" +""" +function output_shapes(self :: AbstractModule) + throw(MethodError(output_shapes, (self,))) +end + +## +# Parameters +## + +""" +""" +function get_params(self :: AbstractModule) + throw(MethodError(get_params, (self,))) +end + +""" +""" +function set_params(self :: AbstractModule, arg_params, aux_params) + throw(MethodError(set_params, (self, arg_params, aux_params))) +end + +""" +""" +function init_params(self :: AbstractModule, args...) + throw(MethodError(init_params, (self, args...))) +end + +### +# Setup +### +""" +""" +function bind(self :: AbstractModule, ) +end + +""" +""" +function init_optimizer(self :: AbstractModule, ) +end + +### +# Computation +### +""" +""" +function forward(self :: AbstractModule, ) +end + +""" +""" +function backward(self :: AbstractModule, ) +end + +""" +""" +function update(self :: AbstractModule, ) +end + +""" +""" +function get_outputs(self :: AbstractModule, ) end -function score(self :: AbstractModule, eval_data, eval_metric, num_batch=nothing, batch_end_callback=nothing, reset=true, epoch=0) +""" +""" +function get_input_grads(self :: AbstractModule, ) +end + +""" +""" +function update_metric(self :: AbstractModule, ) +end + +### +# Optional +## +""" + get_symbol(self::AbstractModule) -> Nullable{SymbolicNode} + +Returns the associated [`SymbolicNode`](@ref) of the module. It might not be defined or change over time. 
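The lifecycle described in the interface above (bind, initialize parameters, install an optimizer, then drive forward/backward/update yourself) is easiest to see end to end. The following is an editorial sketch, not part of the patch: it assumes the concrete `SymbolModule` type, the `eachdatabatch` iterator and the `mx.ArrayDataProvider` usage that later patches and unit tests in this series introduce.

```julia
# Editorial sketch of the intended AbstractModule lifecycle. SymbolModule and
# eachdatabatch are introduced by later patches in this series; the data layout
# mirrors the unit test added in test/unittest/symbol-module.jl.
using MXNet

net = mx.@chain mx.Variable(:data) =>
      mx.FullyConnected(name=:fc1, num_hidden=1) =>
      mx.LinearRegressionOutput(name=:linout)

m = mx.Module.SymbolModule(net, label_names = [:linout_label])

mx.Module.bind(m, [(1, 10)], [(1, 10)])   # allocate buffers: isbinded(m) == true
mx.Module.init_params(m)                  # isinitialized(m) == true
mx.Module.init_optimizer(m)               # hasoptimizer(m) == true

x = reshape(collect(1.0:10.0), (1, 10))
y = reshape(collect(2.0:11.0), (1, 10))
data = mx.ArrayDataProvider(:data => x, :linout_label => y, batch_size = 10)

for batch in mx.eachdatabatch(data)
    mx.Module.forward(m, batch)   # is_train defaults to the for_training flag
    mx.Module.backward(m)
    mx.Module.update(m)
end
```

The high-level `score` and `predict` helpers are then thin loops over the same `forward` call once the module is bound and initialized.
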
+""" +function get_symbol(self::AbstractModule) + return Nullable{SymbolicNode}() +end + +### +# High-level +### + +""" +""" +function fit(self::AbstractModule) +end + +""" +""" +function predict(self::AbstractModule) +end + +""" + score(self::AbstractModule, eval_data, eval_metric; num_batch, batch_end_callback, reset=true, epoch=0) +""" +function score(self :: AbstractModule, eval_data, eval_metric; num_batch=nothing, batch_end_callback=nothing, reset=true, epoch=0) @assert isbinded(self) && isinitialized(self) reset && reset!(eval_data) @@ -100,15 +250,11 @@ function score(self :: AbstractModule, eval_data, eval_metric, num_batch=nothing get(eval_metric) end -function iter_predict(self :: AbstractModule, eval_data, num_batch=nothing, reset=true) - @assert isbinded(self) && isinitialized(self) - - reset && reset!(eval_data) - - for (nbatch, eval_batch) in enumerate(eval_data) - if num_batch !== nothing && nbatch == num_back - break - end - forward(self, eval_batch, is_train=false) - samples = count_samples(eval_batch) +""" + forward_backward(self :: AbstractModule, data_batch) +""" +function forward_backward(self :: AbstractModule, data_batch) + forward(self, data_batch, is_train=true) + backward(self, data_batch) end + From 84c3e3326ae27f4aaabb63d36c399fcead012db8 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Sep 2016 06:09:12 +0900 Subject: [PATCH 03/18] integrate modules to docs --- docs/mkdocs.yml | 1 + docs/src/api/modules.md | 6 ++++++ src/MXNet.jl | 3 ++- 3 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 docs/src/api/modules.md diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index bf6048180..914a786fe 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -45,4 +45,5 @@ pages: - Symbolic API: api/symbolic-node.md - Neural Networks Factory: api/nn-factory.md - Executor: api/executor.md + - Modules: api/modules.md - Network Visualization: api/visualize.md diff --git a/docs/src/api/modules.md b/docs/src/api/modules.md new file mode 100644 index 000000000..6bb3a4b53 --- /dev/null +++ b/docs/src/api/modules.md @@ -0,0 +1,6 @@ +# Modules + +```@autodocs +Modules = [MXNet.mx] +Pages = ["module/Module.jl"] +``` diff --git a/src/MXNet.jl b/src/MXNet.jl index b9de52a58..1d8bcc181 100644 --- a/src/MXNet.jl +++ b/src/MXNet.jl @@ -1,4 +1,4 @@ -__precompile__() +#__precompile__() module MXNet @@ -37,6 +37,7 @@ include("kvstore.jl") include("callback.jl") include("model.jl") +include("module/Module.jl") include("visualize.jl") From e3f81f9f22e6f9225ce80611f0289208895428b7 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Sep 2016 07:25:26 +0900 Subject: [PATCH 04/18] more prototypes --- src/executor.jl | 2 +- src/module/Module.jl | 35 ++++++-- src/module/symbol_module.jl | 174 ++++++++++++++++++++++++++++++++++++ test/test-module.jl | 10 +++ 4 files changed, 214 insertions(+), 7 deletions(-) create mode 100644 src/module/symbol_module.jl create mode 100644 test/test-module.jl diff --git a/src/executor.jl b/src/executor.jl index 3ae5301a6..75bccba26 100644 --- a/src/executor.jl +++ b/src/executor.jl @@ -152,7 +152,7 @@ function simple_bind(self :: SymbolicNode, ctx :: Context; end end - aux_arrays = [zeros(shape, ctx) for shape in aux_shapes] + aux_arrays = NDArray[zeros(shape, ctx) for shape in aux_shapes] return bind(self, ctx, arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays) end diff --git a/src/module/Module.jl b/src/module/Module.jl index 221ef3222..7f94f7b5c 100644 --- a/src/module/Module.jl +++ b/src/module/Module.jl @@ 
-1,9 +1,11 @@ +module Module + """ AbstractModule A module represnets a computation component. The design purpose of a module is that abstracts a computation unit, that one can run forward, backward, update parameters, etc. -We aim to make the APIs easy to use, especially in the case when we need to use +We aim to make the APIs easy to use, especially in the case when we need to use an imperative API to work with multiple modules (e.g. stochastic depth networks). A module has several states: @@ -63,7 +65,7 @@ abstract AbstractModule ## # Names -## +## """ data_names(self::AbstractModule) -> Vector{Symbol} """ @@ -80,7 +82,7 @@ end ## # State information -## +## """ isbinded(self::AbstractModule) -> Bool @@ -112,7 +114,7 @@ end ## # Input/Output information -## +## """ """ @@ -218,12 +220,29 @@ end """ """ -function fit(self::AbstractModule) +function fit(self::AbstractModule, train_data) + + error("Not yet implemented") end """ """ -function predict(self::AbstractModule) +function predict(self::AbstractModule, eval_data; + num_batch=nothing, merge_batches=true, reset=true) + @assert isbinded(self) && isinitialized(self) + + reset && reset!(eval_data) + + for (nbatch, eval_batch) in enumerate(eval_data) + if num_batch !== nothing && nbatch == num_back + break + end + forward(self, eval_batch, is_train=false) + + outputs = get_outputs(self) + + error("Not yet implemented") + end end """ @@ -258,3 +277,7 @@ function forward_backward(self :: AbstractModule, data_batch) backward(self, data_batch) end +# include implementations +include("symbol_module.jl") + +end diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl new file mode 100644 index 000000000..a70de1980 --- /dev/null +++ b/src/module/symbol_module.jl @@ -0,0 +1,174 @@ +import ....MXNet: mx # in order to use mx. +import ..mx: SymbolicNode, NDArray, Context, Executor + +""" + Module + +Module is a basic module that wraps a `SymbolicNode`. It is functionally the same +as the `FeedForward` model, except using the module API. + +A current limitation is that it only supports one context. 
+ +# Parameters +* `symbol :: SymbolicNode`: The wrapped `SymbolicNode` +* `data_names :: Vector{Symbol}`: +""" +type SymbolModule <: AbstractModule + symbol :: SymbolicNode + data_names :: Vector{Symbol} + label_names :: Vector{Symbol} + aux_names :: Vector{Symbol} + context :: Context + + binded :: Bool + for_training :: Bool + inputs_need_grad :: Bool + params_initialized :: Bool + optimizer_initialized :: Bool + + data_shapes :: Nullable{Vector{Tuple{Int}}} + label_shapes :: Nullable{Vector{Tuple{Int}}} + output_shapes :: Nullable{Vector{Tuple{Int}}} + + arg_arrays :: Nullable{Vector{NDArray}} + aux_arrays :: Nullable{Vector{NDArray}} + grad_arrays :: Nullable{Vector{NDArray}} + params_dirty :: Bool + + executor :: Nullable{Executor} + + function SymbolModule(symbol::SymbolicNode, data_names::Vector{Symbol}, + label_names::Vector{Symbol}, context :: Context) + + aux_names = mx.list_auxiliary_states(symbol) + return new(symbol, data_names, label_names, aux_names, context, + false, false, false, false, false, + Nullable{Vector{Tuple{Int}}}(), + Nullable{Vector{Tuple{Int}}}(), + Nullable{Vector{Tuple{Int}}}(), + Nullable{Vector{NDArray}}(), + Nullable{Vector{NDArray}}(), + Nullable{Vector{NDArray}}(), + false, + Nullable{Executor}()) + end +end + +function SymbolModule(symbol::SymbolicNode; + data_names = [:data], label_names = [:softmax_label], + context = mx.cpu()) + return SymbolModule(symbol, data_names, label_names, context) +end + +### default API +isbinded(self::SymbolModule) = self.binded +allows_training(self::SymbolModule) = self.for_training +isinitialized(self::SymbolModule) = self.params_initialized +hasoptimizer(self::SymbolModule) = self.hasoptimizer + +data_names(self::SymbolModule) = self.data_names +output_names(self::SymbolModule) = list_outputs(symbol) + +function data_shapes(self::SymbolModule) + !isbinded(self) && return Nullable{Vector{Tuple{Int}}}() + return self.data_shapes +end + +function label_shapes(self::SymbolModule) + !isbinded(self) && return Nullable{Vector{Tuple{Int}}}() + return self.label_shapes +end + +function output_shapes(self::SymbolModule) + !isbinded(self) && return Nullable{Vector{Tuple{Int}}}() + return self.output_shapes +end + +function get_params(self::SymbolModule) + if !(isbinded(self) && isinitialized(self)) + return (Nullable{Dict{Symbol, NDArray}}(), Nullable{Dict{Symbol, NDArray}}()) + end + if self.params_dirty + sync_params_from_device(self) + end + return (Dict(name => data for (name, data) in zip()), + Dict(name => data for (name, data) in zip())) +end + +function init_params(self::SymbolModule; initializer=nothing, arg_params=nothing, + aux_params=nothing, allow_missing=false, force_init=false) + if isinitialized(self) && !force_init + return + end + + @assert isbinded(self) "Call `bind` before initialization" +end + +function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Typle{Int}}(); + for_training=true, inputs_need_grad=true, force_rebind=false, + grad_req=mx.GRAD_WRITE) + if force_rebind + reset_bind(self) + end + + if isbinded(self) + warn("Already bound, ignoring bind()") + return + end + + self.for_training = for_training + self.inputs_need_grad = inputs_need_grad + self.binded = true + + #@assert !for_training && !inputs_need_grad + + @assert length(self.data_names) == length(data_shapes) + @assert length(self.label_names) == length(label_shapes) + + self.data_shapes = Nullable(data_shapes) + self.label_shapes = Nullable(label_shapes) + + provided_shapes = merge( + Dict(name => shape for (name, shape) in 
zip(self.data_names, data_shapes)), + Dict(name => shape for (name, shape) in zip(self.label_names, label_shapes))) + + arg_shapes, out_shapes, aux_shapes = infer_shape(self; provided_shapes...) + @assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference") + + # TODO: perform type inference + + arg_arrays = NDArray[mx.zeros(shape, ctx) for shape in arg_shapes] + arg_names = list_arguments(self.symbol) + + grad_arrays = Dict{Symbol,NDArray}() + + if grad_req != GRAD_NOP + shapes = zip(arg_names, arg_shapes) + + # if not in provided data, should be parameters + provided_data_names = [x[1] for x in keys(provided_shapes)] + shapes = filter(x -> !in(x[1], provided_data_names), shapes) + + # Remove all gradients for nop params + # if isa(grad_req, Dict{Symbol, GRAD_REQ}) + # shapes = filter(x -> grad_req[x[1]] != GRAD_NOP,shapes) + # end + + for (name, shape) in shapes + grad_arrays[name] = mx.zeros(shape, ctx) + end + end + + aux_arrays = NDArray[mx.zeros(shape, ctx) for shape in aux_shapes] + executor = mx.bind(self, ctx, arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays) + + self.executor = Nullable{Executor}(executor) +end + +## +# Internals +## + +function sync_params_from_devices(self::SymbolModule) + throw(MethodError(sync_params_from_devices, (self,))) +end diff --git a/test/test-module.jl b/test/test-module.jl new file mode 100644 index 000000000..74fb3ec16 --- /dev/null +++ b/test/test-module.jl @@ -0,0 +1,10 @@ +using MXNet + +# Create Network +symbol = mx.@chain mx.Variable(:data) => +mx.Convolution(kernel = (3,3), pad = (1,1), stride = (1,1), num_filter = 64) => +mx.SoftmaxOutput(name=:softmax, multi_output = true) + +m1 = mx.Module.SymbolModule(symbol) + +mx.Module.bind(m1, [(20,20,1,10)], [(20,20,1,10)]) From 8806d613a267de9559e888ed46a1345be685be5a Mon Sep 17 00:00:00 2001 From: Andrey Oskin Date: Mon, 16 Jan 2017 14:28:24 +0300 Subject: [PATCH 05/18] bugfix for running test-module --- src/module/symbol_module.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index a70de1980..b0a8bddc3 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -1,5 +1,5 @@ import ....MXNet: mx # in order to use mx. 
-import ..mx: SymbolicNode, NDArray, Context, Executor +import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP """ Module @@ -26,9 +26,9 @@ type SymbolModule <: AbstractModule params_initialized :: Bool optimizer_initialized :: Bool - data_shapes :: Nullable{Vector{Tuple{Int}}} - label_shapes :: Nullable{Vector{Tuple{Int}}} - output_shapes :: Nullable{Vector{Tuple{Int}}} + data_shapes :: Nullable{Vector{Tuple{Vararg{Int}}}} + label_shapes :: Nullable{Vector{Tuple{Vararg{Int}}}} + output_shapes :: Nullable{Vector{Tuple{Vararg{Int}}}} arg_arrays :: Nullable{Vector{NDArray}} aux_arrays :: Nullable{Vector{NDArray}} @@ -104,7 +104,7 @@ function init_params(self::SymbolModule; initializer=nothing, arg_params=nothing @assert isbinded(self) "Call `bind` before initialization" end -function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Typle{Int}}(); +function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Tuple{Int}}(); for_training=true, inputs_need_grad=true, force_rebind=false, grad_req=mx.GRAD_WRITE) if force_rebind @@ -132,13 +132,13 @@ function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Typle{Int}} Dict(name => shape for (name, shape) in zip(self.data_names, data_shapes)), Dict(name => shape for (name, shape) in zip(self.label_names, label_shapes))) - arg_shapes, out_shapes, aux_shapes = infer_shape(self; provided_shapes...) + arg_shapes, out_shapes, aux_shapes = infer_shape(self.symbol; provided_shapes...) @assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference") # TODO: perform type inference - arg_arrays = NDArray[mx.zeros(shape, ctx) for shape in arg_shapes] - arg_names = list_arguments(self.symbol) + arg_arrays = NDArray[mx.zeros(shape, self.context) for shape in arg_shapes] + arg_names = mx.list_arguments(self.symbol) grad_arrays = Dict{Symbol,NDArray}() @@ -146,7 +146,7 @@ function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Typle{Int}} shapes = zip(arg_names, arg_shapes) # if not in provided data, should be parameters - provided_data_names = [x[1] for x in keys(provided_shapes)] + provided_data_names = keys(provided_shapes) shapes = filter(x -> !in(x[1], provided_data_names), shapes) # Remove all gradients for nop params @@ -155,12 +155,12 @@ function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Typle{Int}} # end for (name, shape) in shapes - grad_arrays[name] = mx.zeros(shape, ctx) + grad_arrays[name] = mx.zeros(shape, self.context) end end - aux_arrays = NDArray[mx.zeros(shape, ctx) for shape in aux_shapes] - executor = mx.bind(self, ctx, arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays) + aux_arrays = NDArray[mx.zeros(shape, self.context) for shape in aux_shapes] + executor = mx.bind(self.symbol, self.context, arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays) self.executor = Nullable{Executor}(executor) end From d1059213502d5f50f30ff2885bd3a0932aebd548 Mon Sep 17 00:00:00 2001 From: Andrey Oskin Date: Mon, 16 Jan 2017 19:01:49 +0300 Subject: [PATCH 06/18] Prototype of Executor Group. Additional tests. Added DataBatchProvider. 
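
One piece of this patch worth calling out before the diff: `DataBatchProvider` (in the `src/io.jl` hunk below) wraps any `AbstractDataProvider` so that iterating it yields plain `DataBatch` values whose data, label and sample count have already been extracted from the inner provider. A minimal editorial sketch of that contract; the `ArrayDataProvider` construction and the `:label` name are assumptions, not part of the patch:

```julia
# Editorial sketch of the iteration contract added by DataBatchProvider:
# eachdatabatch(provider) wraps the provider so that each loop element is a
# DataBatch with data, label and count already materialised.
x = rand(1, 10)
y = rand(1, 10)
provider = mx.ArrayDataProvider(:data => x, :label => y, batch_size = 5)

for batch in mx.eachdatabatch(provider)   # batch :: mx.DataBatch
    @assert mx.count_samples(batch) == 5  # simply batch.count
end
```
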
--- src/MXNet.jl | 1 + src/executor-group.jl | 225 ++++++++++++++++++++++++++++++++ src/executor.jl | 4 +- src/io.jl | 27 ++++ src/module/Module.jl | 6 +- src/module/symbol_module.jl | 228 +++++++++++++++++++++++++-------- test/test-module.jl | 10 +- test/unittest/symbol-module.jl | 80 ++++++++++++ 8 files changed, 518 insertions(+), 63 deletions(-) create mode 100644 src/executor-group.jl create mode 100644 test/unittest/symbol-module.jl diff --git a/src/MXNet.jl b/src/MXNet.jl index 1d8bcc181..192f6ed51 100644 --- a/src/MXNet.jl +++ b/src/MXNet.jl @@ -37,6 +37,7 @@ include("kvstore.jl") include("callback.jl") include("model.jl") +include("executor-group.jl") include("module/Module.jl") include("visualize.jl") diff --git a/src/executor-group.jl b/src/executor-group.jl new file mode 100644 index 000000000..cdeae6d69 --- /dev/null +++ b/src/executor-group.jl @@ -0,0 +1,225 @@ +""" + AbstractExecutorGroup +Executor group is a convenient tool for managing a group of executors. +""" +abstract AbstractExecutorGroup + +function forward(self::AbstractExecutorGroup, data_provider :: AbstractDataProvider, + data_batch :: AbstractDataBatch, is_train) + throw(MethodError(forward, (self, ))) +end + +type DataParallelExecutorGroup <: AbstractExecutorGroup + symbol :: SymbolicNode + context :: Vector{Context} + execs :: Vector{Executor} + + data_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} + label_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} + for_training :: Bool + + shared_group :: Nullable{DataParallelExecutorGroup} + inputs_need_grad :: Bool + fixed_param_names :: Nullable{Vector{Symbol}} + grad_req :: Dict{Symbol, GRAD_REQ} + freeze_idx + + data_arrays :: Vector{Vector{SlicedNDArray}} + label_arrays :: Vector{Vector{SlicedNDArray}} + param_arrays :: Vector{Vector{NDArray}} + grad_arrays :: Vector{Vector{NDArray}} + aux_arrays :: Vector{Vector{NDArray}} + input_grad_arrays :: Vector{Vector{NDArray}} + + arg_params :: Dict{Symbol, NDArray} + aux_params :: Dict{Symbol, NDArray} +end +function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context}, + data_shapes, data_names, label_shapes, label_names, for_training, inputs_need_grad, + shared_group, fixed_param_names, grad_req) + + num_dev = length(context) + arg_names = list_arguments(symbol) + input_names = [data_names; label_names] + param_names = setdiff(arg_names, input_names) + aux_names = list_auxiliary_states(symbol) + + batch_size = data_shapes[1][end] + for shape in data_shapes + @assert batch_size == shape[end] + end + if !isempty(label_shapes) + for shape in label_shapes + @assert batch_size == shape[end] + end + end + + # TODO imlplement workload + slices = _split_inputs(batch_size, num_dev) + + execs = Vector{Executor}(num_dev) + + provided_shapes = merge(Dict(name => shape for (name, shape) in zip(data_names, data_shapes)), + Dict(name => shape for (name, shape) in zip(label_names, label_shapes))) + arg_shapes, out_shapes, aux_shapes = infer_shape(symbol; provided_shapes...) 
+ @assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference") + + grad_req, freeze_idx = get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, fixed_param_names, grad_req) + + arg_params = Dict{Symbol, NDArray}() + aux_params = Dict{Symbol, NDArray}() + + for (name, shape) in filter(x -> in(x[1], param_names), zip(arg_names, arg_shapes)) + arg_params[name] = empty(shape) + end + + for (name, shape) in zip(aux_names, aux_shapes) + aux_params[name] = empty(shape) + end + + for i = 1:num_dev + data_shapes = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(data_names, data_shapes)) + label_shapes = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(label_names, label_shapes)) + arg_arrays = NDArray[zeros(shape, context[i]) for shape in arg_shapes] + grad_arrays = Dict{Symbol,NDArray}() + aux_arrays = NDArray[zeros(shape, context[i]) for shape in aux_shapes] + + shapes = zip(arg_names, arg_shapes) + + # if not in provided data, should be parameters + if inputs_need_grad + provided_data_names = label_names + else + provided_data_names = [data_names; label_names] + end + shapes = filter(x -> !in(x[1], provided_data_names), shapes) + + # Remove all gradients for nop params + shapes = filter(x -> grad_req[x[1]] != GRAD_NOP, shapes) + + for (name, shape) in shapes + grad_arrays[name] = zeros(shape, context[i]) + end + + execs[i] = bind(symbol, context[i], arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays) + #= dbg_str = mx.debug_str(train_execs[i]) =# + #= info(string("TempSpace: ", split(dbg_str, ['\n'])[end-2]..., " on ", self.ctx[i])) =# + end + + # TODO: perform type inference + + # set up input data structures + data_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(execs)] for name in data_names] + label_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(execs)] for name in label_names] + + param_idx = filter(i -> in(arg_names[i], param_names), 1:length(arg_names)) + name_idx = filter(i -> in(arg_names[i], data_names), 1:length(arg_names)) + + param_arrays = [NDArray[exec.arg_arrays[i] for exec in execs] for i in param_idx] + grad_arrays = [NDArray[exec.grad_arrays[i] for exec in execs] for i in param_idx] + aux_arrays = [NDArray[exec.aux_arrays[i] for exec in execs] for i = 1:length(aux_names)] + + if inputs_need_grad + input_grad_arrays = [NDArray[exec.grad_arrays[i] for exec in execs] for i in name_idx] + else + input_grad_arrays = [] + end + + return DataParallelExecutorGroup( + symbol, context, execs, + data_shapes, label_shapes, for_training, + shared_group, inputs_need_grad, fixed_param_names, grad_req, freeze_idx, + data_arrays, label_arrays, param_arrays, grad_arrays, aux_arrays, + input_grad_arrays, arg_params, aux_params) +end + +""" + forward(exec_group, data_batch, is_train) +Split `data_batch` according to workload and run forward on each devices. +# Arguments +* `data_batch` : AbstractDataBatch +* `is_train` : Nullable{Bool} + The hint for the backend, indicating whether we are during training phase. + Default is `nothing`, then the value `self.for_training` will be used. 
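
For orientation, the workload split that this `forward` relies on is fixed in the constructor above: `_split_inputs` (the existing helper also used by `FeedForward`) partitions the batch dimension into one slice per context, and each data/label shape keeps its leading dimensions while the trailing batch dimension shrinks to the slice length. A rough editorial illustration, assuming two devices and a 100-sample batch:

```julia
# Editorial illustration of the per-device slicing done by the constructor.
# _split_inputs is assumed to behave as in model.jl: it returns one index
# range per device covering the batch dimension.
batch_size = 100
slices = mx._split_inputs(batch_size, 2)   # e.g. [1:50, 51:100]

shape = (28, 28, 1, batch_size)            # batch dimension comes last
dev_shapes = [tuple(shape[1:end-1]..., length(s)) for s in slices]
# dev_shapes == [(28, 28, 1, 50), (28, 28, 1, 50)]
```
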
+""" +function forward(self:: DataParallelExecutorGroup, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train = nothing) + + load_data!(data_provider, data_batch, self.data_arrays) + is_train = get(is_train, self.for_training) + + if is_train && !isempty(get_label(data_provider, data_batch)) + load_label!(data_provider, data_batch, self.label_arrays) + end + + for exec in self.execs + forward(exec, is_train=is_train) + end + # TODO add callbacks here +end + +""" + set_params!(self::DataParallelExecutorGroup, arg_params, aux_params) + +Assign, i.e. copy parameters to all the executors. +# Arguments +* `arg_params` : Dict{Symbol, NDArray} + A dictionary of name to `NDArray` parameter mapping. +* `aux_params` : Dict{Symbol, NDArray} + A dictionary of name to `NDArray` auxiliary variable mapping. +""" +function set_params!(self::DataParallelExecutorGroup, + arg_params, aux_params; allow_extra_params::Bool = false) + for exec in self.execs + copy_params_from(exec, arg_params, aux_params, allow_extra_params=allow_extra_params) + end +end + +## +# Internals +## + + +function output_shapes(self:: DataParallelExecutorGroup) + #= outputs = [size(out) for out in self.execs[1].outputs] =# + #= return [tuple(key, shape) for key, shape in zip(list_outputs(exec_group.symbol), outputs)] =# +end + +function get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, fixed_param_names, grad_req) + if isnull(fixed_param_names) + # get grad attribute to allow for freezing + fixed_param_names = Symbol[] + for (attr, value) in list_all_attr(symbol) + sattr = string(attr) + if endswith(sattr, "grad") && value == "freeze" + push!(fixed_param_names, Symbol(sattr[1:end-5])) + end + end + else + fixed_param_names = get(fixed_param_names) + end + + # Needs to correspond to the correct id in the update loop layer idx=1:length(param_names). + freeze_idx = filter(i -> in(param_names[i], fixed_param_names), 1:length(param_names)) + + # Setup grad_req as a dictionary + grad_req_dict = Dict{Symbol, GRAD_REQ}() + for param in arg_names + if param in param_names + if in(param, fixed_param_names) + grad_req_dict[param] = GRAD_NOP + else + grad_req_dict[param] = grad_req + end + elseif param in data_names + if inputs_need_grad + grad_req_dict[param] = grad_req + else + grad_req_dict[param] = GRAD_NOP + end + else + grad_req_dict[param] = GRAD_NOP + end + end + + return grad_req_dict, freeze_idx +end diff --git a/src/executor.jl b/src/executor.jl index 75bccba26..077491216 100644 --- a/src/executor.jl +++ b/src/executor.jl @@ -179,8 +179,7 @@ function backward(self :: Executor, out_grads :: Vector{NDArray}) end -function copy_params_from(self::Executor, arg_params::Dict{Base.Symbol,NDArray}, - aux_params::Union{Void,Dict{Base.Symbol,NDArray}}=nothing; +function copy_params_from(self::Executor, arg_params, aux_params=nothing; allow_extra_params::Bool=false) for (name, array) in arg_params if haskey(self.arg_dict, name) @@ -201,7 +200,6 @@ function copy_params_from(self::Executor, arg_params::Dict{Base.Symbol,NDArray}, end end - """ Get a debug string about internal execution plan. 
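
The `get_grads` helper in the `src/executor-group.jl` hunk above decides, per argument, whether a gradient buffer is needed: parameters listed in `fixed_param_names`, or tagged with a `grad = "freeze"` attribute on the symbol, receive `GRAD_NOP` and are recorded in `freeze_idx`, so `update_params` skips them. A hedged sketch of the user-facing side, reusing the `fixed_param_names` keyword from `symbol_module.jl`; the parameter name `:fc1_bias` follows the usual `<layer>_bias` convention and is an assumption here:

```julia
# Editorial sketch: freezing a parameter so the executor group never updates it.
net = mx.@chain mx.Variable(:data) =>
      mx.FullyConnected(name=:fc1, num_hidden=1) =>
      mx.LinearRegressionOutput(name=:linout)

m = mx.Module.SymbolModule(net, label_names = [:linout_label],
                           fixed_param_names = [:fc1_bias])
mx.Module.bind(m, [(1, 10)], [(1, 10)])
# :fc1_bias now gets grad_req == GRAD_NOP and its index lands in freeze_idx,
# so later parameter updates skip it entirely.
```
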
diff --git a/src/io.jl b/src/io.jl index f65314e67..1d50dd44a 100644 --- a/src/io.jl +++ b/src/io.jl @@ -122,6 +122,33 @@ count_samples(batch :: DataBatch) = batch.count get_data{Provider<:AbstractDataProvider}(::Provider, batch :: DataBatch) = batch.data get_label{Provider<:AbstractDataProvider}(::Provider, batch :: DataBatch) = batch.label +type DataBatchProvider <: AbstractDataProvider + provider :: AbstractDataProvider + + DataBatchProvider() = new() + DataBatchProvider(provider) = new(provider) +end + +eachdatabatch(provider :: AbstractDataProvider) = DataBatchProvider(provider) + +function Base.eltype(provider :: DataBatchProvider) + DataBatch +end +function Base.start(provider :: DataBatchProvider) + return Base.start(provider.provider) +end +function Base.next(provider :: DataBatchProvider, state :: AbstractDataProviderState) + (inner_batch, next_state) = Base.next(provider.provider, state) + batch = DataBatch(get_data(provider.provider, inner_batch), + get_label(provider.provider, inner_batch), + count_samples(provider.provider, inner_batch)) + + return (batch, next_state) +end +function Base.done(provider :: DataBatchProvider, state :: AbstractDataProviderState) + return Base.done(provider.provider, state) +end + """ SlicedNDArray diff --git a/src/module/Module.jl b/src/module/Module.jl index 7f94f7b5c..511d3a2e5 100644 --- a/src/module/Module.jl +++ b/src/module/Module.jl @@ -1,4 +1,6 @@ module Module +import ....MXNet: mx +import ..mx: DataBatch, AbstractDataProvider, AbstractDataBatch, DataBatchProvider """ AbstractModule @@ -174,7 +176,9 @@ end ### """ """ -function forward(self :: AbstractModule, ) +forward(self :: AbstractModule, data_batch :: DataBatch, is_train=nothing) = forward(self, DataBatchProvider(), data_batch, is_train) +function forward(self :: AbstractModule, provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train=nothing) + throw(MethodError(forward, (self, ))) end """ diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index b0a8bddc3..a4277e5b1 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -1,5 +1,5 @@ import ....MXNet: mx # in order to use mx. 
-import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP +import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer """ Module @@ -18,7 +18,7 @@ type SymbolModule <: AbstractModule data_names :: Vector{Symbol} label_names :: Vector{Symbol} aux_names :: Vector{Symbol} - context :: Context + context :: Vector{Context} binded :: Bool for_training :: Bool @@ -26,61 +26,76 @@ type SymbolModule <: AbstractModule params_initialized :: Bool optimizer_initialized :: Bool - data_shapes :: Nullable{Vector{Tuple{Vararg{Int}}}} - label_shapes :: Nullable{Vector{Tuple{Vararg{Int}}}} - output_shapes :: Nullable{Vector{Tuple{Vararg{Int}}}} + data_shapes :: Vector{Tuple{Vararg{Int}}} + label_shapes :: Vector{Tuple{Vararg{Int}}} + output_shapes :: Vector{Tuple{Vararg{Int}}} arg_arrays :: Nullable{Vector{NDArray}} aux_arrays :: Nullable{Vector{NDArray}} grad_arrays :: Nullable{Vector{NDArray}} params_dirty :: Bool - executor :: Nullable{Executor} + fixed_param_names :: Nullable{Vector{Symbol}} + optimizer + kvstore + update_on_kvstore + + arg_params + aux_params + + exec_group :: AbstractExecutorGroup function SymbolModule(symbol::SymbolicNode, data_names::Vector{Symbol}, - label_names::Vector{Symbol}, context :: Context) + label_names::Vector{Symbol}, context :: Vector{Context}, + fixed_param_names::Nullable{Vector{Symbol}}) aux_names = mx.list_auxiliary_states(symbol) return new(symbol, data_names, label_names, aux_names, context, false, false, false, false, false, - Nullable{Vector{Tuple{Int}}}(), - Nullable{Vector{Tuple{Int}}}(), - Nullable{Vector{Tuple{Int}}}(), + Vector{Tuple{Int}}(), + Vector{Tuple{Int}}(), + Vector{Tuple{Int}}(), Nullable{Vector{NDArray}}(), Nullable{Vector{NDArray}}(), Nullable{Vector{NDArray}}(), false, - Nullable{Executor}()) + fixed_param_names) end end function SymbolModule(symbol::SymbolicNode; data_names = [:data], label_names = [:softmax_label], - context = mx.cpu()) - return SymbolModule(symbol, data_names, label_names, context) + context = [mx.cpu()], fixed_param_names = nothing) + fixed_param_names = Nullable{Vector{Symbol}}(fixed_param_names) + if !isa(context, Vector{Context}) + context = [context] + end + @assert !isempty(data_names) + @assert !isempty(context) + return SymbolModule(symbol, data_names, label_names, context, fixed_param_names) end ### default API isbinded(self::SymbolModule) = self.binded allows_training(self::SymbolModule) = self.for_training isinitialized(self::SymbolModule) = self.params_initialized -hasoptimizer(self::SymbolModule) = self.hasoptimizer +hasoptimizer(self::SymbolModule) = self.optimizer_initialized data_names(self::SymbolModule) = self.data_names -output_names(self::SymbolModule) = list_outputs(symbol) +output_names(self::SymbolModule) = list_outputs(self.symbol) function data_shapes(self::SymbolModule) - !isbinded(self) && return Nullable{Vector{Tuple{Int}}}() + !isbinded(self) && return Vector{Tuple{Int}}() return self.data_shapes end function label_shapes(self::SymbolModule) - !isbinded(self) && return Nullable{Vector{Tuple{Int}}}() + !isbinded(self) && return Vector{Tuple{Int}}() return self.label_shapes end function output_shapes(self::SymbolModule) - !isbinded(self) && return Nullable{Vector{Tuple{Int}}}() + !isbinded(self) && return Vector{Tuple{Int}}() return self.output_shapes end @@ -91,78 +106,172 @@ 
function get_params(self::SymbolModule) if self.params_dirty sync_params_from_device(self) end - return (Dict(name => data for (name, data) in zip()), - Dict(name => data for (name, data) in zip())) + + return (self.arg_params, self.aux_params) end -function init_params(self::SymbolModule; initializer=nothing, arg_params=nothing, - aux_params=nothing, allow_missing=false, force_init=false) +function init_params(self::SymbolModule; initializer=UniformInitializer(0.07), arg_params=nothing, + aux_params=nothing, allow_extra_params=false, force_init=false) if isinitialized(self) && !force_init - return + return self end @assert isbinded(self) "Call `bind` before initialization" + + if !isdefined(self, :arg_params) || isempty(self.arg_params) + self.arg_params = Dict(k => zeros(size(v)) for (k, v) in self.exec_group.arg_params) + end + + if !isdefined(self, :aux_params) || isempty(self.aux_params) + self.aux_params = Dict(k => zeros(size(v)) for (k, v) in self.exec_group.aux_params) + end + + # TODO need initialization + + # copy the initialized parameters to devices + set_params!(self.exec_group, self.arg_params, self.aux_params, allow_extra_params=allow_extra_params) + + self.params_dirty = false + self.params_initialized = true + + return self end function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Tuple{Int}}(); for_training=true, inputs_need_grad=true, force_rebind=false, - grad_req=mx.GRAD_WRITE) + grad_req=mx.GRAD_WRITE, shared_group = nothing) if force_rebind reset_bind(self) end if isbinded(self) warn("Already bound, ignoring bind()") - return + return self + end + + if !for_training + @assert !inputs_need_grad end self.for_training = for_training self.inputs_need_grad = inputs_need_grad self.binded = true - #@assert !for_training && !inputs_need_grad @assert length(self.data_names) == length(data_shapes) @assert length(self.label_names) == length(label_shapes) - self.data_shapes = Nullable(data_shapes) - self.label_shapes = Nullable(label_shapes) + self.data_shapes = data_shapes + self.label_shapes = label_shapes + + self.exec_group = DataParallelExecutorGroup(self.symbol, self.context, + self.data_shapes, self.data_names, + self.label_shapes, self.label_names, + self.for_training, self.inputs_need_grad, shared_group, + self.fixed_param_names, grad_req) + return self +end - provided_shapes = merge( - Dict(name => shape for (name, shape) in zip(self.data_names, data_shapes)), - Dict(name => shape for (name, shape) in zip(self.label_names, label_shapes))) +# TODO add description +function init_optimizer(self::SymbolModule; optimizer::AbstractOptimizer=ADAM(), kvstore :: Union{Base.Symbol, KVStore}=:local, force_init :: Bool=false) + @assert isbinded(self) && isinitialized(self) - arg_shapes, out_shapes, aux_shapes = infer_shape(self.symbol; provided_shapes...) 
- @assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference") + if hasoptimizer(self) && !force_init + warn("Optimizer already initialized, ignoring...") + return self + end - # TODO: perform type inference + # TODO initialize KV store + # setup kvstore + #= kvstore, update_on_kvstore = _create_kvstore(kvstore, length(self.context), self.arg_params) =# + kvstore, update_on_kvstore = nothing, false - arg_arrays = NDArray[mx.zeros(shape, self.context) for shape in arg_shapes] - arg_names = mx.list_arguments(self.symbol) + self.optimizer = optimizer + self.kvstore = kvstore + self.update_on_kvstore = update_on_kvstore + self.optimizer_initialized = true - grad_arrays = Dict{Symbol,NDArray}() + # add adequate calculation of batch_size + op_state = OptimizationState(self.data_shapes[1][end]) + optimizer.state = op_state - if grad_req != GRAD_NOP - shapes = zip(arg_names, arg_shapes) + if !isa(kvstore, Void) + if update_on_kvstore + set_optimizer(kvstore, optimizer) + end - # if not in provided data, should be parameters - provided_data_names = keys(provided_shapes) - shapes = filter(x -> !in(x[1], provided_data_names), shapes) + info("Initializing KVStore...") + # init kv with gradients + for idx = 1:length(param_arrays) + param_on_devs = param_arrays[idx] - # Remove all gradients for nop params - # if isa(grad_req, Dict{Symbol, GRAD_REQ}) - # shapes = filter(x -> grad_req[x[1]] != GRAD_NOP,shapes) - # end + init!(kvstore, idx, self.arg_params[param_names[idx]]) - for (name, shape) in shapes - grad_arrays[name] = mx.zeros(shape, self.context) + if update_on_kvstore + # pull weights back + pull!(kvstore, idx, param_on_devs, priority=-idx) + end end end + + # TODO add preloaded states + #= if !isa(self.preload_opt_states, Void) =# + #= load_optimizer_states!(self, self.preload_opt_states) =# + #= self.preload_opt_states = nothing =# + #= end =# + + return self +end - aux_arrays = NDArray[mx.zeros(shape, self.context) for shape in aux_shapes] - executor = mx.bind(self.symbol, self.context, arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays) +# TODO add description +""" + forward(module, data_provider, data_batch; is_train) +Forward computation. +# Arguments +* `data_batch` : AbstractDataBatch +* `is_train` : Nullable{Bool} + Default is `nothing`, which means `is_train` takes the value of `self.for_training`. +""" +function forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train=nothing) + @assert isbinded(self) && isinitialized(self) + is_train = convert(Nullable{Bool}, is_train) + mx.forward(self.exec_group, data_provider, data_batch, is_train) +end - self.executor = Nullable{Executor}(executor) +""" + backward(module, out_grads) +Backward computation. +# Arguments +* `out_grads` : NDArray or list of NDArray, optional + Gradient on the outputs to be propagated back. + This parameter is only needed when bind is called + on outputs that are not a loss function. +""" +function backward(self:: SymbolModule, out_grads=nothing) + @assert isbinded(self) && isinitialized(self) + backward(self.exec_group, out_grads=out_grads) +end + + +""" + update!(mod) +Update parameters according to the installed optimizer and the gradients computed +in the previous forward-backward batch. 
+""" +function update!(self::SymbolModule) + @assert isbinded(self) && isinitialized(self) && hasoptimizer(self) + self.params_dirty = true + if self.update_on_kvstore + _update_params_on_kvstore(self.kvstore, + self.exec_group.param_arrays, + self.exec_group.grad_arrays) + else + _update_params(self.kvstore, + self.exec_group.param_arrays, + self.exec_group.grad_arrays, + updater=self.updater, + num_device=length(self.context)) + end end ## @@ -172,3 +281,20 @@ end function sync_params_from_devices(self::SymbolModule) throw(MethodError(sync_params_from_devices, (self,))) end + +""" + borrow_optimizer!(module, shared_module) +Borrow optimizer from a shared module. Used in bucketing, where exactly the same +optimizer (esp. kvstore) is used. +# Arguments +* `module` : SymbolModule +* `shared_module` : SymbolModule +""" +function borrow_optimizer!(self::SymbolModule, shared_module::SymbolModule) + @assert hasoptimizer(shared_module) + self.optimizer = shared_module.optimizer + self.kvstore = shared_module.kvstore + self.update_on_kvstore = shared_module.update_on_kvstore + self.updater = shared_module.updater + self.optimizer_initialized = true +end diff --git a/test/test-module.jl b/test/test-module.jl index 74fb3ec16..66a1f63df 100644 --- a/test/test-module.jl +++ b/test/test-module.jl @@ -1,10 +1,4 @@ using MXNet -# Create Network -symbol = mx.@chain mx.Variable(:data) => -mx.Convolution(kernel = (3,3), pad = (1,1), stride = (1,1), num_filter = 64) => -mx.SoftmaxOutput(name=:softmax, multi_output = true) - -m1 = mx.Module.SymbolModule(symbol) - -mx.Module.bind(m1, [(20,20,1,10)], [(20,20,1,10)]) +include(joinpath(dirname(@__FILE__), "common.jl")) +include(joinpath(dirname(@__FILE__), "unittest", "symbol-module.jl")) diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl new file mode 100644 index 000000000..e29309016 --- /dev/null +++ b/test/unittest/symbol-module.jl @@ -0,0 +1,80 @@ +module TestSymbolModule +using MXNet +using Base.Test + +################################################################################ +# Utils +################################################################################ + +function create_network() + arch = mx.@chain mx.Variable(:data) => + mx.Convolution(kernel = (3,3), pad = (1,1), stride = (1,1), num_filter = 64) => + mx.SoftmaxOutput(name=:softmax, multi_output = true) + + return arch +end + +function create_single_neuron() + arch = @mx.chain mx.Variable(:data) => + mx.FullyConnected(name=:fc1, num_hidden=1) => + mx.LinearRegressionOutput(name=:linout) + return arch +end + +################################################################################ +# Test Implementations +################################################################################ + +function test_basic() + info("SymbolModule::basic") + + m1 = mx.Module.SymbolModule(create_network()) + + @test !mx.Module.isbinded(m1) + @test !mx.Module.allows_training(m1) + @test !mx.Module.isinitialized(m1) + @test !mx.Module.hasoptimizer(m1) + + @test mx.Module.data_names(m1) == [:data] + @test mx.Module.output_names(m1) == [:softmax_output] + + mx.Module.bind(m1, [(20, 20, 1, 10)], [(20, 20, 1, 10)]) + @test mx.Module.isbinded(m1) + @test !mx.Module.isinitialized(m1) + @test !mx.Module.hasoptimizer(m1) + + mx.Module.init_params(m1) + @test mx.Module.isinitialized(m1) + + mx.Module.init_optimizer(m1) + @test mx.Module.hasoptimizer(m1) +end + +function test_init_params() + info("SymbolModule::InitParams") + m1 = mx.Module.SymbolModule(create_single_neuron(), + 
label_names = [:linout_label]) + mx.Module.bind(m1, [(1, 10)], [(1, 10)]) + mx.Module.init_params(m1) + + # TODO Should be changed to tests + info(mx.Module.get_params(m1)) + + x = reshape(collect(1:10), (1, 10)) + y = reshape(collect(2:11), (1, 10)) + data = mx.ArrayDataProvider(:data => x, :linout_label => y; batch_size = 10) + for batch in mx.eachdatabatch(data) + mx.Module.forward(m1, batch) + end +end + +################################################################################ +# Run tests +################################################################################ + +@testset "Symbol Module Test" begin + test_basic() + test_init_params() +end + +end From ff66f55cf0bd43b001d741ea8736387eacc5b816 Mon Sep 17 00:00:00 2001 From: Andrej Oskin Date: Fri, 20 Jan 2017 01:55:23 +0300 Subject: [PATCH 07/18] Prototyping full training --- src/executor-group.jl | 61 +++++++++++++++++++++++++++++++++- src/module/symbol_module.jl | 26 ++++++--------- test/unittest/symbol-module.jl | 2 ++ 3 files changed, 72 insertions(+), 17 deletions(-) diff --git a/src/executor-group.jl b/src/executor-group.jl index cdeae6d69..658c9fd78 100644 --- a/src/executor-group.jl +++ b/src/executor-group.jl @@ -16,7 +16,10 @@ type DataParallelExecutorGroup <: AbstractExecutorGroup data_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} label_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} + for_training :: Bool + slices :: Vector{UnitRange{Int}} + batch_size :: Int shared_group :: Nullable{DataParallelExecutorGroup} inputs_need_grad :: Bool @@ -127,7 +130,7 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context return DataParallelExecutorGroup( symbol, context, execs, - data_shapes, label_shapes, for_training, + data_shapes, label_shapes, for_training, slices, batch_size, shared_group, inputs_need_grad, fixed_param_names, grad_req, freeze_idx, data_arrays, label_arrays, param_arrays, grad_arrays, aux_arrays, input_grad_arrays, arg_params, aux_params) @@ -157,6 +160,21 @@ function forward(self:: DataParallelExecutorGroup, data_provider :: AbstractData # TODO add callbacks here end +# TODO Add description +backward(self::DataParallelExecutorGroup, out_grads::Void) = backward(self, NDArray[]) +backward(self::DataParallelExecutorGroup, out_grads::NDArray) = backward(self, [out_grads]) +function backward(self::DataParallelExecutorGroup, out_grads::Vector{NDArray}=Vector{NDArray}()) + @assert self.for_training, "re-bind with for_training=True to run backward" + + for (i, exec) in enumerate(self.execs) + out_grad_slices = NDArray[] + for grad in out_grads + push!(out_grad_slices, copy(grad, self.context[i])) + end + backward(exec, out_grad_slices) + end +end + """ set_params!(self::DataParallelExecutorGroup, arg_params, aux_params) @@ -174,6 +192,47 @@ function set_params!(self::DataParallelExecutorGroup, end end +## +# Utility +## + +update_params(self::DataParallelExecutorGroup, updater, update_on_kvstore, kvstore::Void = nothing) = update_params(self, updater, update_on_kvstore, Nullable{KVStore}()) +update_params(self::DataParallelExecutorGroup, updater, update_on_kvstore, kvstore::KVStore) = update_params(self, updater, update_on_kvstore, Nullable(kvstore)) +function update_params(self::DataParallelExecutorGroup, updater, update_on_kvstore, kvstore::Nullable{KVStore}) + num_dev = length(self.context) + for idx = 1:length(self.param_names) + #= if isa(self.grad_arrays[i][1], Void) =# + #= continue =# + #= end =# + if in(idx, self.freeze_idx) + continue # Skip parameter update entirely + end + if 
!isnull(kvstore) + kvstore = get(kvstore) + # push gradient, priority is negative index + push!(kvstore, idx, self.param_arrays[idx], priority=-idx) + if update_on_kvstore + # pull back the weights + pull!(kvstore, idx, self.param_arrays[idx], priority=-idx) + else + # pull back the sum-ed gradients, to the same locations + pull!(kvstore, idx, self.grad_arrays[idx], priority=-idx) + end + end + + if !update_on_kvstore + # manual updating + for i_dev = 1:num_dev + # create a fake index, so that the updater create states + # for different param AND different devices, TODO(mli) + # use a better solution later + fake_idx = idx * num_dev + i_dev + get(updater)(fake_idx, self.grad_arrays[idx][i_dev], self.param_arrays[idx][i_dev]) + end + end + end +end + ## # Internals ## diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index a4277e5b1..b53a5b733 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -189,12 +189,15 @@ function init_optimizer(self::SymbolModule; optimizer::AbstractOptimizer=ADAM(), self.optimizer = optimizer self.kvstore = kvstore self.update_on_kvstore = update_on_kvstore - self.optimizer_initialized = true + self.updater = Nullable() # add adequate calculation of batch_size op_state = OptimizationState(self.data_shapes[1][end]) optimizer.state = op_state + if !update_on_kvstore + self.updater = Nullable(get_updater(optimizer)) + end if !isa(kvstore, Void) if update_on_kvstore set_optimizer(kvstore, optimizer) @@ -220,6 +223,8 @@ function init_optimizer(self::SymbolModule; optimizer::AbstractOptimizer=ADAM(), #= self.preload_opt_states = nothing =# #= end =# + self.optimizer_initialized = true + return self end @@ -247,31 +252,20 @@ Backward computation. This parameter is only needed when bind is called on outputs that are not a loss function. """ -function backward(self:: SymbolModule, out_grads=nothing) +function backward(self::SymbolModule, out_grads::Union{NDArray, Vector{NDArray}}=Vector{NDArray}()) @assert isbinded(self) && isinitialized(self) backward(self.exec_group, out_grads=out_grads) end - """ - update!(mod) + update!(module) Update parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch. 
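
The per-device branch of `update_params` above deliberately hands the updater a fabricated index. The reason is worth stating: the closure returned by `get_updater` keeps optimizer state (momentum, ADAM moments, and so on) keyed by index, so each (parameter, device) pair needs its own index or the replicas would share state. An editorial sketch of the idea; `mx.ADAM()` is taken from `init_optimizer`'s default and no real gradients are touched here:

```julia
# Editorial sketch of the "fake index" trick in update_params: the updater
# closure keeps per-index optimizer state, so every (parameter, device) pair
# is addressed with a distinct index.
updater = mx.get_updater(mx.ADAM())
num_dev = 2
for idx in 1:3, i_dev in 1:num_dev
    fake_idx = idx * num_dev + i_dev
    # updater(fake_idx, grad_on_device, weight_on_device) would run here, with
    # ADAM moments kept separately for each replica of each parameter.
end
```
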
""" -function update!(self::SymbolModule) +function update(self::SymbolModule) @assert isbinded(self) && isinitialized(self) && hasoptimizer(self) self.params_dirty = true - if self.update_on_kvstore - _update_params_on_kvstore(self.kvstore, - self.exec_group.param_arrays, - self.exec_group.grad_arrays) - else - _update_params(self.kvstore, - self.exec_group.param_arrays, - self.exec_group.grad_arrays, - updater=self.updater, - num_device=length(self.context)) - end + update_params(self.exec_group, self.updater, self.update_on_kvstore, self.kvstore) end ## diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl index e29309016..71e63377b 100644 --- a/test/unittest/symbol-module.jl +++ b/test/unittest/symbol-module.jl @@ -65,6 +65,8 @@ function test_init_params() data = mx.ArrayDataProvider(:data => x, :linout_label => y; batch_size = 10) for batch in mx.eachdatabatch(data) mx.Module.forward(m1, batch) + mx.Module.backward(m1) + mx.Module.update(m1) end end From 7b64029eb9067330e80b95050db08bee28c5a270 Mon Sep 17 00:00:00 2001 From: Andrey Oskin Date: Fri, 20 Jan 2017 18:16:22 +0300 Subject: [PATCH 08/18] Almost finished forward-backward --- src/executor-group.jl | 95 ++++++++++++++++++++++++++++++---- src/module/symbol_module.jl | 63 +++++++++++++++++----- src/ndarray.jl | 29 ++++++++++- test/unittest/symbol-module.jl | 27 ++++++---- 4 files changed, 181 insertions(+), 33 deletions(-) diff --git a/src/executor-group.jl b/src/executor-group.jl index 658c9fd78..4429c091c 100644 --- a/src/executor-group.jl +++ b/src/executor-group.jl @@ -14,8 +14,8 @@ type DataParallelExecutorGroup <: AbstractExecutorGroup context :: Vector{Context} execs :: Vector{Executor} - data_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} - label_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} + data_shapes :: Vector{Tuple{Vararg{Int}}} + label_shapes :: Vector{Tuple{Vararg{Int}}} for_training :: Bool slices :: Vector{UnitRange{Int}} @@ -36,6 +36,8 @@ type DataParallelExecutorGroup <: AbstractExecutorGroup arg_params :: Dict{Symbol, NDArray} aux_params :: Dict{Symbol, NDArray} + param_names :: Vector{Symbol} + aux_names :: Vector{Symbol} end function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context}, data_shapes, data_names, label_shapes, label_names, for_training, inputs_need_grad, @@ -81,13 +83,15 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context end for i = 1:num_dev - data_shapes = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(data_names, data_shapes)) - label_shapes = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(label_names, label_shapes)) - arg_arrays = NDArray[zeros(shape, context[i]) for shape in arg_shapes] + data_shapes_dev = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(data_names, data_shapes)) + label_shapes_dev = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(label_names, label_shapes)) + arg_shapes_dev, out_shapes_dev, aux_shapes_dev = infer_shape(symbol; data_shapes_dev..., label_shapes_dev...) 
+ @assert(!isa(arg_shapes_dev, Void), "Information not enough to perform complete shape inference") + arg_arrays = NDArray[zeros(shape, context[i]) for shape in arg_shapes_dev] grad_arrays = Dict{Symbol,NDArray}() - aux_arrays = NDArray[zeros(shape, context[i]) for shape in aux_shapes] + aux_arrays = NDArray[zeros(shape, context[i]) for shape in aux_shapes_dev] - shapes = zip(arg_names, arg_shapes) + shapes = zip(arg_names, arg_shapes_dev) # if not in provided data, should be parameters if inputs_need_grad @@ -133,7 +137,7 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context data_shapes, label_shapes, for_training, slices, batch_size, shared_group, inputs_need_grad, fixed_param_names, grad_req, freeze_idx, data_arrays, label_arrays, param_arrays, grad_arrays, aux_arrays, - input_grad_arrays, arg_params, aux_params) + input_grad_arrays, arg_params, aux_params, param_names, aux_names) end """ @@ -163,8 +167,8 @@ end # TODO Add description backward(self::DataParallelExecutorGroup, out_grads::Void) = backward(self, NDArray[]) backward(self::DataParallelExecutorGroup, out_grads::NDArray) = backward(self, [out_grads]) -function backward(self::DataParallelExecutorGroup, out_grads::Vector{NDArray}=Vector{NDArray}()) - @assert self.for_training, "re-bind with for_training=True to run backward" +function backward(self::DataParallelExecutorGroup, out_grads::Vector{NDArray}) + @assert(self.for_training, "re-bind with for_training=true to run backward") for (i, exec) in enumerate(self.execs) out_grad_slices = NDArray[] @@ -233,6 +237,75 @@ function update_params(self::DataParallelExecutorGroup, updater, update_on_kvsto end end +""" + get_params!(self, arg_params, aux_params) + +Copy data from each executor to `arg_params` and `aux_params`. +# Arguments +* `arg_params` : Dict{Symbol, Vector{NDArray}}. Target parameter arrays +* `aux_params` : Dict{Symbol, Vector{NDArray}}. Target aux arrays + +# Notes +This function will inplace update the NDArrays in arg_params and aux_params. +""" +function get_params!(self::DataParallelExecutorGroup, arg_params::Dict{Symbol, NDArray}, + aux_params::Dict{Symbol, NDArray}) + for (name, block) in zip(self.param_names, self.param_arrays) + w = empty(size(block[1])) + for i in 1:length(block) + @inplace w .+= copy(block[i], cpu()) + end + @inplace w ./= length(block) + copy!(arg_params[name], w) + end + for (name, block) in zip(self.aux_names, self.aux_arrays) + w = empty(size(block[1])) + for i in 1:length(block) + @inplace w .+= copy(block[i], cpu()) + end + @inplace w ./= length(block) + copy!(aux_params[name], w) + end +end + +""" + +Accumulate the performance according to `eval_metric` on all devices. +# Parameters +* eval_metric : EvalMetric + The metric used for evaluation. +* labels : list of NDArray + Typically comes from `label` of a `DataBatch`. +""" +function update_metric(self::DataParallelExecutorGroup, eval_metric::AbstractEvalMetric, labels) + +end + +""" + get_outputs + +Get outputs of the previous forward computation. + +# Arguments +merge_multi_context : Bool +Default is `True`. In the case when data-parallelism is used, the outputs +will be collected from multiple devices. A `True` value indicate that we +should merge the collected results so that they look like from a single +executor. +# Returns +If `merge_multi_context` is `true`, it is like `[out1, out2]`. Otherwise, it +is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output +elements are `NDArray`. 
+""" +function get_outputs(self::DataParallelExecutorGroup, merge_multi_context::Bool=true) + outputs = [[exec.outputs[i] for exec in self.execs] for i in 1:length(self.execs[1].outputs)] + + if merge_multi_context + return _merge_multi_context(outputs) + else + return outputs + end +end ## # Internals ## @@ -282,3 +355,5 @@ function get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, return grad_req_dict, freeze_idx end + +_merge_multi_context(outputs) = [concatenate(tensors, always_copy=false) for tensors in outputs] diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index b53a5b733..aefcba93e 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -1,5 +1,5 @@ import ....MXNet: mx # in order to use mx. -import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer +import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer, get_updater, update_params, provide_data, provide_label, AbstractEvalMetric """ Module @@ -37,6 +37,7 @@ type SymbolModule <: AbstractModule fixed_param_names :: Nullable{Vector{Symbol}} optimizer + updater kvstore update_on_kvstore @@ -104,7 +105,7 @@ function get_params(self::SymbolModule) return (Nullable{Dict{Symbol, NDArray}}(), Nullable{Dict{Symbol, NDArray}}()) end if self.params_dirty - sync_params_from_device(self) + mx.get_params!(self.exec_group, self.arg_params, self.aux_params) end return (self.arg_params, self.aux_params) @@ -119,11 +120,11 @@ function init_params(self::SymbolModule; initializer=UniformInitializer(0.07), a @assert isbinded(self) "Call `bind` before initialization" if !isdefined(self, :arg_params) || isempty(self.arg_params) - self.arg_params = Dict(k => zeros(size(v)) for (k, v) in self.exec_group.arg_params) + self.arg_params = Dict(k => mx.empty(size(v)) for (k, v) in self.exec_group.arg_params) end if !isdefined(self, :aux_params) || isempty(self.aux_params) - self.aux_params = Dict(k => zeros(size(v)) for (k, v) in self.exec_group.aux_params) + self.aux_params = Dict(k => mx.empty(size(v)) for (k, v) in self.exec_group.aux_params) end # TODO need initialization @@ -137,6 +138,8 @@ function init_params(self::SymbolModule; initializer=UniformInitializer(0.07), a return self end +bind(self::SymbolModule, data_provider::AbstractDataProvider; kwargs...) = bind(self, map((x) -> x[2], provide_data(data_provider)), + map((x) -> x[2], provide_label(data_provider)); kwargs...) function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Tuple{Int}}(); for_training=true, inputs_need_grad=true, force_rebind=false, grad_req=mx.GRAD_WRITE, shared_group = nothing) @@ -228,6 +231,29 @@ function init_optimizer(self::SymbolModule; optimizer::AbstractOptimizer=ADAM(), return self end +""" + get_outputs + +Get outputs of the previous forward computation. + +# Arguments +* merge_multi_context : bool + Default is `True`. In the case when data-parallelism is used, the outputs + will be collected from multiple devices. A `True` value indicate that we + should merge the collected results so that they look like from a single + executor. + +# Returns +If `merge_multi_context` is `true`, it is like `[out1, out2]`. 
Otherwise, it +is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output +elements are `NDArray`. +""" +function get_outputs(self::SymbolModule, merge_multi_context::Bool=true) + @assert isbinded(self) && isinitialized(self) + + mx.get_outputs(self.exec_group, merge_multi_context) +end + # TODO add description """ forward(module, data_provider, data_batch; is_train) @@ -237,9 +263,9 @@ Forward computation. * `is_train` : Nullable{Bool} Default is `nothing`, which means `is_train` takes the value of `self.for_training`. """ -function forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train=nothing) +forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train = nothing) = forward(self, data_provider, data_batch, Nullable{Bool}(is_train)) +function forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Nullable{Bool}) @assert isbinded(self) && isinitialized(self) - is_train = convert(Nullable{Bool}, is_train) mx.forward(self.exec_group, data_provider, data_batch, is_train) end @@ -252,13 +278,15 @@ Backward computation. This parameter is only needed when bind is called on outputs that are not a loss function. """ -function backward(self::SymbolModule, out_grads::Union{NDArray, Vector{NDArray}}=Vector{NDArray}()) +backward(self::SymbolModule, out_grads::Void=nothing) = backward(self, NDArray[]) +backward(self::SymbolModule, out_grads::NDArray) = backward(self, [out_grads]) +function backward(self::SymbolModule, out_grads::Vector{NDArray}) @assert isbinded(self) && isinitialized(self) - backward(self.exec_group, out_grads=out_grads) + mx.backward(self.exec_group, out_grads) end """ - update!(module) + update(module) Update parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch. """ @@ -268,14 +296,23 @@ function update(self::SymbolModule) update_params(self.exec_group, self.updater, self.update_on_kvstore, self.kvstore) end +""" + update_metric() + +Evaluate and accumulate evaluation metric on outputs of the last forward computation. +# Arguments +* eval_metric : EvalMetric +* labels : Dict of NDArray + Typically `data_batch.label`. +""" +function update_metric(self::SymbolModule, eval_metric::AbstractEvalMetric, labels) + mx.update_metric(self.exec_group, eval_metric, labels) +end + ## # Internals ## -function sync_params_from_devices(self::SymbolModule) - throw(MethodError(sync_params_from_devices, (self,))) -end - """ borrow_optimizer!(module, shared_module) Borrow optimizer from a shared module. 
Used in bucketing, where exactly the same diff --git a/src/ndarray.jl b/src/ndarray.jl index d37b321a2..1a0e8aff0 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -353,7 +353,10 @@ end function setindex!(arr :: NDArray, val :: NDArray, ::Colon) copy!(arr, val) end -function setindex!{T<:Real}(arr :: NDArray, val :: Union{T,Array{T},NDArray}, idx::UnitRange{Int}) +function setindex!(arr :: NDArray, val :: NDArray, idx::UnitRange{Int}) + setindex!(slice(arr, idx), val, Colon()) +end +function setindex!{T<:Real}(arr :: NDArray, val :: Union{T,Array{T}}, idx::UnitRange{Int}) setindex!(slice(arr, idx), val, Colon()) end @@ -679,6 +682,30 @@ function /(arg0 :: NDArray, arg :: Real) ./(arg0, arg) end +function concatenate(arrays::Vector{NDArray}; always_copy=true) + if isempty(arrays) || (!always_copy && length(arrays) == 1) + return arrays + end + + shape_axis = size(arrays[1])[end] + shape_rest = size(arrays[1])[1:end-1] + for i in 2:length(arrays) + shape_axis += size(arrays[i])[end] + @assert shape_rest == size(arrays[i])[1:end-1] + end + + ret_shape = tuple(shape_rest..., shape_axis) + ret = empty(ret_shape, context(arrays[1])) + + idx = 1 + for arr in arrays + setindex!(ret, arr, idx:idx+size(arr)[end]) + #= ret[idx:idx + size(arr)[end]] = arr =# + idx += size(arr)[end] + end + + return ret +end """ Manipulating as Julia Arrays diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl index 71e63377b..a62a3fa5f 100644 --- a/test/unittest/symbol-module.jl +++ b/test/unittest/symbol-module.jl @@ -14,9 +14,9 @@ function create_network() return arch end -function create_single_neuron() +function create_linreg(num_hidden::Int=1) arch = @mx.chain mx.Variable(:data) => - mx.FullyConnected(name=:fc1, num_hidden=1) => + mx.FullyConnected(name=:fc1, num_hidden=num_hidden) => mx.LinearRegressionOutput(name=:linout) return arch end @@ -52,19 +52,28 @@ end function test_init_params() info("SymbolModule::InitParams") - m1 = mx.Module.SymbolModule(create_single_neuron(), - label_names = [:linout_label]) - mx.Module.bind(m1, [(1, 10)], [(1, 10)]) + + #= x = reshape(collect(1:10), (1, 10)) =# + #= y = reshape(collect(2:11), (1, 10)) =# + srand(123456) + epsilon = rand(1, 10) + x = rand(4, 10) + y = 2*x .+ epsilon + data = mx.ArrayDataProvider(:data => x, :linout_label => y; batch_size = 5) + + m1 = mx.Module.SymbolModule(create_linreg(4), + label_names = [:linout_label], + context=[mx.cpu(), mx.cpu()]) + mx.Module.bind(m1, data) mx.Module.init_params(m1) + mx.Module.init_optimizer(m1) # TODO Should be changed to tests - info(mx.Module.get_params(m1)) + #= info(mx.Module.get_params(m1)) =# - x = reshape(collect(1:10), (1, 10)) - y = reshape(collect(2:11), (1, 10)) - data = mx.ArrayDataProvider(:data => x, :linout_label => y; batch_size = 10) for batch in mx.eachdatabatch(data) mx.Module.forward(m1, batch) + info("SymbolModule::InitParams: $(copy(mx.Module.get_outputs(m1)[1]))") mx.Module.backward(m1) mx.Module.update(m1) end From 8060e9e6e76bc61738effb4d5c2d3599ad0b8b35 Mon Sep 17 00:00:00 2001 From: Andrej Oskin Date: Sat, 21 Jan 2017 02:09:47 +0300 Subject: [PATCH 09/18] Most of medium api is working. 
Started implementation of high level api --- src/executor-group.jl | 11 +++++++++- src/io.jl | 6 +++--- src/module/Module.jl | 39 ++++++++++++++++++++++++++++++++-- src/module/symbol_module.jl | 34 +++++++++++++++++++++++------ src/ndarray.jl | 22 +++++++++++++++---- test/unittest/symbol-module.jl | 34 +++++++++++++++++++++++------ 6 files changed, 122 insertions(+), 24 deletions(-) diff --git a/src/executor-group.jl b/src/executor-group.jl index 4429c091c..777842d8d 100644 --- a/src/executor-group.jl +++ b/src/executor-group.jl @@ -277,8 +277,13 @@ Accumulate the performance according to `eval_metric` on all devices. * labels : list of NDArray Typically comes from `label` of a `DataBatch`. """ -function update_metric(self::DataParallelExecutorGroup, eval_metric::AbstractEvalMetric, labels) +function update_metric(self::DataParallelExecutorGroup, eval_metric::AbstractEvalMetric, provider::AbstractDataProvider, batch::AbstractDataBatch) + # XXX: there is a possibiilty, that label arrays lie in different + # context than cpu_output_arrays. It should be checked and labels + # should be copied to corresponding context + cpu_output_arrays = get_outputs(self) + update!(eval_metric, get_label(provider, batch), cpu_output_arrays) end """ @@ -301,6 +306,10 @@ function get_outputs(self::DataParallelExecutorGroup, merge_multi_context::Bool= outputs = [[exec.outputs[i] for exec in self.execs] for i in 1:length(self.execs[1].outputs)] if merge_multi_context + # TODO In original FeedForward model single predefined + # output was used. _merge_multi_context creates new array + # each time it is called. Need to benchmark, may be it's better + # to predefine cpu_output_arrays in self. return _merge_multi_context(outputs) else return outputs diff --git a/src/io.jl b/src/io.jl index 1d50dd44a..89f6be129 100644 --- a/src/io.jl +++ b/src/io.jl @@ -17,6 +17,9 @@ Normally this involves defining: """ abstract AbstractDataProvider +type StubProvider <: AbstractDataProvider +end + """ get_batch_size(provider) -> Int @@ -124,9 +127,6 @@ get_label{Provider<:AbstractDataProvider}(::Provider, batch :: DataBatch) = batc type DataBatchProvider <: AbstractDataProvider provider :: AbstractDataProvider - - DataBatchProvider() = new() - DataBatchProvider(provider) = new(provider) end eachdatabatch(provider :: AbstractDataProvider) = DataBatchProvider(provider) diff --git a/src/module/Module.jl b/src/module/Module.jl index 511d3a2e5..44c79a498 100644 --- a/src/module/Module.jl +++ b/src/module/Module.jl @@ -1,6 +1,7 @@ module Module import ....MXNet: mx import ..mx: DataBatch, AbstractDataProvider, AbstractDataBatch, DataBatchProvider +import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer, get_updater, update_params, provide_data, provide_label, AbstractEvalMetric, StubProvider, init, copy! 
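
# A condensed sketch (not part of the patch) of the medium-level loop these
# interfaces support, mirroring the unit test later in this series; `net`,
# `data` and `metric` are assumed to be a SymbolicNode, a data provider with a
# `:linout_label` label, and an eval metric set up as in that test:
#
#   m = mx.Module.SymbolModule(net, label_names = [:linout_label])
#   mx.Module.bind(m, data)          # infer shapes from the provider, allocate executors
#   mx.Module.init_params(m)
#   mx.Module.init_optimizer(m)
#   for batch in mx.eachdatabatch(data)
#       mx.Module.forward(m, batch)
#       mx.Module.backward(m)
#       mx.Module.update(m)
#       mx.Module.update_metric(m, metric, batch)
#   end
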
""" AbstractModule @@ -176,7 +177,7 @@ end ### """ """ -forward(self :: AbstractModule, data_batch :: DataBatch, is_train=nothing) = forward(self, DataBatchProvider(), data_batch, is_train) +forward(self :: AbstractModule, data_batch :: DataBatch, is_train=nothing) = forward(self, StubProvider(), data_batch, is_train) function forward(self :: AbstractModule, provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train=nothing) throw(MethodError(forward, (self, ))) end @@ -203,7 +204,10 @@ end """ """ -function update_metric(self :: AbstractModule, ) + +update_metric(self :: AbstractModule, eval_metric::AbstractEvalMetric, batch::AbstractDataBatch) = update_metric(self, eval_metric, StubProvider(), batch) +function update_metric(self :: AbstractModule, eval_metric::AbstractEvalMetric, provider::AbstractDataProvider, batch::AbstractDataBatch) + throw(MethodError(update_metric, (self, ))) end ### @@ -230,6 +234,37 @@ function fit(self::AbstractModule, train_data) end """ + predict + +Run prediction and collect the outputs. + +# Arguments + +* `eval_data` : `AbstractDataProvider` +* `num_batch` : Int + Default is `None`, indicating running all the batches in the data iterator. +* `merge_batches` : `Bool` + Default is `true`, see the doc for return values. +* `always_output_list` : `Bool` + Default is `false`, see the doc for return values. + +# Returns +When `merge_batches` is `true` (by default), the return value will be a vector +`[out1, out2, out3]`. Where each element is concatenation of the outputs for +all the mini-batches. If further that `always_output_list` is `false` (by default), +then in the case of a single output, `out1` is returned instead of `[out1]`. +When `merge_batches` is `false`, the return value will be a nested list like +`[[out1_batch1, out2_batch1], [out1_batch2], ...]`. This mode is useful because +in some cases (e.g. bucketing), the module does not necessarily produce the same +number of outputs. +The objects in the results are `NDArray`s. If you need to work with julia array, +just call `Array{Float32}` on each of the `NDArray`. + +# Examples +# TODO finish doc +An example of using predict for prediction:: + >>> #Predict on the first 10 batches of val_dataiter + >>> mod.predict(eval_data=val_dataiter, num_batch=10) """ function predict(self::AbstractModule, eval_data; num_batch=nothing, merge_batches=true, reset=true) diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index aefcba93e..bd0b60598 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -1,5 +1,4 @@ import ....MXNet: mx # in order to use mx. 
-import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer, get_updater, update_params, provide_data, provide_label, AbstractEvalMetric """ Module @@ -106,17 +105,21 @@ function get_params(self::SymbolModule) end if self.params_dirty mx.get_params!(self.exec_group, self.arg_params, self.aux_params) + self.params_dirty = false end return (self.arg_params, self.aux_params) end -function init_params(self::SymbolModule; initializer=UniformInitializer(0.07), arg_params=nothing, - aux_params=nothing, allow_extra_params=false, force_init=false) +function init_params(self::SymbolModule; + initializer=UniformInitializer(0.07), + arg_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), + aux_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), + allow_missing=false, allow_extra_params=false, force_init=false) + if isinitialized(self) && !force_init return self end - @assert isbinded(self) "Call `bind` before initialization" if !isdefined(self, :arg_params) || isempty(self.arg_params) @@ -127,7 +130,24 @@ function init_params(self::SymbolModule; initializer=UniformInitializer(0.07), a self.aux_params = Dict(k => mx.empty(size(v)) for (k, v) in self.exec_group.aux_params) end - # TODO need initialization + map([[self.arg_params, arg_params], [self.aux_params, aux_params]]) do param_arr + dst, src = param_arr + for (name, arr) in dst + if isempty(src) + init(initializer, name, arr) + else + src = get(src) + if name in keys(src) + if src[name] != arr + copy!(arr, src[name]) + end + else + @assert(!allow_missing, "$name is not presented") + init(initializer, name, arr) + end + end + end + end # copy the initialized parameters to devices set_params!(self.exec_group, self.arg_params, self.aux_params, allow_extra_params=allow_extra_params) @@ -305,8 +325,8 @@ Evaluate and accumulate evaluation metric on outputs of the last forward computa * labels : Dict of NDArray Typically `data_batch.label`. """ -function update_metric(self::SymbolModule, eval_metric::AbstractEvalMetric, labels) - mx.update_metric(self.exec_group, eval_metric, labels) +function update_metric(self::SymbolModule, eval_metric::AbstractEvalMetric, provider::AbstractDataProvider, batch::AbstractDataBatch) + mx.update_metric(self.exec_group, eval_metric, provider, batch) end ## diff --git a/src/ndarray.jl b/src/ndarray.jl index 1a0e8aff0..f86d30337 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -682,7 +682,22 @@ function /(arg0 :: NDArray, arg :: Real) ./(arg0, arg) end -function concatenate(arrays::Vector{NDArray}; always_copy=true) +""" + concatenate(arrays; always_copy=true, context=cpu) + +Concatenate a list of NDArrays along the last dimension. + +# Arguments +* `arrays` : vector of NDArray + Arrays to be concatenate. They must have identical shape except + the last dimension. +* `always_copy` : `Bool`, default `true`. When `false`, if the arrays only contain one `NDArray`, that element will be returned directly, avoid copying. +* `context`: `Context`, context of output NDArray. + +# Returns +An `NDArray` that lives on the `context`. 
+""" +function concatenate(arrays::Vector{NDArray}; always_copy=true, context=cpu()) if isempty(arrays) || (!always_copy && length(arrays) == 1) return arrays end @@ -695,12 +710,11 @@ function concatenate(arrays::Vector{NDArray}; always_copy=true) end ret_shape = tuple(shape_rest..., shape_axis) - ret = empty(ret_shape, context(arrays[1])) + ret = empty(ret_shape, context) idx = 1 for arr in arrays - setindex!(ret, arr, idx:idx+size(arr)[end]) - #= ret[idx:idx + size(arr)[end]] = arr =# + ret[idx:(idx + size(arr)[end] - 1)] = arr idx += size(arr)[end] end diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl index a62a3fa5f..741758489 100644 --- a/test/unittest/symbol-module.jl +++ b/test/unittest/symbol-module.jl @@ -17,6 +17,7 @@ end function create_linreg(num_hidden::Int=1) arch = @mx.chain mx.Variable(:data) => mx.FullyConnected(name=:fc1, num_hidden=num_hidden) => + mx.FullyConnected(name=:fc2, num_hidden=1) => mx.LinearRegressionOutput(name=:linout) return arch end @@ -50,17 +51,18 @@ function test_basic() @test mx.Module.hasoptimizer(m1) end -function test_init_params() +function test_init_params(n_epoch::Int = 10) info("SymbolModule::InitParams") #= x = reshape(collect(1:10), (1, 10)) =# #= y = reshape(collect(2:11), (1, 10)) =# srand(123456) - epsilon = rand(1, 10) + epsilon = randn(1, 10) x = rand(4, 10) - y = 2*x .+ epsilon + y = mapslices(sum, [1, 2, 3, 4] .* x, 1) .+ epsilon data = mx.ArrayDataProvider(:data => x, :linout_label => y; batch_size = 5) + metric = mx.MSE() m1 = mx.Module.SymbolModule(create_linreg(4), label_names = [:linout_label], context=[mx.cpu(), mx.cpu()]) @@ -71,12 +73,30 @@ function test_init_params() # TODO Should be changed to tests #= info(mx.Module.get_params(m1)) =# + for i in 1:n_epoch + for batch in mx.eachdatabatch(data) + mx.Module.forward(m1, batch) + mx.Module.update_metric(m1, metric, batch) + + mx.Module.backward(m1) + mx.Module.update(m1) + end + + for (name, value) in get(metric) + info("Epoch: $i: $name = $value") + end + mx.reset!(metric) + end + + y_pred = Float64[] for batch in mx.eachdatabatch(data) mx.Module.forward(m1, batch) - info("SymbolModule::InitParams: $(copy(mx.Module.get_outputs(m1)[1]))") - mx.Module.backward(m1) - mx.Module.update(m1) + append!(y_pred, Array{Float64}(mx.Module.get_outputs(m1)[1])) end + + info("Prediction: $y_pred") + info("Actual: $y") + info("No noise: $(mapslices(sum, [1, 2, 3, 4] .* x, 1))") end ################################################################################ @@ -85,7 +105,7 @@ end @testset "Symbol Module Test" begin test_basic() - test_init_params() + test_init_params(500) end end From 493caa7282c9d85330bd1ede18cd6835776496ef Mon Sep 17 00:00:00 2001 From: Andrej Oskin Date: Sun, 22 Jan 2017 02:05:55 +0300 Subject: [PATCH 10/18] More tests, partially implemented high level api --- src/executor-group.jl | 19 ++++--- src/module/Module.jl | 100 ++++++++++++++++++++++++++------- src/module/symbol_module.jl | 69 +++++++++++++++++------ test/unittest/symbol-module.jl | 54 +++++++++++++----- 4 files changed, 182 insertions(+), 60 deletions(-) diff --git a/src/executor-group.jl b/src/executor-group.jl index 777842d8d..ce1b8fa14 100644 --- a/src/executor-group.jl +++ b/src/executor-group.jl @@ -180,14 +180,15 @@ function backward(self::DataParallelExecutorGroup, out_grads::Vector{NDArray}) end """ - set_params!(self::DataParallelExecutorGroup, arg_params, aux_params) + set_params!(self::DataParallelExecutorGroup, arg_params, aux_params; allow_extra_params) Assign, i.e. 
copy parameters to all the executors. # Arguments -* `arg_params` : Dict{Symbol, NDArray} +* `arg_params` : `Dict{Symbol, NDArray}` A dictionary of name to `NDArray` parameter mapping. -* `aux_params` : Dict{Symbol, NDArray} +* `aux_params` : `Dict{Symbol, NDArray}` A dictionary of name to `NDArray` auxiliary variable mapping. +* `allow_extra_params`: `Bool`, default `false`, allow parameters in `arg_params` or `aux_params` that not exists in `self`. """ function set_params!(self::DataParallelExecutorGroup, arg_params, aux_params; allow_extra_params::Bool = false) @@ -315,16 +316,16 @@ function get_outputs(self::DataParallelExecutorGroup, merge_multi_context::Bool= return outputs end end -## -# Internals -## - function output_shapes(self:: DataParallelExecutorGroup) - #= outputs = [size(out) for out in self.execs[1].outputs] =# - #= return [tuple(key, shape) for key, shape in zip(list_outputs(exec_group.symbol), outputs)] =# + outputs = [size(out) for out in self.execs[1].outputs] + return Dict(key => shape for (key, shape) in zip(list_outputs(self.symbol), outputs)) end +## +# Internals +## + function get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, fixed_param_names, grad_req) if isnull(fixed_param_names) # get grad attribute to allow for freezing diff --git a/src/module/Module.jl b/src/module/Module.jl index 44c79a498..9ebdc6b2a 100644 --- a/src/module/Module.jl +++ b/src/module/Module.jl @@ -1,7 +1,7 @@ module Module import ....MXNet: mx import ..mx: DataBatch, AbstractDataProvider, AbstractDataBatch, DataBatchProvider -import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer, get_updater, update_params, provide_data, provide_label, AbstractEvalMetric, StubProvider, init, copy! +import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer, get_updater, update_params, provide_data, provide_label, AbstractEvalMetric, StubProvider, init, copy!, concatenate, eachdatabatch, reset! """ AbstractModule @@ -120,18 +120,27 @@ end ## """ + data_shapes(AbstractModule) + +A Dict of (name, shape) pairs specifying the data inputs to this module. """ function data_shapes(self :: AbstractModule) throw(MethodError(data_shapes, (self,))) end """ + label_shapes(AbstractModule) + +A Dict of (name, shape) pairs specifying the label inputs to this module. If this module does not accept labels -- either it is a module without loss function, or it is not binded for training, then this should return an empty Dict. """ function label_shapes(self :: AbstractModule) throw(MethodError(label_shapes, (self,))) end """ + output_shapes(AbstractModule) + +A Dict of (name, shape) pairs specifying the outputs of this module. """ function output_shapes(self :: AbstractModule) throw(MethodError(output_shapes, (self,))) @@ -142,18 +151,47 @@ end ## """ + get_params(self::AbstractModule) + +Get parameters, those are potentially copies of the the actual parameters used to do computation on the device. + +# Returns +`(arg_params, aux_params)`, a pair of dictionary of name to value mapping. 
""" function get_params(self :: AbstractModule) throw(MethodError(get_params, (self,))) end """ -""" -function set_params(self :: AbstractModule, arg_params, aux_params) - throw(MethodError(set_params, (self, arg_params, aux_params))) + set_params(self::AbstractModule; arg_params, aux_params, allow_missing, force_init) + +Assign parameter and aux state values. + +# Arguments +* `arg_params` : `Dict`. Dictionary of name to value (`NDArray`) mapping. +* `aux_params` : `Dict`. Dictionary of name to value (`NDArray`) mapping. +* `allow_missing` : `Bool`. If true, params could contain missing values, and the initializer will be called to fill those missing params. +* `force_init` : `Bool`. If true, will force re-initialize even if already initialized. +""" +function set_params(self::AbstractModule, + arg_params::Dict{Symbol, NDArray}, + aux_params::Dict{Symbol, NDArray}; + allow_missing=false, force_init=false) + init_params(self, arg_params=arg_params, aux_params=aux_params, allow_missing=allow_missing, force_init=force_init) end """ + init_params!(self; kwargs...) + +Initialize the parameters and auxiliary states. + +# Arguments +* `self` : `AbstractModule` +* `initializer` : `AbstractInitializer`. Called to initialize parameters if needed. +* `arg_params` : `Dict{Symbol, NDArray}`. If not empty, should be a dictionary of existing `arg_params`. Initialization will be copied from that. +* `aux_params` : `Dict{Symbol, NDArray}`. If not empty, should be a dictionary of existing `aux_params`. Initialization will be copied from that. +* `allow_missing` : `Bool`. If true, params could contain missing values, and the initializer will be called to fill those missing params. +* `force_init` : `Bool`. If true, will force re-initialize even if already initialized. """ function init_params(self :: AbstractModule, args...) throw(MethodError(init_params, (self, args...))) @@ -233,6 +271,8 @@ function fit(self::AbstractModule, train_data) error("Not yet implemented") end + +# XXX: warning, this function is not type stable. """ predict @@ -261,45 +301,65 @@ The objects in the results are `NDArray`s. If you need to work with julia array, just call `Array{Float32}` on each of the `NDArray`. # Examples -# TODO finish doc An example of using predict for prediction:: - >>> #Predict on the first 10 batches of val_dataiter - >>> mod.predict(eval_data=val_dataiter, num_batch=10) +```julia +# Predict on the first 10 batches of `data` DataProvider +predict(m1, data, num_batch=10) +``` """ -function predict(self::AbstractModule, eval_data; - num_batch=nothing, merge_batches=true, reset=true) +function predict(self::AbstractModule, eval_data::AbstractDataProvider; + num_batch=nothing, merge_batches=true, always_output_list::Bool=false) @assert isbinded(self) && isinitialized(self) - reset && reset!(eval_data) - - for (nbatch, eval_batch) in enumerate(eval_data) + output_list = [] + for (nbatch, eval_batch) in enumerate(eachdatabatch(eval_data)) if num_batch !== nothing && nbatch == num_back break end - forward(self, eval_batch, is_train=false) + forward(self, eval_batch, false) outputs = get_outputs(self) + push!(output_list, outputs) + + end + + if length(output_list) == 0 + return output_list + end + + if merge_batches + num_outputs = length(output_list[1]) + for out in output_list + @assert(length(out) == num_outputs, + "Cannot merge batches, as num of outputs is not the same in mini-batches. 
Maybe bucketing is used?") + end + output_list2 = [concatenate([out[i] for out in output_list]) for i = 1:num_outputs] - error("Not yet implemented") + if num_outputs == 1 && !always_output_list + return output_list2[1] + end + + return output_list2 end + + return output_list end """ score(self::AbstractModule, eval_data, eval_metric; num_batch, batch_end_callback, reset=true, epoch=0) """ -function score(self :: AbstractModule, eval_data, eval_metric; num_batch=nothing, batch_end_callback=nothing, reset=true, epoch=0) +function score(self :: AbstractModule, eval_data, eval_metric; num_batch=nothing, batch_end_callback=nothing, epoch=0) @assert isbinded(self) && isinitialized(self) - reset && reset!(eval_data) reset!(eval_metric) - for (nbatch, eval_batch) in enumerate(eval_data) + for (nbatch, eval_batch) in enumerate(eachdatabatch(eval_data)) if num_batch !== nothing && nbatch == num_back break end - forward(self, eval_batch, is_train=false) - update_metric(self, eval_metric, label(eval_batch)) + forward(self, eval_batch, false) + update_metric(self, eval_metric, eval_batch) if batch_end_callback !== nothing error("Not implemented yet!") @@ -312,7 +372,7 @@ end forward_backward(self :: AbstractModule, data_batch) """ function forward_backward(self :: AbstractModule, data_batch) - forward(self, data_batch, is_train=true) + forward(self, data_batch, true) backward(self, data_batch) end diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index bd0b60598..2b19cb151 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -27,7 +27,6 @@ type SymbolModule <: AbstractModule data_shapes :: Vector{Tuple{Vararg{Int}}} label_shapes :: Vector{Tuple{Vararg{Int}}} - output_shapes :: Vector{Tuple{Vararg{Int}}} arg_arrays :: Nullable{Vector{NDArray}} aux_arrays :: Nullable{Vector{NDArray}} @@ -54,7 +53,6 @@ type SymbolModule <: AbstractModule false, false, false, false, false, Vector{Tuple{Int}}(), Vector{Tuple{Int}}(), - Vector{Tuple{Int}}(), Nullable{Vector{NDArray}}(), Nullable{Vector{NDArray}}(), Nullable{Vector{NDArray}}(), @@ -64,12 +62,12 @@ type SymbolModule <: AbstractModule end function SymbolModule(symbol::SymbolicNode; - data_names = [:data], label_names = [:softmax_label], + data_names = [:data], + label_names = [:softmax_label], context = [mx.cpu()], fixed_param_names = nothing) fixed_param_names = Nullable{Vector{Symbol}}(fixed_param_names) - if !isa(context, Vector{Context}) - context = [context] - end + label_names = Vector{Symbol}(label_names) + context = _wrap_context(context) @assert !isempty(data_names) @assert !isempty(context) return SymbolModule(symbol, data_names, label_names, context, fixed_param_names) @@ -82,26 +80,36 @@ isinitialized(self::SymbolModule) = self.params_initialized hasoptimizer(self::SymbolModule) = self.optimizer_initialized data_names(self::SymbolModule) = self.data_names +label_names(self::SymbolModule) = self.label_names output_names(self::SymbolModule) = list_outputs(self.symbol) +""" + get_symbol(self::SymbolModule) -> Nullable{SymbolicNode} + +Returns the associated [`SymbolicNode`](@ref) of the module. It might not be defined or change over time. 
+""" +function get_symbol(self::SymbolModule) + return Nullable{SymbolicNode}(self.symbol) +end + function data_shapes(self::SymbolModule) - !isbinded(self) && return Vector{Tuple{Int}}() - return self.data_shapes + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return Dict(k => v for (k, v) in zip(data_names(self), self.data_shapes)) end function label_shapes(self::SymbolModule) - !isbinded(self) && return Vector{Tuple{Int}}() - return self.label_shapes + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return Dict(k => v for (k, v) in zip(label_names(self), self.label_shapes)) end function output_shapes(self::SymbolModule) - !isbinded(self) && return Vector{Tuple{Int}}() - return self.output_shapes + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return mx.output_shapes(self.exec_group) end function get_params(self::SymbolModule) if !(isbinded(self) && isinitialized(self)) - return (Nullable{Dict{Symbol, NDArray}}(), Nullable{Dict{Symbol, NDArray}}()) + return (Dict{Symbol, NDArray}(), Dict{Symbol, NDArray}()) end if self.params_dirty mx.get_params!(self.exec_group, self.arg_params, self.aux_params) @@ -115,7 +123,7 @@ function init_params(self::SymbolModule; initializer=UniformInitializer(0.07), arg_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), aux_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), - allow_missing=false, allow_extra_params=false, force_init=false) + allow_missing=false, force_init=false) if isinitialized(self) && !force_init return self @@ -150,7 +158,7 @@ function init_params(self::SymbolModule; end # copy the initialized parameters to devices - set_params!(self.exec_group, self.arg_params, self.aux_params, allow_extra_params=allow_extra_params) + set_params!(self.exec_group, self.arg_params, self.aux_params) self.params_dirty = false self.params_initialized = true @@ -274,10 +282,11 @@ function get_outputs(self::SymbolModule, merge_multi_context::Bool=true) mx.get_outputs(self.exec_group, merge_multi_context) end -# TODO add description """ forward(module, data_provider, data_batch; is_train) + Forward computation. + # Arguments * `data_batch` : AbstractDataBatch * `is_train` : Nullable{Bool} @@ -293,7 +302,7 @@ end backward(module, out_grads) Backward computation. # Arguments -* `out_grads` : NDArray or list of NDArray, optional +* `out_grads` : `NDArray` or vector of `NDArray`, default `nothing`. Gradient on the outputs to be propagated back. This parameter is only needed when bind is called on outputs that are not a loss function. @@ -329,6 +338,29 @@ function update_metric(self::SymbolModule, eval_metric::AbstractEvalMetric, prov mx.update_metric(self.exec_group, eval_metric, provider, batch) end +""" + get_input_grads(self::SymbolModule, merge_multi_context=true) + +Get the gradients with respect to the inputs of the module. + +# Arguments + +* `merge_multi_context` : `Bool` + Default is `true`. In the case when data-parallelism is used, the outputs +will be collected from multiple devices. A `true` value indicate that we +should merge the collected results so that they look like from a single +executor. + +# Returns +If `merge_multi_context` is `true`, it is like `[grad1, grad2]`. Otherwise, it +is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output +elements are `NDArray`. 
+""" +function get_input_grads(self::SymbolModule, merge_multi_context::Bool=true) + @assert isbinded(self) && isinitialized(self) && self.inputs_need_grad + return mx.get_input_grads(self.exec_group, merge_multi_context) +end + ## # Internals ## @@ -349,3 +381,6 @@ function borrow_optimizer!(self::SymbolModule, shared_module::SymbolModule) self.updater = shared_module.updater self.optimizer_initialized = true end + +_wrap_context(context::Context) = [context] +_wrap_context(context::Vector{Context}) = context diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl index 741758489..2824bab8f 100644 --- a/test/unittest/symbol-module.jl +++ b/test/unittest/symbol-module.jl @@ -2,23 +2,25 @@ module TestSymbolModule using MXNet using Base.Test +using ..Main: reldiff + ################################################################################ # Utils ################################################################################ function create_network() arch = mx.@chain mx.Variable(:data) => - mx.Convolution(kernel = (3,3), pad = (1,1), stride = (1,1), num_filter = 64) => - mx.SoftmaxOutput(name=:softmax, multi_output = true) + mx.Convolution(kernel = (3,3), pad = (1,1), stride = (1,1), num_filter = 64) => + mx.SoftmaxOutput(name=:softmax, multi_output = true) return arch end function create_linreg(num_hidden::Int=1) arch = @mx.chain mx.Variable(:data) => - mx.FullyConnected(name=:fc1, num_hidden=num_hidden) => - mx.FullyConnected(name=:fc2, num_hidden=1) => - mx.LinearRegressionOutput(name=:linout) + mx.FullyConnected(name=:fc1, num_hidden=num_hidden) => + mx.FullyConnected(name=:fc2, num_hidden=1) => + mx.LinearRegressionOutput(name=:linout) return arch end @@ -28,7 +30,7 @@ end function test_basic() info("SymbolModule::basic") - + m1 = mx.Module.SymbolModule(create_network()) @test !mx.Module.isbinded(m1) @@ -38,8 +40,8 @@ function test_basic() @test mx.Module.data_names(m1) == [:data] @test mx.Module.output_names(m1) == [:softmax_output] - - mx.Module.bind(m1, [(20, 20, 1, 10)], [(20, 20, 1, 10)]) + + mx.Module.bind(m1, [(20, 20, 1, 10)], [(400, 10)]) @test mx.Module.isbinded(m1) @test !mx.Module.isinitialized(m1) @test !mx.Module.hasoptimizer(m1) @@ -51,11 +53,24 @@ function test_basic() @test mx.Module.hasoptimizer(m1) end -function test_init_params(n_epoch::Int = 10) - info("SymbolModule::InitParams") +function test_shapes() + info("SymbolModule::Shapes") + + m1 = mx.Module.SymbolModule(create_network()) + mx.Module.bind(m1, [(20, 20, 1, 10)], [(20, 20, 1, 10)]) + + @test mx.Module.data_shapes(m1) == Dict(:data => (20, 20, 1, 10)) + @test mx.Module.label_shapes(m1) == Dict(:softmax_label => (20, 20, 1, 10)) + @test mx.Module.output_shapes(m1) == Dict(:softmax_output => (20, 20, 64, 10)) + + m2 = mx.Module.SymbolModule(create_network(), label_names=[]) + mx.Module.bind(m2, [(20, 20, 1, 10)]) + @test isempty(mx.Module.label_shapes(m2)) +end + +function test_linear_regression(n_epoch::Int = 10) + info("SymbolModule::LinearRegression") - #= x = reshape(collect(1:10), (1, 10)) =# - #= y = reshape(collect(2:11), (1, 10)) =# srand(123456) epsilon = randn(1, 10) x = rand(4, 10) @@ -90,13 +105,22 @@ function test_init_params(n_epoch::Int = 10) y_pred = Float64[] for batch in mx.eachdatabatch(data) - mx.Module.forward(m1, batch) + mx.Module.forward(m1, batch, false) append!(y_pred, Array{Float64}(mx.Module.get_outputs(m1)[1])) end + y_pred = reshape(y_pred, 1, 10) info("Prediction: $y_pred") info("Actual: $y") info("No noise: $(mapslices(sum, [1, 2, 3, 4] .* x, 1))") + 
+ # High Level Api + name, score = mx.Module.score(m1, data, metric)[1] + ha_pred = mx.copy(mx.Module.predict(m1, data)) + info("Predict result: ", ha_pred) + info("Score $name : ", score) + + @test sum(abs(ha_pred-y_pred)) < 1e-6 end ################################################################################ @@ -105,7 +129,9 @@ end @testset "Symbol Module Test" begin test_basic() - test_init_params(500) + test_shapes() + #= test_init_params(500) =# + test_linear_regression() end end From 4a6df476db190a20dde5f9fc30667e8ef0085948 Mon Sep 17 00:00:00 2001 From: Andrej Oskin Date: Sun, 22 Jan 2017 18:27:19 +0300 Subject: [PATCH 11/18] High Level Api for SymbolModule --- src/executor-group.jl | 32 ++++++- src/module/Module.jl | 147 +++++++++++++++++++++++++++++++-- src/module/symbol_module.jl | 6 +- test/unittest/symbol-module.jl | 14 +++- 4 files changed, 183 insertions(+), 16 deletions(-) diff --git a/src/executor-group.jl b/src/executor-group.jl index ce1b8fa14..a3e69134a 100644 --- a/src/executor-group.jl +++ b/src/executor-group.jl @@ -145,14 +145,13 @@ end Split `data_batch` according to workload and run forward on each devices. # Arguments * `data_batch` : AbstractDataBatch -* `is_train` : Nullable{Bool} +* `is_train` : `Bool` The hint for the backend, indicating whether we are during training phase. Default is `nothing`, then the value `self.for_training` will be used. """ -function forward(self:: DataParallelExecutorGroup, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train = nothing) +function forward(self:: DataParallelExecutorGroup, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool = self.for_training) load_data!(data_provider, data_batch, self.data_arrays) - is_train = get(is_train, self.for_training) if is_train && !isempty(get_label(data_provider, data_batch)) load_label!(data_provider, data_batch, self.label_arrays) @@ -317,6 +316,33 @@ function get_outputs(self::DataParallelExecutorGroup, merge_multi_context::Bool= end end +""" + get_input_grads(self, merge_multi_context) + +Get the gradients with respect to the inputs of the module. + +# Arguments +* `merge_multi_context` : `Bool` + Default is `true`. In the case when data-parallelism is used, the outputs + will be collected from multiple devices. A `true` value indicate that we + should merge the collected results so that they look like from a single + executor. + +# Returns +If `merge_multi_context` is `True`, it is like `[grad1, grad2]`. Otherwise, it +is like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`. All the output +elements are `NDArray`. 
+""" +function get_input_grads(self::DataParallelExecutorGroup, merge_multi_context::Bool=true) + !self.inputs_need_grad && NDArray[] + + if merge_multi_context + return _merge_multi_context(self.input_grad_arrays) + end + + return self.input_grad_arrays +end + function output_shapes(self:: DataParallelExecutorGroup) outputs = [size(out) for out in self.execs[1].outputs] return Dict(key => shape for (key, shape) in zip(list_outputs(self.symbol), outputs)) diff --git a/src/module/Module.jl b/src/module/Module.jl index 9ebdc6b2a..1c8dfd68b 100644 --- a/src/module/Module.jl +++ b/src/module/Module.jl @@ -1,7 +1,7 @@ module Module import ....MXNet: mx import ..mx: DataBatch, AbstractDataProvider, AbstractDataBatch, DataBatchProvider -import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer, get_updater, update_params, provide_data, provide_label, AbstractEvalMetric, StubProvider, init, copy!, concatenate, eachdatabatch, reset! +import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer, get_updater, update_params, provide_data, provide_label, AbstractEvalMetric, StubProvider, init, copy!, concatenate, eachdatabatch, reset!, Accuracy """ AbstractModule @@ -215,8 +215,8 @@ end ### """ """ -forward(self :: AbstractModule, data_batch :: DataBatch, is_train=nothing) = forward(self, StubProvider(), data_batch, is_train) -function forward(self :: AbstractModule, provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train=nothing) +forward(self :: AbstractModule, data_batch :: DataBatch, is_train) = forward(self, StubProvider(), data_batch, is_train) +function forward(self :: AbstractModule, provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train) throw(MethodError(forward, (self, ))) end @@ -265,12 +265,138 @@ end ### """ + fit(self::AbstractModule, train_data::AbstractDataProvider; kwargs...) + +Train the module parameters. + +# Arguments +* `train_data` : AbstractDataProvider +* `eval_data` : AbstractDataProvider + If not `nothing`, will be used as validation set and evaluate the performance + after each epoch. +* `eval_metric` : str or EvalMetric + Default `'acc'`. The performance measure used to display during training. +* `epoch_end_callback` : function or list of function + Each callback will be called with the current `epoch`, `symbol`, `arg_params` + and `aux_params`. +* `batch_end_callback` : function or list of function + Each callback will be called with a `BatchEndParam`. +* `kvstore` : Symbol or KVStore + Default `:local`. +* `optimizer` : AbstractOptimizer + Default `ADAM` +* `eval_end_callback` : function or list of function + These will be called at the end of each full evaluation, with the metrics over + the entire evaluation set. +* `eval_batch_end_callback` : function or list of function + These will be called at the end of each minibatch during evaluation +* `initializer` : Initializer + Will be called to initialize the module parameters if not already initialized. +* `arg_params` : dict + Default `nothing`, if not `nothing`, should be existing parameters from a trained + model or loaded from a checkpoint (previously saved model). 
In this case, + the value here will be used to initialize the module parameters, unless they + are already initialized by the user via a call to `init_params` or `fit`. +`arg_params` has higher priority to `initializer`. +* `aux_params` : dict + Default `None`. Similar to `arg_params`, except for auxiliary states. +* `allow_missing` : bool + Default `False`. Indicate whether we allow missing parameters when `arg_params` + and `aux_params` are not `None`. If this is `True`, then the missing parameters + will be initialized via the `initializer`. +* `force_rebind` : bool + Default `False`. Whether to force rebinding the executors if already binded. +* `force_init` : bool + Default `False`. Indicate whether we should force initialization even if the + parameters are already initialized. +* `begin_epoch` : int + Default `1`. Indicate the starting epoch. Usually, if we are resuming from a + checkpoint saved at a previous training phase at epoch N, then we should specify + this value as N+1. +* `num_epoch` : int + Number of epochs to run training. + +# Examples +An example of using fit for training:: +```julia +# Assume training train_data and validation eval_data are ready +mx.Module.fit(mod, train_data, 10, eval_data=eval_data, + optimizer=mx.SGD(lr=0.01, momentum=0.9)) +``` """ -function fit(self::AbstractModule, train_data) +function fit(self::AbstractModule, train_data, num_epoch; + initializer=UniformInitializer(0.07), + optimizer = ADAM(), + eval_data=nothing, + eval_metric::AbstractEvalMetric=Accuracy(), + validation_metric::AbstractEvalMetric = eval_metric, + epoch_end_callback = nothing, + batch_end_callback = nothing, + kvstore = :local, + eval_end_callback = nothing, + eval_batch_end_callback = nothing, + arg_params = Dict{Symbol, NDArray}(), + aux_params = Dict{Symbol, NDArray}(), + allow_missing = false, + force_rebind = false, + force_init = false, + begin_epoch = 1) + bind(self, train_data, for_training=true, force_rebind=force_rebind) + init_params(self, initializer=initializer, arg_params=arg_params, + aux_params=aux_params, allow_missing=allow_missing, + force_init=force_init) + init_optimizer(self, kvstore=kvstore, optimizer=optimizer, force_init=force_init) + + if validation_metric == nothing + validation_metric = eval_metric + end - error("Not yet implemented") -end + #################################################################### + # training loop + #################################################################### + for epoch in begin_epoch:num_epoch + time_start = time() + reset!(eval_metric) + for (nbatch, batch) in enumerate(eachdatabatch(train_data)) + forward_backward(self, batch) + update(self) + update_metric(self, eval_metric, batch) + + if batch_end_callback !== nothing + error("Not implemented yet") + end + end + # one epoch of training is finished + for (name, val) in get(eval_metric) + info("Epoch[$epoch] Train-$name=$val") + end + time_stop = time() + info("Epoch[$epoch] Time cost=$(time_stop - time_start)") + + # sync aux params across devices + arg_params, aux_params = get_params(self) + set_params(self, arg_params, aux_params) + + if epoch_end_callback !== nothing + error("Not implemented yet") + end + + ################################################################## + # evaluation on validation set + ################################################################## + if eval_data !== nothing + res = score(self, eval_data, validation_metric, + score_end_callback=eval_end_callback, + batch_end_callback=eval_batch_end_callback, + epoch=epoch) + 
#TODO: pull this into default + for (name, val) in res + info("Epoch[$epoch] Validation-$name=$val") + end + end + end +end # XXX: warning, this function is not type stable. """ @@ -348,7 +474,8 @@ end """ score(self::AbstractModule, eval_data, eval_metric; num_batch, batch_end_callback, reset=true, epoch=0) """ -function score(self :: AbstractModule, eval_data, eval_metric; num_batch=nothing, batch_end_callback=nothing, epoch=0) +function score(self :: AbstractModule, eval_data, eval_metric; num_batch=nothing, batch_end_callback=nothing, score_end_callback=nothing, + epoch=0) @assert isbinded(self) && isinitialized(self) reset!(eval_metric) @@ -365,6 +492,10 @@ function score(self :: AbstractModule, eval_data, eval_metric; num_batch=nothing error("Not implemented yet!") end end + if score_end_callback !== nothing + error("Not implemented yet!") + end + get(eval_metric) end @@ -373,7 +504,7 @@ end """ function forward_backward(self :: AbstractModule, data_batch) forward(self, data_batch, true) - backward(self, data_batch) + backward(self) end # include implementations diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index 2b19cb151..f77a64151 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -289,11 +289,11 @@ Forward computation. # Arguments * `data_batch` : AbstractDataBatch -* `is_train` : Nullable{Bool} +* `is_train` : Bool Default is `nothing`, which means `is_train` takes the value of `self.for_training`. """ -forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train = nothing) = forward(self, data_provider, data_batch, Nullable{Bool}(is_train)) -function forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Nullable{Bool}) +forward(self::SymbolModule, data_batch::DataBatch) = forward(self, StubProvider(), data_batch, self.for_training) +function forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool) @assert isbinded(self) && isinitialized(self) mx.forward(self.exec_group, data_provider, data_batch, is_train) end diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl index 2824bab8f..f0e121e31 100644 --- a/test/unittest/symbol-module.jl +++ b/test/unittest/symbol-module.jl @@ -91,10 +91,10 @@ function test_linear_regression(n_epoch::Int = 10) for i in 1:n_epoch for batch in mx.eachdatabatch(data) mx.Module.forward(m1, batch) - mx.Module.update_metric(m1, metric, batch) - mx.Module.backward(m1) mx.Module.update(m1) + + mx.Module.update_metric(m1, metric, batch) end for (name, value) in get(metric) @@ -121,6 +121,16 @@ function test_linear_regression(n_epoch::Int = 10) info("Score $name : ", score) @test sum(abs(ha_pred-y_pred)) < 1e-6 + + m2 = mx.Module.SymbolModule(create_linreg(4), + label_names = [:linout_label], + context=[mx.cpu(), mx.cpu()]) + mx.Module.fit(m2, data, 10, eval_metric=mx.MSE()) + name, score = mx.Module.score(m2, data, metric)[1] + ha_pred = mx.copy(mx.Module.predict(m2, data)) + info("Predict result: ", ha_pred) + info("Score $name : ", score) + @test sum(abs(ha_pred-y_pred)) < 1e-1 end ################################################################################ From 065fd253690c1379b690c75610989d085d9f5ecd Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sat, 21 Jan 2017 11:51:43 +0900 Subject: [PATCH 12/18] fix num_args passed to infer_type --- src/symbolic-node.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
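
The one-line fix below passes the full count of keyword-provided types to the C
type-inference call (it was previously off by one), which the executor group's
type inference in the next patch relies on. A minimal usage sketch, using the
toy regression network from the tests (`:data` and the auto-generated
`:linout_label` are the input names):

```julia
net = @mx.chain mx.Variable(:data) =>
      mx.FullyConnected(name=:fc1, num_hidden=4) =>
      mx.LinearRegressionOutput(name=:linout)

# keyword arguments give the element types of the named inputs;
# infer_type propagates them and returns types for all arguments,
# outputs and auxiliary states
arg_types, out_types, aux_types = mx.infer_type(net; data=Float32, linout_label=Float32)
```
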
diff --git a/src/symbolic-node.jl b/src/symbolic-node.jl index f5a518c35..a151cbb25 100644 --- a/src/symbolic-node.jl +++ b/src/symbolic-node.jl @@ -345,7 +345,7 @@ function _infer_type(self, keys, arg_type_data) Ref{MX_uint}, Ref{Ptr{Cint}}, Ref{MX_uint}, Ref{Ptr{Cint}}, Ref{Cint}), - self, length(arg_type_data)-1, keys, arg_type_data, + self, length(arg_type_data), keys, arg_type_data, ref_in_type_size, ref_in_type_data, ref_out_type_size, ref_out_type_data, ref_aux_type_size, ref_aux_type_data, From f941ef04d547b2b8deb35bb8f0a9f5ba7c441ff3 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 23 Jan 2017 15:23:39 +0900 Subject: [PATCH 13/18] add type inference to executor-group --- src/executor-group.jl | 79 +++++++++++++++++++++------------- src/module/symbol_module.jl | 12 ++++-- test/unittest/symbol-module.jl | 2 +- 3 files changed, 59 insertions(+), 34 deletions(-) diff --git a/src/executor-group.jl b/src/executor-group.jl index a3e69134a..14e3389b6 100644 --- a/src/executor-group.jl +++ b/src/executor-group.jl @@ -9,6 +9,14 @@ function forward(self::AbstractExecutorGroup, data_provider :: AbstractDataProvi throw(MethodError(forward, (self, ))) end +""" + DataParallelExecutorGroup + +Supports: + - Fixed parameters (freezing) + - Shape inference + - Type inference +""" type DataParallelExecutorGroup <: AbstractExecutorGroup symbol :: SymbolicNode context :: Vector{Context} @@ -39,15 +47,16 @@ type DataParallelExecutorGroup <: AbstractExecutorGroup param_names :: Vector{Symbol} aux_names :: Vector{Symbol} end + function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context}, - data_shapes, data_names, label_shapes, label_names, for_training, inputs_need_grad, - shared_group, fixed_param_names, grad_req) + data_shapes, data_names, data_types, label_shapes, label_names, label_types, + for_training, inputs_need_grad, shared_group, fixed_param_names, grad_req) num_dev = length(context) arg_names = list_arguments(symbol) input_names = [data_names; label_names] param_names = setdiff(arg_names, input_names) - aux_names = list_auxiliary_states(symbol) + aux_names = list_auxiliary_states(symbol) batch_size = data_shapes[1][end] for shape in data_shapes @@ -59,39 +68,54 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context end end - # TODO imlplement workload + # TODO implement workload slices = _split_inputs(batch_size, num_dev) execs = Vector{Executor}(num_dev) - provided_shapes = merge(Dict(name => shape for (name, shape) in zip(data_names, data_shapes)), - Dict(name => shape for (name, shape) in zip(label_names, label_shapes))) + # Shape inference based on data_shapes and label_shapes + provided_shapes = merge( + Dict(name => shape for (name, shape) in zip(data_names, data_shapes)), + Dict(name => shape for (name, shape) in zip(label_names, label_shapes)) + ) + arg_shapes, out_shapes, aux_shapes = infer_shape(symbol; provided_shapes...) @assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference") + # Type inference based on data_types and lable_types + provided_types = merge( + Dict(name => T for (name, T) in zip(data_names, data_types)), + Dict(name => T for (name, T) in zip(label_names, label_types)) + ) + + arg_types, out_types, aux_types = infer_type(symbol; provided_types...) 
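
# For concreteness (comments only, not part of the patch): with the linear
# regression setup used in the unit tests, the inputs to the two inference
# passes above would look roughly like
#
#   data_names  = [:data];          data_shapes  = [(4, 5)];  data_types  = [Float32]
#   label_names = [:linout_label];  label_shapes = [(1, 5)];  label_types = [Float32]
#
# and `arg_types`/`aux_types` then supply one element type per argument and
# auxiliary state, matching `arg_shapes`/`aux_shapes` for the typed
# allocations that follow.
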
+ + # Check for what arg we needs gradients and which are frozen grad_req, freeze_idx = get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, fixed_param_names, grad_req) arg_params = Dict{Symbol, NDArray}() aux_params = Dict{Symbol, NDArray}() - for (name, shape) in filter(x -> in(x[1], param_names), zip(arg_names, arg_shapes)) - arg_params[name] = empty(shape) + for (name, shape, T) in filter(x -> in(x[1], param_names), zip(arg_names, arg_shapes, arg_types)) + arg_params[name] = empty(T, shape) end - for (name, shape) in zip(aux_names, aux_shapes) - aux_params[name] = empty(shape) + for (name, shape, T) in zip(aux_names, aux_shapes, aux_types) + aux_params[name] = empty(T, shape) end + dev_shapes(shapes, slice) = (tuple(shape[1:end-1]..., slice) for shape in shapes) + for i = 1:num_dev - data_shapes_dev = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(data_names, data_shapes)) - label_shapes_dev = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(label_names, label_shapes)) - arg_shapes_dev, out_shapes_dev, aux_shapes_dev = infer_shape(symbol; data_shapes_dev..., label_shapes_dev...) - @assert(!isa(arg_shapes_dev, Void), "Information not enough to perform complete shape inference") - arg_arrays = NDArray[zeros(shape, context[i]) for shape in arg_shapes_dev] - grad_arrays = Dict{Symbol,NDArray}() - aux_arrays = NDArray[zeros(shape, context[i]) for shape in aux_shapes_dev] + arg_shapes_dev = dev_shapes(arg_shapes, length(slices[i])) + aux_shapes_dev = dev_shapes(aux_shapes, length(slices[i])) + + arg_arrays = NDArray[zeros(T, shape, context[i]) for (shape, T) in zip(arg_shapes_dev, arg_types)] + aux_arrays = NDArray[zeros(T, shape, context[i]) for (shape, T) in zip(aux_shapes_dev, aux_types)] - shapes = zip(arg_names, arg_shapes_dev) + # Process arguments to create gradient arrays + grad_arrays = Dict{Symbol,NDArray}() + arg_info = zip(arg_names, arg_shapes_dev, arg_types) # if not in provided data, should be parameters if inputs_need_grad @@ -99,22 +123,18 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context else provided_data_names = [data_names; label_names] end - shapes = filter(x -> !in(x[1], provided_data_names), shapes) + arg_info = filter(x -> !in(x[1], provided_data_names), arg_info) # Remove all gradients for nop params - shapes = filter(x -> grad_req[x[1]] != GRAD_NOP, shapes) + arg_info = filter(x -> grad_req[x[1]] != GRAD_NOP, arg_info) - for (name, shape) in shapes - grad_arrays[name] = zeros(shape, context[i]) + for (name, shape, T) in arg_info + grad_arrays[name] = zeros(T, shape, context[i]) end execs[i] = bind(symbol, context[i], arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays) - #= dbg_str = mx.debug_str(train_execs[i]) =# - #= info(string("TempSpace: ", split(dbg_str, ['\n'])[end-2]..., " on ", self.ctx[i])) =# end - # TODO: perform type inference - # set up input data structures data_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(execs)] for name in data_names] label_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(execs)] for name in label_names] @@ -152,7 +172,8 @@ Split `data_batch` according to workload and run forward on each devices. 
function forward(self:: DataParallelExecutorGroup, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool = self.for_training) load_data!(data_provider, data_batch, self.data_arrays) - + is_train = get(is_train, self.for_training) + if is_train && !isempty(get_label(data_provider, data_batch)) load_label!(data_provider, data_batch, self.label_arrays) end @@ -168,7 +189,7 @@ backward(self::DataParallelExecutorGroup, out_grads::Void) = backward(self, NDAr backward(self::DataParallelExecutorGroup, out_grads::NDArray) = backward(self, [out_grads]) function backward(self::DataParallelExecutorGroup, out_grads::Vector{NDArray}) @assert(self.for_training, "re-bind with for_training=true to run backward") - + for (i, exec) in enumerate(self.execs) out_grad_slices = NDArray[] for grad in out_grads @@ -179,7 +200,7 @@ function backward(self::DataParallelExecutorGroup, out_grads::Vector{NDArray}) end """ - set_params!(self::DataParallelExecutorGroup, arg_params, aux_params; allow_extra_params) + set_params!(self::DataParallelExecutorGroup, arg_params, aux_params; allow_extra_params) Assign, i.e. copy parameters to all the executors. # Arguments diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index f77a64151..fec4c90c1 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -195,9 +195,13 @@ function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Tuple{Int}} self.data_shapes = data_shapes self.label_shapes = label_shapes + # TODO propagate type information + data_types = [Float32 for _ in 1:length(self.data_names)] + label_types = [Float32 for _ in 1:length(self.label_names)] + self.exec_group = DataParallelExecutorGroup(self.symbol, self.context, - self.data_shapes, self.data_names, - self.label_shapes, self.label_names, + self.data_shapes, self.data_names, data_types, + self.label_shapes, self.label_names, label_types, self.for_training, self.inputs_need_grad, shared_group, self.fixed_param_names, grad_req) return self @@ -247,7 +251,7 @@ function init_optimizer(self::SymbolModule; optimizer::AbstractOptimizer=ADAM(), end end end - + # TODO add preloaded states #= if !isa(self.preload_opt_states, Void) =# #= load_optimizer_states!(self, self.preload_opt_states) =# @@ -367,7 +371,7 @@ end """ borrow_optimizer!(module, shared_module) -Borrow optimizer from a shared module. Used in bucketing, where exactly the same +Borrow optimizer from a shared module. Used in bucketing, where exactly the same optimizer (esp. kvstore) is used. 
# Arguments * `module` : SymbolModule diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl index f0e121e31..160d8f64a 100644 --- a/test/unittest/symbol-module.jl +++ b/test/unittest/symbol-module.jl @@ -78,7 +78,7 @@ function test_linear_regression(n_epoch::Int = 10) data = mx.ArrayDataProvider(:data => x, :linout_label => y; batch_size = 5) metric = mx.MSE() - m1 = mx.Module.SymbolModule(create_linreg(4), + m1 = mx.Module.SymbolModule(create_linreg(4), label_names = [:linout_label], context=[mx.cpu(), mx.cpu()]) mx.Module.bind(m1, data) From 9ea80731e7a346960dc6d8bc012a777c177a3bf5 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 23 Jan 2017 16:16:04 +0900 Subject: [PATCH 14/18] address my own review comments --- src/executor-group.jl | 31 ++++++++++++++++++++----------- src/module/Module.jl | 19 +++++++++++++------ src/module/symbol_module.jl | 10 +++++----- src/symbolic-node.jl | 10 ++++++++-- 4 files changed, 46 insertions(+), 24 deletions(-) diff --git a/src/executor-group.jl b/src/executor-group.jl index 14e3389b6..2fbd5ff05 100644 --- a/src/executor-group.jl +++ b/src/executor-group.jl @@ -79,7 +79,8 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context Dict(name => shape for (name, shape) in zip(label_names, label_shapes)) ) - arg_shapes, out_shapes, aux_shapes = infer_shape(symbol; provided_shapes...) + # Run shape inference globally + arg_shapes, out_shapes, aux_shapes = infer_shape(symbol, provided_shapes) @assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference") # Type inference based on data_types and lable_types @@ -88,7 +89,7 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context Dict(name => T for (name, T) in zip(label_names, label_types)) ) - arg_types, out_types, aux_types = infer_type(symbol; provided_types...) + arg_types, out_types, aux_types = infer_type(symbol, provided_types) # Check for what arg we needs gradients and which are frozen grad_req, freeze_idx = get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, fixed_param_names, grad_req) @@ -107,8 +108,16 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context dev_shapes(shapes, slice) = (tuple(shape[1:end-1]..., slice) for shape in shapes) for i = 1:num_dev - arg_shapes_dev = dev_shapes(arg_shapes, length(slices[i])) - aux_shapes_dev = dev_shapes(aux_shapes, length(slices[i])) + slice = length(slices[i]) + # Shape inference based on data_shapes and label_shapes per device + provided_shapes_dev = merge( + Dict(name => shape for (name, shape) in zip(data_names, dev_shapes(data_shapes, slice))), + Dict(name => shape for (name, shape) in zip(label_names, dev_shapes(label_shapes, slice))) + ) + + # Run shape inference locally (per-device) + arg_shapes_dev, out_shapes_dev, aux_shapes_dev = infer_shape(symbol, provided_shapes_dev) + @assert(!isa(arg_shapes_dev, Void), "Information not enough to perform complete shape inference") arg_arrays = NDArray[zeros(T, shape, context[i]) for (shape, T) in zip(arg_shapes_dev, arg_types)] aux_arrays = NDArray[zeros(T, shape, context[i]) for (shape, T) in zip(aux_shapes_dev, aux_types)] @@ -172,7 +181,6 @@ Split `data_batch` according to workload and run forward on each devices. 
function forward(self:: DataParallelExecutorGroup, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool = self.for_training) load_data!(data_provider, data_batch, self.data_arrays) - is_train = get(is_train, self.for_training) if is_train && !isempty(get_label(data_provider, data_batch)) load_label!(data_provider, data_batch, self.label_arrays) @@ -256,9 +264,9 @@ function update_params(self::DataParallelExecutorGroup, updater, update_on_kvsto end end end -end +end -""" +""" get_params!(self, arg_params, aux_params) Copy data from each executor to `arg_params` and `aux_params`. @@ -269,7 +277,7 @@ Copy data from each executor to `arg_params` and `aux_params`. # Notes This function will inplace update the NDArrays in arg_params and aux_params. """ -function get_params!(self::DataParallelExecutorGroup, arg_params::Dict{Symbol, NDArray}, +function get_params!(self::DataParallelExecutorGroup, arg_params::Dict{Symbol, NDArray}, aux_params::Dict{Symbol, NDArray}) for (name, block) in zip(self.param_names, self.param_arrays) w = empty(size(block[1])) @@ -290,6 +298,7 @@ function get_params!(self::DataParallelExecutorGroup, arg_params::Dict{Symbol, N end """ + update_metric Accumulate the performance according to `eval_metric` on all devices. # Parameters @@ -300,8 +309,8 @@ Accumulate the performance according to `eval_metric` on all devices. """ function update_metric(self::DataParallelExecutorGroup, eval_metric::AbstractEvalMetric, provider::AbstractDataProvider, batch::AbstractDataBatch) - # XXX: there is a possibiilty, that label arrays lie in different - # context than cpu_output_arrays. It should be checked and labels + # XXX: there is a possibility, that label arrays lie in different + # context than cpu_output_arrays. It should be checked and labels # should be copied to corresponding context cpu_output_arrays = get_outputs(self) update!(eval_metric, get_label(provider, batch), cpu_output_arrays) @@ -360,7 +369,7 @@ function get_input_grads(self::DataParallelExecutorGroup, merge_multi_context::B if merge_multi_context return _merge_multi_context(self.input_grad_arrays) end - + return self.input_grad_arrays end diff --git a/src/module/Module.jl b/src/module/Module.jl index 1c8dfd68b..e8a5afc25 100644 --- a/src/module/Module.jl +++ b/src/module/Module.jl @@ -1,7 +1,12 @@ module Module import ....MXNet: mx import ..mx: DataBatch, AbstractDataProvider, AbstractDataBatch, DataBatchProvider -import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer, get_updater, update_params, provide_data, provide_label, AbstractEvalMetric, StubProvider, init, copy!, concatenate, eachdatabatch, reset!, Accuracy +import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, + GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, + KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, + AbstractOptimizer, get_updater, update_params, provide_data, + provide_label, AbstractEvalMetric, StubProvider, init, copy!, + concatenate, eachdatabatch, reset!, Accuracy """ AbstractModule @@ -122,7 +127,7 @@ end """ data_shapes(AbstractModule) -A Dict of (name, shape) pairs specifying the data inputs to this module. +A Dict of (name, shape) pairs specifying the data inputs to this module. 
""" function data_shapes(self :: AbstractModule) throw(MethodError(data_shapes, (self,))) @@ -131,7 +136,9 @@ end """ label_shapes(AbstractModule) -A Dict of (name, shape) pairs specifying the label inputs to this module. If this module does not accept labels -- either it is a module without loss function, or it is not binded for training, then this should return an empty Dict. +A Dict of (name, shape) pairs specifying the label inputs to this module. +If this module does not accept labels -- either it is a module without loss function, +or it is not binded for training, then this should return an empty Dict. """ function label_shapes(self :: AbstractModule) throw(MethodError(label_shapes, (self,))) @@ -173,7 +180,7 @@ Assign parameter and aux state values. * `allow_missing` : `Bool`. If true, params could contain missing values, and the initializer will be called to fill those missing params. * `force_init` : `Bool`. If true, will force re-initialize even if already initialized. """ -function set_params(self::AbstractModule, +function set_params(self::AbstractModule, arg_params::Dict{Symbol, NDArray}, aux_params::Dict{Symbol, NDArray}; allow_missing=false, force_init=false) @@ -327,7 +334,7 @@ mx.Module.fit(mod, train_data, 10, eval_data=eval_data, function fit(self::AbstractModule, train_data, num_epoch; initializer=UniformInitializer(0.07), optimizer = ADAM(), - eval_data=nothing, + eval_data=nothing, eval_metric::AbstractEvalMetric=Accuracy(), validation_metric::AbstractEvalMetric = eval_metric, epoch_end_callback = nothing, @@ -456,7 +463,7 @@ function predict(self::AbstractModule, eval_data::AbstractDataProvider; if merge_batches num_outputs = length(output_list[1]) for out in output_list - @assert(length(out) == num_outputs, + @assert(length(out) == num_outputs, "Cannot merge batches, as num of outputs is not the same in mini-batches. Maybe bucketing is used?") end output_list2 = [concatenate([out[i] for out in output_list]) for i = 1:num_outputs] diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index fec4c90c1..05d81879c 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -64,7 +64,7 @@ end function SymbolModule(symbol::SymbolicNode; data_names = [:data], label_names = [:softmax_label], - context = [mx.cpu()], fixed_param_names = nothing) + context = [mx.cpu()], fixed_param_names = nothing) fixed_param_names = Nullable{Vector{Symbol}}(fixed_param_names) label_names = Vector{Symbol}(label_names) context = _wrap_context(context) @@ -119,8 +119,8 @@ function get_params(self::SymbolModule) return (self.arg_params, self.aux_params) end -function init_params(self::SymbolModule; - initializer=UniformInitializer(0.07), +function init_params(self::SymbolModule; + initializer=UniformInitializer(0.07), arg_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), aux_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), allow_missing=false, force_init=false) @@ -281,8 +281,7 @@ is like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`. All the output elements are `NDArray`. """ function get_outputs(self::SymbolModule, merge_multi_context::Bool=true) - @assert isbinded(self) && isinitialized(self) - + @assert isbinded(self) && isinitialized(self) mx.get_outputs(self.exec_group, merge_multi_context) end @@ -371,6 +370,7 @@ end """ borrow_optimizer!(module, shared_module) + Borrow optimizer from a shared module. Used in bucketing, where exactly the same optimizer (esp. kvstore) is used. 
# Arguments diff --git a/src/symbolic-node.jl b/src/symbolic-node.jl index a151cbb25..d5b805ee1 100644 --- a/src/symbolic-node.jl +++ b/src/symbolic-node.jl @@ -297,6 +297,7 @@ end """ infer_shape(self :: SymbolicNode, args...) + infer_shape(self :: SymbolicNode, args::Dict) infer_shape(self :: SymbolicNode; kwargs...) Do shape inference according to the input shapes. The input shapes could be provided @@ -308,7 +309,8 @@ Returns a 3-tuple containing shapes of all the arguments, shapes of all the outp shapes of all the auxiliary variables. If shape inference failed due to incomplete or incompatible inputs, the return value will be `(nothing, nothing, nothing)`. """ -function infer_shape(self :: SymbolicNode; kwargs...) + +function infer_shape(self :: SymbolicNode, kwargs::Dict) sdata = MX_uint[] indptr = MX_uint[0] for (k,v) in kwargs @@ -318,6 +320,8 @@ function infer_shape(self :: SymbolicNode; kwargs...) keys = AbstractString[string(x[1]) for x in kwargs] _infer_shape(self, keys, indptr, sdata) end +infer_shape(self :: SymbolicNode; kwargs...) = infer_shape(self, Dict(kwargs)) + function infer_shape(self :: SymbolicNode, args :: Union{Tuple, Void}...) sdata = MX_uint[] indptr = MX_uint[0] @@ -365,6 +369,7 @@ end """ infer_type(self :: SymbolicNode; kwargs...) + infer_type(self :: SymbolicNode, args::Dict) infer_type(self :: SymbolicNode, args...) Do type inference according to the input types. The input types could be provided @@ -376,11 +381,12 @@ Returns a 3-tuple containing types of all the arguments, types of all the output types of all the auxiliary variables. If type inference failed due to incomplete or incompatible inputs, the return value will be `(nothing, nothing, nothing)`. """ -function infer_type(self :: SymbolicNode; kwargs...) +function infer_type(self :: SymbolicNode, kwargs::Dict) types = Cint[toTypeFlag(x[2]) for x in kwargs] keys = AbstractString[string(x[1]) for x in kwargs] _infer_type(self, keys, types) end +infer_type(self :: SymbolicNode; kwargs...) = infer_type(self, Dict(kwargs)) function infer_type(self :: SymbolicNode, args :: Union{Tuple, Void}...) types = Cint[] From 3a79bf2e764ef91050bf451246773ab6e85d5701 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 23 Jan 2017 14:11:09 +0900 Subject: [PATCH 15/18] start thinking about pipelines and native modules --- src/module/Module.jl | 2 ++ src/module/native_module.jl | 10 +++++++++ src/module/pipeline.jl | 39 ++++++++++++++++++++++++++++++++++ test/unittest/symbol-module.jl | 3 +++ 4 files changed, 54 insertions(+) create mode 100644 src/module/native_module.jl create mode 100644 src/module/pipeline.jl diff --git a/src/module/Module.jl b/src/module/Module.jl index e8a5afc25..4240e1d9a 100644 --- a/src/module/Module.jl +++ b/src/module/Module.jl @@ -516,5 +516,7 @@ end # include implementations include("symbol_module.jl") +include("pipeline.jl") +include("native_module.jl") end diff --git a/src/module/native_module.jl b/src/module/native_module.jl new file mode 100644 index 000000000..99e2b8699 --- /dev/null +++ b/src/module/native_module.jl @@ -0,0 +1,10 @@ +""" + NativeModule + +Allows the implementation of a MXNet module in native Julia. NDArrays +will be translated into native Julia arrays. 
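# Example

What usage might eventually look like -- a tentative sketch; the callback
signatures below are an assumption, not something fixed by this patch:

```julia
# Squares its inputs; backward applies the chain rule. `data` and `grads` are
# assumed to arrive as plain Julia arrays (one per data entry).
square_mod = NativeModule(
    data -> map(x -> x .^ 2, data),
    (data, grads) -> map((x, dy) -> 2 .* x .* dy, data, grads))
```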
+""" +type NativeModule{F<:Function,B<:Function} <: AbstractModule + forward :: F + backward :: B +end diff --git a/src/module/pipeline.jl b/src/module/pipeline.jl new file mode 100644 index 000000000..394fa16a7 --- /dev/null +++ b/src/module/pipeline.jl @@ -0,0 +1,39 @@ +abstract PipelineModule <: AbstractModule + +""" + SimplePipelineModule + +Allows the pipelining of several modules. + +# Arguments: +* `pipeline :: Vector{Module}` + The elements that are called sequentially + +# Functionality +* +""" +type SimplePipelineModule <: PipelineModule + pipeline :: Vector{Module} +end + +type ModuleDataProvider <: mx.AbstractDataProvider + mod :: Module +end + + +function forward(self :: SimplePipelineModule) + for mod in self.pipeline + forward(mod) + end +end + +function backward(self :: SimplePipelineModule) + for i in length(self.pipeline):-1:1 + mod = self.pipeline[i] + backward(mod) + end +end + +function get_outputs(self :: SimplePipelineModule) + return get_outputs(last(self.pipeline)) +end diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl index 160d8f64a..f444f78c5 100644 --- a/test/unittest/symbol-module.jl +++ b/test/unittest/symbol-module.jl @@ -133,6 +133,9 @@ function test_linear_regression(n_epoch::Int = 10) @test sum(abs(ha_pred-y_pred)) < 1e-1 end +function test_simplepipeline() +end + ################################################################################ # Run tests ################################################################################ From 48c849fd631c63f106178fc1ff510b5b2cd9a39e Mon Sep 17 00:00:00 2001 From: Andrej Oskin Date: Mon, 23 Jan 2017 00:45:12 +0300 Subject: [PATCH 16/18] Initial layout for sequential module Begin testing --- src/executor-group.jl | 13 +- src/io.jl | 3 + src/module/Module.jl | 45 +++-- src/module/sequential_module.jl | 287 +++++++++++++++++++++++++++++ src/module/symbol_module.jl | 60 +++--- test/test-module.jl | 15 +- test/unittest/sequential-module.jl | 50 +++++ test/unittest/symbol-module.jl | 8 +- 8 files changed, 427 insertions(+), 54 deletions(-) create mode 100644 src/module/sequential_module.jl create mode 100644 test/unittest/sequential-module.jl diff --git a/src/executor-group.jl b/src/executor-group.jl index 2fbd5ff05..8d86f2084 100644 --- a/src/executor-group.jl +++ b/src/executor-group.jl @@ -22,8 +22,8 @@ type DataParallelExecutorGroup <: AbstractExecutorGroup context :: Vector{Context} execs :: Vector{Executor} - data_shapes :: Vector{Tuple{Vararg{Int}}} - label_shapes :: Vector{Tuple{Vararg{Int}}} + data_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} + label_shapes :: Dict{Symbol, Tuple{Vararg{Int}}} for_training :: Bool slices :: Vector{UnitRange{Int}} @@ -161,6 +161,9 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context input_grad_arrays = [] end + data_shapes = Dict(name => shape for (name, shape) in zip(data_names, data_shapes)) + label_shapes = Dict(name => shape for (name, shape) in zip(label_names, label_shapes)) + return DataParallelExecutorGroup( symbol, context, execs, data_shapes, label_shapes, for_training, slices, batch_size, @@ -340,7 +343,7 @@ function get_outputs(self::DataParallelExecutorGroup, merge_multi_context::Bool= # output was used. _merge_multi_context creates new array # each time it is called. Need to benchmark, may be it's better # to predefine cpu_output_arrays in self. 
- return _merge_multi_context(outputs) + return [concatenate(tensors, always_copy=false) for tensors in outputs] else return outputs end @@ -367,7 +370,7 @@ function get_input_grads(self::DataParallelExecutorGroup, merge_multi_context::B !self.inputs_need_grad && NDArray[] if merge_multi_context - return _merge_multi_context(self.input_grad_arrays) + return [concatenate(tensors, always_copy=false) for tensors in self.input_grad_arrays] end return self.input_grad_arrays @@ -421,5 +424,3 @@ function get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, return grad_req_dict, freeze_idx end - -_merge_multi_context(outputs) = [concatenate(tensors, always_copy=false) for tensors in outputs] diff --git a/src/io.jl b/src/io.jl index 89f6be129..218827247 100644 --- a/src/io.jl +++ b/src/io.jl @@ -121,6 +121,9 @@ type DataBatch <: AbstractDataBatch label :: Vector{NDArray} count :: Int end +DataBatch(provider::AbstractDataProvider, batch::AbstractDataBatch) = + DataBatch(get_data(provider, batch), get_label(provider, batch), count_samples(provider, batch)) +DataBatch(batch::DataBatch) = DataBatch(batch.data, batch.label, batch.count) count_samples(batch :: DataBatch) = batch.count get_data{Provider<:AbstractDataProvider}(::Provider, batch :: DataBatch) = batch.data get_label{Provider<:AbstractDataProvider}(::Provider, batch :: DataBatch) = batch.label diff --git a/src/module/Module.jl b/src/module/Module.jl index 4240e1d9a..5bacf04b2 100644 --- a/src/module/Module.jl +++ b/src/module/Module.jl @@ -1,12 +1,12 @@ module Module import ....MXNet: mx -import ..mx: DataBatch, AbstractDataProvider, AbstractDataBatch, DataBatchProvider -import ..mx: SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, +import ..mx: DataBatch, AbstractDataProvider, AbstractDataBatch, DataBatchProvider, + SymbolicNode, NDArray, Context, Executor, list_arguments, infer_shape, GRAD_NOP, AbstractExecutorGroup, list_outputs, DataParallelExecutorGroup, KVStore, OptimizationState, ADAM, UniformInitializer, set_params!, AbstractOptimizer, get_updater, update_params, provide_data, provide_label, AbstractEvalMetric, StubProvider, init, copy!, - concatenate, eachdatabatch, reset!, Accuracy + concatenate, eachdatabatch, reset!, Accuracy, @defstruct, AbstractInitializer """ AbstractModule @@ -187,6 +187,13 @@ function set_params(self::AbstractModule, init_params(self, arg_params=arg_params, aux_params=aux_params, allow_missing=allow_missing, force_init=force_init) end +@defstruct ModuleInitParamsOptions ( + initializer::AbstractInitializer=UniformInitializer(0.07), + arg_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), + aux_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), + allow_missing::Bool=false, + force_init::Bool=false +) """ init_params!(self; kwargs...) @@ -200,16 +207,31 @@ Initialize the parameters and auxiliary states. * `allow_missing` : `Bool`. If true, params could contain missing values, and the initializer will be called to fill those missing params. * `force_init` : `Bool`. If true, will force re-initialize even if already initialized. """ -function init_params(self :: AbstractModule, args...) - throw(MethodError(init_params, (self, args...))) +init_params(self :: AbstractModule; kwargs...) 
= init_params(self, ModuleInitParamsOptions(; kwargs...)) +function init_params(self :: AbstractModule, opts::ModuleInitParamsOptions) + throw(MethodError(init_params, (self, opts))) end ### # Setup ### -""" -""" -function bind(self :: AbstractModule, ) + +@defstruct ModuleBindOptions ( + for_training::Bool = true, + inputs_need_grad::Bool = true, + force_rebind::Bool = false, + grad_req::mx.GRAD_REQ = mx.GRAD_WRITE, + shared_module::Union{Void, AbstractModule} = nothing +) +""" +""" +bind(self::AbstractModule, data_provider::AbstractDataProvider; kwargs...) = + bind(self, + [x[2] for x in provide_data(data_provider)], + [x[2] for x in provide_label(data_provider)]; kwargs...) +bind(self::AbstractModule, data_shapes, label_shapes = Tuple{Int}[]; kwargs...) = bind(self, data_shapes, label_shapes, ModuleBindOptions(;kwargs...)) +function bind(self :: AbstractModule, data_shapes, label_shapes, opts::ModuleBindOptions) + throw(MethodError(bind, (self, data_shapes, label_shapes, opts))) end """ @@ -222,7 +244,7 @@ end ### """ """ -forward(self :: AbstractModule, data_batch :: DataBatch, is_train) = forward(self, StubProvider(), data_batch, is_train) +forward{T <: AbstractModule}(self :: T, data_batch :: DataBatch, is_train) = forward(self, StubProvider(), data_batch, is_train) function forward(self :: AbstractModule, provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train) throw(MethodError(forward, (self, ))) end @@ -272,7 +294,7 @@ end ### """ - fit(self::AbstractModule, train_data::AbstractDataProvider; kwargs...) + fit(self::AbstractModule, train_data::AbstractDataProvider, num_epoch::Int; kwargs...) Train the module parameters. @@ -516,7 +538,8 @@ end # include implementations include("symbol_module.jl") -include("pipeline.jl") +# include("pipeline.jl") include("native_module.jl") +include("sequential_module.jl") end diff --git a/src/module/sequential_module.jl b/src/module/sequential_module.jl new file mode 100644 index 000000000..66ced4bde --- /dev/null +++ b/src/module/sequential_module.jl @@ -0,0 +1,287 @@ +import ....MXNet: mx # in order to use mx. + +@defstruct SequentialModuleMetas ( + take_labels :: Bool = false, + auto_wiring :: Bool = false +) + +""" + SequentialModule + +A SequentialModule is a container module that can chain multiple modules together. +Note building a computation graph with this kind of imperative container is less +flexible and less efficient than the symbolic graph. So this should be only used as a +handy utility. + +# Parameters + +""" +type SequentialModule <: AbstractModule + modules :: Vector{AbstractModule} + metas :: Vector{SequentialModuleMetas} + + binded :: Bool + for_training :: Bool + inputs_need_grad :: Bool + params_initialized :: Bool + optimizer_initialized :: Bool + + label_names :: Vector{Symbol} + label_shapes :: Vector{Tuple{Vararg{Int}}} + function SequentialModule() + new(Vector{AbstractModule}(), + Vector{Symbol}[], + false, false, false, false, false) + end +end + +### default API +isbinded(self::SequentialModule) = self.binded +allows_training(self::SequentialModule) = self.for_training +isinitialized(self::SequentialModule) = self.params_initialized +hasoptimizer(self::SequentialModule) = self.optimizer_initialized + +data_names(self::SequentialModule) = length(self.modules) > 0 ? data_names(self.modules[1]) : Symbol[] +label_names(self::SequentialModule) = self.label_names +output_names(self::SequentialModule) = length(self.modules) > 0 ? 
output_names(self.modules[end]) : Symbol[] + +""" + add(self, module; kwargs...) + +Add a module to the chain. +# Arguments +* `module` : AbstractModule + The new module to add. +* `kwargs` : keywords + All the keyword arguments are saved as meta information + for the added module. The currently known meta includes + * `:take_labels`: indicating whether the module expect to + take labels when doing computation. Note any module in + the chain can take labels (not necessarily only the top + most one), and they all take the same labels passed + from the original data batch for the `SequentialModule`. + * `:auto_wiring`: TODO... + +# Returns + +This function returns `self` to allow us to easily chain a +series of `add` calls. + +# Examples +An example of addinging two modules to a chain:: +```julia +seq_mod = @mx.chain mx.Module.SequentialModule() => + add(mod1) => + add(mod2) +``` +""" +function add(self::SequentialModule, mod::AbstractModule; kwargs...) + push!(self.modules, mod) + + metas = SequentialModuleMetas(;kwargs...) + for (key, _) in (kwargs...) + @assert(key ∈ fieldnames(metas), "Unknown meta '$key', a typo?") + end + push!(self.metas, metas) + + # after adding new modules, we are reset back to raw states, needs + # to bind, init_params, etc. + self.binded = false + self.params_initialized = false + self.optimizer_initialized = false + + return self # for easier chaining +end + +function data_shapes(self::SequentialModule) + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return data_shapes(self.modules[1]) +end + +function label_shapes(self::SequentialModule) + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return self.label_shapes +end + +function output_shapes(self::SequentialModule) + !isbinded(self) && return Dict{Symbol, Vector{Tuple{Int}}}() + return output_shapes(self.modules[end]) +end + +function get_params(self::SequentialModule) + @assert isbinded(self) && isinitialized(self) + + reduce((Dict{Symbol, NDArray}(), Dict{Symbol, NDArray}()), self.modules) do acc, mod + arg, aux = get_params(mod) + merge(acc[1], arg) + merge(acc[2], aux) + end +end + +""" +""" +function init_params(self::SequentialModule, opts::ModuleInitParamsOptions) + if isinitialized(self) && !opts.force_init + return self + end + + @assert(isbinded(self), "call bind before initializing the parameters") + # make sure we do not have duplicated parameter names + arg_dict = Dict() + aux_dict = Dict() + duplicates = false + for (i, mod) in enumerate(self.modules) + arg_params, aux_params = get_params(mod) + map((arg_dict, arg_params), (aux_dict, aux_params)) do arg + dict, params = arg + for name in keys(params) + if haskey(dict, name) + info("Name $name in layer $i ($(typeof(mod))) is already used in layer $(dict[name][1])($(typeof(dict[name][2])))") + duplicates = true + else + dict[name] = (i, typeof(mod)) + end + end + end + end + if duplicates + error("Duplicates in layer names") + end + + for mod in self.modules + init_params(mod, opts) + end + + self.params_initialized = true + + return self +end + +""" +""" +function bind(self::SequentialModule, data_shapes, label_shapes, opts::ModuleBindOptions) + info("SequentialModule: label_shapes=$label_shapes") + if opts.inputs_need_grad + @assert opts.for_training + end + if opts.shared_module !== nothing + info("Shared module is not supported") + end + @assert(length(self.modules) > 0, "Attempting to bind empty SequentialModule") + + # the same label shapes are used for all chained modules + self.label_shapes = label_shapes + + 
module_data_shapes = data_shapes + anybody_ever_needs_label = false + for (i, mod) in enumerate(self.modules) + meta = self.metas[i] + if meta.take_labels + module_label_shapes = label_shapes + anybody_ever_needs_label = true + else + module_label_shapes = Tuple{Int}[] + end + + module_inputs_need_grad = opts.inputs_need_grad || (opts.for_training && i > 1) + + if meta.auto_wiring + data_names = data_names(mod) + @assert length(module_data_shapes) == length(data_names) + module_data_shapes = [(new_name, shape) for (new_name, (_, shape)) in zip(data_names, module_data_shapes)] + end + + bind(mod, module_data_shapes, module_label_shapes, + for_training=opts.for_training, inputs_need_grad=module_inputs_need_grad, + force_rebind=opts.force_rebind, shared_module=nothing, grad_req=opts.grad_req) + + # the output of the previous module is the data of the next module + module_data_shapes = output_shapes(mod) + end + + if !anybody_ever_needs_label + # then I do not need label either + self.label_shapes = Tuple{Int}[] + end + self.binded = true + + return self +end + +""" +""" +function init_optimizer(self::SequentialModule; optimizer::AbstractOptimizer=ADAM(), kvstore :: Union{Base.Symbol, KVStore}=:local, force_init :: Bool=false) + @assert isbinded(self) && isinitialized(self) + if hasoptimizer(self) && !force_init + warn("Optimizer already initialized, ignoring.") + end + for mod in self.modules + init_optimizer(mod, optimizer=optimizer, kvstore=kvstore, force_init=force_init) + end + + self.optimizer_initialized = true + return self +end + +""" +""" +function get_outputs(self::SequentialModule, merge_multi_context::Bool=true) + @assert isbinded(self) && isinitialized(self) + return get_outputs(last(self.modules), merge_multi_context) +end + +""" +""" +function forward(self::SequentialModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool=self.for_training) + @assert isbinded(self) && isinitialized(self) + + batch = DataBatch(data_provider, data_batch) + for (i, mod) in enumerate(self.modules) + forward(mod, batch, is_train) + batch.data = get_outputs(mod) + end +end + +""" +""" +function backward(self::SequentialModule, out_grads::Vector{NDArray}) + @assert isbinded(self) && isinitialized(self) + + for (i, mod) in reverse(zip(1:length(self.modules), self.modules)) + backward(mod, out_grads) + if i == 1 + break + end + + out_grads = get_input_grads(mod) + end +end + +""" +""" +function update(self::SequentialModule) + @assert isbinded(self) && isinitialized(self) && hasoptimizer(self) + + for mod in self.modules + update(mod) + end +end + +""" +""" +function update_metric(self::SequentialModule, eval_metric::AbstractEvalMetric, provider::AbstractDataProvider, batch::AbstractDataBatch) + @assert isbinded(self) && isinitialized(self) + for (meta, mod) in zip(self.metas, self.modules) + if meta.take_labels + update_metric(mod, eval_metric, provider, batch) + end + end +end + +""" +""" +function get_input_grads(self::SequentialModule, merge_multi_context::Bool=true) + @assert isbinded(self) && isinitialized(self) && self.inputs_need_grad + + return get_input_grads(self.modules[1], merge_multi_context) +end diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index 05d81879c..1c5e1ee1d 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -1,9 +1,9 @@ import ....MXNet: mx # in order to use mx. """ - Module + SymbolModule -Module is a basic module that wraps a `SymbolicNode`. 
It is functionally the same +SymbolModule is a basic module that wraps a `SymbolicNode`. It is functionally the same as the `FeedForward` model, except using the module API. A current limitation is that it only supports one context. @@ -46,7 +46,7 @@ type SymbolModule <: AbstractModule function SymbolModule(symbol::SymbolicNode, data_names::Vector{Symbol}, label_names::Vector{Symbol}, context :: Vector{Context}, - fixed_param_names::Nullable{Vector{Symbol}}) + fixed_param_names::Nullable{Vector{Symbol}}) aux_names = mx.list_auxiliary_states(symbol) return new(symbol, data_names, label_names, aux_names, context, @@ -64,10 +64,11 @@ end function SymbolModule(symbol::SymbolicNode; data_names = [:data], label_names = [:softmax_label], - context = [mx.cpu()], fixed_param_names = nothing) + context::Union{Context, Vector{Context}} = [mx.cpu()], + fixed_param_names = nothing) fixed_param_names = Nullable{Vector{Symbol}}(fixed_param_names) label_names = Vector{Symbol}(label_names) - context = _wrap_context(context) + context = Vector{Context}(context) @assert !isempty(data_names) @assert !isempty(context) return SymbolModule(symbol, data_names, label_names, context, fixed_param_names) @@ -119,13 +120,8 @@ function get_params(self::SymbolModule) return (self.arg_params, self.aux_params) end -function init_params(self::SymbolModule; - initializer=UniformInitializer(0.07), - arg_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), - aux_params::Dict{Symbol, NDArray}=Dict{Symbol, NDArray}(), - allow_missing=false, force_init=false) - - if isinitialized(self) && !force_init +function init_params(self::SymbolModule, opts::ModuleInitParamsOptions) + if isinitialized(self) && !opts.force_init return self end @assert isbinded(self) "Call `bind` before initialization" @@ -138,11 +134,11 @@ function init_params(self::SymbolModule; self.aux_params = Dict(k => mx.empty(size(v)) for (k, v) in self.exec_group.aux_params) end - map([[self.arg_params, arg_params], [self.aux_params, aux_params]]) do param_arr + map([[self.arg_params, opts.arg_params], [self.aux_params, opts.aux_params]]) do param_arr dst, src = param_arr for (name, arr) in dst if isempty(src) - init(initializer, name, arr) + init(opts.initializer, name, arr) else src = get(src) if name in keys(src) @@ -150,8 +146,8 @@ function init_params(self::SymbolModule; copy!(arr, src[name]) end else - @assert(!allow_missing, "$name is not presented") - init(initializer, name, arr) + @assert(!opts.allow_missing, "$name is not presented") + init(opts.initializer, name, arr) end end end @@ -166,12 +162,8 @@ function init_params(self::SymbolModule; return self end -bind(self::SymbolModule, data_provider::AbstractDataProvider; kwargs...) = bind(self, map((x) -> x[2], provide_data(data_provider)), - map((x) -> x[2], provide_label(data_provider)); kwargs...) 
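# Both entry points end up in the same place; an invented usage sketch
# (assumes `m` is a SymbolModule built with label_names = [:linout_label]):
#
#   x, y = rand(Float32, 4, 10), rand(Float32, 1, 10)
#   data = mx.ArrayDataProvider(:data => x, :linout_label => y; batch_size = 5)
#
#   mx.Module.bind(m, data)                  # shapes taken from provide_data/provide_label
#   mx.Module.bind(m, [(4, 5)], [(1, 5)])    # or, equivalently, explicit per-batch shapes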
-function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Tuple{Int}}(); - for_training=true, inputs_need_grad=true, force_rebind=false, - grad_req=mx.GRAD_WRITE, shared_group = nothing) - if force_rebind +function bind(self::SymbolModule, data_shapes, label_shapes, opts::ModuleBindOptions) + if opts.force_rebind reset_bind(self) end @@ -180,15 +172,16 @@ function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Tuple{Int}} return self end - if !for_training - @assert !inputs_need_grad + if !opts.for_training + @assert !opts.inputs_need_grad end - self.for_training = for_training - self.inputs_need_grad = inputs_need_grad + self.for_training = opts.for_training + self.inputs_need_grad = opts.inputs_need_grad self.binded = true + info("SymbolModule: self.label_names=$(self.label_names); self.label_shapes=$(self.label_shapes)") @assert length(self.data_names) == length(data_shapes) @assert length(self.label_names) == length(label_shapes) @@ -199,11 +192,17 @@ function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Tuple{Int}} data_types = [Float32 for _ in 1:length(self.data_names)] label_types = [Float32 for _ in 1:length(self.label_names)] + if opts.shared_module !== nothing + shared_group = Nullable(opts.shared_module.exec_group) + else + shared_group = Nullable{DataParallelExecutorGroup}() + end + self.exec_group = DataParallelExecutorGroup(self.symbol, self.context, self.data_shapes, self.data_names, data_types, self.label_shapes, self.label_names, label_types, self.for_training, self.inputs_need_grad, shared_group, - self.fixed_param_names, grad_req) + self.fixed_param_names, opts.grad_req) return self end @@ -295,8 +294,8 @@ Forward computation. * `is_train` : Bool Default is `nothing`, which means `is_train` takes the value of `self.for_training`. 
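# Examples

A minimal sketch; assumes `m` is an already bound and initialized `SymbolModule`
and `provider` is its data provider:

```julia
for batch in mx.eachdatabatch(provider)
    mx.Module.forward(m, provider, batch, false)   # is_train = false for prediction
    outs = mx.Module.get_outputs(m)
end
```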
""" -forward(self::SymbolModule, data_batch::DataBatch) = forward(self, StubProvider(), data_batch, self.for_training) -function forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool) +forward(self::SymbolModule, data_batch::DataBatch) = forward(self, data_batch, self.for_training) +function forward(self::SymbolModule, data_provider :: AbstractDataProvider, data_batch :: AbstractDataBatch, is_train::Bool=self.for_training) @assert isbinded(self) && isinitialized(self) mx.forward(self.exec_group, data_provider, data_batch, is_train) end @@ -385,6 +384,3 @@ function borrow_optimizer!(self::SymbolModule, shared_module::SymbolModule) self.updater = shared_module.updater self.optimizer_initialized = true end - -_wrap_context(context::Context) = [context] -_wrap_context(context::Vector{Context}) = context diff --git a/test/test-module.jl b/test/test-module.jl index 66a1f63df..da0daaf73 100644 --- a/test/test-module.jl +++ b/test/test-module.jl @@ -1,4 +1,17 @@ using MXNet +using Base.Test + +# run test in the whole directory, latest modified files +# are run first, this makes waiting time shorter when writing +# or modifying unit-tests +function test_dir(dir) + jl_files = sort(filter(x -> ismatch(r".*module\.jl$", x), readdir(dir)), by = fn -> stat(joinpath(dir,fn)).mtime) + map(reverse(jl_files)) do file + include("$dir/$file") + end +end include(joinpath(dirname(@__FILE__), "common.jl")) -include(joinpath(dirname(@__FILE__), "unittest", "symbol-module.jl")) +@testset "Modules Test" begin + test_dir(joinpath(dirname(@__FILE__), "unittest")) +end diff --git a/test/unittest/sequential-module.jl b/test/unittest/sequential-module.jl new file mode 100644 index 000000000..4fac18e56 --- /dev/null +++ b/test/unittest/sequential-module.jl @@ -0,0 +1,50 @@ +module TestSequentialModule +using MXNet +using Base.Test + +using ..Main: reldiff + +################################################################################ +# Utils +################################################################################ + +################################################################################ +# Test Implementations +################################################################################ + +function test_basic() + info("SequentialModule::basic") + + net1 = @mx.chain mx.Variable(:data) => + mx.FullyConnected(name=:fc1, num_hidden=4) + net2 = @mx.chain mx.FullyConnected(mx.SymbolicNode, name=:fc2, num_hidden=1) => + mx.LinearRegressionOutput(name=:linout) + + m1 = mx.Module.SymbolModule(net1, label_names=Symbol[]) + m2 = mx.Module.SymbolModule(net2) + seq_mod = mx.Module.SequentialModule() + mx.Module.add(seq_mod, m1) + mx.Module.add(seq_mod, m2, take_labels=true) + @test !mx.Module.isbinded(seq_mod) + @test !mx.Module.allows_training(seq_mod) + @test !mx.Module.isinitialized(seq_mod) + @test !mx.Module.hasoptimizer(seq_mod) + + @test mx.Module.data_names(seq_mod) == [:data] + @test mx.Module.output_names(seq_mod) == [:linout_output] + + mx.Module.bind(seq_mod, [(4, 10)], [(1, 10)]) + @test mx.Module.isbinded(seq_mod) + @test !mx.Module.isinitialized(seq_mod) + @test !mx.Module.hasoptimizer(seq_mod) +end + +################################################################################ +# Run tests +################################################################################ + +@testset " Sequential Module Test" begin + test_basic() +end + +end diff --git a/test/unittest/symbol-module.jl b/test/unittest/symbol-module.jl index 
f444f78c5..f9ad74b08 100644 --- a/test/unittest/symbol-module.jl +++ b/test/unittest/symbol-module.jl @@ -41,7 +41,7 @@ function test_basic() @test mx.Module.data_names(m1) == [:data] @test mx.Module.output_names(m1) == [:softmax_output] - mx.Module.bind(m1, [(20, 20, 1, 10)], [(400, 10)]) + mx.Module.bind(m1, [(20, 20, 1, 10)], [(20, 20, 1, 10)]) @test mx.Module.isbinded(m1) @test !mx.Module.isinitialized(m1) @test !mx.Module.hasoptimizer(m1) @@ -57,10 +57,10 @@ function test_shapes() info("SymbolModule::Shapes") m1 = mx.Module.SymbolModule(create_network()) - mx.Module.bind(m1, [(20, 20, 1, 10)], [(20, 20, 1, 10)]) + mx.Module.bind(m1, [(20, 20, 1, 10)], [(400, 10)]) @test mx.Module.data_shapes(m1) == Dict(:data => (20, 20, 1, 10)) - @test mx.Module.label_shapes(m1) == Dict(:softmax_label => (20, 20, 1, 10)) + @test mx.Module.label_shapes(m1) == Dict(:softmax_label => (400, 10)) @test mx.Module.output_shapes(m1) == Dict(:softmax_output => (20, 20, 64, 10)) m2 = mx.Module.SymbolModule(create_network(), label_names=[]) @@ -140,7 +140,7 @@ end # Run tests ################################################################################ -@testset "Symbol Module Test" begin +@testset " Symbol Module Test" begin test_basic() test_shapes() #= test_init_params(500) =# From 133951ee19296f3aa608d320e29ddb811bfd7b66 Mon Sep 17 00:00:00 2001 From: Andrej Oskin Date: Wed, 25 Jan 2017 00:10:08 +0300 Subject: [PATCH 17/18] Sequential Module, pre alpha --- src/module/sequential_module.jl | 10 +++++----- src/module/symbol_module.jl | 5 ++++- test/unittest/sequential-module.jl | 9 +++++---- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/module/sequential_module.jl b/src/module/sequential_module.jl index 66ced4bde..f989bad89 100644 --- a/src/module/sequential_module.jl +++ b/src/module/sequential_module.jl @@ -1,4 +1,5 @@ import ....MXNet: mx # in order to use mx. +import Base.push! @defstruct SequentialModuleMetas ( take_labels :: Bool = false, @@ -46,7 +47,7 @@ label_names(self::SequentialModule) = self.label_names output_names(self::SequentialModule) = length(self.modules) > 0 ? output_names(self.modules[end]) : Symbol[] """ - add(self, module; kwargs...) + push!(self, module; kwargs...) Add a module to the chain. # Arguments @@ -71,11 +72,11 @@ series of `add` calls. An example of addinging two modules to a chain:: ```julia seq_mod = @mx.chain mx.Module.SequentialModule() => - add(mod1) => - add(mod2) + mx.Module.push!(mod1) => + mx.Module.push!(mod2) ``` """ -function add(self::SequentialModule, mod::AbstractModule; kwargs...) +function push!(self::SequentialModule, mod::AbstractModule; kwargs...) push!(self.modules, mod) metas = SequentialModuleMetas(;kwargs...) 
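    # With the rename, a chain is assembled as, e.g. (m1, m2 hypothetical modules):
    #
    #   seq = mx.Module.SequentialModule()
    #   push!(seq, m1)
    #   push!(seq, m2, take_labels = true)   # this module consumes the labels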
@@ -160,7 +161,6 @@ end """ """ function bind(self::SequentialModule, data_shapes, label_shapes, opts::ModuleBindOptions) - info("SequentialModule: label_shapes=$label_shapes") if opts.inputs_need_grad @assert opts.for_training end diff --git a/src/module/symbol_module.jl b/src/module/symbol_module.jl index 1c5e1ee1d..f5935b821 100644 --- a/src/module/symbol_module.jl +++ b/src/module/symbol_module.jl @@ -180,8 +180,11 @@ function bind(self::SymbolModule, data_shapes, label_shapes, opts::ModuleBindOpt self.inputs_need_grad = opts.inputs_need_grad self.binded = true + wrap_in_vector(x::Vector, names) = x + wrap_in_vector(x::Dict, names) = [x[name] for name in names] + data_shapes = wrap_in_vector(data_shapes, self.data_names) + label_shapes = wrap_in_vector(label_shapes, self.label_names) - info("SymbolModule: self.label_names=$(self.label_names); self.label_shapes=$(self.label_shapes)") @assert length(self.data_names) == length(data_shapes) @assert length(self.label_names) == length(label_shapes) diff --git a/test/unittest/sequential-module.jl b/test/unittest/sequential-module.jl index 4fac18e56..e4318424c 100644 --- a/test/unittest/sequential-module.jl +++ b/test/unittest/sequential-module.jl @@ -17,14 +17,15 @@ function test_basic() net1 = @mx.chain mx.Variable(:data) => mx.FullyConnected(name=:fc1, num_hidden=4) - net2 = @mx.chain mx.FullyConnected(mx.SymbolicNode, name=:fc2, num_hidden=1) => + net2 = @mx.chain mx.Variable(:fc1_output) => + mx.FullyConnected(name=:fc2, num_hidden=1) => mx.LinearRegressionOutput(name=:linout) m1 = mx.Module.SymbolModule(net1, label_names=Symbol[]) - m2 = mx.Module.SymbolModule(net2) + m2 = mx.Module.SymbolModule(net2, data_names=[:fc1_output], label_names=[:linout_label]) seq_mod = mx.Module.SequentialModule() - mx.Module.add(seq_mod, m1) - mx.Module.add(seq_mod, m2, take_labels=true) + mx.Module.push!(seq_mod, m1) + mx.Module.push!(seq_mod, m2, take_labels=true) @test !mx.Module.isbinded(seq_mod) @test !mx.Module.allows_training(seq_mod) @test !mx.Module.isinitialized(seq_mod) From e8c5bf1fa3e80e199e12a94a16ea0ef9eb56c9c0 Mon Sep 17 00:00:00 2001 From: Andrej Oskin Date: Thu, 26 Jan 2017 01:13:37 +0300 Subject: [PATCH 18/18] Changed bind behaviour --- src/module/sequential_module.jl | 47 ++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/src/module/sequential_module.jl b/src/module/sequential_module.jl index f989bad89..811e21c2d 100644 --- a/src/module/sequential_module.jl +++ b/src/module/sequential_module.jl @@ -3,7 +3,8 @@ import Base.push! @defstruct SequentialModuleMetas ( take_labels :: Bool = false, - auto_wiring :: Bool = false + auto_wiring :: Bool = false, + join :: Dict{Symbol, Symbol} = Dict{Symbol, Symbol}() ) """ @@ -29,10 +30,11 @@ type SequentialModule <: AbstractModule label_names :: Vector{Symbol} label_shapes :: Vector{Tuple{Vararg{Int}}} - function SequentialModule() + function SequentialModule(label_names = [:softmax_label]) new(Vector{AbstractModule}(), Vector{Symbol}[], - false, false, false, false, false) + false, false, false, false, false, + label_names) end end @@ -56,13 +58,16 @@ Add a module to the chain. * `kwargs` : keywords All the keyword arguments are saved as meta information for the added module. The currently known meta includes - * `:take_labels`: indicating whether the module expect to - take labels when doing computation. 
Note any module in - the chain can take labels (not necessarily only the top - most one), and they all take the same labels passed - from the original data batch for the `SequentialModule`. - * `:auto_wiring`: TODO... - + * `take_labels`: `Bool` indicating whether the module expect to + take labels when doing computation. Note any module in + the chain can take labels (not necessarily only the top + most one), and they all take the same labels passed + from the original data batch for the `SequentialModule`. + * `auto_wiring`: `Bool` transfer data outputs of previous module + to the data inputs of next module, with the order given by + `output_names()` + * `join`: `Dict{Symbol, Symbol}`. + # Returns This function returns `self` to allow us to easily chain a @@ -169,10 +174,15 @@ function bind(self::SequentialModule, data_shapes, label_shapes, opts::ModuleBin end @assert(length(self.modules) > 0, "Attempting to bind empty SequentialModule") + wrap_in_dict(x::Vector, names) = Dict(k => v for (k, v) in zip(names, x)) + wrap_in_dict(x::Dict, names) = x + data_shapes = wrap_in_dict(data_shapes, data_names(self)) + # the same label shapes are used for all chained modules self.label_shapes = label_shapes module_data_shapes = data_shapes + module_data_names = data_names(self) anybody_ever_needs_label = false for (i, mod) in enumerate(self.modules) meta = self.metas[i] @@ -184,11 +194,17 @@ function bind(self::SequentialModule, data_shapes, label_shapes, opts::ModuleBin end module_inputs_need_grad = opts.inputs_need_grad || (opts.for_training && i > 1) - - if meta.auto_wiring - data_names = data_names(mod) - @assert length(module_data_shapes) == length(data_names) - module_data_shapes = [(new_name, shape) for (new_name, (_, shape)) in zip(data_names, module_data_shapes)] + + @assert length(module_data_names) == length(data_names(mod)) + if length(module_data_names) == 1 + module_data_shapes = collect(values(module_data_shapes)) + elseif Set(module_data_names) != Set(data_names(mod)) + if meta.auto_wiring + module_data_shapes = [module_data_shapes[name] for name in module_data_names] + else + @assert union(setdiff(Set(module_data_names), Set(keys(meta.join))), Set(values(meta.join))) == Set(data_names(mod)) + module_data_shapes = Dict(get(meta.join, name, name) => value for (name, value) in zip(module_data_names, module_data_shapes)) + end end bind(mod, module_data_shapes, module_label_shapes, @@ -196,6 +212,7 @@ function bind(self::SequentialModule, data_shapes, label_shapes, opts::ModuleBin force_rebind=opts.force_rebind, shared_module=nothing, grad_req=opts.grad_req) # the output of the previous module is the data of the next module + module_data_names = output_names(mod) module_data_shapes = output_shapes(mod) end
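To illustrate the wiring rules above, a sketch with invented modules and names:
`m1` is assumed to produce two outputs `:mu_output` and `:sigma_output`, `m2` to
declare data inputs `:mu` and `:sigma`, and `m3` to be wired purely by position:

```julia
seq = mx.Module.SequentialModule()
push!(seq, m1)
push!(seq, m2, join = Dict(:mu_output => :mu, :sigma_output => :sigma))
push!(seq, m3, auto_wiring = true)   # the order of output_names decides the mapping
mx.Module.bind(seq, [(4, 10)], [(1, 10)])
```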