From df68a6ba1bef81a644f8bb8c312eede4d462c41c Mon Sep 17 00:00:00 2001 From: mkw_pop Date: Fri, 1 Dec 2023 15:50:00 -0800 Subject: [PATCH 1/2] added mood caching --- schema.sql | 11 +++++ src/api/ollamarunner.py | 105 ++++++++++++++++++++++++++++++++++++++++ src/api/server.py | 4 +- src/api/songs.py | 72 +++++++++------------------ 4 files changed, 143 insertions(+), 49 deletions(-) create mode 100644 src/api/ollamarunner.py diff --git a/schema.sql b/schema.sql index afa438d..be6d992 100644 --- a/schema.sql +++ b/schema.sql @@ -91,6 +91,17 @@ create table constraint users_playlist_position_user_id_fkey foreign key (user_id) references users (id) on delete cascade ) tablespace pg_default; +create table + public.user_moods ( + user_id bigint generated by default as identity, + last_updated timestamp with time zone not null default now(), + mood text not null, + songs_played bigint not null, + constraint user_moods_pkey primary key (user_id), + constraint user_moods_user_id_key unique (user_id), + constraint user_moods_user_id_fkey foreign key (user_id) references users (id) on update cascade on delete cascade + ) tablespace pg_default; + INSERT INTO songs (song_name, artist, album) VALUES ('Mr. Brightside', 'The Killers', 'Hot Fuss'); diff --git a/src/api/ollamarunner.py b/src/api/ollamarunner.py new file mode 100644 index 0000000..e30ea8f --- /dev/null +++ b/src/api/ollamarunner.py @@ -0,0 +1,105 @@ +from fastapi import APIRouter, Header +from pydantic import BaseModel +import sqlalchemy +from src import database as db +import json +import requests +import os +import random +import queue +import threading + +q = queue.Queue() + +def gen_mood(user_id) -> str: + with db.engine.begin() as conn: + result = conn.execute(sqlalchemy.text(""" + SELECT COUNT(*) FROM song_history + WHERE user_id = :user_id + """), + [{ + "user_id": user_id + }]).scalar_one() + + if result < 5: + return None + + result = conn.execute(sqlalchemy.text(""" + SELECT song_name, artist + FROM song_history + JOIN songs ON song_history.song_id = songs.id + WHERE user_id = 23 + ORDER BY song_history.created_at DESC + LIMIT 5 + """)).all() + + song_prompt = "Songs:\n" + for song in result: + song_prompt += song.song_name + " by " + song.artist + "\n" + + payload = json.dumps({ + "model": "llama2-uncensored", + "system": "Classify the user's mood based on the following song titles into only one of these emotions: Happy, Sad, Angry. Only include the classification as one word.", + "prompt": song_prompt, + "stream": False + }) + headers = { + 'Content-Type': 'application/json' + } + + print("Getting Sentiment from OLLAMA") + + response = requests.request("POST", os.environ.get("OLLAMA_URI"), headers=headers, data=payload) + response = response.json() + print(response) + + mood = "" + if "happy" in response["response"].lower(): + mood = "HAPPY" + elif "sad" in response["response"].lower(): + mood = "SAD" + elif "angry" in response["response"].lower(): + mood = "ANGRY" + + if mood == "": + return None + return mood + +def thread_func(jobs=queue.Queue): + print("ollama runner daemon started...") + while True: + if not jobs.empty(): + # there is a job + user_id = jobs.get() + with db.engine.begin() as conn: + result = conn.execute(sqlalchemy.text(""" + SELECT COUNT(*) + FROM song_history + WHERE user_id = :user_id + """), + [{ + "user_id":user_id + }]).scalar_one_or_none() + if result is None or result < 5: + continue + else: + mood = gen_mood(user_id) + with db.engine.begin() as conn: + conn.execute(sqlalchemy.text(""" + INSERT INTO user_moods(mood, songs_played,user_id) + VALUES(:mood, 0, :user_id) + ON CONFLICT (user_id) + DO UPDATE SET + last_updated=now(), + mood=:mood, + songs_played=0 + WHERE user_moods.user_id = :user_id + """ + ), [{ + "user_id": user_id, + "mood":mood + }]) + +def start_daemon(): + t = threading.Thread(daemon=True, target=thread_func, args=(q,)) + t.start() \ No newline at end of file diff --git a/src/api/server.py b/src/api/server.py index 597661b..eb8ed73 100644 --- a/src/api/server.py +++ b/src/api/server.py @@ -1,7 +1,7 @@ from fastapi import FastAPI, exceptions from fastapi.responses import JSONResponse from pydantic import ValidationError -from src.api import users, songs, playlists, ad +from src.api import users, songs, playlists, ad, ollamarunner import json import logging @@ -25,6 +25,8 @@ app.include_router(playlists.router) app.include_router(ad.router) +ollamarunner.start_daemon() + @app.exception_handler(exceptions.RequestValidationError) @app.exception_handler(ValidationError) async def validation_exception_handler(request, exc): diff --git a/src/api/songs.py b/src/api/songs.py index 7a8334c..f40898f 100644 --- a/src/api/songs.py +++ b/src/api/songs.py @@ -2,11 +2,13 @@ from pydantic import BaseModel import sqlalchemy from src import database as db +from src.api.ollamarunner import q import json import requests import os import random + router = APIRouter( prefix="/song", tags=["song"], @@ -181,56 +183,27 @@ def play_ad_if_needed(conn, user_id) -> str | None: if random.choice([True, False]): return None + # Check if mood is already cached + result = conn.execute(sqlalchemy.text(""" - SELECT COUNT(*) FROM song_history - WHERE user_id = :user_id - """), - [{ - "user_id": user_id - }]).scalar_one() - - if result < 5: + SELECT last_updated, mood, songs_played + FROM user_moods + WHERE user_id = :user_id + """ + ), [{ + "user_id": user_id + }]).one_or_none() + if result is None: + # no mood calculated + # mood = gen_mood(conn, user_id) + print("calling ollama") + q.put(user_id) return None - - result = conn.execute(sqlalchemy.text(""" - SELECT song_name, artist - FROM song_history - JOIN songs ON song_history.song_id = songs.id - WHERE user_id = 23 - ORDER BY song_history.created_at DESC - LIMIT 5 - """)).all() - - song_prompt = "Songs:\n" - for song in result: - song_prompt += song.song_name + " by " + song.artist + "\n" - - payload = json.dumps({ - "model": "llama2-uncensored", - "system": "Classify the user's mood based on the following song titles into only one of these emotions: Happy, Sad, Angry. Only include the classification as one word.", - "prompt": song_prompt, - "stream": False - }) - headers = { - 'Content-Type': 'application/json' - } - - print("Getting Sentiment from OLLAMA") + elif result.songs_played >= 5: - response = requests.request("POST", os.environ.get("OLLAMA_URI"), headers=headers, data=payload) - response = response.json() - print(response) - - mood = "" - if "happy" in response["response"].lower(): - mood = "HAPPY" - elif "sad" in response["response"].lower(): - mood = "SAD" - elif "angry" in response["response"].lower(): - mood = "ANGRY" - - if mood == "": - return None + print("calling ollama") + q.put(user_id) + mood = result.mood result = conn.execute(sqlalchemy.text(""" SELECT link FROM ad_campaigns @@ -300,7 +273,10 @@ def play_song(song_id: int, user_id: str = Header(None)) -> SongResponse: return SongResponse(url=ad_link, is_ad=True) conn.execute(sqlalchemy.text(""" - INSERT INTO song_history (user_id, song_id) VALUES (:user_id, :song_id) + INSERT INTO song_history (user_id, song_id) VALUES (:user_id, :song_id); + UPDATE user_moods + SET songs_played = songs_played + 1 + WHERE user_id = :user_id """), [{ "song_id": song_id, From 67b4899de2df667012573a575e64c443c9c65bbb Mon Sep 17 00:00:00 2001 From: mkw_pop Date: Fri, 1 Dec 2023 15:51:49 -0800 Subject: [PATCH 2/2] lint fix --- src/api/ollamarunner.py | 3 --- src/api/songs.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/src/api/ollamarunner.py b/src/api/ollamarunner.py index e30ea8f..b5a49dd 100644 --- a/src/api/ollamarunner.py +++ b/src/api/ollamarunner.py @@ -1,11 +1,8 @@ -from fastapi import APIRouter, Header -from pydantic import BaseModel import sqlalchemy from src import database as db import json import requests import os -import random import queue import threading diff --git a/src/api/songs.py b/src/api/songs.py index f40898f..074f151 100644 --- a/src/api/songs.py +++ b/src/api/songs.py @@ -3,9 +3,6 @@ import sqlalchemy from src import database as db from src.api.ollamarunner import q -import json -import requests -import os import random