Skip to content

Commit

Permalink
[SYSTEMDS-3701] Add test suite for Scuro
Browse files Browse the repository at this point in the history
Closes #2143.
  • Loading branch information
christinadionysio authored and mboehm7 committed Nov 21, 2024
1 parent 08875cb commit 4e00aa1
Show file tree
Hide file tree
Showing 10 changed files with 592 additions and 84 deletions.
7 changes: 6 additions & 1 deletion src/main/python/systemds/scuro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
from systemds.scuro.representations.bert import Bert
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from systemds.scuro.representations.lstm import LSTM
from systemds.scuro.representations.utils import NPY, Pickle, HDF5, JSON
from systemds.scuro.representations.representation_dataloader import (
NPY,
Pickle,
HDF5,
JSON,
)
from systemds.scuro.models.model import Model
from systemds.scuro.models.discrete_model import DiscreteModel
from systemds.scuro.modality.aligned_modality import AlignedModality
Expand Down
41 changes: 35 additions & 6 deletions src/main/python/systemds/scuro/aligner/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,19 @@
from typing import List

from systemds.scuro.models.model import Model
import numpy as np
from sklearn.model_selection import KFold


class Task:
def __init__(
self, name: str, model: Model, labels, train_indices: List, val_indices: List
self,
name: str,
model: Model,
labels,
train_indices: List,
val_indices: List,
kfold=5,
):
"""
Parent class for the prediction task that is performed on top of the aligned representation
Expand All @@ -34,12 +42,15 @@ def __init__(
:param labels: Labels used for prediction
:param train_indices: Indices to extract training data
:param val_indices: Indices to extract validation data
:param kfold: Number of crossvalidation runs
"""
self.name = name
self.model = model
self.labels = labels
self.train_indices = train_indices
self.val_indices = val_indices
self.kfold = kfold

def get_train_test_split(self, data):
X_train = [data[i] for i in self.train_indices]
Expand All @@ -51,9 +62,27 @@ def get_train_test_split(self, data):

def run(self, data):
"""
The run method need to be implemented by every task class
It handles the training and validation procedures for the specific task
:param data: The aligned data used in the prediction process
:return: the validation accuracy
The run method needs to be implemented by every task class
It handles the training and validation procedures for the specific task
:param data: The aligned data used in the prediction process
:return: the validation accuracy
"""
pass
skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
train_scores = []
test_scores = []
fold = 0
X, y, X_test, y_test = self.get_train_test_split(data)

for train, test in skf.split(X, y):
train_X = np.array(X)[train]
train_y = np.array(y)[train]

train_score = self.model.fit(train_X, train_y, X_test, y_test)
train_scores.append(train_score)

test_score = self.model.test(X_test, y_test)
test_scores.append(test_score)

fold += 1

return [np.mean(train_scores), np.mean(test_scores)]
6 changes: 4 additions & 2 deletions src/main/python/systemds/scuro/representations/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def parse_all(self, filepath, indices, get_sequences=False):
data = file.readlines()

model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(
model_name, clean_up_tokenization_spaces=True
)

if self.avg_layers is not None:
model = BertModel.from_pretrained(model_name, output_hidden_states=True)
Expand Down Expand Up @@ -89,7 +91,7 @@ def create_embeddings(self, data, model, tokenizer):
cls_embedding = torch.mean(torch.stack(cls_embedding), dim=0)
else:
cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
embeddings.append(cls_embedding)
embeddings.append(cls_embedding.numpy())

embeddings = np.array(embeddings)
return embeddings.reshape((embeddings.shape[0], embeddings.shape[-1]))
Expand Down
2 changes: 0 additions & 2 deletions src/main/python/systemds/scuro/representations/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
# -------------------------------------------------------------
from typing import List

from sklearn.preprocessing import StandardScaler

from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.representation import Representation

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------


import json
import pickle
import numpy as np
import h5py

from systemds.scuro.representations.unimodal import UnimodalRepresentation


class NPY(UnimodalRepresentation):
def __init__(self):
super().__init__("NPY")

def parse_all(self, filepath, indices, get_sequences=False):
data = np.load(filepath, allow_pickle=True)

if indices is not None:
return np.array([data[index] for index in indices])
else:
return np.array([data[index] for index in data])


class Pickle(UnimodalRepresentation):
def __init__(self):
super().__init__("Pickle")

def parse_all(self, file_path, indices, get_sequences=False):
with open(file_path, "rb") as f:
data = pickle.load(f)

embeddings = []
for n, idx in enumerate(indices):
embeddings.append(data[idx])

return np.array(embeddings)


class HDF5(UnimodalRepresentation):
def __init__(self):
super().__init__("HDF5")

def parse_all(self, filepath, indices=None, get_sequences=False):
data = h5py.File(filepath)

if get_sequences:
max_emb = 0
for index in indices:
if max_emb < len(data[index][()]):
max_emb = len(data[index][()])

emb = []
if indices is not None:
for index in indices:
emb_i = data[index].tolist()
for i in range(len(emb_i), max_emb):
emb_i.append([0 for x in range(0, len(emb_i[0]))])
emb.append(emb_i)

return np.array(emb)
else:
if indices is not None:
return np.array([np.mean(data[index], axis=0) for index in indices])
else:
return np.array([np.mean(data[index][()], axis=0) for index in data])


class JSON(UnimodalRepresentation):
def __init__(self):
super().__init__("JSON")

def parse_all(self, filepath, indices):
with open(filepath) as file:
return json.load(file)
73 changes: 0 additions & 73 deletions src/main/python/systemds/scuro/representations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,81 +19,8 @@
#
# -------------------------------------------------------------


import json
import pickle

import h5py
import numpy as np

from systemds.scuro.representations.unimodal import UnimodalRepresentation


class NPY(UnimodalRepresentation):
def __init__(self):
super().__init__("NPY")

def parse_all(self, filepath, indices, get_sequences=False):
data = np.load(filepath, allow_pickle=True)

if indices is not None:
return np.array([data[index] for index in indices])
else:
return np.array([data[index] for index in data])


class Pickle(UnimodalRepresentation):
def __init__(self):
super().__init__("Pickle")

def parse_all(self, file_path, indices, get_sequences=False):
with open(file_path, "rb") as f:
data = pickle.load(f)

embeddings = []
for n, idx in enumerate(indices):
embeddings.append(data[idx])

return np.array(embeddings)


class HDF5(UnimodalRepresentation):
def __init__(self):
super().__init__("HDF5")

def parse_all(self, filepath, indices=None, get_sequences=False):
data = h5py.File(filepath)

if get_sequences:
max_emb = 0
for index in indices:
if max_emb < len(data[index][()]):
max_emb = len(data[index][()])

emb = []
if indices is not None:
for index in indices:
emb_i = data[index].tolist()
for i in range(len(emb_i), max_emb):
emb_i.append([0 for x in range(0, len(emb_i[0]))])
emb.append(emb_i)

return np.array(emb)
else:
if indices is not None:
return np.array([np.mean(data[index], axis=0) for index in indices])
else:
return np.array([np.mean(data[index][()], axis=0) for index in data])


class JSON(UnimodalRepresentation):
def __init__(self):
super().__init__("JSON")

def parse_all(self, filepath, indices):
with open(filepath) as file:
return json.load(file)


def pad_sequences(sequences, maxlen=None, dtype="float32", value=0):
if maxlen is None:
Expand Down
20 changes: 20 additions & 0 deletions src/main/python/tests/scuro/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------
Loading

0 comments on commit 4e00aa1

Please sign in to comment.