Skip to content

Commit

Permalink
keras_nlp workspace changes
Browse files Browse the repository at this point in the history
Signed-off-by: yes <[email protected]>
  • Loading branch information
tanwarsh committed Nov 27, 2024
1 parent 78093fc commit a00ade2
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 86 deletions.
10 changes: 5 additions & 5 deletions openfl-workspace/keras_nlp/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ aggregator :
defaults : plan/defaults/aggregator.yaml
template : openfl.component.Aggregator
settings :
init_state_path : save/keras_nlp_init.pbuf
best_state_path : save/keras_nlp_best.pbuf
last_state_path : save/keras_nlp_last.pbuf
init_state_path : save/init.pbuf
best_state_path : save/best.pbuf
last_state_path : save/last.pbuf
rounds_to_train : 10

collaborator :
Expand All @@ -21,7 +21,7 @@ collaborator :

data_loader :
defaults : plan/defaults/data_loader.yaml
template : src.nlp_dataloader.NLPDataLoader
template : src.dataloader.NLPDataLoader
settings :
collaborator_count : 2
batch_size : 64
Expand All @@ -30,7 +30,7 @@ data_loader :

task_runner :
defaults : plan/defaults/task_runner.yaml
template : src.nlp_taskrunner.KerasNLP
template : src.taskrunner.KerasNLP
settings :
latent_dim : 256

Expand Down
3 changes: 2 additions & 1 deletion openfl-workspace/keras_nlp/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
tensorflow==2.13
keras==3.6.0
tensorflow==2.18.0
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
license agreement between Intel Corporation and you.
"""
from logging import getLogger
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
Expand Down Expand Up @@ -57,7 +55,7 @@ def get_feature_shape(self) -> Tuple[int, ...]:
"""Get the shape of an example feature array."""
return self.X_train[0].shape

def get_train_loader(self, batch_size: Optional[int] = None) -> Iterator[List[np.ndarray]]:
def get_train_loader(self, batch_size: Optional[int] = None):
"""
Get training data loader.
Expand All @@ -68,7 +66,7 @@ def get_train_loader(self, batch_size: Optional[int] = None) -> Iterator[List[np
return self._get_batch_generator(X1=self.X_train[0], X2=self.X_train[1],
y=self.y_train, batch_size=batch_size)

def get_valid_loader(self, batch_size: Optional[int] = None) -> Iterator[List[np.ndarray]]:
def get_valid_loader(self, batch_size: Optional[int] = None):
"""
Get validation data loader.
Expand Down Expand Up @@ -100,7 +98,7 @@ def get_valid_data_size(self) -> int:
def _batch_generator(X1: np.ndarray, X2: np.ndarray,
y: np.ndarray, idxs: np.ndarray,
batch_size: int,
num_batches: int) -> Iterator[List[np.ndarray]]:
num_batches: int):
"""
Generate batch of data.
Expand All @@ -116,11 +114,11 @@ def _batch_generator(X1: np.ndarray, X2: np.ndarray,
for i in range(num_batches):
a = i * batch_size
b = a + batch_size
yield [X1[idxs[a:b]], X2[idxs[a:b]]], y[idxs[a:b]]
yield (X1[idxs[a:b]], X2[idxs[a:b]]), y[idxs[a:b]]

def _get_batch_generator(self, X1: np.ndarray, X2: np.ndarray,
y: np.ndarray,
batch_size: Union[int, None]) -> Iterator[List[np.ndarray]]:
batch_size: Union[int, None]):
"""
Return the dataset generator.
Expand Down
73 changes: 0 additions & 73 deletions openfl-workspace/keras_nlp/src/nlp_taskrunner.py

This file was deleted.

72 changes: 72 additions & 0 deletions openfl-workspace/keras_nlp/src/taskrunner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Copyright (C) 2020-2021 Intel Corporation
SPDX-License-Identifier: Apache-2.0
Licensed subject to the terms of the separately executed evaluation
license agreement between Intel Corporation and you.
"""
import keras as ke

from openfl.federated import KerasTaskRunner


class KerasNLP(KerasTaskRunner):
    """Task runner wrapping an LSTM encoder-decoder (sequence-to-sequence) Keras model."""

    def __init__(self, latent_dim, **kwargs):
        """
        Set up the task runner and build its model.

        Args:
            latent_dim: Number of units used by both the encoder and decoder LSTM layers
            **kwargs: Additional parameters; forwarded to the parent constructor
                and to ``build_model``
        """
        super().__init__(**kwargs)

        # Token counts are read from the data loader
        # (presumably attached by the parent constructor -- confirm).
        loader = self.data_loader
        self.model = self.build_model(
            latent_dim,
            loader.num_encoder_tokens,
            loader.num_decoder_tokens,
            **kwargs
        )

        self.initialize_tensorkeys_for_functions()

        # Route the Keras summary through the runner's logger.
        self.model.summary(print_fn=self.logger.info)

        train_size = self.get_train_data_size()
        self.logger.info(f'Train Set Size : {train_size}')

    def build_model(self, latent_dim, num_encoder_tokens, num_decoder_tokens, **kwargs):
        """
        Build and compile the encoder-decoder model.

        Args:
            latent_dim: Number of units in each LSTM layer
            num_encoder_tokens: Size of the last axis of the encoder input
            num_decoder_tokens: Size of the last axis of the decoder input, and
                the number of units in the softmax output layer
            **kwargs: Accepted for interface compatibility; not used here

        Returns:
            The compiled Keras model mapping ``[encoder_inputs, decoder_inputs]``
            to the decoder's softmax outputs
        """
        # NOTE(review): keep the layer construction order stable -- auto-generated
        # layer names (and anything keyed on them) presumably depend on it; confirm.

        # Encoder: only the final LSTM states are kept, the output itself is dropped.
        enc_in = ke.Input(shape=(None, num_encoder_tokens))
        enc_lstm = ke.layers.LSTM(latent_dim, return_state=True)
        _, hidden_state, cell_state = enc_lstm(enc_in)

        # Decoder: starts from the encoder states and emits the full output sequence.
        # Its own returned states are ignored when building this training model.
        dec_in = ke.Input(shape=(None, num_decoder_tokens))
        dec_lstm = ke.layers.LSTM(latent_dim, return_sequences=True, return_state=True)
        dec_seq, _, _ = dec_lstm(dec_in, initial_state=[hidden_state, cell_state])

        # Per-step softmax over the decoder tokens.
        token_probs = ke.layers.Dense(num_decoder_tokens, activation='softmax')(dec_seq)

        seq2seq = ke.Model([enc_in, dec_in], token_probs)
        seq2seq.compile(
            optimizer="RMSprop",
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )
        return seq2seq

0 comments on commit a00ade2

Please sign in to comment.