From e0d56d2da766aa23bb96e7cc1b1cd4d481aa3105 Mon Sep 17 00:00:00 2001 From: Paul Korzhyk Date: Thu, 28 Nov 2024 14:37:42 +0200 Subject: [PATCH] fix: make derivative much faster (#3322) --- AUTHORS | 1 + src/function/algebra/derivative.js | 149 ++++++++---------- test/benchmark/derivative.js | 32 ++++ test/benchmark/index.js | 1 + .../function/algebra/derivative.test.js | 6 +- 5 files changed, 105 insertions(+), 84 deletions(-) create mode 100644 test/benchmark/derivative.js diff --git a/AUTHORS b/AUTHORS index 4f5059f986..13f40f4cdd 100644 --- a/AUTHORS +++ b/AUTHORS @@ -257,5 +257,6 @@ gauravchawhan <117282267+gauravchawhan@users.noreply.github.com> Akki <63336443+Aakash-Rana@users.noreply.github.com> Neeraj Kumawat <42885320+nkumawat34@users.noreply.github.com> Emmanuel Ferdman +Paul K # Generated by tools/update-authors.js diff --git a/src/function/algebra/derivative.js b/src/function/algebra/derivative.js index aad2f323b3..119f3facd9 100644 --- a/src/function/algebra/derivative.js +++ b/src/function/algebra/derivative.js @@ -71,9 +71,19 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ * @return {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} The derivative of `expr` */ function plainDerivative (expr, variable, options = { simplify: true }) { - const constNodes = {} - constTag(constNodes, expr, variable.name) - const res = _derivative(expr, constNodes) + const cache = new Map() + const variableName = variable.name + function isConstCached (node) { + const cached = cache.get(node) + if (cached !== undefined) { + return cached + } + const res = _isConst(isConstCached, node, variableName) + cache.set(node, res) + return res + } + + const res = _derivative(expr, isConstCached) return options.simplify ? simplify(res) : res } @@ -96,9 +106,8 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ 'Node, SymbolNode, ConstantNode': function (expr, variable, {order}) { let res = expr for (let i = 0; i < order; i++) { - let constNodes = {} - constTag(constNodes, expr, variable.name) - res = _derivative(res, constNodes) + + res = _derivative(res, isConst) } return res } @@ -143,61 +152,43 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ }) /** - * Does a depth-first search on the expression tree to identify what Nodes - * are constants (e.g. 2 + 2), and stores the ones that are constants in - * constNodes. Classification is done as follows: + * Checks if a node is constants (e.g. 2 + 2). + * Accepts (usually memoized) version of self as the first parameter for recursive calls. + * Classification is done as follows: * * 1. ConstantNodes are constants. * 2. If there exists a SymbolNode, of which we are differentiating over, * in the subtree it is not constant. * - * @param {Object} constNodes Holds the nodes that are constant + * @param {function} isConst Function that tells whether sub-expression is a constant * @param {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} node * @param {string} varName Variable that we are differentiating * @return {boolean} if node is constant */ - // TODO: can we rewrite constTag into a pure function? - const constTag = typed('constTag', { - 'Object, ConstantNode, string': function (constNodes, node) { - constNodes[node] = true + const _isConst = typed('_isConst', { + 'function, ConstantNode, string': function () { return true }, - 'Object, SymbolNode, string': function (constNodes, node, varName) { + 'function, SymbolNode, string': function (isConst, node, varName) { // Treat other variables like constants. For reasoning, see: // https://en.wikipedia.org/wiki/Partial_derivative - if (node.name !== varName) { - constNodes[node] = true - return true - } - return false + return node.name !== varName }, - 'Object, ParenthesisNode, string': function (constNodes, node, varName) { - return constTag(constNodes, node.content, varName) + 'function, ParenthesisNode, string': function (isConst, node, varName) { + return isConst(node.content, varName) }, - 'Object, FunctionAssignmentNode, string': function (constNodes, node, varName) { + 'function, FunctionAssignmentNode, string': function (isConst, node, varName) { if (!node.params.includes(varName)) { - constNodes[node] = true return true } - return constTag(constNodes, node.expr, varName) + return isConst(node.expr, varName) }, - 'Object, FunctionNode | OperatorNode, string': function (constNodes, node, varName) { - if (node.args.length > 0) { - let isConst = constTag(constNodes, node.args[0], varName) - for (let i = 1; i < node.args.length; ++i) { - isConst = constTag(constNodes, node.args[i], varName) && isConst - } - - if (isConst) { - constNodes[node] = true - return true - } - } - return false + 'function, FunctionNode | OperatorNode, string': function (isConst, node, varName) { + return node.args.every(arg => isConst(arg, varName)) } }) @@ -205,34 +196,34 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ * Applies differentiation rules. * * @param {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} node - * @param {Object} constNodes Holds the nodes that are constant + * @param {function} isConst Function that tells if a node is constant * @return {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} The derivative of `expr` */ const _derivative = typed('_derivative', { - 'ConstantNode, Object': function (node) { + 'ConstantNode, function': function () { return createConstantNode(0) }, - 'SymbolNode, Object': function (node, constNodes) { - if (constNodes[node] !== undefined) { + 'SymbolNode, function': function (node, isConst) { + if (isConst(node)) { return createConstantNode(0) } return createConstantNode(1) }, - 'ParenthesisNode, Object': function (node, constNodes) { - return new ParenthesisNode(_derivative(node.content, constNodes)) + 'ParenthesisNode, function': function (node, isConst) { + return new ParenthesisNode(_derivative(node.content, isConst)) }, - 'FunctionAssignmentNode, Object': function (node, constNodes) { - if (constNodes[node] !== undefined) { + 'FunctionAssignmentNode, function': function (node, isConst) { + if (isConst(node)) { return createConstantNode(0) } - return _derivative(node.expr, constNodes) + return _derivative(node.expr, isConst) }, - 'FunctionNode, Object': function (node, constNodes) { - if (constNodes[node] !== undefined) { + 'FunctionNode, function': function (node, isConst) { + if (isConst(node)) { return createConstantNode(0) } @@ -274,10 +265,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ node.args[1] ]) - // Is a variable? - constNodes[arg1] = constNodes[node.args[1]] - - return _derivative(new OperatorNode('^', 'pow', [arg0, arg1]), constNodes) + return _derivative(new OperatorNode('^', 'pow', [arg0, arg1]), isConst) } break case 'log10': @@ -289,7 +277,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ funcDerivative = arg0.clone() div = true } else if ((node.args.length === 1 && arg1) || - (node.args.length === 2 && constNodes[node.args[1]] !== undefined)) { + (node.args.length === 2 && isConst(node.args[1]))) { // d/dx(log(x, c)) = 1 / (x*ln(c)) funcDerivative = new OperatorNode('*', 'multiply', [ arg0.clone(), @@ -301,14 +289,13 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ return _derivative(new OperatorNode('/', 'divide', [ new FunctionNode('log', [arg0]), new FunctionNode('log', [node.args[1]]) - ]), constNodes) + ]), isConst) } break case 'pow': if (node.args.length === 2) { - constNodes[arg1] = constNodes[node.args[1]] // Pass to pow operator node parser - return _derivative(new OperatorNode('^', 'pow', [arg0, node.args[1]]), constNodes) + return _derivative(new OperatorNode('^', 'pow', [arg0, node.args[1]]), isConst) } break case 'exp': @@ -585,22 +572,22 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ /* Apply chain rule to all functions: F(x) = f(g(x)) F'(x) = g'(x)*f'(g(x)) */ - let chainDerivative = _derivative(arg0, constNodes) + let chainDerivative = _derivative(arg0, isConst) if (negative) { chainDerivative = new OperatorNode('-', 'unaryMinus', [chainDerivative]) } return new OperatorNode(op, func, [chainDerivative, funcDerivative]) }, - 'OperatorNode, Object': function (node, constNodes) { - if (constNodes[node] !== undefined) { + 'OperatorNode, function': function (node, isConst) { + if (isConst(node)) { return createConstantNode(0) } if (node.op === '+') { // d/dx(sum(f(x)) = sum(f'(x)) return new OperatorNode(node.op, node.fn, node.args.map(function (arg) { - return _derivative(arg, constNodes) + return _derivative(arg, isConst) })) } @@ -608,15 +595,15 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ // d/dx(+/-f(x)) = +/-f'(x) if (node.isUnary()) { return new OperatorNode(node.op, node.fn, [ - _derivative(node.args[0], constNodes) + _derivative(node.args[0], isConst) ]) } // Linearity of differentiation, d/dx(f(x) +/- g(x)) = f'(x) +/- g'(x) if (node.isBinary()) { return new OperatorNode(node.op, node.fn, [ - _derivative(node.args[0], constNodes), - _derivative(node.args[1], constNodes) + _derivative(node.args[0], isConst), + _derivative(node.args[1], isConst) ]) } } @@ -624,19 +611,19 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ if (node.op === '*') { // d/dx(c*f(x)) = c*f'(x) const constantTerms = node.args.filter(function (arg) { - return constNodes[arg] !== undefined + return isConst(arg) }) if (constantTerms.length > 0) { const nonConstantTerms = node.args.filter(function (arg) { - return constNodes[arg] === undefined + return !isConst(arg) }) const nonConstantNode = nonConstantTerms.length === 1 ? nonConstantTerms[0] : new OperatorNode('*', 'multiply', nonConstantTerms) - const newArgs = constantTerms.concat(_derivative(nonConstantNode, constNodes)) + const newArgs = constantTerms.concat(_derivative(nonConstantNode, isConst)) return new OperatorNode('*', 'multiply', newArgs) } @@ -645,7 +632,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ return new OperatorNode('+', 'add', node.args.map(function (argOuter) { return new OperatorNode('*', 'multiply', node.args.map(function (argInner) { return (argInner === argOuter) - ? _derivative(argInner, constNodes) + ? _derivative(argInner, isConst) : argInner.clone() })) })) @@ -656,16 +643,16 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ const arg1 = node.args[1] // d/dx(f(x) / c) = f'(x) / c - if (constNodes[arg1] !== undefined) { - return new OperatorNode('/', 'divide', [_derivative(arg0, constNodes), arg1]) + if (isConst(arg1)) { + return new OperatorNode('/', 'divide', [_derivative(arg0, isConst), arg1]) } // Reciprocal Rule, d/dx(c / f(x)) = -c(f'(x)/f(x)^2) - if (constNodes[arg0] !== undefined) { + if (isConst(arg0)) { return new OperatorNode('*', 'multiply', [ new OperatorNode('-', 'unaryMinus', [arg0]), new OperatorNode('/', 'divide', [ - _derivative(arg1, constNodes), + _derivative(arg1, isConst), new OperatorNode('^', 'pow', [arg1.clone(), createConstantNode(2)]) ]) ]) @@ -674,8 +661,8 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ // Quotient rule, d/dx(f(x) / g(x)) = (f'(x)g(x) - f(x)g'(x)) / g(x)^2 return new OperatorNode('/', 'divide', [ new OperatorNode('-', 'subtract', [ - new OperatorNode('*', 'multiply', [_derivative(arg0, constNodes), arg1.clone()]), - new OperatorNode('*', 'multiply', [arg0.clone(), _derivative(arg1, constNodes)]) + new OperatorNode('*', 'multiply', [_derivative(arg0, isConst), arg1.clone()]), + new OperatorNode('*', 'multiply', [arg0.clone(), _derivative(arg1, isConst)]) ]), new OperatorNode('^', 'pow', [arg1.clone(), createConstantNode(2)]) ]) @@ -685,7 +672,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ const arg0 = node.args[0] const arg1 = node.args[1] - if (constNodes[arg0] !== undefined) { + if (isConst(arg0)) { // If is secretly constant; 0^f(x) = 1 (in JS), 1^f(x) = 1 if (isConstantNode(arg0) && (isZero(arg0.value) || equal(arg0.value, 1))) { return createConstantNode(0) @@ -696,12 +683,12 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ node, new OperatorNode('*', 'multiply', [ new FunctionNode('log', [arg0.clone()]), - _derivative(arg1.clone(), constNodes) + _derivative(arg1.clone(), isConst) ]) ]) } - if (constNodes[arg1] !== undefined) { + if (isConst(arg1)) { if (isConstantNode(arg1)) { // If is secretly constant; f(x)^0 = 1 -> d/dx(1) = 0 if (isZero(arg1.value)) { @@ -709,7 +696,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ } // Ignore exponent; f(x)^1 = f(x) if (equal(arg1.value, 1)) { - return _derivative(arg0, constNodes) + return _derivative(arg0, isConst) } } @@ -725,7 +712,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ return new OperatorNode('*', 'multiply', [ arg1.clone(), new OperatorNode('*', 'multiply', [ - _derivative(arg0, constNodes), + _derivative(arg0, isConst), powMinusOne ]) ]) @@ -736,11 +723,11 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({ new OperatorNode('^', 'pow', [arg0.clone(), arg1.clone()]), new OperatorNode('+', 'add', [ new OperatorNode('*', 'multiply', [ - _derivative(arg0, constNodes), + _derivative(arg0, isConst), new OperatorNode('/', 'divide', [arg1.clone(), arg0.clone()]) ]), new OperatorNode('*', 'multiply', [ - _derivative(arg1, constNodes), + _derivative(arg1, isConst), new FunctionNode('log', [arg0.clone()]) ]) ]) diff --git a/test/benchmark/derivative.js b/test/benchmark/derivative.js new file mode 100644 index 0000000000..6a8a3726bd --- /dev/null +++ b/test/benchmark/derivative.js @@ -0,0 +1,32 @@ +// test performance of derivative + +import Benchmark from 'benchmark' +import { derivative, parse } from '../../lib/esm/index.js' + +let expr = parse('0') +for (let i = 1; i <= 5; i++) { + for (let j = 1; j <= 5; j++) { + expr = parse(`${expr} + sin(${i + j} * x ^ ${i} + ${i * j} * y ^ ${j})`) + } +} + +const results = [] + +Benchmark.options.minSamples = 100 + +const suite = new Benchmark.Suite() +suite + .add('ddf', function () { + const res = derivative(derivative(expr, parse('x'), { simplify: false }), parse('x'), { simplify: false }) + results.splice(0, 1, res) + }) + .add('df ', function () { + const res = derivative(expr, parse('x'), { simplify: false }) + results.splice(0, 1, res) + }) + .on('cycle', function (event) { + console.log(String(event.target)) + }) + .on('complete', function () { + }) + .run() diff --git a/test/benchmark/index.js b/test/benchmark/index.js index 7ada9f253d..221e4722a9 100644 --- a/test/benchmark/index.js +++ b/test/benchmark/index.js @@ -1,6 +1,7 @@ // run all benchmarks import './expression_parser.js' import './algebra.js' +import './derivative.js' import './roots.js' import './matrix_operations.js' import './prime.js' diff --git a/test/unit-tests/function/algebra/derivative.test.js b/test/unit-tests/function/algebra/derivative.test.js index 0f798abef7..b06fb515cc 100644 --- a/test/unit-tests/function/algebra/derivative.test.js +++ b/test/unit-tests/function/algebra/derivative.test.js @@ -126,7 +126,7 @@ describe('derivative', function () { compareString(derivativeWithoutSimplify('nthRoot(6x)', 'x'), '6 * 1 / (2 * sqrt(6 x))') compareString(derivativeWithoutSimplify('nthRoot(6x, 3)', 'x'), '1 / 3 * 6 * 1 * (6 x) ^ (1 / 3 - 1)') - compareString(derivativeWithoutSimplify('nthRoot((6x), (2x))', 'x'), '(6 x) ^ (1 / (2 x)) * ((6 * 1) * 1 / (2 x) / (6 x) + (0 * (2 x) - 1 * (2 * 1)) / (2 x) ^ 2 * log((6 x)))') + compareString(derivativeWithoutSimplify('nthRoot((6x), (2x))', 'x'), '(6 x) ^ (1 / (2 x)) * ((6 * 1) * 1 / (2 x) / (6 x) + -1 * (2 * 1) / (2 x) ^ 2 * log((6 x)))') compareString(derivativeWithoutSimplify('log((6*x))', 'x'), '(6 * 1) / (6 * x)') compareString(derivativeWithoutSimplify('log10((6x))', 'x'), '(6 * 1) / ((6 x) * log(10))') compareString(derivativeWithoutSimplify('log((6x), 10)', 'x'), '(6 * 1) / ((6 x) * log(10))') @@ -270,11 +270,11 @@ describe('derivative', function () { assert.throws(function () { derivative('[1, 2; 3, 4]', 'x') - }, /TypeError: Unexpected type of argument in function constTag \(expected: ConstantNode or FunctionNode or FunctionAssignmentNode or OperatorNode or ParenthesisNode or SymbolNode, actual:.*ArrayNode.*, index: 1\)/) + }, /TypeError: Unexpected type of argument in function _derivative \(expected: ConstantNode or FunctionNode or FunctionAssignmentNode or OperatorNode or ParenthesisNode or SymbolNode, actual:.*ArrayNode.*, index: 0\)/) assert.throws(function () { derivative('x + [1, 2; 3, 4]', 'x') - }, /TypeError: Unexpected type of argument in function constTag \(expected: ConstantNode or FunctionNode or FunctionAssignmentNode or OperatorNode or ParenthesisNode or SymbolNode, actual:.*ArrayNode.*, index: 1\)/) + }, /TypeError: Unexpected type of argument in function _derivative \(expected: ConstantNode or FunctionNode or FunctionAssignmentNode or OperatorNode or ParenthesisNode or SymbolNode, actual:.*ArrayNode.*, index: 0\)/) }) it('should throw error if incorrect number of arguments', function () {