feat(openai): send stream data via SSE
IsthisLee committed Aug 17, 2023
1 parent c65f18a commit 240c709
Showing 3 changed files with 46 additions and 16 deletions.
17 changes: 8 additions & 9 deletions src/apis/laws/laws.controller.ts
@@ -7,9 +7,10 @@ import {
Post,
Body,
ParseIntPipe,
Res,
UseGuards,
Delete,
Sse,
MessageEvent,
} from '@nestjs/common';
import { LawsService } from './laws.service';
import {
@@ -32,14 +33,14 @@ import {
LawSummaryResponseData,
} from 'src/common/types';
import { Throttle } from '@nestjs/throttler';
import { Response } from 'express';
import { OpenaiService } from 'src/shared/services/openai.service';
import { Stream } from 'openai/streaming';
import { ChatCompletionChunk } from 'openai/resources/chat';
import { JwtUserPayload } from 'src/common/decorators/jwt-user.decorator';
import { JwtPayloadInfo } from 'src/common/types';
import { AuthGuard } from '@nestjs/passport';
import { GetBookmarkLawListDto } from './dtos/get-bookmark-laws.dto';
import { Observable, Subscriber } from 'rxjs';

@Controller({ path: 'laws' })
@ApiTags('Laws')
@@ -233,7 +234,7 @@ export class LawsController {
return this.lawsService.createLawSummary(type, id.toString(), requestSummaryDto.recentSummaryMsg);
}

@Post(':type/:id/summary-stream')
@Sse(':type/:id/summary-stream')
@ApiOperation({
summary: '판례/법령 요약 요청 - stream version',
description: `
@@ -251,18 +252,16 @@
description: '판례 또는 법령의 ID(판례일련번호/법령ID)',
})
async createLawStreamSummary(
@Res() res: Response,
@Param('type', new ParseEnumPipe(SearchTabEnum))
type: SearchTabEnum,
@Param('type', new ParseEnumPipe(SearchTabEnum)) type: SearchTabEnum,
@Param('id', new ParseIntPipe()) id: number,
@Body() requestSummaryDto?: RequestSummaryDto,
) {
): Promise<Observable<MessageEvent>> {
const lawSummaryReadableStream: Stream<ChatCompletionChunk> = await this.lawsService.createLawStreamSummary(
type,
id.toString(),
requestSummaryDto.recentSummaryMsg,
requestSummaryDto?.recentSummaryMsg,
);
return this.openaiService.sendResWithOpenAIStream(res, lawSummaryReadableStream);
return this.openaiService.sendSSEWithOpenAIStream(lawSummaryReadableStream);
}

@Post(':type/:id/bookmark')
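Not part of the commit: a minimal client-side sketch of how a browser could consume the new SSE endpoint, assuming the route is served over GET (the default for NestJS `@Sse()` routes) and using hypothetical `:type`/`:id` values.

```typescript
// Client sketch (assumed path; 'prec' and '123' are placeholder params).
const source = new EventSource('/laws/prec/123/summary-stream');

// Default events carry one content delta per chunk, matching
// subscriber.next({ data: content, retry: 1000 }) on the server.
source.onmessage = (event) => {
  console.log('summary chunk:', event.data);
};

// The server emits { type: 'close', data: 'true' } when the stream ends.
source.addEventListener('close', () => {
  source.close();
});
```

One caveat under these assumptions: `EventSource` can only issue GET requests without a body, so the optional `recentSummaryMsg` from `RequestSummaryDto` would have to be passed another way (for example, a query parameter) if the client still needs it.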
2 changes: 1 addition & 1 deletion src/apis/laws/laws.service.ts
@@ -182,7 +182,7 @@ export class LawsService {
queryParams: GetBookmarkLawListDto,
): Promise<PageResponse<PrecDetailData[] | StatuteDetailData[]>> {
const params: GetBookmarkLawListParams = { type: lawType, ...queryParams };
const where = {
const where: Prisma.LawBookmarkWhereInput = {
userId,
lawType,
deletedAt: null,
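Not part of the commit: a small sketch of what the explicit `Prisma.LawBookmarkWhereInput` annotation buys. The helper name and the indexed-access parameter types are illustrative only.

```typescript
import { Prisma } from '@prisma/client';

// With the explicit type, a key that does not exist on the LawBookmark model
// is rejected at compile time instead of surfacing as a runtime Prisma error.
function buildBookmarkFilter(
  userId: Prisma.LawBookmarkWhereInput['userId'],
  lawType: Prisma.LawBookmarkWhereInput['lawType'],
): Prisma.LawBookmarkWhereInput {
  return {
    userId,
    lawType,
    deletedAt: null,
    // unknownField: 1, // would now fail type-checking
  };
}
```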
43 changes: 37 additions & 6 deletions src/shared/services/openai.service.ts
@@ -1,11 +1,12 @@
import { BadRequestException, Injectable } from '@nestjs/common';
import { BadRequestException, Injectable, MessageEvent } from '@nestjs/common';
import { OpenAI } from 'openai';
import { Stream } from 'openai/streaming';
import { ChatCompletionChunk, ChatCompletionMessage, ChatCompletion } from 'openai/resources/chat';
import { ConfigService } from '@nestjs/config';
import { encoding_for_model } from '@dqbd/tiktoken';
import { TiktokenModel } from '@dqbd/tiktoken';
import { Response } from 'express';
import { Observable, Subscriber } from 'rxjs';

@Injectable()
export class OpenaiService {
@@ -40,15 +41,45 @@ export class OpenaiService {
return stream;
}

async sendResWithOpenAIStream(res: Response, opanAIStream: Stream<ChatCompletionChunk>): Promise<void> {
const encoder = new TextEncoder();
async sendSSEWithOpenAIStream(opanAIStream: Stream<ChatCompletionChunk>): Promise<Observable<MessageEvent>> {
return new Observable((subscriber: Subscriber<MessageEvent>) => {
this.sendData(opanAIStream, subscriber);
});
}

private async sendData(opanAIStream: Stream<ChatCompletionChunk>, subscriber: Subscriber<MessageEvent>) {
for await (const part of opanAIStream) {
const content: string = part.choices[0]?.delta?.content;
if (content) {
subscriber.next({ data: content, retry: 1000 });
}
}

subscriber.next({
type: 'close',
data: 'true',
});

subscriber.complete();
}

// unused(legacy) code
async sendSSEResponseWithOpenAIStream(res: Response, opanAIStream: Stream<ChatCompletionChunk>): Promise<void> {
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');
res.flushHeaders(); // flush the SSE headers to the client immediately

res.write('retry: 1000\n\n'); // tell the client to retry the connection after 1 second if it drops

for await (const part of opanAIStream) {
const str = typeof part === 'string' ? part : JSON.stringify(part);
const bytes = encoder.encode(str);
res.write(bytes);
const content: string = part.choices[0].delta.content;
content && res.write(`data: ${content}\n\n`);
}

res.write(`event: close\n`);
res.write(`data: true\n\n`);

res.end();
}
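Not part of the commit: a hedged variant of the private `sendData` helper that also propagates stream failures, so an error thrown by the OpenAI stream terminates the SSE response instead of leaving the subscriber open. Names mirror the commit's code; the error handling itself is an assumption about the desired behaviour.

```typescript
import { MessageEvent } from '@nestjs/common';
import { Subscriber } from 'rxjs';
import { Stream } from 'openai/streaming';
import { ChatCompletionChunk } from 'openai/resources/chat';

// Sketch only: same emission pattern as sendData, plus error propagation.
async function sendDataSafely(
  openAIStream: Stream<ChatCompletionChunk>,
  subscriber: Subscriber<MessageEvent>,
): Promise<void> {
  try {
    for await (const part of openAIStream) {
      const content = part.choices[0]?.delta?.content;
      if (content) {
        subscriber.next({ data: content, retry: 1000 });
      }
    }
    subscriber.next({ type: 'close', data: 'true' });
    subscriber.complete();
  } catch (error) {
    // End the SSE response with an error instead of silently stalling.
    subscriber.error(error);
  }
}
```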
