
BruteForce layer not working after loading a saved model #744

Open
Toby1218 opened this issue Nov 11, 2024 · 0 comments



@Toby1218

I have trained a deep retrieval model with NCF. When saving the model, I tried both the Keras and the TensorFlow methods of saving and loading, but with either one I cannot specify the number of top-k items to retrieve from the loaded BruteForce layer.

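import tensorflow as tf
import tensorflow_recommenders as tfrs

# Build a brute-force retrieval index over the trained item embeddings
brute = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
ds2 = items.map(lambda x: (x['parent_asin'], model.item_model(x)))
brute.index_from_dataset(ds2)
brute(tf.constant(['B1000000028']), k=15)  # query before saving

# Save, reload, and query again with a different k
brute.save('nc')
loaded = tf.keras.models.load_model('nc')
print(loaded(tf.constant(['B1000000028']), k=5))  # raises the ValueError below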

It returned the following error:

ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
  Positional arguments (2 total):
    * <tf.Tensor 'queries:0' shape=(1,) dtype=string>
    * 5
  Keyword arguments: {'training': False}

 Expected these arguments to match one of the following 2 option(s):

Option 1:
  Positional arguments (2 total):
    * TensorSpec(shape=(None,), dtype=tf.string, name='input_1')
    * None
  Keyword arguments: {'training': True}

Option 2:
  Positional arguments (2 total):
    * TensorSpec(shape=(None,), dtype=tf.string, name='input_1')
    * None
  Keyword arguments: {'training': False}

How can I specify the number of items to retrieve after saving and loading the model?
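The error lists only concrete functions whose second argument is `None`, which suggests the saved layer was traced with `k` left at its default, so a new Python int such as `5` has no matching signature after loading. One possible workaround, sketched below under stated assumptions, is to export your own function whose `input_signature` declares `k` as a scalar `tf.int32` tensor; `tf.math.top_k` accepts a tensor `k`, so it stays dynamic in the SavedModel. This is a generic TensorFlow sketch, not a TFRS API: `TopKRetriever` is a hypothetical name, and it takes precomputed query embeddings (you would apply `model.user_model` yourself) with a toy three-item index in place of the real one.

```python
import tempfile

import tensorflow as tf


class TopKRetriever(tf.Module):
    """Hypothetical brute-force retriever whose exported signature takes k as a tensor."""

    def __init__(self, identifiers, candidate_embeddings):
        super().__init__()
        # Store the index as variables so they are serialized into the SavedModel.
        self.identifiers = tf.Variable(identifiers, trainable=False)
        self.candidates = tf.Variable(candidate_embeddings, trainable=False)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, None), dtype=tf.float32),  # query embeddings
        tf.TensorSpec(shape=(), dtype=tf.int32),              # k as a scalar tensor
    ])
    def __call__(self, query_embeddings, k):
        # Dot-product score between every query and every candidate.
        scores = tf.matmul(query_embeddings, self.candidates, transpose_b=True)
        # k is a tensor here, so it is not frozen into the traced graph.
        values, indices = tf.math.top_k(scores, k=k)
        return values, tf.gather(self.identifiers, indices)


# Toy index (assumption): three items with 2-d embeddings.
retriever = TopKRetriever(
    identifiers=tf.constant(["a", "b", "c"]),
    candidate_embeddings=tf.constant([[1.0, 0.0], [0.5, 0.0], [0.0, 1.0]]),
)

export_dir = tempfile.mkdtemp()
tf.saved_model.save(retriever, export_dir)
loaded = tf.saved_model.load(export_dir)

query = tf.constant([[1.0, 0.0]])
# The same loaded function serves different k values; pass k as a tensor.
top1_scores, top1_ids = loaded(query, tf.constant(1))
top2_scores, top2_ids = loaded(query, tf.constant(2))
print(top2_ids.numpy())
```

A simpler alternative, if a fixed upper bound is acceptable: the `BruteForce` constructor also takes a `k` argument (default 10), so building the layer with the largest `k` you need before saving, calling the loaded model without `k`, and slicing the results may avoid the signature mismatch altogether.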
