[jax] Gaussian and GaussLegendre throw errors #214

Open
javier-garcia-tilburg opened this issue Nov 25, 2024 · 0 comments · May be fixed by #215

Issue

Problem Description

The Gaussian and GaussLegendre integrators don't work with the jax backend; they throw:

TypeError: prod requires ndarray or scalar arguments, got <class 'list'> at position 0.

This error comes from the following call, where the list returned by anp.meshgrid is passed directly to anp.prod:

return anp.prod(
    anp.meshgrid(*([weights] * dim), like=backend), axis=0
).ravel()
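A minimal reproduction without torchquad, assuming only that jax is installed. Unlike NumPy, whose np.prod silently converts a list into an array, jax.numpy rejects Python lists as array inputs. The weights below are made-up illustrative values, not torchquad's actual quadrature weights; the sketch only shows that jnp.meshgrid hands back a list, that jnp.prod refuses it, and that stacking first works.

```python
import jax.numpy as jnp

# Illustrative 1D weights (an assumption; any 1D array reproduces the issue).
weights = jnp.array([0.5, 1.0, 0.5])
dim = 2

# jnp.meshgrid returns a Python list of arrays rather than a single array.
grids = jnp.meshgrid(*([weights] * dim))

try:
    # jnp.prod only accepts arrays and scalars, so passing the list fails
    # with "prod requires ndarray or scalar arguments".
    jnp.prod(grids, axis=0)
    list_rejected = False
except TypeError as err:
    list_rejected = True
    print("TypeError:", err)

# Stacking first yields a single (dim, 3, 3) array that jnp.prod accepts.
product_weights = jnp.prod(jnp.stack(grids), axis=0).ravel()
print(product_weights)
```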

What Needs to be Done

I propose converting the list returned by anp.meshgrid into a single array with anp.stack before passing it to anp.prod. This fix works for me with jax; however, we should first check that it doesn't cause problems with the other backends.

return anp.prod(
    anp.stack(anp.meshgrid(*([weights] * dim), like=backend)), axis=0
).ravel()
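As a first check for the other backends, here is a small NumPy-only sketch comparing the current and proposed expressions for dim = 1, 2 and 3. It calls numpy directly instead of going through autoray's anp with like=backend, and the helper names and weights are hypothetical. Because np.prod already converts the list into the same (dim, ...) array that np.stack builds, the two forms should agree exactly on NumPy; torch or tensorflow would still need their own checks.

```python
import numpy as np


def product_weights_list(weights, dim):
    # Current form: np.prod converts the meshgrid output to an array itself.
    return np.prod(np.meshgrid(*([weights] * dim)), axis=0).ravel()


def product_weights_stacked(weights, dim):
    # Proposed form: stack the meshgrid output into one array first.
    return np.prod(np.stack(np.meshgrid(*([weights] * dim))), axis=0).ravel()


# Hypothetical 1D weights that sum to 1.
weights = np.array([0.2, 0.5, 0.3])

for dim in (1, 2, 3):
    a = product_weights_list(weights, dim)
    b = product_weights_stacked(weights, dim)
    # Both should have 3**dim entries and identical values.
    print(dim, a.shape, np.array_equal(a, b))
```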

How Can It Be Tested or Reproduced

Using jax 0.4.35 and torchquad 0.4.0, run:

import jax
import jax.numpy as jnp
from torchquad import set_up_backend, MonteCarlo, Gaussian, GaussLegendre

set_up_backend(backend="jax")

@jax.jit
def some_function(x):
    return jnp.power(x[:, 0] - x[:, 1], 2)

g = Gaussian()
# It also fails with GaussLegendre
# g = GaussLegendre()

integral_value = g.integrate(
    lambda x: some_function(x),
    dim=2,
    N=10000,
    integration_domain=jnp.asarray([[-1.0, 1.0], [-1.0, 1.0]]),
)
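For reference, the exact value the repro should return once it runs: the integral of (x0 - x1)^2 over [-1, 1]^2 is 4/3 + 4/3 - 0 = 8/3 (about 2.6667), because the cross term -2 x0 x1 integrates to zero. A Gauss-Legendre product rule with at least 2 nodes per dimension integrates this quadratic exactly, so the sketch below computes the same integral with plain NumPy (numpy.polynomial.legendre.leggauss) and the stacked-meshgrid weights. It is an independent cross-check, not torchquad's implementation; using 100 nodes per dimension for N=10000 is my assumption about how the points are split, though any node count of 2 or more gives the same exact value. Gaussian should match too if its default nodes are Legendre, which I have not checked here.

```python
import numpy as np

# Assumed split of N=10000 points: 100 Gauss-Legendre nodes per dimension.
n_per_dim = 100
dim = 2
nodes, weights = np.polynomial.legendre.leggauss(n_per_dim)

# Tensor-product grid of nodes on [-1, 1]^2, shape (n_per_dim**dim, dim).
grid = np.stack(np.meshgrid(*([nodes] * dim)), axis=-1).reshape(-1, dim)

# Product weights built with the stacked-meshgrid form proposed above;
# the row-major ordering matches the grid points.
w = np.prod(np.stack(np.meshgrid(*([weights] * dim))), axis=0).ravel()


def some_function(x):
    return (x[:, 0] - x[:, 1]) ** 2


estimate = np.sum(w * some_function(grid))
print(estimate, 8 / 3)
```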
javier-garcia-tilburg added a commit to javier-garcia-tilburg/torchquad that referenced this issue on Nov 25, 2024:
Gaussian and GaussLegendre integrators throw an error on jax if anp.prod is called on a list instead of an array. See esa#214
javier-garcia-tilburg linked a pull request on Nov 25, 2024 that will close this issue.