add type inference to executor-group
vchuravy committed Jan 21, 2017
1 parent 75fcd81 commit 2426f16
Showing 3 changed files with 52 additions and 27 deletions.
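In short, DataParallelExecutorGroup now runs dtype inference alongside shape inference: the caller passes data_types and label_types next to the shapes, the constructor calls infer_type the same way it calls infer_shape, and parameter, auxiliary, and gradient arrays are allocated with the inferred element types instead of defaulting to Float32. The sketch below illustrates the inference pattern on a toy network; it is a minimal example assuming the standard MXNet.jl symbol builders (mx.Variable, mx.FullyConnected, mx.SoftmaxOutput) and assuming mx.infer_type accepts keyword-provided input types just as mx.infer_shape accepts keyword-provided shapes, which is how the splatted provided_types dictionary is used in the new constructor.

using MXNet

# Toy network, shaped like the fixtures in test/unittest/symbol-module.jl.
data = mx.Variable(:data)
fc   = mx.FullyConnected(data=data, name=:fc1, num_hidden=10)
net  = mx.SoftmaxOutput(data=fc, name=:softmax)

# Shape inference: batch dimension last, as elsewhere in MXNet.jl.
arg_shapes, out_shapes, aux_shapes = mx.infer_shape(net; data=(20, 20, 1, 10))

# Type inference: same keyword convention, but the values are element types.
arg_types, out_types, aux_types = mx.infer_type(net; data=Float32)

# The executor group consumes (name, shape, T) triples like these to allocate
# its parameter and gradient NDArrays with the right element type.
for (name, shape, T) in zip(mx.list_arguments(net), arg_shapes, arg_types)
    println(name, " => ", T, " ", shape)
end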
61 changes: 41 additions & 20 deletions src/executor-group.jl
@@ -9,6 +9,14 @@ function forward(self::AbstractExecutorGroup, data_provider :: AbstractDataProvi
throw(MethodError(forward, (self, )))
end

"""
DataParallelExecutorGroup
Supports:
- Fixed parameters (freezing)
- Shape inference
- Type inference
"""
type DataParallelExecutorGroup <: AbstractExecutorGroup
symbol :: SymbolicNode
context :: Vector{Context}
@@ -34,15 +42,16 @@ type DataParallelExecutorGroup <: AbstractExecutorGroup
arg_params :: Dict{Symbol, NDArray}
aux_params :: Dict{Symbol, NDArray}
end

function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context},
data_shapes, data_names, label_shapes, label_names, for_training, inputs_need_grad,
shared_group, fixed_param_names, grad_req)
data_shapes, data_names, data_types, label_shapes, label_names, label_types,
for_training, inputs_need_grad, shared_group, fixed_param_names, grad_req)

num_dev = length(context)
arg_names = list_arguments(symbol)
input_names = [data_names; label_names]
param_names = setdiff(arg_names, input_names)
aux_names = list_auxiliary_states(symbol)
aux_names = list_auxiliary_states(symbol)

batch_size = data_shapes[1][end]
for shape in data_shapes
@@ -54,60 +63,72 @@ function DataParallelExecutorGroup(symbol::SymbolicNode, context::Vector{Context
end
end

# TODO imlplement workload
# TODO implement workload
slices = _split_inputs(batch_size, num_dev)

execs = Vector{Executor}(num_dev)

provided_shapes = merge(Dict(name => shape for (name, shape) in zip(data_names, data_shapes)),
Dict(name => shape for (name, shape) in zip(label_names, label_shapes)))
# Shape inference based on data_shapes and label_shapes
provided_shapes = merge(
Dict(name => shape for (name, shape) in zip(data_names, data_shapes)),
Dict(name => shape for (name, shape) in zip(label_names, label_shapes))
)

arg_shapes, out_shapes, aux_shapes = infer_shape(symbol; provided_shapes...)
@assert(!isa(arg_shapes, Void), "Not enough information to perform complete shape inference")

# Type inference based on data_types and label_types
provided_types = merge(
Dict(name => _type for (name, _type) in zip(data_names, data_types)),
Dict(name => _type for (name, _type) in zip(label_names, label_types))
)

arg_types, out_types, aux_types = infer_type(symbol; provided_types...)

# Check which args need gradients and which are frozen
grad_req, freeze_idx = get_grads(symbol, param_names, arg_names, data_names, inputs_need_grad, fixed_param_names, grad_req)

arg_params = Dict{Symbol, NDArray}()
aux_params = Dict{Symbol, NDArray}()

for (name, shape) in filter(x -> in(x[1], param_names), zip(arg_names, arg_shapes))
arg_params[name] = empty(shape)
for (name, shape, T) in filter(x -> in(x[1], param_names), zip(arg_names, arg_shapes, arg_types))
arg_params[name] = empty(T, shape)
end

for (name, shape) in zip(aux_names, aux_shapes)
aux_params[name] = empty(shape)
for (name, shape, T) in zip(aux_names, aux_shapes, aux_types)
aux_params[name] = empty(T, shape)
end

for i = 1:num_dev
data_shapes = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(data_names, data_shapes))
label_shapes = Dict(k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in zip(label_names, label_shapes))
arg_arrays = NDArray[zeros(shape, context[i]) for shape in arg_shapes]
arg_arrays = NDArray[zeros(T, shape, context[i]) for (shape, T) in zip(arg_shapes, arg_types)]
grad_arrays = Dict{Symbol,NDArray}()
aux_arrays = NDArray[zeros(shape, context[i]) for shape in aux_shapes]
aux_arrays = NDArray[zeros(T, shape, context[i]) for (shape, T) in zip(aux_shapes, aux_types)]

shapes = zip(arg_names, arg_shapes)
# Process arguments
arg_info = zip(arg_names, arg_shapes, arg_types)

# if not in provided data, should be parameters
if inputs_need_grad
provided_data_names = label_names
else
provided_data_names = [data_names; label_names]
end
shapes = filter(x -> !in(x[1], provided_data_names), shapes)
arg_info = filter(x -> !in(x[1], provided_data_names), arg_info)

# Remove all gradients for nop params
shapes = filter(x -> grad_req[x[1]] != GRAD_NOP, shapes)
arg_info = filter(x -> grad_req[x[1]] != GRAD_NOP, arg_info)

for (name, shape) in shapes
grad_arrays[name] = zeros(shape, context[i])
for (name, shape, T) in arg_info
grad_arrays[name] = zeros(T, shape, context[i])
end

execs[i] = bind(symbol, context[i], arg_arrays, args_grad=grad_arrays, grad_req=grad_req, aux_states=aux_arrays)
#= dbg_str = mx.debug_str(train_execs[i]) =#
#= info(string("TempSpace: ", split(dbg_str, ['\n'])[end-2]..., " on ", self.ctx[i])) =#
end

# TODO: perform type inference

# set up input data structures
data_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(execs)] for name in data_names]
label_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(execs)] for name in label_names]
@@ -146,7 +167,7 @@ function forward(self:: DataParallelExecutorGroup, data_provider :: AbstractData

load_data!(data_provider, data_batch, self.data_arrays)
is_train = get(is_train, self.for_training)

if is_train && !isempty(get_label(data_provider, data_batch))
load_label!(data_provider, data_batch, self.label_arrays)
end
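The new docstring above says the executor group supports frozen parameters plus shape and type inference; the allocation side of that is threading the inferred element type through empty and zeros so that host-side parameters, per-device buffers, and gradient arrays all share a consistent dtype. A minimal sketch of that pattern, assuming the typed mx.empty(T, shape) and mx.zeros(T, shape, ctx) methods the new executor-group code relies on (Float64 is picked purely for illustration):

ctx   = mx.cpu()
shape = (20, 10)

# Before this commit the buffers were always Float32:
w_f32 = mx.zeros(shape, ctx)

# With the inferred element type threaded through, the same buffers can be
# allocated with another dtype:
T      = Float64
w      = mx.zeros(T, shape, ctx)   # per-device parameter / gradient buffer
w_host = mx.empty(T, shape)        # host-side arg_params / aux_params entry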
12 changes: 8 additions & 4 deletions src/module/symbol_module.jl
@@ -164,9 +164,13 @@ function bind(self::SymbolModule, data_shapes, label_shapes = Vector{Tuple{Int}}
self.data_shapes = data_shapes
self.label_shapes = label_shapes

# TODO propagate type information
data_types = [Float32 for _ in 1:length(self.data_names)]
label_types = [Float32 for _ in 1:length(self.label_names)]

self.exec_group = DataParallelExecutorGroup(self.symbol, self.context,
self.data_shapes, self.data_names,
self.label_shapes, self.label_names,
self.data_shapes, self.data_names, data_types,
self.label_shapes, self.label_names, label_types,
self.for_training, self.inputs_need_grad, shared_group,
self.fixed_param_names, grad_req)
return self
@@ -213,7 +217,7 @@ function init_optimizer(self::SymbolModule; optimizer::AbstractOptimizer=ADAM(),
end
end
end

# TODO add preloaded states
#= if !isa(self.preload_opt_states, Void) =#
#= load_optimizer_states!(self, self.preload_opt_states) =#
@@ -284,7 +288,7 @@ end

"""
borrow_optimizer!(module, shared_module)
Borrow optimizer from a shared module. Used in bucketing, where exactly the same
Borrow optimizer from a shared module. Used in bucketing, where exactly the same
optimizer (esp. kvstore) is used.
# Arguments
* `module` : SymbolModule
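For context on the borrow_optimizer! docstring touched above, a hedged usage sketch follows. Only the call pattern from the docstring is taken as given; the create_network fixture comes from the tests below, and the exact mx.Module export paths and mx.ADAM keyword call are assumptions based on the surrounding diff.

# Bucketing-style setup: two modules over the same symbol share one optimizer.
shared = mx.Module.SymbolModule(create_network())
mx.Module.bind(shared, [(20, 20, 1, 10)], [(20, 20, 1, 10)])
mx.Module.init_params(shared)
mx.Module.init_optimizer(shared, optimizer=mx.ADAM())

bucket = mx.Module.SymbolModule(create_network())
mx.Module.bind(bucket, [(20, 20, 1, 10)], [(20, 20, 1, 10)])
mx.Module.init_params(bucket)

# Reuse the shared module's optimizer (and kvstore) instead of creating a new one.
mx.Module.borrow_optimizer!(bucket, shared)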
6 changes: 3 additions & 3 deletions test/unittest/symbol-module.jl
@@ -27,7 +27,7 @@ end

function test_basic()
info("SymbolModule::basic")

m1 = mx.Module.SymbolModule(create_network())

@test !mx.Module.isbinded(m1)
@@ -37,7 +37,7 @@ function test_basic()

@test mx.Module.data_names(m1) == [:data]
@test mx.Module.output_names(m1) == [:softmax_output]

mx.Module.bind(m1, [(20, 20, 1, 10)], [(20, 20, 1, 10)])
@test mx.Module.isbinded(m1)
@test !mx.Module.isinitialized(m1)
@@ -52,7 +52,7 @@ end

function test_init_params()
info("SymbolModule::InitParams")
m1 = mx.Module.SymbolModule(create_single_neuron(),
m1 = mx.Module.SymbolModule(create_single_neuron(),
label_names = [:linout_label])
mx.Module.bind(m1, [(1, 10)], [(1, 10)])
mx.Module.init_params(m1)
