feat: allow PyTorch weights to be loaded as TF weights for BERT models
souvikg10 committed Dec 11, 2023
1 parent cca30d4 commit 086097f
Showing 2 changed files with 18 additions and 3 deletions.
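The feature is driven entirely by the new "other" key added to the registries below. A hypothetical featurizer configuration that would exercise it (keys as accepted by Rasa's LanguageModelFeaturizer; in a real project this lives in config.yml, and the checkpoint name is only an illustration):

    # Hypothetical component config; "model_weights" may point at any
    # checkpoint that ships only PyTorch weights.
    component_config = {
        "model_name": "other",  # routes to AutoTokenizer / TFAutoModel
        "model_weights": "sentence-transformers/all-MiniLM-L6-v2",
    }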
rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py (10 additions, 3 deletions)
@@ -35,6 +35,7 @@
     "distilbert": 512,
     "roberta": 512,
     "camembert": 512,
+    "other": 512,
 }


@@ -152,9 +153,15 @@ def _load_model_instance(self) -> None:
         self.tokenizer = model_tokenizer_dict[self.model_name].from_pretrained(
             self.model_weights, cache_dir=self.cache_dir
         )
-        self.model = model_class_dict[self.model_name].from_pretrained(
-            self.model_weights, cache_dir=self.cache_dir
-        )
+        if self.model_name == "other":
+            # Always load PyTorch weights and convert them to TF via from_pt.
+            self.model = model_class_dict[self.model_name].from_pretrained(
+                self.model_weights, cache_dir=self.cache_dir, from_pt=True
+            )
+        else:
+            self.model = model_class_dict[self.model_name].from_pretrained(
+                self.model_weights, cache_dir=self.cache_dir
+            )

         # Use a universal pad token since all transformer architectures do not have a
         # consistent token. Instead of pad_token_id we use unk_token_id because
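For reference, a minimal standalone sketch of what the new "other" branch does, assuming only the public transformers API: from_pt=True makes a TF model class load a PyTorch checkpoint and convert the weights on the fly. The model ID mirrors the default registered for "other" in this commit; any PyTorch-only checkpoint whose architecture TFAutoModel can infer should behave the same.

    from transformers import AutoTokenizer, TFAutoModel

    weights = "sentence-transformers/all-MiniLM-L6-v2"

    tokenizer = AutoTokenizer.from_pretrained(weights)
    # from_pt=True converts the PyTorch checkpoint to TF weights at load time;
    # without it, TF loading can fail for repos that ship no tf_model.h5.
    model = TFAutoModel.from_pretrained(weights, from_pt=True)

    inputs = tokenizer("hello world", return_tensors="tf")
    outputs = model(inputs)
    print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)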
rasa/nlu/utils/hugging_face/registry.py (8 additions, 0 deletions)
@@ -15,6 +15,7 @@
     TFDistilBertModel,
     TFRobertaModel,
     TFCamembertModel,
+    TFAutoModel,
     PreTrainedTokenizer,
     BertTokenizer,
     OpenAIGPTTokenizer,
@@ -24,6 +25,7 @@
     DistilBertTokenizer,
     RobertaTokenizer,
     CamembertTokenizer,
+    AutoTokenizer,
 )
 from rasa.nlu.utils.hugging_face.transformers_pre_post_processors import (  # noqa: E402, E501
     bert_tokens_pre_processor,
@@ -52,6 +54,7 @@
     "distilbert": TFDistilBertModel,
     "roberta": TFRobertaModel,
     "camembert": TFCamembertModel,
+    "other": TFAutoModel,
 }
 model_tokenizer_dict: Dict[Text, Type[PreTrainedTokenizer]] = {
     "bert": BertTokenizer,
@@ -62,6 +65,7 @@
     "distilbert": DistilBertTokenizer,
     "roberta": RobertaTokenizer,
     "camembert": CamembertTokenizer,
+    "other": AutoTokenizer,
 }
 model_weights_defaults = {
     "bert": "rasa/LaBSE",
@@ -72,6 +76,7 @@
     "distilbert": "distilbert-base-uncased",
     "roberta": "roberta-base",
     "camembert": "camembert-base",
+    "other": "sentence-transformers/all-MiniLM-L6-v2",
 }

 model_special_tokens_pre_processors = {
@@ -83,6 +88,7 @@
     "distilbert": bert_tokens_pre_processor,
     "roberta": roberta_tokens_pre_processor,
     "camembert": camembert_tokens_pre_processor,
+    "other": bert_tokens_pre_processor,
 }

 model_tokens_cleaners = {
@@ -94,6 +100,7 @@
     "distilbert": bert_tokens_cleaner,  # uses the same as BERT
     "roberta": gpt2_tokens_cleaner,  # uses the same as GPT2
     "camembert": xlnet_tokens_cleaner,  # removes underscores (_)
+    "other": bert_tokens_cleaner,
 }

 model_embeddings_post_processors = {
@@ -105,4 +112,5 @@
     "distilbert": bert_embeddings_post_processor,
     "roberta": roberta_embeddings_post_processor,
     "camembert": roberta_embeddings_post_processor,
+    "other": bert_embeddings_post_processor,
 }

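Taken together, the registry entries let _load_model_instance resolve a tokenizer and model class for checkpoints that have no dedicated entry. A simplified sketch of that lookup, assuming the dictionary names from registry.py (the control flow is condensed, not Rasa's exact code):

    from rasa.nlu.utils.hugging_face.registry import (
        model_class_dict,
        model_tokenizer_dict,
        model_weights_defaults,
    )

    model_name = "other"
    weights = model_weights_defaults[model_name]

    # "other" maps to the Auto classes, so the concrete architecture is
    # inferred from the checkpoint's config rather than hard-coded, and
    # from_pt=True covers checkpoints that ship only PyTorch weights.
    tokenizer = model_tokenizer_dict[model_name].from_pretrained(weights)
    model = model_class_dict[model_name].from_pretrained(weights, from_pt=True)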