diff --git a/packages/cli/package.json b/packages/cli/package.json index fe54ae9..8440856 100644 --- a/packages/cli/package.json +++ b/packages/cli/package.json @@ -16,9 +16,10 @@ "inquirer-search-list": "^1.2.6", "js-tiktoken": "^1.0.8", "llamaindex": "^0.1.2", + "modelfusion": "^0.137.0", "openai": "^4.20.1", "openpipe": "^0.8.0", - "prompt-iteration-assistant": "^0.0.34", + "prompt-iteration-assistant": "^0.0.37", "promptfoo": "^0.28.1", "pybridge-zod": "^1.1.0", "remeda": "^1.29.0", diff --git a/packages/cli/src/recommender/prompts/answersQuestion/answersQuestion.ts b/packages/cli/src/recommender/prompts/answersQuestion/answersQuestion.ts index ada2b15..73b3114 100644 --- a/packages/cli/src/recommender/prompts/answersQuestion/answersQuestion.ts +++ b/packages/cli/src/recommender/prompts/answersQuestion/answersQuestion.ts @@ -5,6 +5,7 @@ import { } from "./schemas/answersQuestionInputSchema"; import { answersQuestionOutputSchema } from "./schemas/answersQuestionOutputSchema"; import { answersQuestionPrompt } from "./prompts/answersQuestionPrompt"; +import { DefaultRun } from "modelfusion"; export const ANSWERS_QUESTION = "Answers Question"; @@ -23,11 +24,12 @@ class AnswersQuestion extends Prompt< output: answersQuestionOutputSchema, }); } - async execute(args: AnswersQuestionInput) { + async execute(args: AnswersQuestionInput & { run?: DefaultRun }) { try { return this.run({ stream: false, promptVariables: args, + run: args.run, }); } catch (e) { console.error(e); diff --git a/packages/cli/src/recommender/prompts/brainstormSubQuestions/brainstormQuestions.ts b/packages/cli/src/recommender/prompts/brainstormSubQuestions/brainstormQuestions.ts index e0fae43..8693f66 100644 --- a/packages/cli/src/recommender/prompts/brainstormSubQuestions/brainstormQuestions.ts +++ b/packages/cli/src/recommender/prompts/brainstormSubQuestions/brainstormQuestions.ts @@ -10,6 +10,7 @@ import { } from "./schemas/brainstormQuestionsInputSchema"; import { brainstormQuestionsOutputSchema } from "./schemas/brainstormQuestionsOutputSchema"; import { brainstormQuestionsZeroShotPrompt } from "./prompts/zeroShot"; +import { DefaultRun } from "modelfusion"; export const BRAINSTORM_QUESTIONS = "Brainstorm Questions"; @@ -29,39 +30,22 @@ export class BrainstormQuestions extends Prompt< }); } - async execute(args: { - query: string; - openPipeRequestTags?: RequestTagsWithoutName; - enableOpenPipeLogging?: boolean; - }) { + async execute(args: { query: string; run?: DefaultRun }) { const promptVariables: BrainstormQuestionsInput = { query: args.query, }; - const candidatePrompt = this.chooseCandidatePrompt(promptVariables); - const res = await openpipe.functionCall({ - function: { - name: this.name, - description: this.description, - input: this.input!, - output: this.output!, - }, - vars: promptVariables, - prompt: candidatePrompt, - body: { - max_tokens: this.max_tokens, - temperature: this.temperature, - model: this.model, + try { + const res = await this.run({ + promptVariables, stream: false, - }, - openPipeRequestTags: args.openPipeRequestTags - ? { - ...args.openPipeRequestTags, - promptName: formatPromptName(this.name, candidatePrompt.name), - } - : undefined, - enableOpenPipeLogging: args.enableOpenPipeLogging, - }); - return res?.questions || []; + run: args.run, + }); + + return res?.questions || []; + } catch (e) { + console.error(e); + return []; + } } } diff --git a/packages/cli/src/recommender/prompts/brainstormSubQuestions/prompts/zeroShot.ts b/packages/cli/src/recommender/prompts/brainstormSubQuestions/prompts/zeroShot.ts index 5e7b1a6..09a78ba 100644 --- a/packages/cli/src/recommender/prompts/brainstormSubQuestions/prompts/zeroShot.ts +++ b/packages/cli/src/recommender/prompts/brainstormSubQuestions/prompts/zeroShot.ts @@ -19,7 +19,8 @@ export const brainstormQuestionsZeroShotPrompt = - Your task is to brainstorm related questions to research. - If the topic is controversial, try to include both sides of the argument. - DO NOT USE ACRONYMS OR ABBREVIATIONS, ALWAYS USE THE FULL NAME WITH THE ACRONYM IN PARANTHESES. -- Brainstorm five questions. +- Don't include introductory or background questions. +- Brainstorm three questions. `.trim() ), { @@ -33,11 +34,9 @@ How can Spaced Repetition Systems (SRS) be improved with AI? name: toCamelCase(BRAINSTORM_QUESTIONS), arguments: { questions: [ - "What are the current limitations of Spaced Repetition Systems (SRS) for learning?", "How has artificial intelligence (AI) been applied to educational systems in other contexts? Can these methods be adapted for Spaced Repetition Systems (SRS)?", "What potential improvements could artificial intelligence (AI) bring to Spaced Repetition Systems (SRS) in terms of personalization and effectiveness?", "What are the possible risks or drawbacks in integrating artificial intelligence (AI) into Spaced Repetition Systems (SRS)?", - "How do professionals in the fields of education and artificial intelligence (AI) view the potential integration of artificial intelligence (AI) and Spaced Repetition Systems (SRS)?", ], }, }), diff --git a/packages/cli/src/recommender/prompts/createQueriesFromProfile/createQueriesFromProfile.ts b/packages/cli/src/recommender/prompts/createQueriesFromProfile/createQueriesFromProfile.ts index dea589d..45f2c38 100644 --- a/packages/cli/src/recommender/prompts/createQueriesFromProfile/createQueriesFromProfile.ts +++ b/packages/cli/src/recommender/prompts/createQueriesFromProfile/createQueriesFromProfile.ts @@ -11,6 +11,7 @@ import { import { openpipe } from "../../../openpipe/openpipe"; import { experilearningDataset } from "./datasets/experilearningDataset"; import { createQueriesFromProfileZeroShotFreeFormPrompt } from "./prompts/zeroShotFreeForm"; +import { DefaultRun } from "modelfusion"; export const CREATE_SEARCH_QUERIES_FROM_PROFILE = "Create Questions"; @@ -38,39 +39,24 @@ export class CreateSearchQueriesFromProfile extends Prompt< user: string; profile: string; bio: string; - openPipeRequestTags?: RequestTagsWithoutName; - enableOpenPipeLogging?: boolean; + run?: DefaultRun; }) { const promptVariables: CreateQueriesFromProfileInput = { user: args.user, bio: args.bio, profile: args.profile, }; - const candidatePrompt = this.chooseCandidatePrompt(promptVariables); - const res = await openpipe.functionCall({ - function: { - name: this.name, - description: this.description, - input: this.input!, - output: this.output!, - }, - vars: promptVariables, - prompt: candidatePrompt, - body: { - max_tokens: this.max_tokens, - temperature: this.temperature, - model: this.model, + try { + const res = await this.run({ + promptVariables, stream: false, - }, - openPipeRequestTags: args.openPipeRequestTags - ? { - ...args.openPipeRequestTags, - promptName: formatPromptName(this.name, candidatePrompt.name), - } - : undefined, - enableOpenPipeLogging: args.enableOpenPipeLogging, - }); - return res || { queries: [] }; + run: args.run, + }); + return res || { queries: [] }; + } catch (e) { + console.error(e); + return { queries: [] }; + } } } diff --git a/packages/cli/src/recommender/prompts/findStartOfAnswer/findStartOfAnswer.ts b/packages/cli/src/recommender/prompts/findStartOfAnswer/findStartOfAnswer.ts index 061f918..b0d1487 100644 --- a/packages/cli/src/recommender/prompts/findStartOfAnswer/findStartOfAnswer.ts +++ b/packages/cli/src/recommender/prompts/findStartOfAnswer/findStartOfAnswer.ts @@ -1,21 +1,18 @@ import { Prompt } from "prompt-iteration-assistant"; -import { - RequestTagsWithoutName, - formatPromptName, -} from "../../../openpipe/requestTags"; -import { openpipe } from "../../../openpipe/openpipe"; +import { RequestTagsWithoutName } from "../../../openpipe/requestTags"; import { FindStartOfAnswerInput, findStartOfAnswerInputSchema, } from "./schemas/findStartOfAnswerOutputSchema"; -import { findStartOfAnswerOutputSchema } from "./schemas/findStartOfAnswerInputSchema"; import { findStartOfAnswerPrompt } from "./prompts/findStartOfAnswerPrompt"; +import { z } from "zod"; +import { DefaultRun } from "modelfusion"; export const FIND_START_OF_ANSWER = "Find Start Of Answer"; class FindStartOfAnswer extends Prompt< typeof findStartOfAnswerInputSchema, - typeof findStartOfAnswerOutputSchema + z.ZodString > { constructor() { super({ @@ -24,7 +21,6 @@ class FindStartOfAnswer extends Prompt< prompts: [findStartOfAnswerPrompt], model: "gpt-4", input: findStartOfAnswerInputSchema, - output: findStartOfAnswerOutputSchema, exampleData: [], }); } @@ -32,9 +28,8 @@ class FindStartOfAnswer extends Prompt< async execute(args: { question: string; text: string; - openPipeRequestTags?: RequestTagsWithoutName; - enableOpenPipeLogging?: boolean; - }) { + run?: DefaultRun; + }): Promise { const promptVariables: FindStartOfAnswerInput = { text: args.text, question: args.question, @@ -43,10 +38,11 @@ class FindStartOfAnswer extends Prompt< return this.run({ stream: false, promptVariables, + run: args.run, }); } catch (e) { console.error(e); - return { quotedAnswer: null }; + return null; } } } diff --git a/packages/cli/src/recommender/prompts/findStartOfAnswer/prompts/findStartOfAnswerPrompt.ts b/packages/cli/src/recommender/prompts/findStartOfAnswer/prompts/findStartOfAnswerPrompt.ts index 648d5dd..3f86b25 100644 --- a/packages/cli/src/recommender/prompts/findStartOfAnswer/prompts/findStartOfAnswerPrompt.ts +++ b/packages/cli/src/recommender/prompts/findStartOfAnswer/prompts/findStartOfAnswerPrompt.ts @@ -12,38 +12,9 @@ export const findStartOfAnswerPrompt = - Given a question from the user, evalutate whether the beginning of the answer is in the text. - If the beginning of the answer is in the text, quote the beginning of the answer. - The answer doesn't need to be complete, just the start of it. -- Quote the beginning of the answer directly from the text as a JSON string, escaping any literal control characters. +- Quote the beginning of the answer directly from the text. `.trim() ), - // ChatMessage.user( - // ` - // # Text - - // # Question - // What is the best way to learn a new language? - // `.trim() - // ), - // ChatMessage.assistant(null, { - // name: toCamelCase(FIND_START_OF_ANSWER), - // arguments: { - // answersQuestion: true, - // quotedAnswer: "", - // }, - // }), - // ChatMessage.user( - // ` - // # Text - - // # Question - // How does chain of thought prompting work? - // `.trim() - // ), - // ChatMessage.assistant(null, { - // name: toCamelCase(FIND_START_OF_ANSWER), - // arguments: { - // answersQuestion: false, - // }, - // }), ChatMessage.user( ` # Question diff --git a/packages/cli/src/recommender/prompts/findStartOfAnswerYouTube/findStartOfAnswerYouTube.ts b/packages/cli/src/recommender/prompts/findStartOfAnswerYouTube/findStartOfAnswerYouTube.ts index ec0fff5..210d5d8 100644 --- a/packages/cli/src/recommender/prompts/findStartOfAnswerYouTube/findStartOfAnswerYouTube.ts +++ b/packages/cli/src/recommender/prompts/findStartOfAnswerYouTube/findStartOfAnswerYouTube.ts @@ -6,6 +6,7 @@ import { findStartOfAnswerYouTubeInputSchema, } from "./schemas/findStartOfAnswerYouTubeInputSchema"; import { findStartOfAnswerYouTubePrompt } from "./prompts/findStartOfAnswerYouTubePrompt"; +import { DefaultRun } from "modelfusion"; export const FIND_START_OF_ANSWER_YOUTUBE = "Find Start Of Answer Cue"; @@ -29,8 +30,7 @@ export class FindStartOfAnswerYouTube extends Prompt< async execute(args: { question: string; cues: { text: string }[]; - openPipeRequestTags?: RequestTagsWithoutName; - enableOpenPipeLogging?: boolean; + run?: DefaultRun; }) { const promptVariables: FindStartOfAnswerYouTubeInput = { transcript: args.cues @@ -47,6 +47,7 @@ ${cue.text} return this.run({ stream: false, promptVariables, + run: args.run, }); } catch (e) { console.error(e); diff --git a/packages/cli/src/recommender/prompts/recursiveTwitterSummarizer/prompts/zeroShot.ts b/packages/cli/src/recommender/prompts/recursiveTwitterSummarizer/prompts/zeroShot.ts index 854eeac..192fa04 100644 --- a/packages/cli/src/recommender/prompts/recursiveTwitterSummarizer/prompts/zeroShot.ts +++ b/packages/cli/src/recommender/prompts/recursiveTwitterSummarizer/prompts/zeroShot.ts @@ -9,13 +9,11 @@ export const zeroShot = new CandidatePrompt({ ` # Instructions - Act as a user profile writer. -- You are shown a collection of user data. - Given the information from multiple sources and not prior knowledge summarize the user's interests into a profile that explains what topics, people, concepts and ideas interest them. - Focus on specific low-level interests, like "the use of large language models (LLMs) in recommender systems", as opposed to generic high-level interests like "AI", "technology" and "innovation". -- Expand abbreviations and acronyms, eg. "LLM agents" should be written as "large language model (LLM) agents". -- You must include any technical terms and names of people, places, and things that are relevant to the user. +- Expand abbreviations and acronyms. For example, "LLM agents" should be written as "Large Language Model (LLM) agents". +- You must include technical terms and names of people, places, and things that are relevant to the user. - If summarizing existing summaries, preserve the technical terms and names of concepts, people, places, ideas and events that are relevant to the user. -- The summary string should be JSON parsable (escaped quotes, etc). - Write a two paragraph summary of around 500 words. `.trim() ), @@ -28,7 +26,7 @@ ${this.getVariable("user")} # Bio ${this.getVariable("bio")} -# User data +# Raw User Data or Existing Summaries ${this.getVariable("tweets")} `.trim(), }, diff --git a/packages/cli/src/recommender/prompts/recursiveTwitterSummarizer/recursiveTwitterSummarizer.ts b/packages/cli/src/recommender/prompts/recursiveTwitterSummarizer/recursiveTwitterSummarizer.ts index 0e0f8b0..d1c17c1 100644 --- a/packages/cli/src/recommender/prompts/recursiveTwitterSummarizer/recursiveTwitterSummarizer.ts +++ b/packages/cli/src/recommender/prompts/recursiveTwitterSummarizer/recursiveTwitterSummarizer.ts @@ -1,12 +1,8 @@ import { Prompt } from "prompt-iteration-assistant"; import { zeroShot } from "./prompts/zeroShot"; -import { - RequestTagsWithoutName, - formatPromptName, -} from "../../../openpipe/requestTags"; +import { RequestTagsWithoutName } from "../../../openpipe/requestTags"; import { Tweet, TwitterUser } from "shared/src/manual/Tweet"; import { tweetsToString } from "../../../twitter/getUserContext"; -import { openpipe } from "../../../openpipe/openpipe"; import { experilearningDataset, experilearningTweets, @@ -17,8 +13,15 @@ import { RecursiveTwitterSummarizerInput, recursiveTwitterSummarizerInputSchema, } from "./schemas/recursiveTwitterSummarizerInputSchema"; -import { recursiveTwitterSummarizerOutputSchema } from "./schemas/recursiveTwitterSummarizerOutputSchema"; import { tokenize } from "../../../tokenize"; +import { z } from "zod"; +import dotenv from "dotenv"; +import path from "path"; +import { DefaultRun, Run } from "modelfusion"; +import { + calculateCost, + OpenAICostCalculator, +} from "@modelfusion/cost-calculator"; export const SUMMARIZE_TWEETS = "Summarize Data"; @@ -27,16 +30,15 @@ export const SUMMARIZE_TWEETS = "Summarize Data"; */ export class RecursiveTwitterSummarizer extends Prompt< typeof recursiveTwitterSummarizerInputSchema, - typeof recursiveTwitterSummarizerOutputSchema + z.ZodString > { constructor() { super({ name: SUMMARIZE_TWEETS, - description: "Summarize the user's data.", + description: "Summarize the user's data into a user profile.", prompts: [zeroShot], model: "gpt-4", input: recursiveTwitterSummarizerInputSchema, - output: recursiveTwitterSummarizerOutputSchema, exampleData: [], }); } @@ -46,6 +48,7 @@ export class RecursiveTwitterSummarizer extends Prompt< tweets: Tweet[]; openPipeRequestTags?: RequestTagsWithoutName; enableOpenPipeLogging?: boolean; + run?: Run; }) { if (args.tweets.length === 0) { return undefined; @@ -70,31 +73,13 @@ export class RecursiveTwitterSummarizer extends Prompt< bio, user: args.user.displayname, }; - const candidatePrompt = this.chooseCandidatePrompt(promptVariables); - const res = await openpipe.functionCall({ - function: { - name: this.name, - description: this.description, - input: this.input!, - output: this.output!, - }, - vars: promptVariables, - prompt: candidatePrompt, - body: { - max_tokens: this.max_tokens, - temperature: this.temperature, - model: this.model, - stream: false, - }, - openPipeRequestTags: args.openPipeRequestTags - ? { - ...args.openPipeRequestTags, - promptName: formatPromptName(this.name, candidatePrompt.name), - } - : undefined, - enableOpenPipeLogging: args.enableOpenPipeLogging, + const res = await this.run({ + stream: false, + promptVariables, + // @ts-ignore + run: args.run, }); - return res?.summary || ""; + return res || ""; }; const parts = await new RecursiveCharacterTextSplitter({ @@ -142,19 +127,22 @@ export class RecursiveTwitterSummarizer extends Prompt< .map((x) => x.replace(/\r?\n/g, " ")); if ( summaries.length <= 1 || - (await tokenize(summaries.reduce((a, b) => a + " " + b))).length <= 2200 + (await tokenize(summaries.join(" "))).length <= 2200 ) { return summaries.join(" "); + } else { + console.log("summaries"); + console.log(summaries); + const text = summaries.join("\n---\n"); + const parts = await new RecursiveCharacterTextSplitter({ + separators: ["---"], + chunkSize: maxTokens, + chunkOverlap: 100, + }).splitText(text); + return await summarizeRecursively( + await Promise.all(parts.map(callApi)) + ); } - console.log("summaries"); - console.log(summaries); - const text = summaries.join("\n---\n"); - const parts = await new RecursiveCharacterTextSplitter({ - separators: ["---"], - chunkSize: maxTokens, - chunkOverlap: 200, - }).splitText(text); - return await summarizeRecursively(await Promise.all(parts.map(callApi))); }; const summary = await summarizeRecursively(summaries); @@ -174,10 +162,22 @@ export const recursivelySummarizeTweets = () => if (require.main === module) { (async () => { + const p = path.resolve("packages/cli/.env"); + dotenv.config({ path: p }); + console.log("SUMMARIZE TWEETS"); + console.time("SUMMARIZE TWEETS"); + const run = new DefaultRun(); const sum = await recursivelySummarizeTweets().execute({ user: experilearningTwitterUser, - tweets: experilearningTweets, + tweets: experilearningTweets.slice(0, 20), + run, + }); + const cost = await calculateCost({ + calls: run.getSuccessfulModelCalls(), + costCalculators: [new OpenAICostCalculator()], }); + console.log("cost", cost.formatAsDollarAmount({ decimals: 2 })); + console.timeEnd("SUMMARIZE TWEETS"); console.log("FINAL SUMMARY"); console.log(sum); })(); diff --git a/packages/cli/src/recommender/prompts/titleClip/titleClip.ts b/packages/cli/src/recommender/prompts/titleClip/titleClip.ts index 6a68fb5..17798cd 100644 --- a/packages/cli/src/recommender/prompts/titleClip/titleClip.ts +++ b/packages/cli/src/recommender/prompts/titleClip/titleClip.ts @@ -5,6 +5,7 @@ import { titleClipInputSchema, } from "./schemas/titleClipInputSchema"; import { titleClipOutputSchema } from "./schemas/titleClipOutputSchema"; +import { DefaultRun } from "modelfusion"; export const TITLE_CLIP = "Title Clip"; @@ -23,11 +24,12 @@ class TitleClip extends Prompt< }); } - async execute(args: TitleClipInput) { + async execute(args: TitleClipInput & { run?: DefaultRun }) { try { return this.run({ stream: false, promptVariables: args, + run: args.run, }); } catch (e) { console.error(e); diff --git a/packages/server/package.json b/packages/server/package.json index 1a02d0c..9e3930c 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -4,6 +4,7 @@ "main": "./dist/src/main.js", "version": "0.0.0", "dependencies": { + "@modelfusion/cost-calculator": "^0.1.0", "@prisma/client": "5.8.0", "@quixo3/prisma-session-store": "^3.1.13", "@react-email/components": "^0.0.15", @@ -17,6 +18,7 @@ "express-session": "^1.17.3", "graphile-worker": "^0.16.2", "graphile-worker-zod": "0.0.2", + "modelfusion": "^0.137.0", "nodemon": "^3.0.2", "passport": "^0.7.0", "passport-twitter": "^1.0.4", diff --git a/packages/server/prisma/migrations/20240309154943_add_csot/migration.sql b/packages/server/prisma/migrations/20240309154943_add_csot/migration.sql new file mode 100644 index 0000000..263cdab --- /dev/null +++ b/packages/server/prisma/migrations/20240309154943_add_csot/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "PipelineTask" ADD COLUMN "costInMillicents" INTEGER; diff --git a/packages/server/prisma/schema.prisma b/packages/server/prisma/schema.prisma index db95055..4dc464f 100644 --- a/packages/server/prisma/schema.prisma +++ b/packages/server/prisma/schema.prisma @@ -219,6 +219,7 @@ model PipelineTask { name String status TaskStatus @default(running) logs PipelineTaskLog[] + costInMillicents Int? } model PipelineTaskLog { diff --git a/packages/server/src/tasks/saga.ts b/packages/server/src/tasks/saga.ts index 791f7b1..7bfea7f 100644 --- a/packages/server/src/tasks/saga.ts +++ b/packages/server/src/tasks/saga.ts @@ -1,5 +1,6 @@ import { Job, JobHelpers } from "graphile-worker"; import { createTask, TypedTask } from "graphile-worker-zod"; +import { DefaultRun } from "modelfusion"; import { ZodTypeAny, z } from "zod"; export type TaskNamePayloadMaps = { @@ -33,6 +34,7 @@ type SagaJobHelpers = JobHelpers & { cancel: (reason?: string) => void; logInfo: (message: string) => void; logDebug: (message: string) => void; + run: DefaultRun; }; class CancelError extends Error {} @@ -47,6 +49,7 @@ const makeSagaJobHelpers = (helpers: JobHelpers): SagaJobHelpers => { helpers.logger.info(message, { jobId: helpers.job.id }), logDebug: (message: string) => helpers.logger.debug(message, { jobId: helpers.job.id }), + run: new DefaultRun(), }; }; @@ -63,6 +66,7 @@ type Step = { runResult: StepResult, helpers: JobHelpers ) => Promise; + maxAttempts?: number; }; type StepTemplate = Step; @@ -107,6 +111,11 @@ interface SagaOptions { step: StepTemplate, helpers: SagaJobHelpers ) => Promise; + afterStage?: ( + initialPayload: InitialPayload, + step: StepTemplate, + helpers: SagaJobHelpers + ) => Promise; } const wrappedRunPayloadSchema = z.object({ @@ -187,10 +196,16 @@ export class Saga< const nextStep = this.steps[stepIdx + 1]; if (nextStep) { const nextJobName = `${this.sagaName}|${nextStep.name}`; - await helpers.addJob(nextJobName, { - initialPayload, - previousResults: accumulatedResults, - }); + await helpers.addJob( + nextJobName, + { + initialPayload, + previousResults: accumulatedResults, + }, + { + maxAttempts: nextStep.maxAttempts, + } + ); } // If next step is null, the saga is done! diff --git a/packages/server/src/tasks/twitterPipeline.saga.ts b/packages/server/src/tasks/twitterPipeline.saga.ts index ff4d1e5..9424ba2 100644 --- a/packages/server/src/tasks/twitterPipeline.saga.ts +++ b/packages/server/src/tasks/twitterPipeline.saga.ts @@ -37,6 +37,11 @@ import { ArticleSnippetWithScore } from "shared/src/manual/ArticleSnippet"; import { chunksToClips } from "cli/src/recommender/chunksToClips"; import { titleClip } from "cli/src/recommender/prompts/titleClip/titleClip"; import { answersQuestion } from "cli/src/recommender/prompts/answersQuestion/answersQuestion"; +import { DefaultRun } from "modelfusion"; +import { + OpenAICostCalculator, + calculateCost, +} from "@modelfusion/cost-calculator"; type QueryWithSearchResultWithTranscript = { searchResults: (VideoResultWithTranscriptFile | MetaphorArticleResult)[]; @@ -47,6 +52,8 @@ type QueryWithSearchResultWithTranscript = { type YouTubeRAGInput = RAGInput & { metadata: YTMetadata }; type ArticleRAGInput = RAGInput & { metadata: HighlightMetadata }; +const MAX_TWEETS = 120; + export const twitterPipeline = new Saga( "twitter-pipeline-v1", z.object({ @@ -96,6 +103,38 @@ export const twitterPipeline = new Saga( }); } }, + async afterStage(initialPayload, step, helpers) { + const pipeline = await prisma.pipelineRun.findUnique({ + where: { + jobKeyId: initialPayload.runId, + }, + }); + + const existingTask = await prisma.pipelineTask.findFirst({ + where: { + jobId: helpers.job.id, + pipelineRunId: pipeline?.id, + }, + }); + + const cost = await calculateCost({ + calls: helpers.run.getSuccessfulModelCalls(), + costCalculators: [new OpenAICostCalculator()], + }); + + if (existingTask && pipeline) { + await prisma.pipelineTask.update({ + where: { + id: existingTask.id, + pipelineRunId: pipeline?.id, + }, + data: { + costInMillicents: + (existingTask.costInMillicents || 0) + cost.costInMillicents, + }, + }); + } + }, } ) .addStep({ @@ -140,21 +179,21 @@ export const twitterPipeline = new Saga( () => twitter.tweets.fetch({ user: initialPayload.username, - n_tweets: 200, + n_tweets: MAX_TWEETS, since_id: lastSavedTweet?.tweetId, }), helpers.logger ) - ).slice(0, 300); + ).slice(0, MAX_TWEETS); helpers.logInfo(`Fetched ${tweets.length} new tweets from Twitter`); const reuseExistingSummary = tweets.length === 0; helpers.logInfo(`Reuse existing summary: ${reuseExistingSummary}`); - if (tweets.length < 300 && lastSavedTweet) { + if (tweets.length < MAX_TWEETS && lastSavedTweet) { const moreTweets = await getSavedTweetsForUser({ username: initialPayload.username, before: lastSavedTweet.tweetedAt.toISOString(), - limit: 300 - tweets.length, + limit: MAX_TWEETS - tweets.length, }); helpers.logInfo(`Fetched ${moreTweets.length} saved tweets from DB`); tweets.push(...moreTweets.map((x) => x.data)); @@ -170,6 +209,7 @@ export const twitterPipeline = new Saga( }) .addStep({ name: "summarize-tweets", + maxAttempts: 3, run: async (initialPayload, priorResults, helpers) => { const { username } = initialPayload; helpers.logInfo(`Summarizing tweets for Twitter user @${username}`); @@ -229,6 +269,7 @@ export const twitterPipeline = new Saga( user: twitterUser, tweets: priorResults["get-tweets"].tweets, enableOpenPipeLogging: initialPayload.enableOpenPipeLogging, + run: helpers.run, }); if (!profile) { @@ -263,11 +304,6 @@ export const twitterPipeline = new Saga( queries.push(...initialPayload.queries); } else { const res = await createQueriesFromProfile().execute({ - enableOpenPipeLogging: initialPayload.enableOpenPipeLogging, - openPipeRequestTags: createRequestTags({ - runId: initialPayload.runId, - user: initialPayload.username, - }), profile: typeof priorResults["summarize-tweets"].profile?.content, bio: priorResults["summarize-tweets"].twitterUser?.rawDescription || "", @@ -278,12 +314,8 @@ export const twitterPipeline = new Saga( const queriesWithQuestions = await pAll( queries.map((query) => async () => { const questions = await brainstormQuestions().execute({ + run: helpers.run, query, - enableOpenPipeLogging: initialPayload.enableOpenPipeLogging, - openPipeRequestTags: createRequestTags({ - runId: initialPayload.runId, - user: initialPayload.username, - }), }); return { query, @@ -507,6 +539,7 @@ export const twitterPipeline = new Saga( }) .addStep({ name: "clean-clips", + maxAttempts: 3, run: async (initialPayload, priorResults, helpers) => { const { clips } = priorResults["rag"]; const { queriesWithQuestions } = priorResults["create-queries-metaphor"]; @@ -516,21 +549,24 @@ export const twitterPipeline = new Saga( const answersQ = await answersQuestion().execute({ question: clip.question, text: clip.text, + run: helpers.run, }); if (!answersQ.answersQuestion) { return null; } if (clip.type === "article") { - const result = await findStartOfAnswer().execute({ + const quotedAnswer = await findStartOfAnswer().execute({ question, text: clip.text, + run: helpers.run, }); - if (result?.quotedAnswer) { - const match = nearestSubstring(result.quotedAnswer, clip.text); + if (quotedAnswer) { + const match = nearestSubstring(quotedAnswer, clip.text); if (match.bestMatch && match.bestScore > 0.8) { const summarizedTitle = await titleClip().execute({ clip: clip.text, videoTitle: clip.articleTitle, + run: helpers.run, question: clip.question, }); return { @@ -546,6 +582,7 @@ export const twitterPipeline = new Saga( const result = await findStartOfAnswerYouTube().execute({ question, cues: clip.cues, + run: helpers.run, }); if (result?.cueId != null) { const newCues = clip.cues.slice(result.cueId); @@ -553,6 +590,7 @@ export const twitterPipeline = new Saga( clip: clip.text, videoTitle: clip.videoTitle, question: clip.question, + run: helpers.run, }); return { ...clip, diff --git a/packages/shared/src/schemas/pipelinetask.ts b/packages/shared/src/schemas/pipelinetask.ts index 0e2bed1..00a9249 100644 --- a/packages/shared/src/schemas/pipelinetask.ts +++ b/packages/shared/src/schemas/pipelinetask.ts @@ -10,6 +10,7 @@ export const PipelineTaskModel = z.object({ pipelineRunId: z.number().int(), name: z.string(), status: z.nativeEnum(TaskStatus), + costInMillicents: z.number().int().nullish(), }) export interface CompletePipelineTask extends z.infer { diff --git a/yarn.lock b/yarn.lock index ae56f7e..890acb7 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1613,6 +1613,13 @@ __metadata: languageName: node linkType: hard +"@modelfusion/cost-calculator@npm:^0.1.0": + version: 0.1.0 + resolution: "@modelfusion/cost-calculator@npm:0.1.0" + checksum: d30b72b26bed7fafdf6b1381397a5a72bfbfd77c4bbd7d4bf22a965f79506c9f59f6bb4381c5a03e3c59ee151f14b55ffdb8f019fa5bde21a4922e9d52853a2d + languageName: node + linkType: hard + "@mongodb-js/saslprep@npm:^1.1.0": version: 1.1.4 resolution: "@mongodb-js/saslprep@npm:1.1.4" @@ -5613,9 +5620,10 @@ __metadata: inquirer-search-list: "npm:^1.2.6" js-tiktoken: "npm:^1.0.8" llamaindex: "npm:^0.1.2" + modelfusion: "npm:^0.137.0" openai: "npm:^4.20.1" openpipe: "npm:^0.8.0" - prompt-iteration-assistant: "npm:^0.0.34" + prompt-iteration-assistant: "npm:^0.0.37" promptfoo: "npm:^0.28.1" pybridge-zod: "npm:^1.1.0" remeda: "npm:^1.29.0" @@ -10692,29 +10700,19 @@ __metadata: languageName: node linkType: hard -"modelfusion-experimental@npm:^0.2.0": - version: 0.2.0 - resolution: "modelfusion-experimental@npm:0.2.0" - dependencies: - zod: "npm:3.22.4" - peerDependencies: - modelfusion: ">=0.105.0" - checksum: 24234d58db13f1351836444527991e8202581b69f47311333239e847b66f462f8c0cd36c07cc5c88b69362ff4b3c196d1357c72e91a181400c5ad9237a62a329 - languageName: node - linkType: hard - -"modelfusion@npm:^0.107.0": - version: 0.107.0 - resolution: "modelfusion@npm:0.107.0" +"modelfusion@npm:^0.137.0": + version: 0.137.0 + resolution: "modelfusion@npm:0.137.0" dependencies: eventsource-parser: "npm:1.1.1" js-tiktoken: "npm:1.0.7" nanoid: "npm:3.3.6" secure-json-parse: "npm:2.7.0" + type-fest: "npm:4.9.0" ws: "npm:8.14.2" zod: "npm:3.22.4" zod-to-json-schema: "npm:3.22.3" - checksum: d54654e091b80b8a3cf9ee3eb1fbb36274675cdaa8c470552d60822ef96c0fe177fc04a052052b861cb4ff932790c745b92c82652233f9b5948008a0b8c9a76f + checksum: 49391c04f2b23d4df725246ac28491c0f3ada70809fb644d15ee86357ce363a4172c83e4c7421db0705a50a7be4a11f9f0d74dec522f6728f8ae4cf59635344d languageName: node linkType: hard @@ -12304,11 +12302,12 @@ __metadata: languageName: node linkType: hard -"prompt-iteration-assistant@npm:^0.0.34": - version: 0.0.34 - resolution: "prompt-iteration-assistant@npm:0.0.34" +"prompt-iteration-assistant@npm:^0.0.37": + version: 0.0.37 + resolution: "prompt-iteration-assistant@npm:0.0.37" dependencies: "@inquirer/prompts": "npm:^3.3.0" + "@modelfusion/cost-calculator": "npm:^0.1.0" boxen: "npm:^7.1.1" chalk: "npm:^4.1.2" cli-highlight: "npm:^2.1.11" @@ -12319,15 +12318,14 @@ __metadata: json-schema-to-typescript: "npm:^13.1.1" marked: "npm:^10.0.0" marked-terminal: "npm:^6.1.0" - modelfusion: "npm:^0.107.0" - modelfusion-experimental: "npm:^0.2.0" + modelfusion: "npm:^0.137.0" openai: "npm:^4.20.1" promptfoo: "npm:^0.31.1" remeda: "npm:^1.29.0" tty-table: "npm:^4.2.3" zod: "npm:^3.22.4" zod-to-json-schema: "npm:^3.22.1" - checksum: cf59b33eef731b6bfa9c7198861be00aa6742b0449dd3dcf70caff33296fe053229bc9295d9caeeef1386ede280e028fafafb4966b42b8d5aa09014e96cdedc2 + checksum: 1ef90850e841d0202967bcc8adda337c338bbc5b77e6334550eb30c60696821fd61998257789dbb5d910e16a3941f0a6faa3293c1900dd27d5219a04fab8892d languageName: node linkType: hard @@ -13594,6 +13592,7 @@ __metadata: version: 0.0.0-use.local resolution: "server@workspace:packages/server" dependencies: + "@modelfusion/cost-calculator": "npm:^0.1.0" "@prisma/client": "npm:5.8.0" "@quixo3/prisma-session-store": "npm:^3.1.13" "@react-email/components": "npm:^0.0.15" @@ -13614,6 +13613,7 @@ __metadata: express-session: "npm:^1.17.3" graphile-worker: "npm:^0.16.2" graphile-worker-zod: "npm:0.0.2" + modelfusion: "npm:^0.137.0" nodemon: "npm:^3.0.2" passport: "npm:^0.7.0" passport-twitter: "npm:^1.0.4" @@ -15040,6 +15040,13 @@ __metadata: languageName: node linkType: hard +"type-fest@npm:4.9.0": + version: 4.9.0 + resolution: "type-fest@npm:4.9.0" + checksum: 7e6423f7337928a7323ce8a68cfbbaf30ecb70b9c635207899e58297d219c71be4a8c50b52afb9fe09c9f44b2c4276d0a44bb95acabab7bc942455f980aad267 + languageName: node + linkType: hard + "type-fest@npm:^0.13.1": version: 0.13.1 resolution: "type-fest@npm:0.13.1"