diff --git a/instruct_qa/generation/generator.py b/instruct_qa/generation/generator.py index 175b994..d616ae2 100644 --- a/instruct_qa/generation/generator.py +++ b/instruct_qa/generation/generator.py @@ -62,6 +62,8 @@ def post_process_response(self, response): class GPTx(BaseGenerator): def __init__(self, *args, **kwargs): + completion_type = kwargs.pop("completion_type", None) + super().__init__(*args, **kwargs) openai.api_key = self.api_key self.model_map = { @@ -70,8 +72,9 @@ def __init__(self, *args, **kwargs): "text-davinci-003": "completions", "text-davinci-002": "completions", } - if "completion_type" in kwargs: - self.model_map[self.model_name] = kwargs["completion_type"] + + if completion_type is not None: + self.model_map[model_name] = completion_type assert ( self.model_name in self.model_map