Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOM] Move from "OAI" naming convention to "BYOM" #27043

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions components/ai_chat/core/browser/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ static_library("browser") {
"constants.h",
"conversation_handler.cc",
"conversation_handler.h",
"engine/byom_api_client.cc",
"engine/byom_api_client.h",
"engine/conversation_api_client.cc",
"engine/conversation_api_client.h",
"engine/engine_consumer.cc",
"engine/engine_consumer.h",
"engine/engine_consumer_byom.cc",
"engine/engine_consumer_byom.h",
"engine/engine_consumer_claude.cc",
"engine/engine_consumer_claude.h",
"engine/engine_consumer_conversation_api.cc",
"engine/engine_consumer_conversation_api.h",
"engine/engine_consumer_llama.cc",
"engine/engine_consumer_llama.h",
"engine/engine_consumer_oai.cc",
"engine/engine_consumer_oai.h",
"engine/oai_api_client.cc",
"engine/oai_api_client.h",
"engine/remote_completion_client.cc",
"engine/remote_completion_client.h",
"local_models_updater.cc",
Expand Down Expand Up @@ -135,12 +135,12 @@ if (!is_ios) {
"ai_chat_service_unittest.cc",
"associated_content_driver_unittest.cc",
"conversation_handler_unittest.cc",
"engine/byom_api_client_unittest.cc",
"engine/conversation_api_client_unittest.cc",
"engine/engine_consumer_byom_unittest.cc",
"engine/engine_consumer_claude_unittest.cc",
"engine/engine_consumer_conversation_api_unittest.cc",
"engine/engine_consumer_llama_unittest.cc",
"engine/engine_consumer_oai_unittest.cc",
"engine/oai_api_client_unittest.cc",
"engine/test_utils.cc",
"engine/test_utils.h",
"local_models_updater_unittest.cc",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.

#include "brave/components/ai_chat/core/browser/engine/oai_api_client.h"
#include "brave/components/ai_chat/core/browser/engine/byom_api_client.h"

#include <ios>
#include <optional>
Expand Down Expand Up @@ -75,19 +75,19 @@ std::string CreateJSONRequestBody(

} // namespace

OAIAPIClient::OAIAPIClient(
BYOMAPIClient::BYOMAPIClient(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
api_request_helper_ = std::make_unique<api_request_helper::APIRequestHelper>(
GetNetworkTrafficAnnotationTag(), url_loader_factory);
}

OAIAPIClient::~OAIAPIClient() = default;
BYOMAPIClient::~BYOMAPIClient() = default;

void OAIAPIClient::ClearAllQueries() {
void BYOMAPIClient::ClearAllQueries() {
api_request_helper_->CancelAll();
}

void OAIAPIClient::PerformRequest(
void BYOMAPIClient::PerformRequest(
const mojom::CustomModelOptions& model_options,
base::Value::List messages,
GenerationDataCallback data_received_callback,
Expand All @@ -108,10 +108,10 @@ void OAIAPIClient::PerformRequest(
}

if (is_sse_enabled) {
auto on_received = base::BindRepeating(&OAIAPIClient::OnQueryDataReceived,
auto on_received = base::BindRepeating(&BYOMAPIClient::OnQueryDataReceived,
weak_ptr_factory_.GetWeakPtr(),
std::move(data_received_callback));
auto on_complete = base::BindOnce(&OAIAPIClient::OnQueryCompleted,
auto on_complete = base::BindOnce(&BYOMAPIClient::OnQueryCompleted,
weak_ptr_factory_.GetWeakPtr(),
std::move(completed_callback));

Expand All @@ -120,7 +120,7 @@ void OAIAPIClient::PerformRequest(
"application/json", std::move(on_received),
std::move(on_complete), headers, {});
} else {
auto on_complete = base::BindOnce(&OAIAPIClient::OnQueryCompleted,
auto on_complete = base::BindOnce(&BYOMAPIClient::OnQueryCompleted,
weak_ptr_factory_.GetWeakPtr(),
std::move(completed_callback));
api_request_helper_->Request(
Expand All @@ -129,8 +129,8 @@ void OAIAPIClient::PerformRequest(
}
}

void OAIAPIClient::OnQueryCompleted(GenerationCompletedCallback callback,
APIRequestResult result) {
void BYOMAPIClient::OnQueryCompleted(GenerationCompletedCallback callback,
APIRequestResult result) {
const bool success = result.Is2XXResponseCode();
// Handle successful request
if (success) {
Expand Down Expand Up @@ -162,7 +162,7 @@ void OAIAPIClient::OnQueryCompleted(GenerationCompletedCallback callback,
std::move(callback).Run(base::unexpected(mojom::APIError::ConnectionIssue));
}

void OAIAPIClient::OnQueryDataReceived(
void BYOMAPIClient::OnQueryDataReceived(
GenerationDataCallback callback,
base::expected<base::Value, std::string> result) {
if (!result.has_value() || !result->is_dict()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.

#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_OAI_API_CLIENT_H_
#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_OAI_API_CLIENT_H_
#ifndef BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_BYOM_API_CLIENT_H_
#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_BYOM_API_CLIENT_H_

#include <memory>
#include <string>
Expand Down Expand Up @@ -33,21 +33,21 @@ namespace mojom {
class CustomModelOptions;
} // namespace mojom

// Performs remote request to the OAI format APIs.
class OAIAPIClient {
// Performs remote request to BYOM endpoints.
class BYOMAPIClient {
public:
using GenerationResult = base::expected<std::string, mojom::APIError>;
using GenerationDataCallback =
base::RepeatingCallback<void(mojom::ConversationEntryEventPtr)>;
using GenerationCompletedCallback =
base::OnceCallback<void(GenerationResult)>;

explicit OAIAPIClient(
explicit BYOMAPIClient(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);

OAIAPIClient(const OAIAPIClient&) = delete;
OAIAPIClient& operator=(const OAIAPIClient&) = delete;
virtual ~OAIAPIClient();
BYOMAPIClient(const BYOMAPIClient&) = delete;
BYOMAPIClient& operator=(const BYOMAPIClient&) = delete;
virtual ~BYOMAPIClient();

virtual void PerformRequest(const mojom::CustomModelOptions& model_options,
base::Value::List messages,
Expand All @@ -73,9 +73,9 @@ class OAIAPIClient {

std::unique_ptr<api_request_helper::APIRequestHelper> api_request_helper_;

base::WeakPtrFactory<OAIAPIClient> weak_ptr_factory_{this};
base::WeakPtrFactory<BYOMAPIClient> weak_ptr_factory_{this};
};

} // namespace ai_chat

#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_OAI_API_CLIENT_H_
#endif // BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_ENGINE_BYOM_API_CLIENT_H_
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at https://mozilla.org/MPL/2.0/. */

#include "brave/components/ai_chat/core/browser/engine/oai_api_client.h"
#include "brave/components/ai_chat/core/browser/engine/byom_api_client.h"

#include <list>
#include <optional>
Expand Down Expand Up @@ -40,7 +40,7 @@ using DataReceivedCallback =
api_request_helper::APIRequestHelper::DataReceivedCallback;
using ResultCallback = api_request_helper::APIRequestHelper::ResultCallback;
using Ticket = api_request_helper::APIRequestHelper::Ticket;
using GenerationResult = ai_chat::OAIAPIClient::GenerationResult;
using GenerationResult = ai_chat::BYOMAPIClient::GenerationResult;

namespace ai_chat {

Expand All @@ -66,26 +66,26 @@ class MockAPIRequestHelper : public api_request_helper::APIRequestHelper {
(override));
};

class TestOAIAPIClient : public OAIAPIClient {
class TestBYOMAPIClient : public BYOMAPIClient {
public:
TestOAIAPIClient() : OAIAPIClient(nullptr) {
TestBYOMAPIClient() : BYOMAPIClient(nullptr) {
SetAPIRequestHelperForTesting(std::make_unique<MockAPIRequestHelper>(
net::NetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
nullptr));
}
~TestOAIAPIClient() override = default;
~TestBYOMAPIClient() override = default;

MockAPIRequestHelper* GetMockAPIRequestHelper() {
return static_cast<MockAPIRequestHelper*>(GetAPIRequestHelperForTesting());
}
};

class OAIAPIUnitTest : public testing::Test {
class BYOMAPIUnitTest : public testing::Test {
public:
OAIAPIUnitTest() = default;
~OAIAPIUnitTest() override = default;
BYOMAPIUnitTest() = default;
~BYOMAPIUnitTest() override = default;

void SetUp() override { client_ = std::make_unique<TestOAIAPIClient>(); }
void SetUp() override { client_ = std::make_unique<TestBYOMAPIClient>(); }

void TearDown() override {}

Expand All @@ -112,18 +112,58 @@ class OAIAPIUnitTest : public testing::Test {

protected:
base::test::TaskEnvironment task_environment_;
std::unique_ptr<TestOAIAPIClient> client_;
std::unique_ptr<TestBYOMAPIClient> client_;
};

TEST_F(OAIAPIUnitTest, PerformRequest) {
TEST_F(BYOMAPIUnitTest, PerformRequest) {
mojom::CustomModelOptionsPtr model_options = mojom::CustomModelOptions::New(
"test_api_key", 0, 0, 0, "test_system_prompt", GURL("https://test.com"),
"test_model");

std::string server_chunk =
R"({"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0125", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":"It was played in Arlington, Texas."},"logprobs":null,"finish_reason":null}]})";
R"({
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1694268190,
"model": "gpt-3.5-turbo-0125",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": "It was played in Arlington, Texas."
},
"logprobs": null,
"finish_reason": null
}
]
})";

std::string server_completion =
R"({"id":"chatcmpl-123","object":"chat.completion","created":1677652288,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"message":{"role":"assistant","content":"\n\nCan I assist you further?"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}})";
R"({
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo-0125",
"system_fingerprint": "fp_44709d6fcb",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "\n\nCan I assist you further?"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
})";

std::string expected_chunk_response = "It was played in Arlington, Texas.";
std::string expected_completion_response = "\n\nCan I assist you further?";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.

#include "brave/components/ai_chat/core/browser/engine/engine_consumer_oai.h"
#include "brave/components/ai_chat/core/browser/engine/engine_consumer_byom.h"

#include <optional>
#include <string>
Expand Down Expand Up @@ -117,27 +117,27 @@ base::Value::List BuildMessages(

} // namespace

EngineConsumerOAIRemote::EngineConsumerOAIRemote(
EngineConsumerBYOMRemote::EngineConsumerBYOMRemote(
const mojom::CustomModelOptions& model_options,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
model_options_ = model_options;
max_associated_content_length_ = model_options.max_associated_content_length;

// Initialize the API client
api_ = std::make_unique<OAIAPIClient>(url_loader_factory);
api_ = std::make_unique<BYOMAPIClient>(url_loader_factory);
}

EngineConsumerOAIRemote::~EngineConsumerOAIRemote() = default;
EngineConsumerBYOMRemote::~EngineConsumerBYOMRemote() = default;

void EngineConsumerOAIRemote::ClearAllQueries() {
void EngineConsumerBYOMRemote::ClearAllQueries() {
api_->ClearAllQueries();
}

bool EngineConsumerOAIRemote::SupportsDeltaTextResponses() const {
bool EngineConsumerBYOMRemote::SupportsDeltaTextResponses() const {
return true;
}

void EngineConsumerOAIRemote::UpdateModelOptions(
void EngineConsumerBYOMRemote::UpdateModelOptions(
const mojom::ModelOptions& options) {
if (options.is_custom_model_options()) {
model_options_ = *options.get_custom_model_options();
Expand All @@ -146,7 +146,7 @@ void EngineConsumerOAIRemote::UpdateModelOptions(
}
}

void EngineConsumerOAIRemote::GenerateRewriteSuggestion(
void EngineConsumerBYOMRemote::GenerateRewriteSuggestion(
std::string text,
const std::string& question,
const std::string& selected_language,
Expand All @@ -173,7 +173,7 @@ void EngineConsumerOAIRemote::GenerateRewriteSuggestion(
std::move(completed_callback));
}

void EngineConsumerOAIRemote::GenerateQuestionSuggestions(
void EngineConsumerBYOMRemote::GenerateQuestionSuggestions(
const bool& is_video,
const std::string& page_content,
const std::string& selected_language,
Expand Down Expand Up @@ -207,11 +207,11 @@ void EngineConsumerOAIRemote::GenerateQuestionSuggestions(
api_->PerformRequest(
model_options_, std::move(messages), base::NullCallback(),
base::BindOnce(
&EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse,
&EngineConsumerBYOMRemote::OnGenerateQuestionSuggestionsResponse,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
}

void EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse(
void EngineConsumerBYOMRemote::OnGenerateQuestionSuggestionsResponse(
SuggestedQuestionsCallback callback,
GenerationResult result) {
if (!result.has_value() || result->empty()) {
Expand Down Expand Up @@ -242,7 +242,7 @@ void EngineConsumerOAIRemote::OnGenerateQuestionSuggestionsResponse(
std::move(callback).Run(std::move(questions));
}

void EngineConsumerOAIRemote::GenerateAssistantResponse(
void EngineConsumerBYOMRemote::GenerateAssistantResponse(
const bool& is_video,
const std::string& page_content,
const ConversationHistory& conversation_history,
Expand Down Expand Up @@ -274,6 +274,6 @@ void EngineConsumerOAIRemote::GenerateAssistantResponse(
std::move(completed_callback));
}

void EngineConsumerOAIRemote::SanitizeInput(std::string& input) {}
void EngineConsumerBYOMRemote::SanitizeInput(std::string& input) {}

} // namespace ai_chat
Loading
Loading