Skip to content

Commit

Permalink
use xla to compile the tf graph, add _some_ jax support
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Nov 22, 2024
1 parent 8bf4639 commit 45fe562
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
13 changes: 6 additions & 7 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,14 @@
import re

from keras import Variable
from keras import backend as K
from keras import ops as Kops
from keras import optimizers as Kopt
from keras.models import Model
import numpy as np

import n3fit.backends.keras_backend.operations as op

# Starting with TF 2.16, a memory leak in TF https://github.com/tensorflow/tensorflow/issues/64170
# makes jit compilation unusable in GPU.
# Before TF 2.16 it was set to `False` by default. From 2.16 onwards, it is set to `True`
JIT_COMPILE = False

# Define in this dictionary new optimizers as well as the arguments they accept
# (with default values if needed be)
optimizers = {
Expand Down Expand Up @@ -296,13 +292,16 @@ def compile(

# If given target output is None, target_output is unnecesary, save just a zero per output
if target_output is None:
self.target_tensors = [op.numpy_to_tensor(np.zeros((1, 1))) for i in self.output_shape]
self.target_tensors = [op.numpy_to_tensor(np.zeros((1, 1))) for _ in self.output_shape]
else:
if not isinstance(target_output, list):
target_output = [target_output]
self.target_tensors = target_output

super().compile(optimizer=opt, loss=loss, jit_compile=JIT_COMPILE)
# For debug purposes it may be interesting to set in the compile call
# jit_compile = False
# run_eager = True
super().compile(optimizer=opt, loss=loss)

def set_masks_to(self, names, val=0.0):
"""Set all mask value to the selected value
Expand Down
2 changes: 1 addition & 1 deletion n3fit/src/n3fit/backends/keras_backend/internal_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
log = logging.getLogger(__name__)

# Prepare Keras-backend dependent functions
if K.backend() == "torch":
if K.backend() in ("torch", "jax"):

def set_eager(flag=True):
"""Pytorch is eager by default"""
Expand Down
9 changes: 5 additions & 4 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
equally operations are automatically converted to layers when used as such.
"""

from typing import Optional

from keras import backend as K
from keras import ops as Kops
from keras.layers import ELU, Input
Expand All @@ -37,9 +35,12 @@

# Backend dependent functions and operations
if K.backend() == "torch":
tensor_to_numpy_or_python = lambda x: x.detach().numpy()
tensor_to_numpy_or_python = lambda x: x.detach().cpu().numpy()
decorator_compiler = lambda f: f
elif K.backend() == "jax":
tensor_to_numpy_or_python = lambda x: np.array(x.block_until_ready())
decorator_compiler = lambda f: f
else:
elif K.backend() == "tensorflow":
tensor_to_numpy_or_python = lambda x: x.numpy()
lambda ret: {k: i.numpy() for k, i in ret.items()}
import tensorflow as tf
Expand Down
2 changes: 1 addition & 1 deletion n3fit/src/n3fit/layers/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def compute_float_mask(bool_mask):
"""
# Create a tensor with the shape (**bool_mask.shape, num_active_flavours)
masked_to_full = []
for idx in np.argwhere(np.array(bool_mask)):
for idx in np.argwhere(op.tensor_to_numpy_or_python(bool_mask)):
temp_matrix = np.zeros(bool_mask.shape)
temp_matrix[tuple(idx)] = 1
masked_to_full.append(temp_matrix)
Expand Down

0 comments on commit 45fe562

Please sign in to comment.