refactor(openai): remove the inefficient logic that created a ReadableStream from the API response and then re-processed it
IsthisLee committed Aug 15, 2023
1 parent 979c66e commit 251a4e7
Showing 3 changed files with 27 additions and 48 deletions.
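At a glance: the OpenAI Node SDK's streaming call already returns an async-iterable Stream<ChatCompletionChunk>, so wrapping it in a web ReadableStream<Uint8Array> only to read it straight back out with a reader was a redundant round trip. A minimal sketch of the pattern the commit moves to (simplified; the client setup and the res writable are assumptions, not code from this repo):

import { OpenAI } from 'openai';

const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

// Iterate the SDK stream directly instead of re-wrapping it in a ReadableStream.
async function streamChatToResponse(res: NodeJS.WritableStream): Promise<void> {
  const stream = await openai.chat.completions.create({
    model: 'gpt-3.5-turbo',
    messages: [{ role: 'user', content: 'Summarize this law.' }],
    stream: true,
  });

  for await (const chunk of stream) {
    res.write(JSON.stringify(chunk)); // each chunk is a ChatCompletionChunk
  }
  res.end();
}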
6 changes: 4 additions & 2 deletions src/apis/laws/laws.controller.ts
@@ -13,6 +13,8 @@ import {
 import { Throttle } from '@nestjs/throttler';
 import { Response } from 'express';
 import { OpenaiService } from 'src/shared/services/openai.service';
+import { Stream } from 'openai/streaming';
+import { ChatCompletionChunk } from 'openai/resources/chat';
 
 @Controller('laws')
 @ApiTags('Laws')
@@ -213,11 +215,11 @@ export class LawsController
     @Param('id', new ParseIntPipe()) id: number,
     @Body() requestSummaryDto?: RequestSummaryDto,
   ) {
-    const lawSummaryReadableStream: ReadableStream<Uint8Array> = await this.lawsService.createLawStreamSummary(
+    const lawSummaryReadableStream: Stream<ChatCompletionChunk> = await this.lawsService.createLawStreamSummary(
       type,
       id,
       requestSummaryDto.recentSummaryMsg,
     );
-    return this.openaiService.sendResWithReadableStream(res, lawSummaryReadableStream);
+    return this.openaiService.sendResWithOpenAIStream(res, lawSummaryReadableStream);
   }
 }
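For reference, a client can consume this chunked response with the Fetch API's body reader. A hypothetical sketch — the endpoint path, HTTP method, and request-body shape are assumptions inferred from the controller, not confirmed by the diff:

// Hypothetical client-side consumer; the route and DTO shape are assumed.
async function readLawSummary(type: string, id: number): Promise<string> {
  const res = await fetch(`/laws/${type}/${id}/summary-stream`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ recentSummaryMsg: '' }),
  });

  const reader = res.body!.getReader();
  const decoder = new TextDecoder();
  let text = '';
  for (;;) {
    const { value, done } = await reader.read();
    if (done) break;
    text += decoder.decode(value, { stream: true }); // body is a concatenation of JSON-serialized chunks
  }
  return text;
}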
13 changes: 5 additions & 8 deletions src/apis/laws/laws.service.ts
@@ -20,7 +20,8 @@ import convert from 'xml-js';
 import { fetchData } from 'src/common/utils';
 import { getLawListDto } from './dtos/get-law.dto';
 import { OpenaiService } from 'src/shared/services/openai.service';
-import { OpenAI } from 'openai';
+import { ChatCompletionMessage } from 'openai/resources/chat';
+
 
 interface GetLawListParams {
   type: SearchTabEnum;
   q: string;
@@ -272,11 +273,7 @@ export class LawsService {
     };
   }
 
-  async createLawStreamSummary(
-    type: SearchTabEnum,
-    id: number,
-    recentSummaryMsg: string,
-  ): Promise<ReadableStream<any>> {
+  async createLawStreamSummary(type: SearchTabEnum, id: number, recentSummaryMsg: string) {
     const lawDetail = await this.getLawDetail(type, id);
 
     const summaryReqMsgs = await this.generateSummaryReqMessasges(lawDetail, recentSummaryMsg, {
@@ -325,14 +322,14 @@
     { onlySummary }: { onlySummary?: boolean } = {
       onlySummary: false,
     },
-  ): Promise<Array<OpenAI.Chat.Completions.ChatCompletionMessage>> {
+  ): Promise<Array<ChatCompletionMessage>> {
     const isFirstSummary = !recentSummaryMsg;
     const initContent = onlySummary
       ? this.configService.get('LAW_SUMMARY_INIT_PROMPT_ONLY_SUMMARY')
       : isFirstSummary
       ? this.configService.get('LAW_SUMMARY_INIT_PROMPT')
       : this.configService.get('LAW_SUMMARY_INIT_PROMPT_ONLY_SUMMARY');
-    const messages: Array<OpenAI.Chat.Completions.ChatCompletionMessage> = [
+    const messages: Array<ChatCompletionMessage> = [
       {
         role: 'system',
         content: initContent,
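The hunk above picks a system prompt from config and builds a ChatCompletionMessage array. A sketch of that shape — the helper name and the user-message entry are illustrative assumptions; only the role: 'system' entry is confirmed by the diff:

import { ChatCompletionMessage } from 'openai/resources/chat';

// Illustrative helper: a config-selected system prompt followed by the text
// to summarize. The user-message field is an assumption.
function buildSummaryMessages(initContent: string, lawText: string): Array<ChatCompletionMessage> {
  return [
    { role: 'system', content: initContent },
    { role: 'user', content: lawText },
  ];
}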
56 changes: 18 additions & 38 deletions src/shared/services/openai.service.ts
@@ -1,6 +1,7 @@
 import { BadRequestException, Injectable } from '@nestjs/common';
 import { OpenAI } from 'openai';
 import { Stream } from 'openai/streaming';
+import { ChatCompletionChunk, ChatCompletionMessage, ChatCompletion } from 'openai/resources/chat';
 import { ConfigService } from '@nestjs/config';
 import { encoding_for_model } from '@dqbd/tiktoken';
 import { TiktokenModel } from '@dqbd/tiktoken';
@@ -20,59 +21,38 @@ export class OpenaiService {
     }); // singleton
   }
 
-  async createAIChatCompletion(
-    promptMessages: Array<OpenAI.Chat.Completions.ChatCompletionMessage>,
-  ): Promise<OpenAI.Chat.Completions.ChatCompletion> {
+  async createAIChatCompletion(promptMessages: Array<ChatCompletionMessage>): Promise<ChatCompletion> {
     const requestData = this.generateOpenAIChatRequestData(promptMessages);
     const chatCompletion = await this.openai.chat.completions.create(requestData);
 
     return chatCompletion;
   }
 
   async createAIStramChatCompletion(
-    promptMessages: Array<OpenAI.Chat.Completions.ChatCompletionMessage>,
-  ): Promise<ReadableStream<Uint8Array>> {
+    promptMessages: Array<ChatCompletionMessage>,
+  ): Promise<Stream<ChatCompletionChunk>> {
     const requestData = this.generateOpenAIChatRequestData(promptMessages);
-    const stream: Stream<OpenAI.Chat.Completions.ChatCompletionChunk> = await this.openai.chat.completions.create({
+    const stream: Stream<ChatCompletionChunk> = await this.openai.chat.completions.create({
       ...requestData,
       stream: true,
     });
 
-    const encoder = new TextEncoder();
-
-    const readableStream = new ReadableStream({
-      async start(controller) {
-        for await (const chunk of stream) {
-          const str = typeof chunk === 'string' ? chunk : JSON.stringify(chunk);
-          const bytes = encoder.encode(str);
-          controller.enqueue(bytes);
-        }
-      },
-    });
-
-    return readableStream;
+    return stream;
   }
 
-  async sendResWithReadableStream(res: Response, readableStream: ReadableStream<Uint8Array>) {
-    const reader = readableStream.getReader();
-
-    const readAllChunks = async () => {
-      const { value, done } = await reader.read();
-      if (done) {
-        res.end();
-        return;
-      }
-
-      res.write(value); // write the chunk data to the client
-      readAllChunks(); // recurse until the stream is drained
-    };
-
-    readAllChunks();
+  async sendResWithOpenAIStream(res: Response, opanAIStream: Stream<ChatCompletionChunk>): Promise<void> {
+    const encoder = new TextEncoder();
+
+    for await (const part of opanAIStream) {
+      const str = typeof part === 'string' ? part : JSON.stringify(part);
+      const bytes = encoder.encode(str);
+      res.write(bytes);
+    }
   }
 
-  private generateOpenAIChatRequestData(promptMessages: Array<OpenAI.Chat.Completions.ChatCompletionMessage>): {
+  private generateOpenAIChatRequestData(promptMessages: Array<ChatCompletionMessage>): {
     model: TiktokenModel | 'gpt-3.5-turbo-16k';
-    messages: OpenAI.Chat.Completions.ChatCompletionMessage[];
+    messages: ChatCompletionMessage[];
   } {
     const { promptMessages: possiblePrompt, currentTokenCount } = this.convertChatPromptToPossible(promptMessages);
 
@@ -85,8 +65,8 @@
     };
   }
 
-  private convertChatPromptToPossible(promptMessages: Array<OpenAI.Chat.Completions.ChatCompletionMessage>): {
-    promptMessages: Array<OpenAI.Chat.Completions.ChatCompletionMessage>;
+  private convertChatPromptToPossible(promptMessages: Array<ChatCompletionMessage>): {
+    promptMessages: Array<ChatCompletionMessage>;
     currentTokenCount: number;
   } {
     let currentTokenCount = this.calculateTokensWithTiktoken(promptMessages);
@@ -117,7 +97,7 @@
     };
   }
 
-  private calculateTokensWithTiktoken(promptMessages: Array<OpenAI.Chat.Completions.ChatCompletionMessage>): number {
+  private calculateTokensWithTiktoken(promptMessages: Array<ChatCompletionMessage>): number {
     const gptModel: TiktokenModel = 'gpt-3.5-turbo';
     const enc = encoding_for_model(gptModel);
 
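For context on the unchanged token-counting path: @dqbd/tiktoken exposes encoding_for_model, and the encoder must be freed after use. A minimal sketch of counting tokens across prompt messages, assuming contents are plain strings (role/name overhead ignored):

import { encoding_for_model, TiktokenModel } from '@dqbd/tiktoken';
import { ChatCompletionMessage } from 'openai/resources/chat';

// Count tokens over message contents only; a simplification of the real method.
function countPromptTokens(messages: Array<ChatCompletionMessage>): number {
  const model: TiktokenModel = 'gpt-3.5-turbo';
  const enc = encoding_for_model(model);
  try {
    return messages.reduce((sum, m) => sum + enc.encode(m.content ?? '').length, 0);
  } finally {
    enc.free(); // tiktoken's WASM-backed encoder must be freed explicitly
  }
}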
