From 713004c17c970a354ee5a04593114d7082a2fac9 Mon Sep 17 00:00:00 2001 From: null-a Date: Tue, 21 Feb 2017 15:47:07 +0000 Subject: [PATCH 1/9] Support passing init. function to param. --- docs/optimization/parameters.rst | 24 ++++-- src/guide.js | 7 +- src/params/header.js | 47 ++++++------ src/params/params.js | 73 ++++++++++--------- .../deterministic/expected/guides.json | 2 +- .../deterministic/models/guides.wppl | 1 + 6 files changed, 87 insertions(+), 67 deletions(-) diff --git a/docs/optimization/parameters.rst b/docs/optimization/parameters.rst index 6227f7c2..67cca0db 100644 --- a/docs/optimization/parameters.rst +++ b/docs/optimization/parameters.rst @@ -5,9 +5,8 @@ Parameters .. js:function:: param([options]) - Retrieves the value of a parameter by name. If the parameter does - not exist, it is created and initialized with a draw from a - Gaussian distribution. + Retrieves the value of a parameter by name. The parameter is + created if it does not already exist. The following options are supported: @@ -18,19 +17,29 @@ Parameters When ``dims`` is omitted, ``param`` returns a scalar. + .. describe:: init + + A function that computes the initial value of the parameter. The + function is passed the dimension of a tensor as its only + argument, and should return a tensor of that dimension. + + When ``init`` is omitted, the parameter is initialized with a + draw from the Gaussian distribution described by the ``mu`` and + ``sigma`` options. + .. describe:: mu The mean of the Gaussian distribution from which the initial - parameter value is drawn. + parameter value is drawn when ``init`` is omitted. Default: ``0`` .. describe:: sigma The standard deviation of the Gaussian distribution from which - the initial parameter value is drawn. Specify a standard - deviation of ``0`` to deterministically initialize the parameter - to ``mu``. + the initial parameter value is drawn when ``init`` is omitted. + Specify a standard deviation of ``0`` to deterministically + initialize the parameter to ``mu``. Default: ``0.1`` @@ -46,6 +55,7 @@ Parameters param({name: 'myparam'}) param({mu: 0, sigma: 0.01, name: 'myparam'}) param({dims: [10, 10]}) + param({dims: [2, 1], init: function(dims) { return ones(dims); }}) .. js:function:: modelParam([options]) diff --git a/src/guide.js b/src/guide.js index d4e127ff..3720c4ec 100644 --- a/src/guide.js +++ b/src/guide.js @@ -143,9 +143,10 @@ function makeParam(paramSpec, paramName, baseName, env) { } function registerParam(env, name, dims) { - return params.register(env, name, function() { - return [new Tensor(dims)]; - })[0]; + if (!params.exists(name)) { + params.create(name, new Tensor(dims)); + } + return params.fetch(name, env); } // This function generates a description of the guide distribution diff --git a/src/params/header.js b/src/params/header.js index e04c2889..03baeea6 100644 --- a/src/params/header.js +++ b/src/params/header.js @@ -44,12 +44,16 @@ function deserializeParams(s, k, a, str) { return k(s, serialize.deserializeParams(str)); } +function defaultInit(mu, sigma) { + return function(s, k, a, dims) { + return k(s, dists.tensorGaussianSample(mu, sigma, dims)); + }; +} + module.exports = function(env) { var dimsForScalarParam = [1]; - // param provides a convenient wrapper around the primitive - // params.register. var param = function(s, k, a, options) { options = util.mergeDefaults(options, { mu: 0, @@ -61,30 +65,31 @@ module.exports = function(env) { util.warn('Warning: Parameter created outside of the guide.', true); } - var mu = options.mu; - var sigma = options.sigma; var dims = options.dims; var name = _.has(options, 'name') ? options.name : util.relativizeAddress(env, a); - var val = params.register(env, name, function() { - - // Initialization. - - var val = new Tensor(dims); - if (sigma === 0) { - val.fill(mu); - } else { - for (var i = 0; i < val.length; i++) { - val.data[i] = dists.gaussianSample(mu, sigma); - } + if (params.exists(name)) { + return finish(s); + } else { + var init = _.has(options, 'init') ? options.init : defaultInit(options.mu, options.sigma); + if (!_.isFunction(init)) { + throw new Error('Expected the init argument to be a function.'); } + return init(s, function(s, initialVal) { + params.create(name, initialVal); + if (!_.isEqual(dims, initialVal.dims)) { + var msg = 'The init function did not return a tensor with the expected shape.'; + throw new Error(msg); + } + return finish(s); + }, a, dims); + } - // params.register tracks an array of parameters for each - // name/address. - return [val]; - - })[0]; - return k(s, dims === dimsForScalarParam ? ad.tensor.get(val, 0) : val); + function finish(s) { + var val = params.fetch(name, env); + var valDims = ad.value(val).dims; + return k(s, dims === dimsForScalarParam ? ad.tensor.get(val, 0) : val); + }; }; return { diff --git a/src/params/params.js b/src/params/params.js index 4a408729..c2084ecd 100644 --- a/src/params/params.js +++ b/src/params/params.js @@ -4,6 +4,7 @@ var assert = require('assert'); var _ = require('lodash'); var fs = require('fs'); var ad = require('../ad'); +var util = require('../util'); var config = require('./config'); var serializeParams = require('./serialize').serializeParams; @@ -54,6 +55,9 @@ function get() { return _params; } +function exists(name) { + return _.has(_params, name); +} // Save the local parameter table to a file function save(filename) { @@ -72,55 +76,54 @@ function set(params, k) { return store.setParams(id, params, next); } +function create(name, initialVal) { + if (exists(name)) { + throw new Error('Parameter "' + name + '" already exists.'); + } + if (!util.isTensor(initialVal)) { + throw new Error('Expected an (unlifted) tensor.'); + } + var paramTable = get(); + paramTable[name] = [initialVal]; +} -function register(env, name, initParams) { +function fetch(name, env) { + if (!exists(name)) { + throw new Error('Parameter "' + name + '" does not exist.'); + } var paramTable = get(); var paramsSeen = env.coroutine.paramsSeen; - if (paramsSeen && _.has(paramsSeen, name)) { - - // We've already lifted these parameters during this execution. - // Re-use ad graph nodes. - - return paramsSeen[name]; + // If we're outside of optimization, just return the value of the + // parameter, unlifted. + if (!paramsSeen) { + return paramTable[name][0]; + } + // Otherwise we're doing optimization. + if (_.has(paramsSeen, name)) { + // Return the same AD graph node that was seen earlier this + // execution. + return paramsSeen[name][0]; } else { - - // Get parameter values from the store, or initialize if this is a - // new parameter. - var _params; - if (_.has(paramTable, name)) { - // Parameters already initialized. Fetch values from store. - _params = paramTable[name]; - } else { - // Never seen. Fetch initial values and add to store. - _params = initParams(); - assert.ok(_.every(_params, _.negate(ad.isLifted)), - 'initParams unexpectedly returned a lifted value.'); - paramTable[name] = _params; - } - - if (paramsSeen) { - // Lift parameters if the current coroutine is tracking - // parameters for optimization. - var params = _params.map(ad.lift); - paramsSeen[name] = params; - return params; - } else { - return _params; - } - + // Fetch the value and lift. Add to paramsSeen so that the + // coroutine knows to update this parameter. + var _param = paramTable[name][0]; + var param = ad.lift(_param); + paramsSeen[name] = [param]; + return param; } } - module.exports = { get: get, set: set, init: init, stop: stop, - register: register, save: save, - sync: sync + sync: sync, + exists: exists, + create: create, + fetch: fetch }; diff --git a/tests/test-data/deterministic/expected/guides.json b/tests/test-data/deterministic/expected/guides.json index 42fe74c4..5fed8656 100644 --- a/tests/test-data/deterministic/expected/guides.json +++ b/tests/test-data/deterministic/expected/guides.json @@ -1,3 +1,3 @@ { - "result": [true, true, true, true, true, true, true] + "result": [true, true, true, true, true, true, true, true] } diff --git a/tests/test-data/deterministic/models/guides.wppl b/tests/test-data/deterministic/models/guides.wppl index b4e45bf7..c1373ceb 100644 --- a/tests/test-data/deterministic/models/guides.wppl +++ b/tests/test-data/deterministic/models/guides.wppl @@ -6,6 +6,7 @@ var numParamsCreatedBy = function(thunk) { [ param({mu: 1, sigma: 0}) === 1, + param({init: ones}) === 1, T.sumreduce(param({dims: [3, 2], mu: 1, sigma: 0})) === 6, // Check (indirectly) that a guide is automatically generated, by From 6133cd534819cd9872b795aa12a994ed683171d9 Mon Sep 17 00:00:00 2001 From: null-a Date: Tue, 21 Feb 2017 16:03:34 +0000 Subject: [PATCH 2/9] Check for dimension mismatch in param. --- src/params/header.js | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/params/header.js b/src/params/header.js index 03baeea6..c7d5fc78 100644 --- a/src/params/header.js +++ b/src/params/header.js @@ -88,6 +88,14 @@ module.exports = function(env) { function finish(s) { var val = params.fetch(name, env); var valDims = ad.value(val).dims; + if (!_.isEqual(dims, valDims)) { + var msg = 'The dims specified here (' + JSON.stringify(dims) + + ') do not match the dims of the current value (' + + JSON.stringify(valDims) + '). The current value may ' + + 'come from an earlier call to param, or from a previous ' + + 'execution when a persistent parameter store is used.'; + throw new Error(msg); + } return k(s, dims === dimsForScalarParam ? ad.tensor.get(val, 0) : val); }; }; From d27b1afc3481139e933c4978c538d5f832ab7f0b Mon Sep 17 00:00:00 2001 From: null-a Date: Fri, 24 Feb 2017 11:37:00 +0000 Subject: [PATCH 3/9] Don't allow sample/factor during parameter init. --- src/headerUtils.js | 26 ++++++++++++++++++++++++++ src/params/header.js | 6 ++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/headerUtils.js b/src/headerUtils.js index 27cafc41..b738d39d 100644 --- a/src/headerUtils.js +++ b/src/headerUtils.js @@ -54,6 +54,31 @@ module.exports = function(env) { return wpplFn.apply(global, [s, k, a].concat(args)); } + function notAllowed(fn, name) { + return function() { + throw new Error(fn + ' is not allowed in ' + name + '.'); + }; + } + + function makeDeterministicCoroutine(name) { + return { + sample: notAllowed('sample', name), + factor: notAllowed('factor', name), + incrementalize: env.defaultCoroutine.incrementalize + }; + } + + // Applies a deterministic function. Attempts by wpplFn to call + // sample or factor generate an error. + function applyd(s, k, a, wpplFn, args, name) { + var coroutine = env.coroutine; + env.coroutine = makeDeterministicCoroutine(name); + return apply(s, function(s, val) { + env.coroutine = coroutine; + return k(s, val); + }, a, wpplFn, args); + } + // Annotating a function object with its lexical id and // a list of its free variable values. var __uniqueid = 0; @@ -182,6 +207,7 @@ module.exports = function(env) { display: display, cache: cache, apply: apply, + applyd: applyd, _Fn: _Fn, _addr: _addr, zeros: zeros, diff --git a/src/params/header.js b/src/params/header.js index c7d5fc78..5597930a 100644 --- a/src/params/header.js +++ b/src/params/header.js @@ -52,6 +52,8 @@ function defaultInit(mu, sigma) { module.exports = function(env) { + var applyd = require('../headerUtils')(env).applyd; + var dimsForScalarParam = [1]; var param = function(s, k, a, options) { @@ -75,14 +77,14 @@ module.exports = function(env) { if (!_.isFunction(init)) { throw new Error('Expected the init argument to be a function.'); } - return init(s, function(s, initialVal) { + return applyd(s, function(s, initialVal) { params.create(name, initialVal); if (!_.isEqual(dims, initialVal.dims)) { var msg = 'The init function did not return a tensor with the expected shape.'; throw new Error(msg); } return finish(s); - }, a, dims); + }, a, init, [dims], 'parameter initialization'); } function finish(s) { From 4ab9c7f7bc9b0f6dfdcf29853552b81ec413bcd3 Mon Sep 17 00:00:00 2001 From: null-a Date: Fri, 24 Feb 2017 13:39:08 +0000 Subject: [PATCH 4/9] Use applyd for drift kernels. --- src/inference/driftKernel.js | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/src/inference/driftKernel.js b/src/inference/driftKernel.js index c7d23201..9572cd1e 100644 --- a/src/inference/driftKernel.js +++ b/src/inference/driftKernel.js @@ -4,11 +4,7 @@ var util = require('../util'); module.exports = function(env) { - var driftKernelCoroutine = { - sample: notAllowed('sample'), - factor: notAllowed('factor'), - incrementalize: env.defaultCoroutine.incrementalize - }; + var applyd = require('../headerUtils')(env).applyd; // A cps function to get the MH proposal distribution based on the // args passed to a sample statement and the value selected for this @@ -24,26 +20,13 @@ module.exports = function(env) { function getProposalDist(s, a, dist, options, prevVal, k) { if (options && options.driftKernel) { - var coroutine = env.coroutine; - env.coroutine = driftKernelCoroutine; - - return options.driftKernel(s, function(s, val) { - // Restore the previous coroutine. - env.coroutine = coroutine; - return k(s, val); - }, a, prevVal); + return applyd(s, k, a, options.driftKernel, [prevVal], 'drift kernel'); } else { // Use the prior as the proposal distribution. return k(s, dist); } } - function notAllowed(fn) { - return function() { - throw new Error(fn + ' not allowed inside drift kernels.'); - }; - } - // We show a warning when the score of a drift proposal is -Infinity // as it's likely this is caused by a bug in the drift kernel // function. From 7c4aec08dc2bcfac757aba64cb8554137efdcaee Mon Sep 17 00:00:00 2001 From: null-a Date: Mon, 27 Feb 2017 08:55:25 +0000 Subject: [PATCH 5/9] Mention how to implement random init. --- docs/optimization/parameters.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/optimization/parameters.rst b/docs/optimization/parameters.rst index 67cca0db..3bdacc18 100644 --- a/docs/optimization/parameters.rst +++ b/docs/optimization/parameters.rst @@ -27,6 +27,12 @@ Parameters draw from the Gaussian distribution described by the ``mu`` and ``sigma`` options. + Calling ``sample(dist)`` from an initialization function is not + supported, and will generate a run time error. Random + initialization strategies should instead be implemented in terms + of ``dist.sample()``. (Where ``dist`` is a :ref:`distribution + object `.) + .. describe:: mu The mean of the Gaussian distribution from which the initial From 74a74d86745e669f0f95f11845a498492e1169fc Mon Sep 17 00:00:00 2001 From: null-a Date: Fri, 3 Mar 2017 13:36:50 +0000 Subject: [PATCH 6/9] Refactor ForwardSample. Allowing the coroutine to be used without collecting return values in a distribution. --- src/inference/forwardSample.js | 109 ++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 51 deletions(-) diff --git a/src/inference/forwardSample.js b/src/inference/forwardSample.js index 6b90105e..e0075788 100644 --- a/src/inference/forwardSample.js +++ b/src/inference/forwardSample.js @@ -11,58 +11,31 @@ var guide = require('../guide'); module.exports = function(env) { - function ForwardSample(s, k, a, wpplFn, options) { - this.opts = util.mergeDefaults(options, { - samples: 1, - guide: false, // true = sample guide, false = sample target - onlyMAP: false, - verbose: false - }); - - this.wpplFn = wpplFn; + function RunForward(s, k, a, wpplFn, sampleGuide) { this.s = s; this.k = k; this.a = a; - this.guideRequired = this.opts.guide; + this.wpplFn = wpplFn; + this.sampleGuide = sampleGuide; - this.factorWarningIssued = false; + // Indicate that guide thunks should run. + this.guideRequired = sampleGuide; + + this.score = 0; + this.logWeight = 0; this.coroutine = env.coroutine; env.coroutine = this; } - ForwardSample.prototype = { + RunForward.prototype = { run: function() { - - var hist = new CountAggregator(this.opts.onlyMAP); - var logWeights = []; // Save total factor weights - - return util.cpsLoop( - this.opts.samples, - - // Loop body. - function(i, next) { - this.score = 0; - this.logWeight = 0; - return this.wpplFn(_.clone(this.s), function(s, val) { - logWeights.push(this.logWeight); - hist.add(val, this.score); - return next(); - }.bind(this), this.a); - }.bind(this), - - // Continuation. - function() { - env.coroutine = this.coroutine; - var dist = hist.toDist(); - if (!this.opts.guide) { - var numSamples = this.opts.samples; - dist.normalizationConstant = util.logsumexp(logWeights) - Math.log(numSamples); - } - return this.k(this.s, dist); - }.bind(this)); - + return this.wpplFn(_.clone(this.s), function(s, val) { + env.coroutine = this.coroutine; + var ret = {val: val, score: this.score, logWeight: this.logWeight}; + return this.k(this.s, ret); + }.bind(this), this.a); }, sample: function(s, k, a, dist, options) { @@ -72,7 +45,7 @@ module.exports = function(env) { return k(s, val); }.bind(this); - if (this.opts.guide) { + if (this.sampleGuide) { options = options || {}; return guide.getDist( options.guide, options.noAutoGuide, dist, env, s, a, @@ -85,27 +58,61 @@ module.exports = function(env) { }, factor: function(s, k, a, score) { - if (!this.opts.guide && !this.factorWarningIssued) { - this.factorWarningIssued = true; + if (!this.sampleGuide) { var msg = 'Note that factor, condition and observe statements are ' + 'ignored when forward sampling from a model.'; - util.warn(msg); + util.warn(msg, true); } this.logWeight += ad.value(score); return k(s); }, incrementalize: env.defaultCoroutine.incrementalize, - constructor: ForwardSample + constructor: RunForward }; + function runForward() { + var coroutine = Object.create(RunForward.prototype); + RunForward.apply(coroutine, arguments); + return coroutine.run(); + } + + function ForwardSample(s, k, a, wpplFn, options) { + var opts = util.mergeDefaults(options, { + samples: 1, + guide: false, // true = sample guide, false = sample target + onlyMAP: false, + verbose: false + }); + + var hist = new CountAggregator(opts.onlyMAP); + var logWeights = []; // Save total factor weights + + return util.cpsLoop( + opts.samples, + // Loop body. + function(i, next) { + return runForward(s, function(s, ret) { + logWeights.push(ret.logWeight); + hist.add(ret.val, ret.score); + return next(); + }, a, wpplFn, opts.guide); + }, + // Continuation. + function() { + var dist = hist.toDist(); + if (!opts.guide) { + dist.normalizationConstant = util.logsumexp(logWeights) - Math.log(opts.samples); + } + return k(s, dist); + } + ); + } + return { - ForwardSample: function() { - var coroutine = Object.create(ForwardSample.prototype); - ForwardSample.apply(coroutine, arguments); - return coroutine.run(); - } + ForwardSample: ForwardSample, + runForward: runForward }; }; From 0f14add134929b2bf943d1a4bf122c0719946067 Mon Sep 17 00:00:00 2001 From: null-a Date: Fri, 3 Mar 2017 13:37:08 +0000 Subject: [PATCH 7/9] Run parameter init. in forward sampling coroutine. This allows `sample(dist)` to be used for random initialization, rather than `dist.sample()`. --- src/inference/forwardSample.js | 2 +- src/params/header.js | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/inference/forwardSample.js b/src/inference/forwardSample.js index e0075788..5bd6da5f 100644 --- a/src/inference/forwardSample.js +++ b/src/inference/forwardSample.js @@ -60,7 +60,7 @@ module.exports = function(env) { factor: function(s, k, a, score) { if (!this.sampleGuide) { var msg = 'Note that factor, condition and observe statements are ' + - 'ignored when forward sampling from a model.'; + 'ignored when forward sampling.'; util.warn(msg, true); } this.logWeight += ad.value(score); diff --git a/src/params/header.js b/src/params/header.js index 5597930a..cfa9441d 100644 --- a/src/params/header.js +++ b/src/params/header.js @@ -52,7 +52,7 @@ function defaultInit(mu, sigma) { module.exports = function(env) { - var applyd = require('../headerUtils')(env).applyd; + var runForward = require('../inference/forwardSample')(env).runForward; var dimsForScalarParam = [1]; @@ -77,14 +77,22 @@ module.exports = function(env) { if (!_.isFunction(init)) { throw new Error('Expected the init argument to be a function.'); } - return applyd(s, function(s, initialVal) { + + var appliedInit = function(s, k, a) { + return init.apply(global, [s, k, a, dims]); + }; + + var next = function(k, ret) { + var initialVal = ret.val; params.create(name, initialVal); if (!_.isEqual(dims, initialVal.dims)) { var msg = 'The init function did not return a tensor with the expected shape.'; throw new Error(msg); } return finish(s); - }, a, init, [dims], 'parameter initialization'); + }; + + return runForward(s, next, a, appliedInit); } function finish(s) { From bb56b2e3cca84c937181532cf80f5c238889feb6 Mon Sep 17 00:00:00 2001 From: null-a Date: Fri, 3 Mar 2017 13:50:13 +0000 Subject: [PATCH 8/9] Revert "Mention how to implement random init." This reverts commit 7c4aec08dc2bcfac757aba64cb8554137efdcaee. --- docs/optimization/parameters.rst | 6 ------ 1 file changed, 6 deletions(-) diff --git a/docs/optimization/parameters.rst b/docs/optimization/parameters.rst index 3bdacc18..67cca0db 100644 --- a/docs/optimization/parameters.rst +++ b/docs/optimization/parameters.rst @@ -27,12 +27,6 @@ Parameters draw from the Gaussian distribution described by the ``mu`` and ``sigma`` options. - Calling ``sample(dist)`` from an initialization function is not - supported, and will generate a run time error. Random - initialization strategies should instead be implemented in terms - of ``dist.sample()``. (Where ``dist`` is a :ref:`distribution - object `.) - .. describe:: mu The mean of the Gaussian distribution from which the initial From 35526bee888f56ea238920444c82bd72a1c05a52 Mon Sep 17 00:00:00 2001 From: null-a Date: Fri, 3 Mar 2017 14:01:30 +0000 Subject: [PATCH 9/9] Tweak error text. --- src/params/header.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/params/header.js b/src/params/header.js index cfa9441d..dec671b6 100644 --- a/src/params/header.js +++ b/src/params/header.js @@ -100,8 +100,8 @@ module.exports = function(env) { var valDims = ad.value(val).dims; if (!_.isEqual(dims, valDims)) { var msg = 'The dims specified here (' + JSON.stringify(dims) + - ') do not match the dims of the current value (' + - JSON.stringify(valDims) + '). The current value may ' + + ') do not match the dims of the current parameter value (' + + JSON.stringify(valDims) + '). This value may ' + 'come from an earlier call to param, or from a previous ' + 'execution when a persistent parameter store is used.'; throw new Error(msg);