Skip to content

Commit

Permalink
Supported SFT for openai
Browse files Browse the repository at this point in the history
  • Loading branch information
JingofXin committed Nov 29, 2024
1 parent 2cea61d commit ef9bb24
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
12 changes: 6 additions & 6 deletions lazyllm/module/onlineChatModule/openaiModule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import uuid
import requests
from typing import Tuple, List
from urllib.parse import urljoin
Expand Down Expand Up @@ -87,10 +88,10 @@ def _update_kw(self, data, normal_config):
current_train_data = self.default_train_data.copy()
current_train_data.update(data)

current_train_data["hyper_parameters"]["n_epochs"] = normal_config["num_epochs"]
current_train_data["hyper_parameters"]["learning_rate_multiplier"] = str(normal_config["learning_rate"])
current_train_data["hyper_parameters"]["batch_size"] = normal_config["batch_size"]
current_train_data["suffix"] = normal_config["finetune_model_name"]
current_train_data["hyperparameters"]["n_epochs"] = normal_config["num_epochs"]
current_train_data["hyperparameters"]["learning_rate_multiplier"] = str(normal_config["learning_rate"])
current_train_data["hyperparameters"]["batch_size"] = normal_config["batch_size"]
current_train_data["suffix"] = str(uuid.uuid4())[:7]

return current_train_data

Expand Down Expand Up @@ -150,8 +151,7 @@ def _get_finetuned_model_names(self) -> Tuple[List[Tuple[str, str]], List[Tuple[
model_data = self._query_finetuned_jobs()
res = list()
for model in model_data['data']:
status = 'Done'if 'successful' in model['message'] else 'Failed'
res.append([model['id'], model['fine_tuned_model'], status])
res.append([model['id'], model['fine_tuned_model'], self._status_mapping(model['status'])])
return res

def _status_mapping(self, status):
Expand Down
14 changes: 7 additions & 7 deletions lazyllm/tools/train_service/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def train(self, train_config, token, source):
Args:
- train_config (dict): Configuration parameters for the training task.
- token (str): API-Key provided by the supplier, used for authentication.
- source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'.
- source (str): Specifies the supplier. Supported suppliers are 'openai', 'glm' and 'qwen'.
Returns:
- tuple: A tuple containing the Job-ID and its status if the training starts successfully.
Expand Down Expand Up @@ -284,7 +284,7 @@ def get_all_trained_models(self, token, source):
Args:
- token (str): API-Key provided by the supplier, used for authentication.
- source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'.
- source (str): Specifies the supplier. Supported suppliers are 'openai', 'glm' and 'qwen'.
Returns:
- list of lists: Each sublist contains [job_id, model_name, status] for each trained model.
Expand All @@ -307,7 +307,7 @@ def get_training_status(self, token, job_id, source):
Args:
- token (str): API-Key provided by the supplier, used for authentication.
- job_id (str): The unique identifier of the training job to query.
- source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'.
- source (str): Specifies the supplier. Supported suppliers are 'openai', 'glm' and 'qwen'.
Returns:
- str: A string representing the current status of the training task. This could be one of:
Expand All @@ -332,7 +332,7 @@ def cancel_training(self, token, job_id, source):
Args:
- token (str): API-Key provided by the supplier, used for authentication.
- job_id (str): The unique identifier of the training job to be cancelled.
- source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'.
- source (str): Specifies the supplier. Supported suppliers are 'openai', 'glm' and 'qwen'.
Returns:
- bool or str: Returns True if the training task was successfully cancelled. If the cancellation fails,
Expand Down Expand Up @@ -360,7 +360,7 @@ def get_training_log(self, token, job_id, source, target_path=None):
Args:
- token (str): API-Key provided by the supplier, used for authentication.
- job_id (str): The unique identifier of the training job for which to retrieve the log.
- source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'.
- source (str): Specifies the supplier. Supported suppliers are 'openai', 'glm' and 'qwen'.
- target_path (str, optional): The path where the log file should be saved. If not provided,
the log will be saved to a temporary directory.
Expand Down Expand Up @@ -389,7 +389,7 @@ def get_training_cost(self, token, job_id, source):
Args:
- token (str): API-Key provided by the supplier, used for authentication.
- job_id (str): The unique identifier of the traning job for which to retrieve the token consumption.
- source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'.
- source (str): Specifies the supplier. Supported suppliers are 'openai', 'glm' and 'qwen'.
Returns:
- int or str: The number of tokens consumed by the traning task if the query is successful.
Expand All @@ -413,7 +413,7 @@ def validate_api_key(self, token, source):
Args:
- token (str): API-Key provided by the user, used for authentication.
- source (str): Specifies the supplier. Supported suppliers are 'glm' and 'qwen'.
- source (str): Specifies the supplier. Supported suppliers are 'openai', 'glm' and 'qwen'.
Returns:
- bool: True if the API key is valid, False otherwise.
Expand Down

0 comments on commit ef9bb24

Please sign in to comment.