Skip to content

Commit

Permalink
fix: make derivative much faster (#3322)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulftw authored Nov 28, 2024
1 parent ae93f07 commit e0d56d2
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 84 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -257,5 +257,6 @@ gauravchawhan <[email protected]>
Akki <[email protected]>
Neeraj Kumawat <[email protected]>
Emmanuel Ferdman <[email protected]>
Paul K <[email protected]>

# Generated by tools/update-authors.js
149 changes: 68 additions & 81 deletions src/function/algebra/derivative.js
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,19 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
* @return {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} The derivative of `expr`
*/
function plainDerivative (expr, variable, options = { simplify: true }) {
const constNodes = {}
constTag(constNodes, expr, variable.name)
const res = _derivative(expr, constNodes)
const cache = new Map()
const variableName = variable.name
function isConstCached (node) {
const cached = cache.get(node)
if (cached !== undefined) {
return cached
}
const res = _isConst(isConstCached, node, variableName)
cache.set(node, res)
return res
}

const res = _derivative(expr, isConstCached)
return options.simplify ? simplify(res) : res
}

Expand All @@ -96,9 +106,8 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
'Node, SymbolNode, ConstantNode': function (expr, variable, {order}) {
let res = expr
for (let i = 0; i < order; i++) {
let constNodes = {}
constTag(constNodes, expr, variable.name)
res = _derivative(res, constNodes)
<create caching isConst>
res = _derivative(res, isConst)
}
return res
}
Expand Down Expand Up @@ -143,96 +152,78 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
})

/**
* Does a depth-first search on the expression tree to identify what Nodes
* are constants (e.g. 2 + 2), and stores the ones that are constants in
* constNodes. Classification is done as follows:
* Checks if a node is constants (e.g. 2 + 2).
* Accepts (usually memoized) version of self as the first parameter for recursive calls.
* Classification is done as follows:
*
* 1. ConstantNodes are constants.
* 2. If there exists a SymbolNode, of which we are differentiating over,
* in the subtree it is not constant.
*
* @param {Object} constNodes Holds the nodes that are constant
* @param {function} isConst Function that tells whether sub-expression is a constant
* @param {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} node
* @param {string} varName Variable that we are differentiating
* @return {boolean} if node is constant
*/
// TODO: can we rewrite constTag into a pure function?
const constTag = typed('constTag', {
'Object, ConstantNode, string': function (constNodes, node) {
constNodes[node] = true
const _isConst = typed('_isConst', {
'function, ConstantNode, string': function () {
return true
},

'Object, SymbolNode, string': function (constNodes, node, varName) {
'function, SymbolNode, string': function (isConst, node, varName) {
// Treat other variables like constants. For reasoning, see:
// https://en.wikipedia.org/wiki/Partial_derivative
if (node.name !== varName) {
constNodes[node] = true
return true
}
return false
return node.name !== varName
},

'Object, ParenthesisNode, string': function (constNodes, node, varName) {
return constTag(constNodes, node.content, varName)
'function, ParenthesisNode, string': function (isConst, node, varName) {
return isConst(node.content, varName)
},

'Object, FunctionAssignmentNode, string': function (constNodes, node, varName) {
'function, FunctionAssignmentNode, string': function (isConst, node, varName) {
if (!node.params.includes(varName)) {
constNodes[node] = true
return true
}
return constTag(constNodes, node.expr, varName)
return isConst(node.expr, varName)
},

'Object, FunctionNode | OperatorNode, string': function (constNodes, node, varName) {
if (node.args.length > 0) {
let isConst = constTag(constNodes, node.args[0], varName)
for (let i = 1; i < node.args.length; ++i) {
isConst = constTag(constNodes, node.args[i], varName) && isConst
}

if (isConst) {
constNodes[node] = true
return true
}
}
return false
'function, FunctionNode | OperatorNode, string': function (isConst, node, varName) {
return node.args.every(arg => isConst(arg, varName))
}
})

/**
* Applies differentiation rules.
*
* @param {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} node
* @param {Object} constNodes Holds the nodes that are constant
* @param {function} isConst Function that tells if a node is constant
* @return {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} The derivative of `expr`
*/
const _derivative = typed('_derivative', {
'ConstantNode, Object': function (node) {
'ConstantNode, function': function () {
return createConstantNode(0)
},

'SymbolNode, Object': function (node, constNodes) {
if (constNodes[node] !== undefined) {
'SymbolNode, function': function (node, isConst) {
if (isConst(node)) {
return createConstantNode(0)
}
return createConstantNode(1)
},

'ParenthesisNode, Object': function (node, constNodes) {
return new ParenthesisNode(_derivative(node.content, constNodes))
'ParenthesisNode, function': function (node, isConst) {
return new ParenthesisNode(_derivative(node.content, isConst))
},

'FunctionAssignmentNode, Object': function (node, constNodes) {
if (constNodes[node] !== undefined) {
'FunctionAssignmentNode, function': function (node, isConst) {
if (isConst(node)) {
return createConstantNode(0)
}
return _derivative(node.expr, constNodes)
return _derivative(node.expr, isConst)
},

'FunctionNode, Object': function (node, constNodes) {
if (constNodes[node] !== undefined) {
'FunctionNode, function': function (node, isConst) {
if (isConst(node)) {
return createConstantNode(0)
}

Expand Down Expand Up @@ -274,10 +265,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
node.args[1]
])

// Is a variable?
constNodes[arg1] = constNodes[node.args[1]]

return _derivative(new OperatorNode('^', 'pow', [arg0, arg1]), constNodes)
return _derivative(new OperatorNode('^', 'pow', [arg0, arg1]), isConst)
}
break
case 'log10':
Expand All @@ -289,7 +277,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
funcDerivative = arg0.clone()
div = true
} else if ((node.args.length === 1 && arg1) ||
(node.args.length === 2 && constNodes[node.args[1]] !== undefined)) {
(node.args.length === 2 && isConst(node.args[1]))) {
// d/dx(log(x, c)) = 1 / (x*ln(c))
funcDerivative = new OperatorNode('*', 'multiply', [
arg0.clone(),
Expand All @@ -301,14 +289,13 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
return _derivative(new OperatorNode('/', 'divide', [
new FunctionNode('log', [arg0]),
new FunctionNode('log', [node.args[1]])
]), constNodes)
]), isConst)
}
break
case 'pow':
if (node.args.length === 2) {
constNodes[arg1] = constNodes[node.args[1]]
// Pass to pow operator node parser
return _derivative(new OperatorNode('^', 'pow', [arg0, node.args[1]]), constNodes)
return _derivative(new OperatorNode('^', 'pow', [arg0, node.args[1]]), isConst)
}
break
case 'exp':
Expand Down Expand Up @@ -585,58 +572,58 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
/* Apply chain rule to all functions:
F(x) = f(g(x))
F'(x) = g'(x)*f'(g(x)) */
let chainDerivative = _derivative(arg0, constNodes)
let chainDerivative = _derivative(arg0, isConst)
if (negative) {
chainDerivative = new OperatorNode('-', 'unaryMinus', [chainDerivative])
}
return new OperatorNode(op, func, [chainDerivative, funcDerivative])
},

'OperatorNode, Object': function (node, constNodes) {
if (constNodes[node] !== undefined) {
'OperatorNode, function': function (node, isConst) {
if (isConst(node)) {
return createConstantNode(0)
}

if (node.op === '+') {
// d/dx(sum(f(x)) = sum(f'(x))
return new OperatorNode(node.op, node.fn, node.args.map(function (arg) {
return _derivative(arg, constNodes)
return _derivative(arg, isConst)
}))
}

if (node.op === '-') {
// d/dx(+/-f(x)) = +/-f'(x)
if (node.isUnary()) {
return new OperatorNode(node.op, node.fn, [
_derivative(node.args[0], constNodes)
_derivative(node.args[0], isConst)
])
}

// Linearity of differentiation, d/dx(f(x) +/- g(x)) = f'(x) +/- g'(x)
if (node.isBinary()) {
return new OperatorNode(node.op, node.fn, [
_derivative(node.args[0], constNodes),
_derivative(node.args[1], constNodes)
_derivative(node.args[0], isConst),
_derivative(node.args[1], isConst)
])
}
}

if (node.op === '*') {
// d/dx(c*f(x)) = c*f'(x)
const constantTerms = node.args.filter(function (arg) {
return constNodes[arg] !== undefined
return isConst(arg)
})

if (constantTerms.length > 0) {
const nonConstantTerms = node.args.filter(function (arg) {
return constNodes[arg] === undefined
return !isConst(arg)
})

const nonConstantNode = nonConstantTerms.length === 1
? nonConstantTerms[0]
: new OperatorNode('*', 'multiply', nonConstantTerms)

const newArgs = constantTerms.concat(_derivative(nonConstantNode, constNodes))
const newArgs = constantTerms.concat(_derivative(nonConstantNode, isConst))

return new OperatorNode('*', 'multiply', newArgs)
}
Expand All @@ -645,7 +632,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
return new OperatorNode('+', 'add', node.args.map(function (argOuter) {
return new OperatorNode('*', 'multiply', node.args.map(function (argInner) {
return (argInner === argOuter)
? _derivative(argInner, constNodes)
? _derivative(argInner, isConst)
: argInner.clone()
}))
}))
Expand All @@ -656,16 +643,16 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
const arg1 = node.args[1]

// d/dx(f(x) / c) = f'(x) / c
if (constNodes[arg1] !== undefined) {
return new OperatorNode('/', 'divide', [_derivative(arg0, constNodes), arg1])
if (isConst(arg1)) {
return new OperatorNode('/', 'divide', [_derivative(arg0, isConst), arg1])
}

// Reciprocal Rule, d/dx(c / f(x)) = -c(f'(x)/f(x)^2)
if (constNodes[arg0] !== undefined) {
if (isConst(arg0)) {
return new OperatorNode('*', 'multiply', [
new OperatorNode('-', 'unaryMinus', [arg0]),
new OperatorNode('/', 'divide', [
_derivative(arg1, constNodes),
_derivative(arg1, isConst),
new OperatorNode('^', 'pow', [arg1.clone(), createConstantNode(2)])
])
])
Expand All @@ -674,8 +661,8 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
// Quotient rule, d/dx(f(x) / g(x)) = (f'(x)g(x) - f(x)g'(x)) / g(x)^2
return new OperatorNode('/', 'divide', [
new OperatorNode('-', 'subtract', [
new OperatorNode('*', 'multiply', [_derivative(arg0, constNodes), arg1.clone()]),
new OperatorNode('*', 'multiply', [arg0.clone(), _derivative(arg1, constNodes)])
new OperatorNode('*', 'multiply', [_derivative(arg0, isConst), arg1.clone()]),
new OperatorNode('*', 'multiply', [arg0.clone(), _derivative(arg1, isConst)])
]),
new OperatorNode('^', 'pow', [arg1.clone(), createConstantNode(2)])
])
Expand All @@ -685,7 +672,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
const arg0 = node.args[0]
const arg1 = node.args[1]

if (constNodes[arg0] !== undefined) {
if (isConst(arg0)) {
// If is secretly constant; 0^f(x) = 1 (in JS), 1^f(x) = 1
if (isConstantNode(arg0) && (isZero(arg0.value) || equal(arg0.value, 1))) {
return createConstantNode(0)
Expand All @@ -696,20 +683,20 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
node,
new OperatorNode('*', 'multiply', [
new FunctionNode('log', [arg0.clone()]),
_derivative(arg1.clone(), constNodes)
_derivative(arg1.clone(), isConst)
])
])
}

if (constNodes[arg1] !== undefined) {
if (isConst(arg1)) {
if (isConstantNode(arg1)) {
// If is secretly constant; f(x)^0 = 1 -> d/dx(1) = 0
if (isZero(arg1.value)) {
return createConstantNode(0)
}
// Ignore exponent; f(x)^1 = f(x)
if (equal(arg1.value, 1)) {
return _derivative(arg0, constNodes)
return _derivative(arg0, isConst)
}
}

Expand All @@ -725,7 +712,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
return new OperatorNode('*', 'multiply', [
arg1.clone(),
new OperatorNode('*', 'multiply', [
_derivative(arg0, constNodes),
_derivative(arg0, isConst),
powMinusOne
])
])
Expand All @@ -736,11 +723,11 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
new OperatorNode('^', 'pow', [arg0.clone(), arg1.clone()]),
new OperatorNode('+', 'add', [
new OperatorNode('*', 'multiply', [
_derivative(arg0, constNodes),
_derivative(arg0, isConst),
new OperatorNode('/', 'divide', [arg1.clone(), arg0.clone()])
]),
new OperatorNode('*', 'multiply', [
_derivative(arg1, constNodes),
_derivative(arg1, isConst),
new FunctionNode('log', [arg0.clone()])
])
])
Expand Down
Loading

0 comments on commit e0d56d2

Please sign in to comment.