From 012380373db2fd750e3035b10857ecbf3f4e00f7 Mon Sep 17 00:00:00 2001 From: piglei Date: Fri, 27 Dec 2024 08:45:25 +0800 Subject: [PATCH 1/2] feat: support DeepSeek API --- poetry.lock | 25 +++- pyproject.toml | 1 + tests/builder/test_ai_svc.py | 164 ++++++++++++++++++++++ voc_builder/builder/ai_svc.py | 204 ++++++++++++++++++---------- voc_builder/builder/views.py | 25 ++-- voc_builder/infras/ai.py | 77 ++++++++++- voc_builder/learn/views.py | 5 +- voc_builder/system/constants.py | 5 + voc_builder/system/models.py | 13 ++ voc_builder/system/serializers.py | 14 ++ voc_builder/system/views.py | 7 + voc_frontend/src/views/HomeView.vue | 2 +- voc_frontend/src/views/Settings.vue | 38 ++++++ 13 files changed, 491 insertions(+), 89 deletions(-) create mode 100644 tests/builder/test_ai_svc.py diff --git a/poetry.lock b/poetry.lock index 99ae7cf..6b69743 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1381,6 +1381,29 @@ type = "legacy" url = "https://mirrors.tencent.com/pypi/simple" reference = "tencent-mirror" +[[package]] +name = "pytest-asyncio" +version = "0.25.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pytest_asyncio-0.25.0-py3-none-any.whl", hash = "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3"}, + {file = "pytest_asyncio-0.25.0.tar.gz", hash = "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + +[package.source] +type = "legacy" +url = "https://mirrors.tencent.com/pypi/simple" +reference = "tencent-mirror" + [[package]] name = "python-dateutil" version = "2.8.2" @@ -2052,4 +2075,4 @@ reference = "tencent-mirror" [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "560e1350ce4373ec76a21e2d4230f47f230b662af40730e1eac6e3304a0ad230" +content-hash = "a60ae7d091c5c67f774b56c9deda3caa5fdda135fed2a9776cd5337f8c5242e2" diff --git a/pyproject.toml b/pyproject.toml index 4ab7581..8cbc883 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ nox = "^2024.10.9" pytest = "^8.3.4" mypy = "^1.13.0" types-requests = "^2.32.0.20241016" +pytest-asyncio = "^0.25.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/builder/test_ai_svc.py b/tests/builder/test_ai_svc.py new file mode 100644 index 0000000..8cab4bf --- /dev/null +++ b/tests/builder/test_ai_svc.py @@ -0,0 +1,164 @@ +from types import SimpleNamespace +from unittest import mock + +import pytest + +from voc_builder.builder.ai_svc import ( + ManuallyWordQuerier, + RareWordQuerier, + WordChoiceModelResp, +) +from voc_builder.exceptions import AIServiceError +from voc_builder.infras.ai import AIResultMode + +# A valid JSON word reply +VALID_JSON_REPLY = """{ + "word": "synergy", + "word_base_form": "synergy", + "definitions": "[noun] 协同作用,协同效应", + "pronunciation": "ˈsɪnərdʒi" +}""" + +# A valid pydantic word reply +VALID_PYDANTIC_REPLY = WordChoiceModelResp( + word="synergy", + word_base_form="synergy", + definitions="[noun] 协同作用,协同效应", + pronunciation="ˈsɪnərdʒi", +) + + +@pytest.mark.asyncio +class TestRareWordQuerierJsonResult: + @pytest.mark.parametrize( + "data", + [ + # The API sometimes returns the JSON string with triple quotes + f"""```json\n{VALID_JSON_REPLY}\n```""", + VALID_JSON_REPLY, + ], + ) + @mock.patch("voc_builder.builder.ai_svc.JsonWordDefGetter.agent_request") + async def test_json_valid_response(self, mocker, data): + mocker.return_value = SimpleNamespace(data=data) + + word = await RareWordQuerier(None, result_mode=AIResultMode.JSON).query( + "The team's synergy was evident in their performance.", + set(), + "Simplified Chinese", + ) + + # Check the prompt + prompt = mocker.call_args[0][1] + assert all( + keyword in prompt.system + for keyword in [ + "rarely encountered word", + "JSON OUTPUT", + "List all possible definitions", + ] + ) + assert all(keyword in prompt.user for keyword in ["Word List:", "synergy"]) + + # Check the word object + assert word.word == "synergy" + assert word.definitions == ["[noun] 协同作用,协同效应"] + + @mock.patch("voc_builder.builder.ai_svc.JsonWordDefGetter.agent_request") + async def test_invalid_response(self, mocker): + mocker.return_value = SimpleNamespace(data="not a valid json") + + with pytest.raises(AIServiceError): + await RareWordQuerier(None, result_mode=AIResultMode.JSON).query( + "The team's synergy was evident in their performance.", + set(), + "Simplified Chinese", + ) + + +@pytest.mark.asyncio +class TestRareWordQuerierPydanticResult: + @mock.patch("voc_builder.builder.ai_svc.PydanticWordDefGetter.agent_request") + async def test_normal(self, mocker): + mocker.return_value = SimpleNamespace(data=VALID_PYDANTIC_REPLY) + + word = await RareWordQuerier(None, result_mode=AIResultMode.PYDANTIC).query( + "The team's synergy was evident in their performance.", + set(), + "Simplified Chinese", + ) + + # Check the prompt + prompt = mocker.call_args[0][1] + assert all( + keyword in prompt.system + for keyword in [ + "rarely encountered word", + "List all possible definitions", + ] + ) + assert all(keyword in prompt.user for keyword in ["Word List:", "synergy"]) + assert "JSON OUTPUT" not in prompt.system + + # Check the word object + assert word.word == "synergy" + assert word.definitions == ["[noun] 协同作用,协同效应"] + + +@pytest.mark.asyncio +class TestManuallyWordQuerierJSONResult: + @mock.patch("voc_builder.builder.ai_svc.JsonWordDefGetter.agent_request") + async def test_valid_json_result(self, mocker): + data = f"""```json\n{VALID_JSON_REPLY}\n```""" + mocker.return_value = SimpleNamespace(data=data) + + word = await ManuallyWordQuerier(None, result_mode=AIResultMode.JSON).query( + "The team's synergy was evident in their performance.", + "synergy", + "Simplified Chinese", + ) + + # Check the prompt + prompt = mocker.call_args[0][1] + assert all( + keyword in prompt.system + for keyword in [ + "an english word", + "JSON OUTPUT", + "List all possible definitions", + ] + ) + assert all(keyword in prompt.user for keyword in ["Word:", "synergy"]) + + # Check the word object + assert word.word == "synergy" + assert word.definitions == ["[noun] 协同作用,协同效应"] + + +@pytest.mark.asyncio +class TestManuallyWordQuerierPydanticResult: + @mock.patch("voc_builder.builder.ai_svc.PydanticWordDefGetter.agent_request") + async def test_valid_pydantic_result(self, mocker): + mocker.return_value = SimpleNamespace(data=VALID_PYDANTIC_REPLY) + + word = await ManuallyWordQuerier(None, result_mode=AIResultMode.PYDANTIC).query( + "The team's synergy was evident in their performance.", + "synergy", + "Simplified Chinese", + ) + + # Check the prompt + prompt = mocker.call_args[0][1] + assert all( + keyword in prompt.system + for keyword in [ + "an english word", + "List all possible definitions", + ] + ) + assert "JSON OUTPUT" not in prompt.system + assert all(keyword in prompt.user for keyword in ["Word:", "synergy"]) + + # Check the word object + assert word.word == "synergy" + assert word.definitions == ["[noun] 协同作用,协同效应"] diff --git a/voc_builder/builder/ai_svc.py b/voc_builder/builder/ai_svc.py index 5c18201..cbaa79a 100644 --- a/voc_builder/builder/ai_svc.py +++ b/voc_builder/builder/ai_svc.py @@ -1,5 +1,6 @@ import logging -from typing import AsyncGenerator, List, Set +import re +from typing import Any, AsyncGenerator, List, Set from pydantic import BaseModel from pydantic_ai import Agent @@ -7,6 +8,7 @@ from voc_builder.builder.models import WordChoice from voc_builder.common.text import get_word_candidates from voc_builder.exceptions import AIServiceError +from voc_builder.infras.ai import AIResultMode, PromptText logger = logging.getLogger() @@ -71,91 +73,153 @@ async def query_translation( yield message -# The prompt being used to extract multiple words -prompt_rare_word_system = """You are a english reading specialist, I will give you a list \ -of english words separated by ",", please find the most rarely encountered word as the result. \ +class RareWordQuerier: + """Query the AI to get the rare word.""" -Reply the result word, the base form, the {language} definition and \ -the pronunciation of the result word. + prompt_system_tmpl = """\ +You are a english reading specialist, I will give you a list \ +of english words separated by ",", please find the most rarely encountered word.""" -- List all possible definitions, separated by "$", with each formatted as \ -"[{{part of speech(adj/noun/...)}}] {{ {language} definition }}". - - Example: [noun] {language} definition1 $ [verb] {language} definition2 -- A paragraph will be given as a reference because there might be homographs. -""" # noqa: E501 + prompt_user_tmpl = """\ +Word List: {words} +Paragraph for reference: {text}""" -prompt_rare_word_user_tmpl = """\ -Words: {words} + def __init__(self, model, result_mode: AIResultMode): + self.model = model + self.result_mode = result_mode -Paragraph for reference: {text} -""" + async def query(self, text: str, known_words: Set[str], language: str) -> WordChoice: + """Query the most rarely word in the given text.""" + words = get_word_candidates(text, known_words=known_words) + if not words: + raise AIServiceError( + "Text does not contain any words that meet the criteria" + ) + prompt = PromptText( + system_lines=[self.prompt_system_tmpl.format(language=language)], + user_lines=[self.prompt_user_tmpl.format(text=text, words=", ".join(words))], + ) + return await word_def_getter_factory(self.result_mode).query( + self.model, prompt, language + ) -async def get_rare_word( - model, text: str, known_words: Set[str], language: str -) -> WordChoice: - """Get the most rarely word in given text.""" - words = get_word_candidates(text, known_words=known_words) - if not words: - raise AIServiceError("Text does not contain any words that meet the criteria") - user_content = prompt_rare_word_user_tmpl.format(text=text, words=", ".join(words)) - prompt = prompt_rare_word_system.format(language=language) + user_content - agent: Agent = Agent(model, result_type=WordChoiceModelResp) - try: - result = await agent.run(prompt) - except Exception as e: - raise AIServiceError("Error calling AI backend API: %s" % e) +class ManuallyWordQuerier: + """Get a word that is manually selected by user.""" - item = result.data - return WordChoice( - word=item.word, - word_normal=item.word_base_form, - pronunciation=item.pronunciation, - definitions=item.get_definition_list(), + prompt_system_tmpl = ( + "You are a translation assistant, I will give you an english word." ) + prompt_user_tmpl = """\ +Word: {word} -prompt_word_manually_system = """You are a translation assistant, I will give you a \ -a english word. +Paragraph for reference: {text}""" -Reply the word, the base form, the {language} definition and the \ -pronunciation of the word. + def __init__(self, model, result_mode: AIResultMode): + self.model = model + self.result_mode = result_mode -- List all possible definitions, separated by "$", with each formatted as \ -"[{{part of speech(adj/noun/...)}}] {{ {language} definition }}". - - Example: [noun] {language} definition1 $ [verb] {language} definition2 -- A paragraph will be given as a reference because there might be homographs. -""" # noqa: E501 + async def query(self, text: str, word: str, language: str) -> WordChoice: + """Query the manually selected word.""" + prompt = PromptText( + system_lines=[self.prompt_system_tmpl.format(language=language)], + user_lines=[self.prompt_user_tmpl.format(text=text, word=word)], + ) + return await word_def_getter_factory(self.result_mode).query( + self.model, prompt, language + ) -prompt_word_manually_user_tmpl = """\ -Word: {word} - -Paragraph for reference: {text} -""" +def word_def_getter_factory(result_mode: AIResultMode) -> "BaseWordDefGetter": + if result_mode == AIResultMode.PYDANTIC: + return PydanticWordDefGetter() + elif result_mode == AIResultMode.JSON: + return JsonWordDefGetter() + raise ValueError("Invalid result getting mode") -async def get_word_manually(model, text: str, word: str, language: str) -> WordChoice: - """Get a word that is manually selected by user. +class BaseWordDefGetter: + """Base class for getting word definition.""" - :param text: The text which contains the word. - :param word: The selected word. - :raise: AIServiceError - """ - user_content = prompt_word_manually_user_tmpl.format(text=text, word=word) - prompt = prompt_word_manually_system.format(language=language) + user_content - agent: Agent = Agent(model, result_type=WordChoiceModelResp) - try: - result = await agent.run(prompt) - except Exception as e: - raise AIServiceError("Error calling AI backend API: %s" % e) - - item = result.data - return WordChoice( - word=item.word, - word_normal=item.word_base_form, - pronunciation=item.pronunciation, - definitions=item.get_definition_list(), - ) + prompt_word_extra_reqs = """\ +- Reply the word, the base form, the {language} definition and \ +the pronunciation of the word. +- List all possible definitions, separated by "$", with each formatted as \ +"[{{part of speech(adj/noun/...)}}] {{ {language} definition }}". + - Example: [noun] {language} definition1 $ [verb] {language} definition2 +- A paragraph will be given as a reference because there might be homographs.""" + + async def query(self, model, prompt: PromptText, language: str) -> WordChoice: + """Query the AI to get the word definition. + + :param model: The AI model object. + :param prompt: The prompt text, it should make the AI return a word. + :param language: The language of the word definition. + """ + raise NotImplementedError + + def _to_word_choice(self, item: WordChoiceModelResp) -> WordChoice: + return WordChoice( + word=item.word, + word_normal=item.word_base_form, + pronunciation=item.pronunciation, + definitions=item.get_definition_list(), + ) + + +class JsonWordDefGetter(BaseWordDefGetter): + """Get a word's definitions, AI agent return JSON result.""" + + prompt_json_output = """\ +output the result in JSON format. + +EXAMPLE JSON OUTPUT: +{{ + "word": "...", + "word_base_form": "...", + "definitions": "...", + "pronunciation": "..." +}}""" + + async def query(self, model, prompt: PromptText, language: str) -> WordChoice: + prompt.system_lines.append(self.prompt_word_extra_reqs.format(language=language)) + prompt.system_lines.append(self.prompt_json_output) + result = await self.agent_request(model, prompt) + item = self._parse_json_output(result.data) + return self._to_word_choice(item) + + async def agent_request(self, model, prompt: PromptText) -> Any: + agent: Agent = Agent(model, system_prompt=prompt.system) + try: + return await agent.run(prompt.user) + except Exception as e: + raise AIServiceError("Error calling AI backend API: %s" % e) + + def _parse_json_output(self, data: str) -> WordChoiceModelResp: + """Parse the JSON output to get the word object.""" + obj = re.search(r"{[\s\S]*}", data, flags=re.MULTILINE) + if not obj: + raise AIServiceError("Invalid JSON output") + return WordChoiceModelResp.model_validate_json(obj.group()) + + +class PydanticWordDefGetter(BaseWordDefGetter): + """Get a word's definitions, AI agent return Pydantic result.""" + + async def query(self, model, prompt: PromptText, language: str) -> WordChoice: + """Query the word using Pydantic mode.""" + prompt.system_lines.append(self.prompt_word_extra_reqs.format(language=language)) + result = await self.agent_request(model, prompt) + return self._to_word_choice(result.data) + + async def agent_request(self, model, prompt: PromptText) -> Any: + agent: Agent = Agent( + model, system_prompt=prompt.system, result_type=WordChoiceModelResp + ) + try: + return await agent.run(prompt.user) + except Exception as e: + raise AIServiceError("Error calling AI backend API: %s" % e) diff --git a/voc_builder/builder/views.py b/voc_builder/builder/views.py index de5131b..9b4467d 100644 --- a/voc_builder/builder/views.py +++ b/voc_builder/builder/views.py @@ -11,11 +11,15 @@ from voc_builder.common.errors import error_codes from voc_builder.common.text import tokenize_text from voc_builder.exceptions import AIServiceError -from voc_builder.infras.ai import create_ai_model +from voc_builder.infras.ai import create_ai_model_config from voc_builder.infras.store import get_mastered_word_store, get_word_store from voc_builder.system.language import get_target_language -from .ai_svc import get_rare_word, get_translation, get_word_manually +from .ai_svc import ( + ManuallyWordQuerier, + RareWordQuerier, + get_translation, +) from .serializers import ( DeleteWordsInput, GetKnownWordsByTextInput, @@ -44,8 +48,9 @@ async def gen_translation_sse(text: str) -> AsyncGenerator[Dict, None]: """ try: + model_config = create_ai_model_config() async for translated_text in get_translation( - create_ai_model(), text, get_target_language() + model_config.model, text, get_target_language() ): yield { "event": "trans_partial", @@ -72,9 +77,10 @@ async def create_word_sample(trans_obj: TranslatedTextInput, response: Response) known_words = word_store.filter(orig_words) | mastered_word_s.filter(orig_words) try: - choice = await get_rare_word( - create_ai_model(), trans_obj.orig_text, known_words, get_target_language() - ) + model_config = create_ai_model_config() + choice = await RareWordQuerier( + model_config.model, model_config.result_mode + ).query(trans_obj.orig_text, known_words, get_target_language()) except Exception as exc: logger.exception("Error extracting word.") raise error_codes.EXACTING_WORD_FAILED.format(str(exc)) @@ -162,9 +168,10 @@ async def manually_save(req: ManuallySelectInput, response: Response): word_store = get_word_store() try: - choice = await get_word_manually( - create_ai_model(), req.orig_text, req.word, get_target_language() - ) + model_config = create_ai_model_config() + choice = await ManuallyWordQuerier( + model_config.model, model_config.result_mode + ).query(req.orig_text, req.word, get_target_language()) except Exception as exc: raise error_codes.MANUALLY_SAVE_WORD_FAILED.format(str(exc)) diff --git a/voc_builder/infras/ai.py b/voc_builder/infras/ai.py index de1e71c..32155f7 100644 --- a/voc_builder/infras/ai.py +++ b/voc_builder/infras/ai.py @@ -1,5 +1,7 @@ import logging -from typing import List +from dataclasses import dataclass +from enum import Enum +from typing import Any, List from anthropic import AsyncAnthropic from openai import AsyncOpenAI @@ -10,10 +12,51 @@ from voc_builder.exceptions import AIModelNotConfiguredError from voc_builder.infras.store import get_sys_settings_store +from voc_builder.system.models import SystemSettings logger = logging.getLogger(__name__) +@dataclass +class PromptText: + """A simple prompt type helps configuring AI prompt. + + :param system: The system prompts. + :param user: The user prompts. + """ + + system_lines: List[str] + user_lines: List[str] + + @property + def system(self) -> str: + return "\n\n".join(self.system_lines) + + @property + def user(self) -> str: + return "\n\n".join(self.user_lines) + + +class AIResultMode(str, Enum): + """The mode to get the AI result.""" + + PYDANTIC = "pydantic" + # JSON mode is for some OpenAI compatible API that doesn't support function call + JSON = "json" + + +@dataclass +class AIModelConfig: + """The AI model configuration, it controls how to interact with the AI model. + + :param model: The AI model object. + :param result_mode: The result mode to get the AI result. + """ + + model: Any + result_mode: AIResultMode + + class WordChoiceModelResp(BaseModel): """The word returned by LLM service.""" @@ -31,15 +74,29 @@ def get_definition_list(self) -> List[str]: return [d.strip() for d in self.definitions.split("$")] -def create_ai_model(): - """Create the AI model object for calling with LLM service. - - :raise AIModelNotConfiguredError: when the model settings is invalid. - """ +def create_ai_model_config() -> AIModelConfig: + """Create the AI model configuration.""" settings = get_sys_settings_store().get_system_settings() if not settings: raise AIModelNotConfiguredError("System settings not found") + model = create_ai_model(settings) + if settings.model_provider == "deepseek": + result_mode = AIResultMode.JSON + else: + result_mode = AIResultMode.PYDANTIC + return AIModelConfig(model, result_mode) + + +# The default base URL for Deepseek API +DEEPSEEK_DEFAULT_BASE_URL = "https://api.deepseek.com" + + +def create_ai_model(settings: SystemSettings): + """Create the AI model object for calling with LLM service. + + :raise AIModelNotConfiguredError: when the model settings is invalid. + """ if settings.model_provider == "openai": openai_config = settings.openai_config client = AsyncOpenAI( @@ -67,5 +124,13 @@ def create_ai_model(): api_key=anthropic_config.api_key, base_url=anthropic_config.api_host or None ) return AnthropicModel(anthropic_config.model, anthropic_client=a_client) + elif settings.model_provider == "deepseek": + deepseek_config = settings.deepseek_config + assert deepseek_config + client = AsyncOpenAI( + api_key=deepseek_config.api_key, + base_url=deepseek_config.api_host or DEEPSEEK_DEFAULT_BASE_URL, + ) + return OpenAIModel(deepseek_config.model, openai_client=client) else: raise AIModelNotConfiguredError("Unknown model provider") diff --git a/voc_builder/learn/views.py b/voc_builder/learn/views.py index 48f7e70..83f0174 100644 --- a/voc_builder/learn/views.py +++ b/voc_builder/learn/views.py @@ -11,7 +11,7 @@ from voc_builder.builder.models import WordSample from voc_builder.builder.serializers import WordSampleOutput from voc_builder.exceptions import AIServiceError -from voc_builder.infras.ai import create_ai_model +from voc_builder.infras.ai import create_ai_model_config from voc_builder.infras.store import get_mastered_word_store, get_word_store from voc_builder.misc.export import VocCSVWriter @@ -48,7 +48,8 @@ async def gen_story_sse(words: List[WordSample]) -> AsyncGenerator[Dict, None]: } try: - async for text in get_story(create_ai_model(), words): + model_config = create_ai_model_config() + async for text in get_story(model_config.model, words): yield {"event": "story_partial", "data": text} except AIServiceError as e: yield {"event": "error", "data": json.dumps({"message": str(e)})} diff --git a/voc_builder/system/constants.py b/voc_builder/system/constants.py index bfcb7cd..f72c0bc 100644 --- a/voc_builder/system/constants.py +++ b/voc_builder/system/constants.py @@ -10,6 +10,7 @@ class ModelProvider(Enum): OPENAI = "openai" GEMINI = "gemini" ANTHROPIC = "anthropic" + DEEPSEEK = "deepseek" @define @@ -97,3 +98,7 @@ def get_by_code(cls, code: str) -> "Optional[TargetLanguage]": "claude-3-5-sonnet-latest", "claude-3-opus-latest", ] + +DEEPSEEK_MODELS = [ + "deepseek-chat", +] diff --git a/voc_builder/system/models.py b/voc_builder/system/models.py index f1004fd..f9f9865 100644 --- a/voc_builder/system/models.py +++ b/voc_builder/system/models.py @@ -29,6 +29,15 @@ class AnthropicConfig: model: str +@dataclass +class DeepSeekConfig: + """The configuration of DeepSeek service.""" + + api_key: str + api_host: str + model: str + + @dataclass class SystemSettings: """The system settings for the project. @@ -41,11 +50,14 @@ class SystemSettings: gemini_config: GeminiConfig anthropic_config: Optional[AnthropicConfig] = None + deepseek_config: Optional[DeepSeekConfig] = None target_language: str = "" def __post_init__(self): if self.anthropic_config is None: self.anthropic_config = AnthropicConfig(api_key="", api_host="", model="") + if self.deepseek_config is None: + self.deepseek_config = DeepSeekConfig(api_key="", api_host="", model="") def build_default_settings() -> SystemSettings: @@ -55,4 +67,5 @@ def build_default_settings() -> SystemSettings: openai_config=OpenAIConfig(api_key="", api_host="", model=""), gemini_config=GeminiConfig(api_key="", api_host="", model=""), anthropic_config=AnthropicConfig(api_key="", api_host="", model=""), + deepseek_config=DeepSeekConfig(api_key="", api_host="", model=""), ) diff --git a/voc_builder/system/serializers.py b/voc_builder/system/serializers.py index 2fbeea6..224bc15 100644 --- a/voc_builder/system/serializers.py +++ b/voc_builder/system/serializers.py @@ -15,6 +15,7 @@ class SettingsInput(BaseModel): openai_config: Dict[str, Any] gemini_config: Dict[str, Any] anthropic_config: Dict[str, Any] + deepseek_config: Dict[str, Any] @field_validator("target_language") @classmethod @@ -72,3 +73,16 @@ def validate_model(cls, value): if not value: raise ValueError("model is required") return value + + +class DeepSeekConfigInput(BaseModel): + api_key: str = Field(..., min_length=1) + api_host: Union[AnyHttpUrl, Literal[""]] + model: str + + @field_validator("model") + @classmethod + def validate_model(cls, value): + if not value: + raise ValueError("model is required") + return value diff --git a/voc_builder/system/views.py b/voc_builder/system/views.py index b6b7e6b..dd8e603 100644 --- a/voc_builder/system/views.py +++ b/voc_builder/system/views.py @@ -9,6 +9,7 @@ from voc_builder.misc.version import get_new_version from voc_builder.system.constants import ( ANTHROPIC_MODELS, + DEEPSEEK_MODELS, GEMINI_MODELS, OPENAI_MODELS, ModelProvider, @@ -17,6 +18,7 @@ from voc_builder.system.language import get_target_language from voc_builder.system.models import ( AnthropicConfig, + DeepSeekConfig, GeminiConfig, OpenAIConfig, build_default_settings, @@ -24,6 +26,7 @@ from .serializers import ( AnthropicConfigInput, + DeepSeekConfigInput, GeminiConfigInput, OpenAIConfigInput, SettingsInput, @@ -75,6 +78,7 @@ async def get_settings(response: Response): "gemini": GEMINI_MODELS, "openai": OPENAI_MODELS, "anthropic": ANTHROPIC_MODELS, + "deepseek": DEEPSEEK_MODELS, }, "target_language_options": [ cattrs.unstructure(lan.value) for lan in TargetLanguage @@ -104,6 +108,9 @@ async def save_settings(settings_input: SettingsInput, response: Response): elif settings_input.model_provider == ModelProvider.ANTHROPIC.value: a_obj = AnthropicConfigInput(**settings_input.anthropic_config) settings.anthropic_config = AnthropicConfig(**a_obj.model_dump(mode="json")) + elif settings_input.model_provider == ModelProvider.DEEPSEEK.value: + d_obj = DeepSeekConfigInput(**settings_input.deepseek_config) + settings.deepseek_config = DeepSeekConfig(**d_obj.model_dump(mode="json")) settings_store.set_system_settings(settings) return {} diff --git a/voc_frontend/src/views/HomeView.vue b/voc_frontend/src/views/HomeView.vue index 6c514dd..716c862 100644 --- a/voc_frontend/src/views/HomeView.vue +++ b/voc_frontend/src/views/HomeView.vue @@ -116,7 +116,7 @@ function extractWord() { source.addEventListener('trans_partial', (event) => { const parsedData = JSON.parse(event.data) - liveTranslatedText.value += parsedData.translated_text + liveTranslatedText.value = parsedData.translated_text }) // Translation finished, set the text and start extracting diff --git a/voc_frontend/src/views/Settings.vue b/voc_frontend/src/views/Settings.vue index bda9373..935b252 100644 --- a/voc_frontend/src/views/Settings.vue +++ b/voc_frontend/src/views/Settings.vue @@ -22,6 +22,11 @@ const settings = reactive({ api_key: "", api_host: "", model: "" + }, + deepseek_config: { + api_key: "", + api_host: "", + model: "" } }) @@ -50,6 +55,11 @@ const toggleAnthropicApiKeyVisibility = () => { showAnthropicApiKey.value = !showAnthropicApiKey.value; }; +const showDeepSeekApiKey = ref(false); + +const toggleDeepSeekApiKeyVisibility = () => { + showDeepSeekApiKey.value = !showDeepSeekApiKey.value; +}; onMounted(() => { loadSettings() @@ -114,6 +124,7 @@ const handleSubmit = async (event: Event) => { + @@ -192,6 +203,33 @@ const handleSubmit = async (event: Event) => { + +
+
+ +
+ + +
+ +
+
+ + +
+ Optional, the default value is https://api.deepseek.com +
+
+
+ + +
+
+