From 439cff3b4802c3e125353243d1b5b432204a999d Mon Sep 17 00:00:00 2001 From: Sampson Date: Tue, 17 Dec 2024 07:10:01 -0600 Subject: [PATCH] Move from OAI to BYOM With the growth of the BYOM model, Brave aims to support popular service endpoints. Originally, the OpenAI model was supported, which was reflected in relevant files and code. With the growing adoption of alternatives (e.g., Ollama locally, Anthropic, Perplexity, Azure hosted models, etc.), it is important to avoid confusion throughout the codebase. This change therefore partially renames files to reflect the move from "oai" to "byom", as well as updates classes, header files, and more. This migration lays the groundwork for a clearer, more maintainable path towards supporting even more endpoints and API patterns. Also, this change addresses the presubmit warnings associated with associated unit tests (specifically, a couple lines of JSON within). --- components/ai_chat/core/browser/BUILD.gn | 12 ++-- .../{oai_api_client.cc => byom_api_client.cc} | 22 +++---- .../{oai_api_client.h => byom_api_client.h} | 20 +++--- ...nittest.cc => byom_api_client_unittest.cc} | 66 +++++++++++++++---- ...onsumer_oai.cc => engine_consumer_byom.cc} | 26 ++++---- ..._consumer_oai.h => engine_consumer_byom.h} | 26 ++++---- ...st.cc => engine_consumer_byom_unittest.cc} | 40 +++++------ .../ai_chat/core/browser/model_service.cc | 6 +- 8 files changed, 129 insertions(+), 89 deletions(-) rename components/ai_chat/core/browser/engine/{oai_api_client.cc => byom_api_client.cc} (90%) rename components/ai_chat/core/browser/engine/{oai_api_client.h => byom_api_client.h} (82%) rename components/ai_chat/core/browser/engine/{oai_api_client_unittest.cc => byom_api_client_unittest.cc} (80%) rename components/ai_chat/core/browser/engine/{engine_consumer_oai.cc => engine_consumer_byom.cc} (92%) rename components/ai_chat/core/browser/engine/{engine_consumer_oai.h => engine_consumer_byom.h} (79%) rename components/ai_chat/core/browser/engine/{engine_consumer_oai_unittest.cc => engine_consumer_byom_unittest.cc} (94%) diff --git a/components/ai_chat/core/browser/BUILD.gn b/components/ai_chat/core/browser/BUILD.gn index 26657a9a258a..19e352882804 100644 --- a/components/ai_chat/core/browser/BUILD.gn +++ b/components/ai_chat/core/browser/BUILD.gn @@ -26,20 +26,20 @@ static_library("browser") { "constants.h", "conversation_handler.cc", "conversation_handler.h", + "engine/byom_api_client.cc", + "engine/byom_api_client.h", "engine/conversation_api_client.cc", "engine/conversation_api_client.h", "engine/engine_consumer.cc", "engine/engine_consumer.h", + "engine/engine_consumer_byom.cc", + "engine/engine_consumer_byom.h", "engine/engine_consumer_claude.cc", "engine/engine_consumer_claude.h", "engine/engine_consumer_conversation_api.cc", "engine/engine_consumer_conversation_api.h", "engine/engine_consumer_llama.cc", "engine/engine_consumer_llama.h", - "engine/engine_consumer_oai.cc", - "engine/engine_consumer_oai.h", - "engine/oai_api_client.cc", - "engine/oai_api_client.h", "engine/remote_completion_client.cc", "engine/remote_completion_client.h", "local_models_updater.cc", @@ -135,12 +135,12 @@ if (!is_ios) { "ai_chat_service_unittest.cc", "associated_content_driver_unittest.cc", "conversation_handler_unittest.cc", + "engine/byom_api_client_unittest.cc", "engine/conversation_api_client_unittest.cc", + "engine/engine_consumer_byom_unittest.cc", "engine/engine_consumer_claude_unittest.cc", "engine/engine_consumer_conversation_api_unittest.cc", "engine/engine_consumer_llama_unittest.cc", - "engine/engine_consumer_oai_unittest.cc", - "engine/oai_api_client_unittest.cc", "engine/test_utils.cc", "engine/test_utils.h", "local_models_updater_unittest.cc", diff --git a/components/ai_chat/core/browser/engine/oai_api_client.cc b/components/ai_chat/core/browser/engine/byom_api_client.cc similarity index 90% rename from components/ai_chat/core/browser/engine/oai_api_client.cc rename to components/ai_chat/core/browser/engine/byom_api_client.cc index 33f9ceae31cb..099b93ec80aa 100644 --- a/components/ai_chat/core/browser/engine/oai_api_client.cc +++ b/components/ai_chat/core/browser/engine/byom_api_client.cc @@ -3,7 +3,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at https://mozilla.org/MPL/2.0/. -#include "brave/components/ai_chat/core/browser/engine/oai_api_client.h" +#include "brave/components/ai_chat/core/browser/engine/byom_api_client.h" #include #include @@ -75,19 +75,19 @@ std::string CreateJSONRequestBody( } // namespace -OAIAPIClient::OAIAPIClient( +BYOMAPIClient::BYOMAPIClient( scoped_refptr url_loader_factory) { api_request_helper_ = std::make_unique( GetNetworkTrafficAnnotationTag(), url_loader_factory); } -OAIAPIClient::~OAIAPIClient() = default; +BYOMAPIClient::~BYOMAPIClient() = default; -void OAIAPIClient::ClearAllQueries() { +void BYOMAPIClient::ClearAllQueries() { api_request_helper_->CancelAll(); } -void OAIAPIClient::PerformRequest( +void BYOMAPIClient::PerformRequest( const mojom::CustomModelOptions& model_options, base::Value::List messages, GenerationDataCallback data_received_callback, @@ -108,10 +108,10 @@ void OAIAPIClient::PerformRequest( } if (is_sse_enabled) { - auto on_received = base::BindRepeating(&OAIAPIClient::OnQueryDataReceived, + auto on_received = base::BindRepeating(&BYOMAPIClient::OnQueryDataReceived, weak_ptr_factory_.GetWeakPtr(), std::move(data_received_callback)); - auto on_complete = base::BindOnce(&OAIAPIClient::OnQueryCompleted, + auto on_complete = base::BindOnce(&BYOMAPIClient::OnQueryCompleted, weak_ptr_factory_.GetWeakPtr(), std::move(completed_callback)); @@ -120,7 +120,7 @@ void OAIAPIClient::PerformRequest( "application/json", std::move(on_received), std::move(on_complete), headers, {}); } else { - auto on_complete = base::BindOnce(&OAIAPIClient::OnQueryCompleted, + auto on_complete = base::BindOnce(&BYOMAPIClient::OnQueryCompleted, weak_ptr_factory_.GetWeakPtr(), std::move(completed_callback)); api_request_helper_->Request( @@ -129,8 +129,8 @@ void OAIAPIClient::PerformRequest( } } -void OAIAPIClient::OnQueryCompleted(GenerationCompletedCallback callback, - APIRequestResult result) { +void BYOMAPIClient::OnQueryCompleted(GenerationCompletedCallback callback, + APIRequestResult result) { const bool success = result.Is2XXResponseCode(); // Handle successful request if (success) { @@ -162,7 +162,7 @@ void OAIAPIClient::OnQueryCompleted(GenerationCompletedCallback callback, std::move(callback).Run(base::unexpected(mojom::APIError::ConnectionIssue)); } -void OAIAPIClient::OnQueryDataReceived( +void BYOMAPIClient::OnQueryDataReceived( GenerationDataCallback callback, base::expected result) { if (!result.has_value() || !result->is_dict()) { diff --git a/components/ai_chat/core/browser/engine/oai_api_client.h b/components/ai_chat/core/browser/engine/byom_api_client.h similarity index 82% rename from components/ai_chat/core/browser/engine/oai_api_client.h rename to components/ai_chat/core/browser/engine/byom_api_client.h index def1b5d726ad..ea4fac5fcc1c 100644 --- a/components/ai_chat/core/browser/engine/oai_api_client.h +++ b/components/ai_chat/core/browser/engine/byom_api_client.h @@ -3,8 +3,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at https://mozilla.org/MPL/2.0/. -#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_OAI_API_CLIENT_H_ -#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_OAI_API_CLIENT_H_ +#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_BYOM_API_CLIENT_H_ +#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_BYOM_API_CLIENT_H_ #include #include @@ -33,8 +33,8 @@ namespace mojom { class CustomModelOptions; } // namespace mojom -// Performs remote request to the OAI format APIs. -class OAIAPIClient { +// Performs remote request to BYOM endpoints. +class BYOMAPIClient { public: using GenerationResult = base::expected; using GenerationDataCallback = @@ -42,12 +42,12 @@ class OAIAPIClient { using GenerationCompletedCallback = base::OnceCallback; - explicit OAIAPIClient( + explicit BYOMAPIClient( scoped_refptr url_loader_factory); - OAIAPIClient(const OAIAPIClient&) = delete; - OAIAPIClient& operator=(const OAIAPIClient&) = delete; - virtual ~OAIAPIClient(); + BYOMAPIClient(const BYOMAPIClient&) = delete; + BYOMAPIClient& operator=(const BYOMAPIClient&) = delete; + virtual ~BYOMAPIClient(); virtual void PerformRequest(const mojom::CustomModelOptions& model_options, base::Value::List messages, @@ -73,9 +73,9 @@ class OAIAPIClient { std::unique_ptr api_request_helper_; - base::WeakPtrFactory weak_ptr_factory_{this}; + base::WeakPtrFactory weak_ptr_factory_{this}; }; } // namespace ai_chat -#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_OAI_API_CLIENT_H_ +#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_BYOM_API_CLIENT_H_ diff --git a/components/ai_chat/core/browser/engine/oai_api_client_unittest.cc b/components/ai_chat/core/browser/engine/byom_api_client_unittest.cc similarity index 80% rename from components/ai_chat/core/browser/engine/oai_api_client_unittest.cc rename to components/ai_chat/core/browser/engine/byom_api_client_unittest.cc index 7c5db050bcf9..43227a32135d 100644 --- a/components/ai_chat/core/browser/engine/oai_api_client_unittest.cc +++ b/components/ai_chat/core/browser/engine/byom_api_client_unittest.cc @@ -3,7 +3,7 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at https://mozilla.org/MPL/2.0/. */ -#include "brave/components/ai_chat/core/browser/engine/oai_api_client.h" +#include "brave/components/ai_chat/core/browser/engine/byom_api_client.h" #include #include @@ -40,7 +40,7 @@ using DataReceivedCallback = api_request_helper::APIRequestHelper::DataReceivedCallback; using ResultCallback = api_request_helper::APIRequestHelper::ResultCallback; using Ticket = api_request_helper::APIRequestHelper::Ticket; -using GenerationResult = ai_chat::OAIAPIClient::GenerationResult; +using GenerationResult = ai_chat::BYOMAPIClient::GenerationResult; namespace ai_chat { @@ -66,26 +66,26 @@ class MockAPIRequestHelper : public api_request_helper::APIRequestHelper { (override)); }; -class TestOAIAPIClient : public OAIAPIClient { +class TestBYOMAPIClient : public BYOMAPIClient { public: - TestOAIAPIClient() : OAIAPIClient(nullptr) { + TestBYOMAPIClient() : BYOMAPIClient(nullptr) { SetAPIRequestHelperForTesting(std::make_unique( net::NetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS), nullptr)); } - ~TestOAIAPIClient() override = default; + ~TestBYOMAPIClient() override = default; MockAPIRequestHelper* GetMockAPIRequestHelper() { return static_cast(GetAPIRequestHelperForTesting()); } }; -class OAIAPIUnitTest : public testing::Test { +class BYOMAPIUnitTest : public testing::Test { public: - OAIAPIUnitTest() = default; - ~OAIAPIUnitTest() override = default; + BYOMAPIUnitTest() = default; + ~BYOMAPIUnitTest() override = default; - void SetUp() override { client_ = std::make_unique(); } + void SetUp() override { client_ = std::make_unique(); } void TearDown() override {} @@ -112,18 +112,58 @@ class OAIAPIUnitTest : public testing::Test { protected: base::test::TaskEnvironment task_environment_; - std::unique_ptr client_; + std::unique_ptr client_; }; -TEST_F(OAIAPIUnitTest, PerformRequest) { +TEST_F(BYOMAPIUnitTest, PerformRequest) { mojom::CustomModelOptionsPtr model_options = mojom::CustomModelOptions::New( "test_api_key", 0, 0, 0, "test_system_prompt", GURL("https://test.com"), "test_model"); std::string server_chunk = - R"({"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":"It was played in Arlington, Texas."},"logprobs":null,"finish_reason":null}]})"; + R"({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1694268190, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": "It was played in Arlington, Texas." + }, + "logprobs": null, + "finish_reason": null + } + ] + })"; + std::string server_completion = - R"({"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"message":{"role":"assistant","content":"\n\nCan I assist you further?"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}})"; + R"({ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "\n\nCan I assist you further?" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + })"; std::string expected_chunk_response = "It was played in Arlington, Texas."; std::string expected_completion_response = "\n\nCan I assist you further?"; diff --git a/components/ai_chat/core/browser/engine/engine_consumer_oai.cc b/components/ai_chat/core/browser/engine/engine_consumer_byom.cc similarity index 92% rename from components/ai_chat/core/browser/engine/engine_consumer_oai.cc rename to components/ai_chat/core/browser/engine/engine_consumer_byom.cc index c22690550eeb..6f281b2b92ba 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_oai.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_byom.cc @@ -3,7 +3,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at https://mozilla.org/MPL/2.0/. -#include "brave/components/ai_chat/core/browser/engine/engine_consumer_oai.h" +#include "brave/components/ai_chat/core/browser/engine/engine_consumer_byom.h" #include #include @@ -117,27 +117,27 @@ base::Value::List BuildMessages( } // namespace -EngineConsumerOAIRemote::EngineConsumerOAIRemote( +EngineConsumerBYOMRemote::EngineConsumerBYOMRemote( const mojom::CustomModelOptions& model_options, scoped_refptr url_loader_factory) { model_options_ = model_options; max_associated_content_length_ = model_options.max_associated_content_length; // Initialize the API client - api_ = std::make_unique(url_loader_factory); + api_ = std::make_unique(url_loader_factory); } -EngineConsumerOAIRemote::~EngineConsumerOAIRemote() = default; +EngineConsumerBYOMRemote::~EngineConsumerBYOMRemote() = default; -void EngineConsumerOAIRemote::ClearAllQueries() { +void EngineConsumerBYOMRemote::ClearAllQueries() { api_->ClearAllQueries(); } -bool EngineConsumerOAIRemote::SupportsDeltaTextResponses() const { +bool EngineConsumerBYOMRemote::SupportsDeltaTextResponses() const { return true; } -void EngineConsumerOAIRemote::UpdateModelOptions( +void EngineConsumerBYOMRemote::UpdateModelOptions( const mojom::ModelOptions& options) { if (options.is_custom_model_options()) { model_options_ = *options.get_custom_model_options(); @@ -146,7 +146,7 @@ void EngineConsumerOAIRemote::UpdateModelOptions( } } -void EngineConsumerOAIRemote::GenerateRewriteSuggestion( +void EngineConsumerBYOMRemote::GenerateRewriteSuggestion( std::string text, const std::string& question, const std::string& selected_language, @@ -173,7 +173,7 @@ void EngineConsumerOAIRemote::GenerateRewriteSuggestion( std::move(completed_callback)); } -void EngineConsumerOAIRemote::GenerateQuestionSuggestions( +void EngineConsumerBYOMRemote::GenerateQuestionSuggestions( const bool& is_video, const std::string& page_content, const std::string& selected_language, @@ -207,11 +207,11 @@ void EngineConsumerOAIRemote::GenerateQuestionSuggestions( api_->PerformRequest( model_options_, std::move(messages), base::NullCallback(), base::BindOnce( - &EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse, + &EngineConsumerBYOMRemote::OnGenerateQuestionSuggestionsResponse, weak_ptr_factory_.GetWeakPtr(), std::move(callback))); } -void EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse( +void EngineConsumerBYOMRemote::OnGenerateQuestionSuggestionsResponse( SuggestedQuestionsCallback callback, GenerationResult result) { if (!result.has_value() || result->empty()) { @@ -242,7 +242,7 @@ void EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse( std::move(callback).Run(std::move(questions)); } -void EngineConsumerOAIRemote::GenerateAssistantResponse( +void EngineConsumerBYOMRemote::GenerateAssistantResponse( const bool& is_video, const std::string& page_content, const ConversationHistory& conversation_history, @@ -274,6 +274,6 @@ void EngineConsumerOAIRemote::GenerateAssistantResponse( std::move(completed_callback)); } -void EngineConsumerOAIRemote::SanitizeInput(std::string& input) {} +void EngineConsumerBYOMRemote::SanitizeInput(std::string& input) {} } // namespace ai_chat diff --git a/components/ai_chat/core/browser/engine/engine_consumer_oai.h b/components/ai_chat/core/browser/engine/engine_consumer_byom.h similarity index 79% rename from components/ai_chat/core/browser/engine/engine_consumer_oai.h rename to components/ai_chat/core/browser/engine/engine_consumer_byom.h index 8586fd045996..5ba2a7c1e226 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_oai.h +++ b/components/ai_chat/core/browser/engine/engine_consumer_byom.h @@ -3,8 +3,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at https://mozilla.org/MPL/2.0/. -#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_ENGINE_CONSUMER_OAI_H_ -#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_ENGINE_CONSUMER_OAI_H_ +#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_ENGINE_CONSUMER_BYOM_H_ +#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_ENGINE_CONSUMER_BYOM_H_ #include #include @@ -12,8 +12,8 @@ #include "base/memory/weak_ptr.h" #include "brave/components/ai_chat/core/browser/ai_chat_credential_manager.h" +#include "brave/components/ai_chat/core/browser/engine/byom_api_client.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" -#include "brave/components/ai_chat/core/browser/engine/oai_api_client.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" @@ -32,14 +32,14 @@ namespace ai_chat { using api_request_helper::APIRequestResult; -class EngineConsumerOAIRemote : public EngineConsumer { +class EngineConsumerBYOMRemote : public EngineConsumer { public: - explicit EngineConsumerOAIRemote( + explicit EngineConsumerBYOMRemote( const mojom::CustomModelOptions& model_options, scoped_refptr url_loader_factory); - EngineConsumerOAIRemote(const EngineConsumerOAIRemote&) = delete; - EngineConsumerOAIRemote& operator=(const EngineConsumerOAIRemote&) = delete; - ~EngineConsumerOAIRemote() override; + EngineConsumerBYOMRemote(const EngineConsumerBYOMRemote&) = delete; + EngineConsumerBYOMRemote& operator=(const EngineConsumerBYOMRemote&) = delete; + ~EngineConsumerBYOMRemote() override; // EngineConsumer void GenerateQuestionSuggestions( @@ -65,10 +65,10 @@ class EngineConsumerOAIRemote : public EngineConsumer { void ClearAllQueries() override; bool SupportsDeltaTextResponses() const override; - void SetAPIForTesting(std::unique_ptr api_for_testing) { + void SetAPIForTesting(std::unique_ptr api_for_testing) { api_ = std::move(api_for_testing); } - OAIAPIClient* GetAPIForTesting() { return api_.get(); } + BYOMAPIClient* GetAPIForTesting() { return api_.get(); } void UpdateModelOptions(const mojom::ModelOptions& options) override; private: @@ -76,12 +76,12 @@ class EngineConsumerOAIRemote : public EngineConsumer { SuggestedQuestionsCallback callback, GenerationResult result); - std::unique_ptr api_ = nullptr; + std::unique_ptr api_ = nullptr; mojom::CustomModelOptions model_options_; - base::WeakPtrFactory weak_ptr_factory_{this}; + base::WeakPtrFactory weak_ptr_factory_{this}; }; } // namespace ai_chat -#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_ENGINE_CONSUMER_OAI_H_ +#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_ENGINE_CONSUMER_BYOM_H_ diff --git a/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc b/components/ai_chat/core/browser/engine/engine_consumer_byom_unittest.cc similarity index 94% rename from components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc rename to components/ai_chat/core/browser/engine/engine_consumer_byom_unittest.cc index 6d044412cc51..c52184f76c6e 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_oai_unittest.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_byom_unittest.cc @@ -3,7 +3,7 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at https://mozilla.org/MPL/2.0/. */ -#include "brave/components/ai_chat/core/browser/engine/engine_consumer_oai.h" +#include "brave/components/ai_chat/core/browser/engine/engine_consumer_byom.h" #include #include @@ -45,10 +45,10 @@ class MockCallback { MOCK_METHOD(void, OnCompleted, (EngineConsumer::GenerationResult)); }; -class MockOAIAPIClient : public OAIAPIClient { +class MockBYOMAPIClient : public BYOMAPIClient { public: - MockOAIAPIClient() : OAIAPIClient(nullptr) {} - ~MockOAIAPIClient() override = default; + MockBYOMAPIClient() : BYOMAPIClient(nullptr) {} + ~MockBYOMAPIClient() override = default; MOCK_METHOD(void, PerformRequest, @@ -59,10 +59,10 @@ class MockOAIAPIClient : public OAIAPIClient { (override)); }; -class EngineConsumerOAIUnitTest : public testing::Test { +class EngineConsumerBYOMUnitTest : public testing::Test { public: - EngineConsumerOAIUnitTest() = default; - ~EngineConsumerOAIUnitTest() override = default; + EngineConsumerBYOMUnitTest() = default; + ~EngineConsumerBYOMUnitTest() override = default; void SetUp() override { auto options = mojom::CustomModelOptions::New(); @@ -81,14 +81,14 @@ class EngineConsumerOAIUnitTest : public testing::Test { model_->options = mojom::ModelOptions::NewCustomModelOptions(std::move(options)); - engine_ = std::make_unique( + engine_ = std::make_unique( *model_->options->get_custom_model_options(), nullptr); - engine_->SetAPIForTesting(std::make_unique()); + engine_->SetAPIForTesting(std::make_unique()); } - MockOAIAPIClient* GetClient() { - return static_cast(engine_->GetAPIForTesting()); + MockBYOMAPIClient* GetClient() { + return static_cast(engine_->GetAPIForTesting()); } void TearDown() override {} @@ -96,10 +96,10 @@ class EngineConsumerOAIUnitTest : public testing::Test { protected: base::test::TaskEnvironment task_environment_; mojom::ModelPtr model_; - std::unique_ptr engine_; + std::unique_ptr engine_; }; -TEST_F(EngineConsumerOAIUnitTest, UpdateModelOptions) { +TEST_F(EngineConsumerBYOMUnitTest, UpdateModelOptions) { auto* client = GetClient(); base::RunLoop run_loop; @@ -148,7 +148,7 @@ TEST_F(EngineConsumerOAIUnitTest, UpdateModelOptions) { testing::Mock::VerifyAndClearExpectations(client); } -TEST_F(EngineConsumerOAIUnitTest, GenerateQuestionSuggestions) { +TEST_F(EngineConsumerBYOMUnitTest, GenerateQuestionSuggestions) { std::string page_content = "This is a test page content"; auto* client = GetClient(); @@ -211,7 +211,7 @@ TEST_F(EngineConsumerOAIUnitTest, GenerateQuestionSuggestions) { testing::Mock::VerifyAndClearExpectations(client); } -TEST_F(EngineConsumerOAIUnitTest, +TEST_F(EngineConsumerBYOMUnitTest, GenerateAssistantResponseWithDefaultSystemPrompt) { // Create a set of options WITHOUT a custom system prompt. auto options = mojom::CustomModelOptions::New(); @@ -227,9 +227,9 @@ TEST_F(EngineConsumerOAIUnitTest, mojom::ModelOptions::NewCustomModelOptions(std::move(options)); // Create a new engine with the new model. - engine_ = std::make_unique( + engine_ = std::make_unique( *model_->options->get_custom_model_options(), nullptr); - engine_->SetAPIForTesting(std::make_unique()); + engine_->SetAPIForTesting(std::make_unique()); EngineConsumer::ConversationHistory history; @@ -297,7 +297,7 @@ TEST_F(EngineConsumerOAIUnitTest, testing::Mock::VerifyAndClearExpectations(client); } -TEST_F(EngineConsumerOAIUnitTest, +TEST_F(EngineConsumerBYOMUnitTest, TestGenerateAssistantResponseWithCustomSystemPrompt) { EngineConsumer::ConversationHistory history; @@ -406,7 +406,7 @@ TEST_F(EngineConsumerOAIUnitTest, testing::Mock::VerifyAndClearExpectations(client); } -TEST_F(EngineConsumerOAIUnitTest, GenerateAssistantResponseEarlyReturn) { +TEST_F(EngineConsumerBYOMUnitTest, GenerateAssistantResponseEarlyReturn) { EngineConsumer::ConversationHistory history; auto* client = GetClient(); auto run_loop = std::make_unique(); @@ -441,7 +441,7 @@ TEST_F(EngineConsumerOAIUnitTest, GenerateAssistantResponseEarlyReturn) { testing::Mock::VerifyAndClearExpectations(client); } -TEST_F(EngineConsumerOAIUnitTest, SummarizePage) { +TEST_F(EngineConsumerBYOMUnitTest, SummarizePage) { auto* client = GetClient(); base::RunLoop run_loop; diff --git a/components/ai_chat/core/browser/model_service.cc b/components/ai_chat/core/browser/model_service.cc index 580e10c083b4..1086ed28f22f 100644 --- a/components/ai_chat/core/browser/model_service.cc +++ b/components/ai_chat/core/browser/model_service.cc @@ -30,10 +30,10 @@ #include "base/values.h" #include "brave/components/ai_chat/core/browser/constants.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" +#include "brave/components/ai_chat/core/browser/engine/engine_consumer_byom.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer_claude.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer_conversation_api.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer_llama.h" -#include "brave/components/ai_chat/core/browser/engine/engine_consumer_oai.h" #include "brave/components/ai_chat/core/browser/model_validator.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" @@ -661,8 +661,8 @@ std::unique_ptr ModelService::GetEngineForModel( } else if (model->options->is_custom_model_options()) { auto& custom_model_opts = model->options->get_custom_model_options(); DVLOG(1) << "Started AI engine: custom"; - engine = std::make_unique(*custom_model_opts, - url_loader_factory); + engine = std::make_unique(*custom_model_opts, + url_loader_factory); } return engine;