diff --git a/docs/optimization/parameters.rst b/docs/optimization/parameters.rst
index 6227f7c2..67cca0db 100644
--- a/docs/optimization/parameters.rst
+++ b/docs/optimization/parameters.rst
@@ -5,9 +5,8 @@ Parameters
 
 .. js:function:: param([options])
 
-   Retrieves the value of a parameter by name. If the parameter does
-   not exist, it is created and initialized with a draw from a
-   Gaussian distribution.
+   Retrieves the value of a parameter by name. The parameter is
+   created if it does not already exist.
 
    The following options are supported:
 
@@ -18,19 +17,29 @@ Parameters
 
       When ``dims`` is omitted, ``param`` returns a scalar.
 
+   .. describe:: init
+
+      A function that computes the initial value of the parameter.
+      The function is passed the dims of the parameter as its only
+      argument, and should return a tensor with those dims.
+
+      When ``init`` is omitted, the parameter is initialized with a
+      draw from the Gaussian distribution described by the ``mu`` and
+      ``sigma`` options.
+
    .. describe:: mu
 
       The mean of the Gaussian distribution from which the initial
-      parameter value is drawn.
+      parameter value is drawn when ``init`` is omitted.
 
      Default: ``0``
 
   .. describe:: sigma
 
      The standard deviation of the Gaussian distribution from which
-      the initial parameter value is drawn. Specify a standard
-      deviation of ``0`` to deterministically initialize the parameter
-      to ``mu``.
+      the initial parameter value is drawn when ``init`` is omitted.
+      Specify a standard deviation of ``0`` to deterministically
+      initialize the parameter to ``mu``.
 
      Default: ``0.1``
 
@@ -46,6 +55,7 @@ Parameters
      param({name: 'myparam'})
      param({mu: 0, sigma: 0.01, name: 'myparam'})
      param({dims: [10, 10]})
+     param({dims: [2, 1], init: function(dims) { return ones(dims); }})
 
 .. js:function:: modelParam([options])
 
diff --git a/src/guide.js b/src/guide.js
index d4e127ff..3720c4ec 100644
--- a/src/guide.js
+++ b/src/guide.js
@@ -143,9 +143,10 @@ function makeParam(paramSpec, paramName, baseName, env) {
 }
 
 function registerParam(env, name, dims) {
-  return params.register(env, name, function() {
-    return [new Tensor(dims)];
-  })[0];
+  if (!params.exists(name)) {
+    params.create(name, new Tensor(dims));
+  }
+  return params.fetch(name, env);
 }
 
 // This function generates a description of the guide distribution
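The `guide.js` hunk above is the first caller of the new `exists`/`create`/`fetch` API introduced in `src/params/params.js` below. From the WebPPL side, the documented ``init`` option can be used from a guide like this; a minimal sketch, assuming the existing `zeros` helper (the model itself is illustrative, not part of this patch):

```js
var model = function() {
  return sample(Gaussian({mu: 0, sigma: 1}), {
    guide: function() {
      // param defaults to scalar dims ([1]); init receives those dims
      // and must return a tensor of that shape.
      return Gaussian({mu: param({init: zeros}), sigma: 1});
    }
  });
};
```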
diff --git a/src/headerUtils.js b/src/headerUtils.js
index 27cafc41..b738d39d 100644
--- a/src/headerUtils.js
+++ b/src/headerUtils.js
@@ -54,6 +54,31 @@ module.exports = function(env) {
     return wpplFn.apply(global, [s, k, a].concat(args));
   }
 
+  function notAllowed(fn, name) {
+    return function() {
+      throw new Error(fn + ' is not allowed in ' + name + '.');
+    };
+  }
+
+  function makeDeterministicCoroutine(name) {
+    return {
+      sample: notAllowed('sample', name),
+      factor: notAllowed('factor', name),
+      incrementalize: env.defaultCoroutine.incrementalize
+    };
+  }
+
+  // Applies a deterministic function. Attempts by wpplFn to call
+  // sample or factor generate an error.
+  function applyd(s, k, a, wpplFn, args, name) {
+    var coroutine = env.coroutine;
+    env.coroutine = makeDeterministicCoroutine(name);
+    return apply(s, function(s, val) {
+      env.coroutine = coroutine;
+      return k(s, val);
+    }, a, wpplFn, args);
+  }
+
   // Annotating a function object with its lexical id and
   // a list of its free variable values.
   var __uniqueid = 0;
@@ -182,6 +207,7 @@ module.exports = function(env) {
     display: display,
     cache: cache,
     apply: apply,
+    applyd: applyd,
     _Fn: _Fn,
     _addr: _addr,
     zeros: zeros,
diff --git a/src/inference/driftKernel.js b/src/inference/driftKernel.js
index c7d23201..9572cd1e 100644
--- a/src/inference/driftKernel.js
+++ b/src/inference/driftKernel.js
@@ -4,11 +4,7 @@ var util = require('../util');
 
 module.exports = function(env) {
 
-  var driftKernelCoroutine = {
-    sample: notAllowed('sample'),
-    factor: notAllowed('factor'),
-    incrementalize: env.defaultCoroutine.incrementalize
-  };
+  var applyd = require('../headerUtils')(env).applyd;
 
   // A cps function to get the MH proposal distribution based on the
   // args passed to a sample statement and the value selected for this
@@ -24,26 +20,13 @@ module.exports = function(env) {
 
   function getProposalDist(s, a, dist, options, prevVal, k) {
     if (options && options.driftKernel) {
-      var coroutine = env.coroutine;
-      env.coroutine = driftKernelCoroutine;
-
-      return options.driftKernel(s, function(s, val) {
-        // Restore the previous coroutine.
-        env.coroutine = coroutine;
-        return k(s, val);
-      }, a, prevVal);
+      return applyd(s, k, a, options.driftKernel, [prevVal], 'drift kernel');
     } else {
       // Use the prior as the proposal distribution.
       return k(s, dist);
     }
   }
 
-  function notAllowed(fn) {
-    return function() {
-      throw new Error(fn + ' not allowed inside drift kernels.');
-    };
-  }
-
   // We show a warning when the score of a drift proposal is -Infinity
   // as it's likely this is caused by a bug in the drift kernel
   // function.
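Moving the coroutine-swapping boilerplate into `applyd` also parameterizes the error message: a drift kernel that calls `sample` now fails with "sample is not allowed in drift kernel." The user-level contract is unchanged; a sketch of a conforming kernel (illustrative, not from this patch):

```js
// A drift kernel deterministically maps the previous value to a
// proposal distribution; calling sample or factor inside it throws.
var kernel = function(prevVal) {
  return Gaussian({mu: prevVal, sigma: 0.1});
};
var x = sample(Gaussian({mu: 0, sigma: 1}), {driftKernel: kernel});
```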
diff --git a/src/inference/forwardSample.js b/src/inference/forwardSample.js
index 6b90105e..5bd6da5f 100644
--- a/src/inference/forwardSample.js
+++ b/src/inference/forwardSample.js
@@ -11,58 +11,31 @@ var guide = require('../guide');
 
 module.exports = function(env) {
 
-  function ForwardSample(s, k, a, wpplFn, options) {
-    this.opts = util.mergeDefaults(options, {
-      samples: 1,
-      guide: false, // true = sample guide, false = sample target
-      onlyMAP: false,
-      verbose: false
-    });
-
-    this.wpplFn = wpplFn;
+  function RunForward(s, k, a, wpplFn, sampleGuide) {
     this.s = s;
     this.k = k;
     this.a = a;
-    this.guideRequired = this.opts.guide;
+    this.wpplFn = wpplFn;
+    this.sampleGuide = sampleGuide;
 
-    this.factorWarningIssued = false;
+    // Indicate that guide thunks should run.
+    this.guideRequired = sampleGuide;
+
+    this.score = 0;
+    this.logWeight = 0;
 
     this.coroutine = env.coroutine;
     env.coroutine = this;
   }
 
-  ForwardSample.prototype = {
+  RunForward.prototype = {
 
     run: function() {
-
-      var hist = new CountAggregator(this.opts.onlyMAP);
-      var logWeights = []; // Save total factor weights
-
-      return util.cpsLoop(
-          this.opts.samples,
-
-          // Loop body.
-          function(i, next) {
-            this.score = 0;
-            this.logWeight = 0;
-            return this.wpplFn(_.clone(this.s), function(s, val) {
-              logWeights.push(this.logWeight);
-              hist.add(val, this.score);
-              return next();
-            }.bind(this), this.a);
-          }.bind(this),
-
-          // Continuation.
-          function() {
-            env.coroutine = this.coroutine;
-            var dist = hist.toDist();
-            if (!this.opts.guide) {
-              var numSamples = this.opts.samples;
-              dist.normalizationConstant = util.logsumexp(logWeights) - Math.log(numSamples);
-            }
-            return this.k(this.s, dist);
-          }.bind(this));
-
+      return this.wpplFn(_.clone(this.s), function(s, val) {
+        env.coroutine = this.coroutine;
+        var ret = {val: val, score: this.score, logWeight: this.logWeight};
+        return this.k(this.s, ret);
+      }.bind(this), this.a);
     },
 
     sample: function(s, k, a, dist, options) {
@@ -72,7 +45,7 @@ module.exports = function(env) {
         return k(s, val);
       }.bind(this);
 
-      if (this.opts.guide) {
+      if (this.sampleGuide) {
         options = options || {};
         return guide.getDist(
             options.guide, options.noAutoGuide, dist, env, s, a,
@@ -85,27 +58,61 @@ module.exports = function(env) {
     },
 
     factor: function(s, k, a, score) {
-      if (!this.opts.guide && !this.factorWarningIssued) {
-        this.factorWarningIssued = true;
+      if (!this.sampleGuide) {
         var msg = 'Note that factor, condition and observe statements are ' +
-            'ignored when forward sampling from a model.';
-        util.warn(msg);
+            'ignored when forward sampling.';
+        util.warn(msg, true);
       }
       this.logWeight += ad.value(score);
       return k(s);
     },
 
     incrementalize: env.defaultCoroutine.incrementalize,
-    constructor: ForwardSample
+    constructor: RunForward
 
   };
 
+  function runForward() {
+    var coroutine = Object.create(RunForward.prototype);
+    RunForward.apply(coroutine, arguments);
+    return coroutine.run();
+  }
+
+  function ForwardSample(s, k, a, wpplFn, options) {
+    var opts = util.mergeDefaults(options, {
+      samples: 1,
+      guide: false, // true = sample guide, false = sample target
+      onlyMAP: false,
+      verbose: false
+    });
+
+    var hist = new CountAggregator(opts.onlyMAP);
+    var logWeights = []; // Save total factor weights
+
+    return util.cpsLoop(
+        opts.samples,
+        // Loop body.
+        function(i, next) {
+          return runForward(s, function(s, ret) {
+            logWeights.push(ret.logWeight);
+            hist.add(ret.val, ret.score);
+            return next();
+          }, a, wpplFn, opts.guide);
+        },
+        // Continuation.
+        function() {
+          var dist = hist.toDist();
+          if (!opts.guide) {
+            dist.normalizationConstant = util.logsumexp(logWeights) - Math.log(opts.samples);
+          }
+          return k(s, dist);
+        }
+    );
+  }
+
   return {
-    ForwardSample: function() {
-      var coroutine = Object.create(ForwardSample.prototype);
-      ForwardSample.apply(coroutine, arguments);
-      return coroutine.run();
-    }
+    ForwardSample: ForwardSample,
+    runForward: runForward
   };
 
 };
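The refactoring above splits forward sampling in two: `RunForward` executes the program once and reports its result together with the accumulated score and log weight, while `ForwardSample` keeps the sampling loop and histogram. The newly exported `runForward` is what `param` (next file) reuses to evaluate ``init`` functions. A sketch of the contract, in the same internal CPS style (`s`, `k`, `a` are the caller's store, continuation and address):

```js
// Run wpplFn once, forward sampling from the target (sampleGuide =
// false), then receive {val, score, logWeight} in the continuation.
return runForward(s, function(s, ret) {
  return k(s, ret.val);
}, a, wpplFn, false);
```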
diff --git a/src/params/header.js b/src/params/header.js
index e04c2889..dec671b6 100644
--- a/src/params/header.js
+++ b/src/params/header.js
@@ -44,12 +44,18 @@ function deserializeParams(s, k, a, str) {
   return k(s, serialize.deserializeParams(str));
 }
 
+function defaultInit(mu, sigma) {
+  return function(s, k, a, dims) {
+    return k(s, dists.tensorGaussianSample(mu, sigma, dims));
+  };
+}
+
 module.exports = function(env) {
 
+  var runForward = require('../inference/forwardSample')(env).runForward;
+
   var dimsForScalarParam = [1];
 
-  // param provides a convenient wrapper around the primitive
-  // params.register.
   var param = function(s, k, a, options) {
     options = util.mergeDefaults(options, {
       mu: 0,
@@ -61,30 +67,47 @@ module.exports = function(env) {
       util.warn('Warning: Parameter created outside of the guide.', true);
     }
 
-    var mu = options.mu;
-    var sigma = options.sigma;
     var dims = options.dims;
     var name = _.has(options, 'name') ? options.name : util.relativizeAddress(env, a);
 
-    var val = params.register(env, name, function() {
+    if (params.exists(name)) {
+      return finish(s);
+    } else {
+      var init = _.has(options, 'init') ? options.init : defaultInit(options.mu, options.sigma);
+      if (!_.isFunction(init)) {
+        throw new Error('Expected the init argument to be a function.');
+      }
 
-      // Initialization.
+      var appliedInit = function(s, k, a) {
+        return init.apply(global, [s, k, a, dims]);
+      };
 
-      var val = new Tensor(dims);
-      if (sigma === 0) {
-        val.fill(mu);
-      } else {
-        for (var i = 0; i < val.length; i++) {
-          val.data[i] = dists.gaussianSample(mu, sigma);
+      var next = function(k, ret) {
+        var initialVal = ret.val;
+        params.create(name, initialVal);
+        if (!_.isEqual(dims, initialVal.dims)) {
+          var msg = 'The init function did not return a tensor with the expected shape.';
+          throw new Error(msg);
         }
-      }
+        return finish(s);
+      };
 
-      // params.register tracks an array of parameters for each
-      // name/address.
-      return [val];
+      return runForward(s, next, a, appliedInit);
+    }
 
-    })[0];
-    return k(s, dims === dimsForScalarParam ? ad.tensor.get(val, 0) : val);
+    function finish(s) {
+      var val = params.fetch(name, env);
+      var valDims = ad.value(val).dims;
+      if (!_.isEqual(dims, valDims)) {
+        var msg = 'The dims specified here (' + JSON.stringify(dims) +
+            ') do not match the dims of the current parameter value (' +
+            JSON.stringify(valDims) + '). This value may ' +
+            'come from an earlier call to param, or from a previous ' +
+            'execution when a persistent parameter store is used.';
+        throw new Error(msg);
+      }
+      return k(s, dims === dimsForScalarParam ? ad.tensor.get(val, 0) : val);
+    }
   };
 
   return {
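One consequence of routing ``init`` through `runForward` rather than `applyd` is that an init function may itself call `sample`: the draw simply happens under the forward-sampling coroutine, and any `factor` is ignored with a warning. A WebPPL sketch of a stochastic initializer (illustrative; `TensorGaussian` is the existing tensor-valued Gaussian distribution):

```js
// Draw the initial value around one. Returning a tensor whose dims
// differ from the dims option would trigger the new shape error above.
var noisyInit = function(dims) {
  return sample(TensorGaussian({mu: 1, sigma: 0.01, dims: dims}));
};
param({dims: [2, 2], init: noisyInit});
```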
diff --git a/src/params/params.js b/src/params/params.js
index 4a408729..c2084ecd 100644
--- a/src/params/params.js
+++ b/src/params/params.js
@@ -4,6 +4,7 @@ var assert = require('assert');
 var _ = require('lodash');
 var fs = require('fs');
 var ad = require('../ad');
+var util = require('../util');
 
 var config = require('./config');
 var serializeParams = require('./serialize').serializeParams;
@@ -54,6 +55,9 @@ function get() {
   return _params;
 }
 
+function exists(name) {
+  return _.has(_params, name);
+}
 
 // Save the local parameter table to a file
 function save(filename) {
@@ -72,55 +76,54 @@ function set(params, k) {
   return store.setParams(id, params, next);
 }
 
+function create(name, initialVal) {
+  if (exists(name)) {
+    throw new Error('Parameter "' + name + '" already exists.');
+  }
+  if (!util.isTensor(initialVal)) {
+    throw new Error('Expected an (unlifted) tensor.');
+  }
+  var paramTable = get();
+  paramTable[name] = [initialVal];
+}
 
-function register(env, name, initParams) {
+function fetch(name, env) {
+  if (!exists(name)) {
+    throw new Error('Parameter "' + name + '" does not exist.');
+  }
 
   var paramTable = get();
   var paramsSeen = env.coroutine.paramsSeen;
 
-  if (paramsSeen && _.has(paramsSeen, name)) {
-
-    // We've already lifted these parameters during this execution.
-    // Re-use ad graph nodes.
-
-    return paramsSeen[name];
+  // If we're outside of optimization, just return the value of the
+  // parameter, unlifted.
+  if (!paramsSeen) {
+    return paramTable[name][0];
+  }
 
+  // Otherwise we're doing optimization.
+  if (_.has(paramsSeen, name)) {
+    // Return the same AD graph node that was seen earlier this
+    // execution.
+    return paramsSeen[name][0];
   } else {
-
-    // Get parameter values from the store, or initialize if this is a
-    // new parameter.
-    var _params;
-    if (_.has(paramTable, name)) {
-      // Parameters already initialized. Fetch values from store.
-      _params = paramTable[name];
-    } else {
-      // Never seen. Fetch initial values and add to store.
-      _params = initParams();
-      assert.ok(_.every(_params, _.negate(ad.isLifted)),
-          'initParams unexpectedly returned a lifted value.');
-      paramTable[name] = _params;
-    }
-
-    if (paramsSeen) {
-      // Lift parameters if the current coroutine is tracking
-      // parameters for optimization.
-      var params = _params.map(ad.lift);
-      paramsSeen[name] = params;
-      return params;
-    } else {
-      return _params;
-    }
-
+    // Fetch the value and lift. Add to paramsSeen so that the
+    // coroutine knows to update this parameter.
+    var _param = paramTable[name][0];
+    var param = ad.lift(_param);
+    paramsSeen[name] = [param];
+    return param;
   }
 }
 
-
 module.exports = {
   get: get,
   set: set,
   init: init,
   stop: stop,
-  register: register,
   save: save,
-  sync: sync
+  sync: sync,
+  exists: exists,
+  create: create,
+  fetch: fetch
 };
diff --git a/tests/test-data/deterministic/expected/guides.json b/tests/test-data/deterministic/expected/guides.json
index 42fe74c4..5fed8656 100644
--- a/tests/test-data/deterministic/expected/guides.json
+++ b/tests/test-data/deterministic/expected/guides.json
@@ -1,3 +1,3 @@
 {
-  "result": [true, true, true, true, true, true, true]
+  "result": [true, true, true, true, true, true, true, true]
 }
diff --git a/tests/test-data/deterministic/models/guides.wppl b/tests/test-data/deterministic/models/guides.wppl
index b4e45bf7..c1373ceb 100644
--- a/tests/test-data/deterministic/models/guides.wppl
+++ b/tests/test-data/deterministic/models/guides.wppl
@@ -6,6 +6,7 @@ var numParamsCreatedBy = function(thunk) {
 
 [
   param({mu: 1, sigma: 0}) === 1,
+  param({init: ones}) === 1,
   T.sumreduce(param({dims: [3, 2], mu: 1, sigma: 0})) === 6,
 
   // Check (indirectly) that a guide is automatically generated, by
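Taken together, the single do-everything `params.register` is replaced by three orthogonal operations. A rough node-level sketch of the lifecycle (the require paths, the `env` value, and the adnn `Tensor` import are assumptions for illustration; `fetch` lifts only when the active coroutine tracks `paramsSeen`):

```js
var Tensor = require('adnn/tensor');          // assumed import
var params = require('./src/params/params');  // path illustrative

if (!params.exists('w')) {
  // create requires an unlifted tensor and rejects duplicate names.
  var t = new Tensor([2, 2]);
  t.fill(0);
  params.create('w', t);
}

// With no paramsSeen on the current coroutine this returns the raw
// tensor; during optimization it returns a lifted AD node and records
// it in paramsSeen so the same node is reused within one execution.
var w = params.fetch('w', env);
```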