Skip to content

Commit

Permalink
Move from OAI to BYOM
Browse files Browse the repository at this point in the history
With the growth of the BYOM model, Brave aims to support popular service endpoints. Originally, the OpenAI model was supported, which was reflected in relevant files and code. With the growing adoption of alternatives (e.g., Ollama locally, Anthropic, Perplexity, Azure hosted models, etc.), it is important to avoid confusion throughout the codebase. This change therefore partially renames files to reflect the move from "oai" to "byom", as well as updates classes, header files, and more. This migration lays the groundwork for a clearer, more maintainable path towards supporting even more endpoints and API patterns.

Also, this change addresses the presubmit warnings associated with the related unit tests (specifically, a couple of lines of JSON within).
  • Loading branch information
jonathansampson committed Dec 17, 2024
1 parent 4be896e commit 439cff3
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 89 deletions.
12 changes: 6 additions & 6 deletions components/ai_chat/core/browser/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ static_library("browser") {
"constants.h",
"conversation_handler.cc",
"conversation_handler.h",
"engine/byom_api_client.cc",
"engine/byom_api_client.h",
"engine/conversation_api_client.cc",
"engine/conversation_api_client.h",
"engine/engine_consumer.cc",
"engine/engine_consumer.h",
"engine/engine_consumer_byom.cc",
"engine/engine_consumer_byom.h",
"engine/engine_consumer_claude.cc",
"engine/engine_consumer_claude.h",
"engine/engine_consumer_conversation_api.cc",
"engine/engine_consumer_conversation_api.h",
"engine/engine_consumer_llama.cc",
"engine/engine_consumer_llama.h",
"engine/engine_consumer_oai.cc",
"engine/engine_consumer_oai.h",
"engine/oai_api_client.cc",
"engine/oai_api_client.h",
"engine/remote_completion_client.cc",
"engine/remote_completion_client.h",
"local_models_updater.cc",
Expand Down Expand Up @@ -135,12 +135,12 @@ if (!is_ios) {
"ai_chat_service_unittest.cc",
"associated_content_driver_unittest.cc",
"conversation_handler_unittest.cc",
"engine/byom_api_client_unittest.cc",
"engine/conversation_api_client_unittest.cc",
"engine/engine_consumer_byom_unittest.cc",
"engine/engine_consumer_claude_unittest.cc",
"engine/engine_consumer_conversation_api_unittest.cc",
"engine/engine_consumer_llama_unittest.cc",
"engine/engine_consumer_oai_unittest.cc",
"engine/oai_api_client_unittest.cc",
"engine/test_utils.cc",
"engine/test_utils.h",
"local_models_updater_unittest.cc",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.

#include "brave/components/ai_chat/core/browser/engine/oai_api_client.h"
#include "brave/components/ai_chat/core/browser/engine/byom_api_client.h"

#include <ios>
#include <optional>
Expand Down Expand Up @@ -75,19 +75,19 @@ std::string CreateJSONRequestBody(

} // namespace

OAIAPIClient::OAIAPIClient(
BYOMAPIClient::BYOMAPIClient(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
api_request_helper_ = std::make_unique<api_request_helper::APIRequestHelper>(
GetNetworkTrafficAnnotationTag(), url_loader_factory);
}

OAIAPIClient::~OAIAPIClient() = default;
BYOMAPIClient::~BYOMAPIClient() = default;

void OAIAPIClient::ClearAllQueries() {
void BYOMAPIClient::ClearAllQueries() {
api_request_helper_->CancelAll();
}

void OAIAPIClient::PerformRequest(
void BYOMAPIClient::PerformRequest(
const mojom::CustomModelOptions& model_options,
base::Value::List messages,
GenerationDataCallback data_received_callback,
Expand All @@ -108,10 +108,10 @@ void OAIAPIClient::PerformRequest(
}

if (is_sse_enabled) {
auto on_received = base::BindRepeating(&OAIAPIClient::OnQueryDataReceived,
auto on_received = base::BindRepeating(&BYOMAPIClient::OnQueryDataReceived,
weak_ptr_factory_.GetWeakPtr(),
std::move(data_received_callback));
auto on_complete = base::BindOnce(&OAIAPIClient::OnQueryCompleted,
auto on_complete = base::BindOnce(&BYOMAPIClient::OnQueryCompleted,
weak_ptr_factory_.GetWeakPtr(),
std::move(completed_callback));

Expand All @@ -120,7 +120,7 @@ void OAIAPIClient::PerformRequest(
"application/json", std::move(on_received),
std::move(on_complete), headers, {});
} else {
auto on_complete = base::BindOnce(&OAIAPIClient::OnQueryCompleted,
auto on_complete = base::BindOnce(&BYOMAPIClient::OnQueryCompleted,
weak_ptr_factory_.GetWeakPtr(),
std::move(completed_callback));
api_request_helper_->Request(
Expand All @@ -129,8 +129,8 @@ void OAIAPIClient::PerformRequest(
}
}

void OAIAPIClient::OnQueryCompleted(GenerationCompletedCallback callback,
APIRequestResult result) {
void BYOMAPIClient::OnQueryCompleted(GenerationCompletedCallback callback,
APIRequestResult result) {
const bool success = result.Is2XXResponseCode();
// Handle successful request
if (success) {
Expand Down Expand Up @@ -162,7 +162,7 @@ void OAIAPIClient::OnQueryCompleted(GenerationCompletedCallback callback,
std::move(callback).Run(base::unexpected(mojom::APIError::ConnectionIssue));
}

void OAIAPIClient::OnQueryDataReceived(
void BYOMAPIClient::OnQueryDataReceived(
GenerationDataCallback callback,
base::expected<base::Value, std::string> result) {
if (!result.has_value() || !result->is_dict()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.

#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_OAI_API_CLIENT_H_
#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_OAI_API_CLIENT_H_
#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_BYOM_API_CLIENT_H_
#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_BYOM_API_CLIENT_H_

#include <memory>
#include <string>
Expand Down Expand Up @@ -33,21 +33,21 @@ namespace mojom {
class CustomModelOptions;
} // namespace mojom

// Performs remote request to the OAI format APIs.
class OAIAPIClient {
// Performs remote request to BYOM endpoints.
class BYOMAPIClient {
public:
using GenerationResult = base::expected<std::string, mojom::APIError>;
using GenerationDataCallback =
base::RepeatingCallback<void(mojom::ConversationEntryEventPtr)>;
using GenerationCompletedCallback =
base::OnceCallback<void(GenerationResult)>;

explicit OAIAPIClient(
explicit BYOMAPIClient(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);

OAIAPIClient(const OAIAPIClient&) = delete;
OAIAPIClient& operator=(const OAIAPIClient&) = delete;
virtual ~OAIAPIClient();
BYOMAPIClient(const BYOMAPIClient&) = delete;
BYOMAPIClient& operator=(const BYOMAPIClient&) = delete;
virtual ~BYOMAPIClient();

virtual void PerformRequest(const mojom::CustomModelOptions& model_options,
base::Value::List messages,
Expand All @@ -73,9 +73,9 @@ class OAIAPIClient {

std::unique_ptr<api_request_helper::APIRequestHelper> api_request_helper_;

base::WeakPtrFactory<OAIAPIClient> weak_ptr_factory_{this};
base::WeakPtrFactory<BYOMAPIClient> weak_ptr_factory_{this};
};

} // namespace ai_chat

#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_OAI_API_CLIENT_H_
#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_BYOM_API_CLIENT_H_
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at https://mozilla.org/MPL/2.0/. */

#include "brave/components/ai_chat/core/browser/engine/oai_api_client.h"
#include "brave/components/ai_chat/core/browser/engine/byom_api_client.h"

#include <list>
#include <optional>
Expand Down Expand Up @@ -40,7 +40,7 @@ using DataReceivedCallback =
api_request_helper::APIRequestHelper::DataReceivedCallback;
using ResultCallback = api_request_helper::APIRequestHelper::ResultCallback;
using Ticket = api_request_helper::APIRequestHelper::Ticket;
using GenerationResult = ai_chat::OAIAPIClient::GenerationResult;
using GenerationResult = ai_chat::BYOMAPIClient::GenerationResult;

namespace ai_chat {

Expand All @@ -66,26 +66,26 @@ class MockAPIRequestHelper : public api_request_helper::APIRequestHelper {
(override));
};

class TestOAIAPIClient : public OAIAPIClient {
class TestBYOMAPIClient : public BYOMAPIClient {
public:
TestOAIAPIClient() : OAIAPIClient(nullptr) {
TestBYOMAPIClient() : BYOMAPIClient(nullptr) {
SetAPIRequestHelperForTesting(std::make_unique<MockAPIRequestHelper>(
net::NetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
nullptr));
}
~TestOAIAPIClient() override = default;
~TestBYOMAPIClient() override = default;

MockAPIRequestHelper* GetMockAPIRequestHelper() {
return static_cast<MockAPIRequestHelper*>(GetAPIRequestHelperForTesting());
}
};

class OAIAPIUnitTest : public testing::Test {
class BYOMAPIUnitTest : public testing::Test {
public:
OAIAPIUnitTest() = default;
~OAIAPIUnitTest() override = default;
BYOMAPIUnitTest() = default;
~BYOMAPIUnitTest() override = default;

void SetUp() override { client_ = std::make_unique<TestOAIAPIClient>(); }
void SetUp() override { client_ = std::make_unique<TestBYOMAPIClient>(); }

void TearDown() override {}

Expand All @@ -112,18 +112,58 @@ class OAIAPIUnitTest : public testing::Test {

protected:
base::test::TaskEnvironment task_environment_;
std::unique_ptr<TestOAIAPIClient> client_;
std::unique_ptr<TestBYOMAPIClient> client_;
};

TEST_F(OAIAPIUnitTest, PerformRequest) {
TEST_F(BYOMAPIUnitTest, PerformRequest) {
mojom::CustomModelOptionsPtr model_options = mojom::CustomModelOptions::New(
"test_api_key", 0, 0, 0, "test_system_prompt", GURL("https://test.com"),
"test_model");

std::string server_chunk =
R"({"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":"It was played in Arlington, Texas."},"logprobs":null,"finish_reason":null}]})";
R"({
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1694268190,
"model": "gpt-3.5-turbo-0125",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "It was played in Arlington, Texas."
},
"logprobs": null,
"finish_reason": null
}
]
})";

std::string server_completion =
R"({"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"message":{"role":"assistant","content":"\n\nCan I assist you further?"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}})";
R"({
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-0125",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "\n\nCan I assist you further?"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
})";

std::string expected_chunk_response = "It was played in Arlington, Texas.";
std::string expected_completion_response = "\n\nCan I assist you further?";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.

#include "brave/components/ai_chat/core/browser/engine/engine_consumer_oai.h"
#include "brave/components/ai_chat/core/browser/engine/engine_consumer_byom.h"

#include <optional>
#include <string>
Expand Down Expand Up @@ -117,27 +117,27 @@ base::Value::List BuildMessages(

} // namespace

EngineConsumerOAIRemote::EngineConsumerOAIRemote(
EngineConsumerBYOMRemote::EngineConsumerBYOMRemote(
const mojom::CustomModelOptions& model_options,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
model_options_ = model_options;
max_associated_content_length_ = model_options.max_associated_content_length;

// Initialize the API client
api_ = std::make_unique<OAIAPIClient>(url_loader_factory);
api_ = std::make_unique<BYOMAPIClient>(url_loader_factory);
}

EngineConsumerOAIRemote::~EngineConsumerOAIRemote() = default;
EngineConsumerBYOMRemote::~EngineConsumerBYOMRemote() = default;

void EngineConsumerOAIRemote::ClearAllQueries() {
void EngineConsumerBYOMRemote::ClearAllQueries() {
api_->ClearAllQueries();
}

bool EngineConsumerOAIRemote::SupportsDeltaTextResponses() const {
bool EngineConsumerBYOMRemote::SupportsDeltaTextResponses() const {
return true;
}

void EngineConsumerOAIRemote::UpdateModelOptions(
void EngineConsumerBYOMRemote::UpdateModelOptions(
const mojom::ModelOptions& options) {
if (options.is_custom_model_options()) {
model_options_ = *options.get_custom_model_options();
Expand All @@ -146,7 +146,7 @@ void EngineConsumerOAIRemote::UpdateModelOptions(
}
}

void EngineConsumerOAIRemote::GenerateRewriteSuggestion(
void EngineConsumerBYOMRemote::GenerateRewriteSuggestion(
std::string text,
const std::string& question,
const std::string& selected_language,
Expand All @@ -173,7 +173,7 @@ void EngineConsumerOAIRemote::GenerateRewriteSuggestion(
std::move(completed_callback));
}

void EngineConsumerOAIRemote::GenerateQuestionSuggestions(
void EngineConsumerBYOMRemote::GenerateQuestionSuggestions(
const bool& is_video,
const std::string& page_content,
const std::string& selected_language,
Expand Down Expand Up @@ -207,11 +207,11 @@ void EngineConsumerOAIRemote::GenerateQuestionSuggestions(
api_->PerformRequest(
model_options_, std::move(messages), base::NullCallback(),
base::BindOnce(
&EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse,
&EngineConsumerBYOMRemote::OnGenerateQuestionSuggestionsResponse,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
}

void EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse(
void EngineConsumerBYOMRemote::OnGenerateQuestionSuggestionsResponse(
SuggestedQuestionsCallback callback,
GenerationResult result) {
if (!result.has_value() || result->empty()) {
Expand Down Expand Up @@ -242,7 +242,7 @@ void EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse(
std::move(callback).Run(std::move(questions));
}

void EngineConsumerOAIRemote::GenerateAssistantResponse(
void EngineConsumerBYOMRemote::GenerateAssistantResponse(
const bool& is_video,
const std::string& page_content,
const ConversationHistory& conversation_history,
Expand Down Expand Up @@ -274,6 +274,6 @@ void EngineConsumerOAIRemote::GenerateAssistantResponse(
std::move(completed_callback));
}

void EngineConsumerOAIRemote::SanitizeInput(std::string& input) {}
void EngineConsumerBYOMRemote::SanitizeInput(std::string& input) {}

} // namespace ai_chat
Loading

0 comments on commit 439cff3

Please sign in to comment.