How do I use a masking layer for TCN? I want to mask certain time steps which are missing. #234
Like this: https://keras.io/api/layers/core_layers/masking

import numpy as np
import tensorflow as tf
from tcn import TCN

samples, timesteps, features = 32, 10, 8
inputs = np.random.random([samples, timesteps, features]).astype(np.float32)

model = tf.keras.models.Sequential()
# Time steps whose features all equal mask_value are masked for downstream layers.
model.add(tf.keras.layers.Masking(mask_value=0., input_shape=(timesteps, features)))
model.add(TCN(32))
output = model(inputs)
print(output)
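The Keras Masking docs linked above describe the rule: a time step is masked only when every feature at that step equals mask_value. The NumPy sketch below reproduces that rule so you can check which steps of your own data would be skipped before feeding it to TCN. It is an illustration, not Keras's implementation, and the helper name `keras_style_mask` is hypothetical.

```python
import numpy as np


def keras_style_mask(x, mask_value=0.0):
    """Hypothetical helper mimicking the rule in the Keras Masking docs.

    x has shape (samples, timesteps, features). A time step is kept (True)
    if at least one feature differs from mask_value, and masked (False)
    only when all of its features equal mask_value.
    """
    return np.any(x != mask_value, axis=-1)


samples, timesteps, features = 2, 4, 3
x = np.ones((samples, timesteps, features), dtype=np.float32)
x[0, 1, :] = 0.0  # every feature equals mask_value -> this step is masked
x[1, 2, 0] = 0.0  # only one feature is 0 -> this step is NOT masked

mask = keras_style_mask(x, mask_value=0.0)
print(mask)
```

If TensorFlow is installed, `tf.keras.layers.Masking(mask_value=0.).compute_mask(x)` should return the same boolean array, which is a quick way to confirm the sketch matches your Keras version.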
I'm adding more comments here; they might be useful later. Masking seems to work correctly, because when every value equals the mask value:

inputs[:, :, :] = 1.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Masking(mask_value=1., input_shape=(timesteps, features)))

the model outputs only zeros. If we change it to

Another example:

import numpy as np
import tensorflow as tf
from tcn import TCN

samples, timesteps, features = 32, 10, 8
inputs = np.random.random([samples, timesteps, features]).astype(np.float32)
# Set every time step except the last one to the mask value, so only the last step is unmasked.
inputs[:, :-1, :] = 1.

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Masking(mask_value=1., input_shape=(timesteps, features)))
model.add(TCN(16, return_sequences=True))  # one output vector per time step
output = model(inputs)
print(output.shape)
print(np.mean(np.abs(output[:, :-1, :])))  # 0.0 EXPECTED (EQUAL TO 0, due to masking).
print(np.mean(np.abs(output[:, -1:, :])))  # e.g. 0.8137888 EXPECTED (DIFFERENT FROM 0; exact value varies between runs).
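For the original question (time steps that are missing), the data must first encode each missing step as a full row of mask_value, since Masking only skips steps where all features match. A minimal sketch, assuming missing steps arrive as NaN in a NumPy array and that 0. never occurs as a genuine full time step; `fill_missing_steps` is a hypothetical helper, not part of keras-tcn or Keras.

```python
import numpy as np


def fill_missing_steps(x, mask_value=0.0):
    """Hypothetical helper: prepare (samples, timesteps, features) data for Masking.

    Any time step containing a NaN is treated as missing and overwritten
    entirely with mask_value, so every feature at that step matches and
    tf.keras.layers.Masking(mask_value=mask_value) would mask it.
    Returns the filled copy and the (samples, timesteps) "missing" flags.
    """
    x = np.array(x, dtype=np.float32, copy=True)  # leave the caller's array untouched
    missing = np.isnan(x).any(axis=-1)            # True where a step has any NaN
    x[missing] = mask_value                       # boolean index selects whole time steps
    return x, missing


rng = np.random.default_rng(0)
data = rng.random((2, 5, 3)).astype(np.float32)
data[0, 3, :] = np.nan  # fully missing time step
data[1, 1, 2] = np.nan  # partially missing step, also treated as missing here

clean, missing = fill_missing_steps(data, mask_value=0.0)
print(missing)
```

Treating a partially missing step as fully missing is a simplification chosen for this sketch; imputing the missing features instead would keep that step visible to TCN. Whatever mask value you pick should also be one that cannot appear as a real, complete time step, otherwise valid data would be masked too.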
Thank you so much @philipperemy!
@arvind267 Happy I could help!