Skip to content

Commit

Permalink
Merge pull request #786 from null-a/param-init
Browse files Browse the repository at this point in the history
More flexible parameter initialization.
  • Loading branch information
stuhlmueller authored Mar 3, 2017
2 parents 6f7f2b7 + 35526be commit 3bd0d8e
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 135 deletions.
24 changes: 17 additions & 7 deletions docs/optimization/parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ Parameters

.. js:function:: param([options])

Retrieves the value of a parameter by name. If the parameter does
not exist, it is created and initialized with a draw from a
Gaussian distribution.
Retrieves the value of a parameter by name. The parameter is
created if it does not already exist.

The following options are supported:

Expand All @@ -18,19 +17,29 @@ Parameters

When ``dims`` is omitted, ``param`` returns a scalar.

.. describe:: init

A function that computes the initial value of the parameter. The
function is passed the dimension of a tensor as its only
argument, and should return a tensor of that dimension.

When ``init`` is omitted, the parameter is initialized with a
draw from the Gaussian distribution described by the ``mu`` and
``sigma`` options.

.. describe:: mu

The mean of the Gaussian distribution from which the initial
parameter value is drawn.
parameter value is drawn when ``init`` is omitted.

Default: ``0``

.. describe:: sigma

The standard deviation of the Gaussian distribution from which
the initial parameter value is drawn. Specify a standard
deviation of ``0`` to deterministically initialize the parameter
to ``mu``.
the initial parameter value is drawn when ``init`` is omitted.
Specify a standard deviation of ``0`` to deterministically
initialize the parameter to ``mu``.

Default: ``0.1``

Expand All @@ -46,6 +55,7 @@ Parameters
param({name: 'myparam'})
param({mu: 0, sigma: 0.01, name: 'myparam'})
param({dims: [10, 10]})
param({dims: [2, 1], init: function(dims) { return ones(dims); }})

.. js:function:: modelParam([options])

Expand Down
7 changes: 4 additions & 3 deletions src/guide.js
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ function makeParam(paramSpec, paramName, baseName, env) {
}

function registerParam(env, name, dims) {
return params.register(env, name, function() {
return [new Tensor(dims)];
})[0];
if (!params.exists(name)) {
params.create(name, new Tensor(dims));
}
return params.fetch(name, env);
}

// This function generates a description of the guide distribution
Expand Down
26 changes: 26 additions & 0 deletions src/headerUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,31 @@ module.exports = function(env) {
return wpplFn.apply(global, [s, k, a].concat(args));
}

function notAllowed(fn, name) {
return function() {
throw new Error(fn + ' is not allowed in ' + name + '.');
};
}

function makeDeterministicCoroutine(name) {
return {
sample: notAllowed('sample', name),
factor: notAllowed('factor', name),
incrementalize: env.defaultCoroutine.incrementalize
};
}

// Applies a deterministic function. Attempts by wpplFn to call
// sample or factor generate an error.
function applyd(s, k, a, wpplFn, args, name) {
var coroutine = env.coroutine;
env.coroutine = makeDeterministicCoroutine(name);
return apply(s, function(s, val) {
env.coroutine = coroutine;
return k(s, val);
}, a, wpplFn, args);
}

// Annotating a function object with its lexical id and
// a list of its free variable values.
var __uniqueid = 0;
Expand Down Expand Up @@ -182,6 +207,7 @@ module.exports = function(env) {
display: display,
cache: cache,
apply: apply,
applyd: applyd,
_Fn: _Fn,
_addr: _addr,
zeros: zeros,
Expand Down
21 changes: 2 additions & 19 deletions src/inference/driftKernel.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@ var util = require('../util');

module.exports = function(env) {

var driftKernelCoroutine = {
sample: notAllowed('sample'),
factor: notAllowed('factor'),
incrementalize: env.defaultCoroutine.incrementalize
};
var applyd = require('../headerUtils')(env).applyd;

// A cps function to get the MH proposal distribution based on the
// args passed to a sample statement and the value selected for this
Expand All @@ -24,26 +20,13 @@ module.exports = function(env) {

function getProposalDist(s, a, dist, options, prevVal, k) {
if (options && options.driftKernel) {
var coroutine = env.coroutine;
env.coroutine = driftKernelCoroutine;

return options.driftKernel(s, function(s, val) {
// Restore the previous coroutine.
env.coroutine = coroutine;
return k(s, val);
}, a, prevVal);
return applyd(s, k, a, options.driftKernel, [prevVal], 'drift kernel');
} else {
// Use the prior as the proposal distribution.
return k(s, dist);
}
}

function notAllowed(fn) {
return function() {
throw new Error(fn + ' not allowed inside drift kernels.');
};
}

// We show a warning when the score of a drift proposal is -Infinity
// as it's likely this is caused by a bug in the drift kernel
// function.
Expand Down
111 changes: 59 additions & 52 deletions src/inference/forwardSample.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,58 +11,31 @@ var guide = require('../guide');

module.exports = function(env) {

function ForwardSample(s, k, a, wpplFn, options) {
this.opts = util.mergeDefaults(options, {
samples: 1,
guide: false, // true = sample guide, false = sample target
onlyMAP: false,
verbose: false
});

this.wpplFn = wpplFn;
function RunForward(s, k, a, wpplFn, sampleGuide) {
this.s = s;
this.k = k;
this.a = a;
this.guideRequired = this.opts.guide;
this.wpplFn = wpplFn;
this.sampleGuide = sampleGuide;

this.factorWarningIssued = false;
// Indicate that guide thunks should run.
this.guideRequired = sampleGuide;

this.score = 0;
this.logWeight = 0;

this.coroutine = env.coroutine;
env.coroutine = this;
}

ForwardSample.prototype = {
RunForward.prototype = {

run: function() {

var hist = new CountAggregator(this.opts.onlyMAP);
var logWeights = []; // Save total factor weights

return util.cpsLoop(
this.opts.samples,

// Loop body.
function(i, next) {
this.score = 0;
this.logWeight = 0;
return this.wpplFn(_.clone(this.s), function(s, val) {
logWeights.push(this.logWeight);
hist.add(val, this.score);
return next();
}.bind(this), this.a);
}.bind(this),

// Continuation.
function() {
env.coroutine = this.coroutine;
var dist = hist.toDist();
if (!this.opts.guide) {
var numSamples = this.opts.samples;
dist.normalizationConstant = util.logsumexp(logWeights) - Math.log(numSamples);
}
return this.k(this.s, dist);
}.bind(this));

return this.wpplFn(_.clone(this.s), function(s, val) {
env.coroutine = this.coroutine;
var ret = {val: val, score: this.score, logWeight: this.logWeight};
return this.k(this.s, ret);
}.bind(this), this.a);
},

sample: function(s, k, a, dist, options) {
Expand All @@ -72,7 +45,7 @@ module.exports = function(env) {
return k(s, val);
}.bind(this);

if (this.opts.guide) {
if (this.sampleGuide) {
options = options || {};
return guide.getDist(
options.guide, options.noAutoGuide, dist, env, s, a,
Expand All @@ -85,27 +58,61 @@ module.exports = function(env) {
},

factor: function(s, k, a, score) {
if (!this.opts.guide && !this.factorWarningIssued) {
this.factorWarningIssued = true;
if (!this.sampleGuide) {
var msg = 'Note that factor, condition and observe statements are ' +
'ignored when forward sampling from a model.';
util.warn(msg);
'ignored when forward sampling.';
util.warn(msg, true);
}
this.logWeight += ad.value(score);
return k(s);
},

incrementalize: env.defaultCoroutine.incrementalize,
constructor: ForwardSample
constructor: RunForward

};

function runForward() {
var coroutine = Object.create(RunForward.prototype);
RunForward.apply(coroutine, arguments);
return coroutine.run();
}

function ForwardSample(s, k, a, wpplFn, options) {
var opts = util.mergeDefaults(options, {
samples: 1,
guide: false, // true = sample guide, false = sample target
onlyMAP: false,
verbose: false
});

var hist = new CountAggregator(opts.onlyMAP);
var logWeights = []; // Save total factor weights

return util.cpsLoop(
opts.samples,
// Loop body.
function(i, next) {
return runForward(s, function(s, ret) {
logWeights.push(ret.logWeight);
hist.add(ret.val, ret.score);
return next();
}, a, wpplFn, opts.guide);
},
// Continuation.
function() {
var dist = hist.toDist();
if (!opts.guide) {
dist.normalizationConstant = util.logsumexp(logWeights) - Math.log(opts.samples);
}
return k(s, dist);
}
);
}

return {
ForwardSample: function() {
var coroutine = Object.create(ForwardSample.prototype);
ForwardSample.apply(coroutine, arguments);
return coroutine.run();
}
ForwardSample: ForwardSample,
runForward: runForward
};

};
59 changes: 41 additions & 18 deletions src/params/header.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,18 @@ function deserializeParams(s, k, a, str) {
return k(s, serialize.deserializeParams(str));
}

function defaultInit(mu, sigma) {
return function(s, k, a, dims) {
return k(s, dists.tensorGaussianSample(mu, sigma, dims));
};
}

module.exports = function(env) {

var runForward = require('../inference/forwardSample')(env).runForward;

var dimsForScalarParam = [1];

// param provides a convenient wrapper around the primitive
// params.register.
var param = function(s, k, a, options) {
options = util.mergeDefaults(options, {
mu: 0,
Expand All @@ -61,30 +67,47 @@ module.exports = function(env) {
util.warn('Warning: Parameter created outside of the guide.', true);
}

var mu = options.mu;
var sigma = options.sigma;
var dims = options.dims;
var name = _.has(options, 'name') ? options.name : util.relativizeAddress(env, a);

var val = params.register(env, name, function() {
if (params.exists(name)) {
return finish(s);
} else {
var init = _.has(options, 'init') ? options.init : defaultInit(options.mu, options.sigma);
if (!_.isFunction(init)) {
throw new Error('Expected the init argument to be a function.');
}

// Initialization.
var appliedInit = function(s, k, a) {
return init.apply(global, [s, k, a, dims]);
};

var val = new Tensor(dims);
if (sigma === 0) {
val.fill(mu);
} else {
for (var i = 0; i < val.length; i++) {
val.data[i] = dists.gaussianSample(mu, sigma);
var next = function(k, ret) {
var initialVal = ret.val;
params.create(name, initialVal);
if (!_.isEqual(dims, initialVal.dims)) {
var msg = 'The init function did not return a tensor with the expected shape.';
throw new Error(msg);
}
}
return finish(s);
};

// params.register tracks an array of parameters for each
// name/address.
return [val];
return runForward(s, next, a, appliedInit);
}

})[0];
return k(s, dims === dimsForScalarParam ? ad.tensor.get(val, 0) : val);
function finish(s) {
var val = params.fetch(name, env);
var valDims = ad.value(val).dims;
if (!_.isEqual(dims, valDims)) {
var msg = 'The dims specified here (' + JSON.stringify(dims) +
') do not match the dims of the current parameter value (' +
JSON.stringify(valDims) + '). This value may ' +
'come from an earlier call to param, or from a previous ' +
'execution when a persistent parameter store is used.';
throw new Error(msg);
}
return k(s, dims === dimsForScalarParam ? ad.tensor.get(val, 0) : val);
};
};

return {
Expand Down
Loading

0 comments on commit 3bd0d8e

Please sign in to comment.