passing NeuralLDA.train_model(top_words=10) to AVITM_model.train_model(top_words=10) #103

Open · wants to merge 9 commits into base: master
octis/models/CTM.py (2 additions & 2 deletions)

@@ -134,7 +134,7 @@ def train_model(self, dataset, hyperparameters=None, top_words=10):
topic_prior_variance=self.hyperparameters["prior_variance"],
top_words=top_words)

-self.model.fit(x_train, x_valid, verbose=False)
+self.model.fit(x_train, x_valid, verbose=self.hyperparameters["verbose"], save_dir=self.hyperparameters["save_dir"])
result = self.inference(x_test)
return result

@@ -161,7 +161,7 @@ def train_model(self, dataset, hyperparameters=None, top_words=10):
topic_prior_variance=self.hyperparameters["prior_variance"],
top_words=top_words)

-self.model.fit(x_train, None, verbose=False)
+self.model.fit(x_train, None, verbose=self.hyperparameters["verbose"], save_dir=self.hyperparameters["save_dir"])
result = self.model.get_info()
return result

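With this change, CTM.train_model reads "verbose" and "save_dir" out of self.hyperparameters instead of hard-coding verbose=False, so those keys must be present when fit is called. A minimal caller-side sketch (the key names come from this diff; the constructor arguments, and whether the class defaults already include these keys, are assumptions):

# Hedged sketch, not part of the patch: make sure the keys the patched
# train_model reads exist, otherwise the dict lookups raise a KeyError.
from octis.models.CTM import CTM

model = CTM(num_topics=10)                        # constructor args assumed
model.hyperparameters["verbose"] = True           # echo training progress
model.hyperparameters["save_dir"] = "./ctm_ckpt"  # where fit() checkpoints
# output = model.train_model(dataset, top_words=10)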
octis/models/contextualized_topic_models/models/ctm.py (2 additions & 2 deletions)

@@ -273,7 +273,7 @@ def fit(self, train_dataset, validation_dataset=None,

train_loader = DataLoader(
self.train_data, batch_size=self.batch_size, shuffle=True,
-num_workers=self.num_data_loader_workers)
+num_workers=self.num_data_loader_workers, drop_last=True)

# init training variables
train_loss = 0
@@ -301,7 +301,7 @@ def fit(self, train_dataset, validation_dataset=None,
if self.validation_data is not None:
validation_loader = DataLoader(
self.validation_data, batch_size=self.batch_size,
-shuffle=True, num_workers=self.num_data_loader_workers)
+shuffle=True, num_workers=self.num_data_loader_workers, drop_last=True)
# train epoch
s = datetime.datetime.now()
val_samples_processed, val_loss = self._validation(
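drop_last=True makes the DataLoader discard the final incomplete batch, so every batch has exactly batch_size samples; a common reason to do this is that a trailing batch of size 1 makes BatchNorm1d fail in training mode (the motivation here is my inference, not stated in the diff). A standalone illustration, independent of OCTIS:

# With 1050 samples and batch_size=64 there are 16 full batches plus a
# 26-sample remainder; drop_last=True throws the remainder away.
import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.randn(1050, 5))
keep_all = DataLoader(data, batch_size=64)
drop_tail = DataLoader(data, batch_size=64, drop_last=True)
print(len(keep_all), len(drop_tail))  # 17 16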
octis/models/pytorchavitm/AVITM.py (3 additions & 3 deletions)

@@ -98,14 +98,14 @@ def train_model(self, dataset, hyperparameters=None, top_words=10):
solver=self.hyperparameters['solver'], num_epochs=self.hyperparameters['num_epochs'],
reduce_on_plateau=self.hyperparameters['reduce_on_plateau'], num_samples=self.hyperparameters[
'num_samples'], topic_prior_mean=self.hyperparameters["prior_mean"],
-topic_prior_variance=self.hyperparameters["prior_variance"]
+topic_prior_variance=self.hyperparameters["prior_variance"], verbose=self.hyperparameters["verbose"], top_words=top_words,
)

if self.use_partitions:
-self.model.fit(x_train, x_valid)
+self.model.fit(x_train, x_valid, save_dir=self.hyperparameters["save_dir"])
result = self.inference(x_test)
else:
-self.model.fit(x_train, None)
+self.model.fit(x_train, None, save_dir=self.hyperparameters["save_dir"])
result = self.model.get_info()
return result

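With top_words now forwarded from the wrapper into AVITM_model, the number of terms per topic reported by get_info follows the caller's request. A hedged end-to-end sketch in the usual OCTIS style (the dataset name is an assumption, and the two hyperparameters entries mirror the keys this patch reads):

# Hedged usage sketch; not part of the patch.
from octis.dataset.dataset import Dataset
from octis.models.pytorchavitm.AVITM import AVITM

dataset = Dataset()
dataset.fetch_dataset("20NewsGroup")            # assumed example corpus
model = AVITM(num_topics=10, use_partitions=False)
model.hyperparameters["verbose"] = False        # keys read by the patched code
model.hyperparameters["save_dir"] = None
output = model.train_model(dataset, top_words=15)
print(len(output["topics"][0]))                 # expect 15 terms per topic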
octis/models/pytorchavitm/avitm/avitm_model.py (5 additions & 4 deletions)

@@ -19,7 +19,7 @@ class AVITM_model(object):
def __init__(self, input_size, num_topics=10, model_type='prodLDA', hidden_sizes=(100, 100),
activation='softplus', dropout=0.2, learn_priors=True, batch_size=64, lr=2e-3, momentum=0.99,
solver='adam', num_epochs=100, reduce_on_plateau=False, topic_prior_mean=0.0,
-topic_prior_variance=None, num_samples=10, num_data_loader_workers=0, verbose=False):
+topic_prior_variance=None, num_samples=10, num_data_loader_workers=0, verbose=False, top_words=10):
"""
Initialize AVITM model.

@@ -68,6 +68,7 @@ def __init__(self, input_size, num_topics=10, model_type='prodLDA', hidden_sizes
# assert isinstance(topic_prior_variance, float), \
# "topic prior_variance must be type float"

+self.top_words = top_words
self.input_size = input_size
self.num_topics = num_topics
self.verbose = verbose
@@ -240,7 +241,7 @@ def fit(self, train_dataset, validation_dataset, save_dir=None):
self.validation_data = validation_dataset
train_loader = DataLoader(
self.train_data, batch_size=self.batch_size, shuffle=True,
-num_workers=self.num_data_loader_workers)
+num_workers=self.num_data_loader_workers, drop_last=True)

# init training variables
train_loss = 0
@@ -267,7 +268,7 @@
if self.validation_data is not None:
validation_loader = DataLoader(
self.validation_data, batch_size=self.batch_size, shuffle=True,
-num_workers=self.num_data_loader_workers)
+num_workers=self.num_data_loader_workers, drop_last=True)
# train epoch
s = datetime.datetime.now()
val_samples_processed, val_loss = self._validation(validation_loader)
@@ -347,7 +348,7 @@ def get_topics(self, k=10):

def get_info(self):
info = {}
-topic_word = self.get_topics()
+topic_word = self.get_topics(k=self.top_words)  # or self.input_size
topic_word_dist = self.get_topic_word_mat()
# topic_document_dist = self.get_topic_document_mat()
info['topics'] = topic_word
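Storing top_words on the model lets get_info request that many terms per topic via get_topics(k=self.top_words), rather than the hard-coded default of 10. For reference, a standalone, hypothetical sketch of the underlying top-k selection over a topic-word matrix (top_k_words is an illustrative name, not OCTIS API):

import numpy as np

def top_k_words(topic_word_dist, vocab, k=10):
    # argsort each topic row by descending probability, keep the first k
    idx = np.argsort(-topic_word_dist, axis=1)[:, :k]
    return [[vocab[i] for i in row] for row in idx]

vocab = ["apple", "bear", "cat", "dog", "egg"]
dist = np.random.dirichlet(np.ones(len(vocab)), size=3)  # 3 topics, 5 words
print(top_k_words(dist, vocab, k=2))  # 2 terms per topic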