Skip to content

Commit

Permalink
Merge pull request #51 from StarDylan/cache-mood
Browse files Browse the repository at this point in the history
added mood caching
  • Loading branch information
cup0noodles authored Dec 1, 2023
2 parents 98b76af + 67b4899 commit 600eee8
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 52 deletions.
11 changes: 11 additions & 0 deletions schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ create table
constraint users_playlist_position_user_id_fkey foreign key (user_id) references users (id) on delete cascade
) tablespace pg_default;

create table
public.user_moods (
user_id bigint generated by default as identity,
last_updated timestamp with time zone not null default now(),
mood text not null,
songs_played bigint not null,
constraint user_moods_pkey primary key (user_id),
constraint user_moods_user_id_key unique (user_id),
constraint user_moods_user_id_fkey foreign key (user_id) references users (id) on update cascade on delete cascade
) tablespace pg_default;

INSERT INTO songs (song_name, artist, album)
VALUES ('Mr. Brightside', 'The Killers', 'Hot Fuss');

Expand Down
102 changes: 102 additions & 0 deletions src/api/ollamarunner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import sqlalchemy
from src import database as db
import json
import requests
import os
import queue
import threading

q = queue.Queue()

def gen_mood(user_id) -> str:
with db.engine.begin() as conn:
result = conn.execute(sqlalchemy.text("""
SELECT COUNT(*) FROM song_history
WHERE user_id = :user_id
"""),
[{
"user_id": user_id
}]).scalar_one()

if result < 5:
return None

result = conn.execute(sqlalchemy.text("""
SELECT song_name, artist
FROM song_history
JOIN songs ON song_history.song_id = songs.id
WHERE user_id = 23
ORDER BY song_history.created_at DESC
LIMIT 5
""")).all()

song_prompt = "Songs:\n"
for song in result:
song_prompt += song.song_name + " by " + song.artist + "\n"

payload = json.dumps({
"model": "llama2-uncensored",
"system": "Classify the user's mood based on the following song titles into only one of these emotions: Happy, Sad, Angry. Only include the classification as one word.",
"prompt": song_prompt,
"stream": False
})
headers = {
'Content-Type': 'application/json'
}

print("Getting Sentiment from OLLAMA")

response = requests.request("POST", os.environ.get("OLLAMA_URI"), headers=headers, data=payload)
response = response.json()
print(response)

mood = ""
if "happy" in response["response"].lower():
mood = "HAPPY"
elif "sad" in response["response"].lower():
mood = "SAD"
elif "angry" in response["response"].lower():
mood = "ANGRY"

if mood == "":
return None
return mood

def thread_func(jobs=queue.Queue):
print("ollama runner daemon started...")
while True:
if not jobs.empty():
# there is a job
user_id = jobs.get()
with db.engine.begin() as conn:
result = conn.execute(sqlalchemy.text("""
SELECT COUNT(*)
FROM song_history
WHERE user_id = :user_id
"""),
[{
"user_id":user_id
}]).scalar_one_or_none()
if result is None or result < 5:
continue
else:
mood = gen_mood(user_id)
with db.engine.begin() as conn:
conn.execute(sqlalchemy.text("""
INSERT INTO user_moods(mood, songs_played,user_id)
VALUES(:mood, 0, :user_id)
ON CONFLICT (user_id)
DO UPDATE SET
last_updated=now(),
mood=:mood,
songs_played=0
WHERE user_moods.user_id = :user_id
"""
), [{
"user_id": user_id,
"mood":mood
}])

def start_daemon():
t = threading.Thread(daemon=True, target=thread_func, args=(q,))
t.start()
4 changes: 3 additions & 1 deletion src/api/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import FastAPI, exceptions
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from src.api import users, songs, playlists, ad
from src.api import users, songs, playlists, ad, ollamarunner
import json
import logging

Expand All @@ -25,6 +25,8 @@
app.include_router(playlists.router)
app.include_router(ad.router)

ollamarunner.start_daemon()

@app.exception_handler(exceptions.RequestValidationError)
@app.exception_handler(ValidationError)
async def validation_exception_handler(request, exc):
Expand Down
75 changes: 24 additions & 51 deletions src/api/songs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from pydantic import BaseModel
import sqlalchemy
from src import database as db
import json
import requests
import os
from src.api.ollamarunner import q
import random


router = APIRouter(
prefix="/song",
tags=["song"],
Expand Down Expand Up @@ -181,56 +180,27 @@ def play_ad_if_needed(conn, user_id) -> str | None:
if random.choice([True, False]):
return None

# Check if mood is already cached

result = conn.execute(sqlalchemy.text("""
SELECT COUNT(*) FROM song_history
WHERE user_id = :user_id
"""),
[{
"user_id": user_id
}]).scalar_one()

if result < 5:
SELECT last_updated, mood, songs_played
FROM user_moods
WHERE user_id = :user_id
"""
), [{
"user_id": user_id
}]).one_or_none()
if result is None:
# no mood calculated
# mood = gen_mood(conn, user_id)
print("calling ollama")
q.put(user_id)
return None

result = conn.execute(sqlalchemy.text("""
SELECT song_name, artist
FROM song_history
JOIN songs ON song_history.song_id = songs.id
WHERE user_id = 23
ORDER BY song_history.created_at DESC
LIMIT 5
""")).all()

song_prompt = "Songs:\n"
for song in result:
song_prompt += song.song_name + " by " + song.artist + "\n"

payload = json.dumps({
"model": "llama2-uncensored",
"system": "Classify the user's mood based on the following song titles into only one of these emotions: Happy, Sad, Angry. Only include the classification as one word.",
"prompt": song_prompt,
"stream": False
})
headers = {
'Content-Type': 'application/json'
}

print("Getting Sentiment from OLLAMA")
elif result.songs_played >= 5:

response = requests.request("POST", os.environ.get("OLLAMA_URI"), headers=headers, data=payload)
response = response.json()
print(response)

mood = ""
if "happy" in response["response"].lower():
mood = "HAPPY"
elif "sad" in response["response"].lower():
mood = "SAD"
elif "angry" in response["response"].lower():
mood = "ANGRY"

if mood == "":
return None
print("calling ollama")
q.put(user_id)
mood = result.mood

result = conn.execute(sqlalchemy.text("""
SELECT link FROM ad_campaigns
Expand Down Expand Up @@ -300,7 +270,10 @@ def play_song(song_id: int, user_id: str = Header(None)) -> SongResponse:
return SongResponse(url=ad_link, is_ad=True)

conn.execute(sqlalchemy.text("""
INSERT INTO song_history (user_id, song_id) VALUES (:user_id, :song_id)
INSERT INTO song_history (user_id, song_id) VALUES (:user_id, :song_id);
UPDATE user_moods
SET songs_played = songs_played + 1
WHERE user_id = :user_id
"""),
[{
"song_id": song_id,
Expand Down

0 comments on commit 600eee8

Please sign in to comment.