Skip to content

Commit

Permalink
stat matching between LLM and SeedData. Finished frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
polskiTran committed Apr 23, 2024
1 parent 29d54ec commit 282bee2
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 43 deletions.
2 changes: 1 addition & 1 deletion projects/Mood2SpotifyRec/.streamlit/config.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Forcing light theme on streamlit app
[theme]
base="light"
base="dark"
22 changes: 22 additions & 0 deletions projects/Mood2SpotifyRec/SeedData.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,25 @@ def get_test_seed_data(self):
target_danceability=0.7,
target_acousticness=0.3,
)

def to_seed_data(self, in_metric):
return SeedData(
seed_genres=["indie", "pop"],
seed_artists=[
"4NHQUGzhtTLFvgF5SZesLK"
], # Example artist ID for Tame Impala
target_valence=in_metric["valence"],
target_energy=in_metric["energy"],
target_danceability=in_metric["danceability"],
target_acousticness=in_metric["acousticness"],
)

def view_seed_data(self):
return {
"seed_genres": self.seed_genres,
"seed_artists": self.seed_artists,
"target_valence": self.target_valence,
"target_energy": self.target_energy,
"target_danceability": self.target_danceability,
"target_acousticness": self.target_acousticness,
}
22 changes: 17 additions & 5 deletions projects/Mood2SpotifyRec/llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import google.generativeai as genai


class NatLang2Latext:
class Mood2SpotifyRec:
def __init__(
self, google_api_key: str | None = None, model_name: str = "gemini-pro"
):
Expand Down Expand Up @@ -55,7 +55,7 @@ def convert(self, text: str) -> dict[str, float]:
genai.configure(api_key=self.google_api_key)
model = genai.GenerativeModel(model_name=self.model_name)
prompt = f"<MOOD>{self.prompt}{text}</MOOD>\n<METRICS>\n"
raw_ouput = model.generate_content(
raw_output = model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(
candidate_count=1,
Expand All @@ -64,11 +64,23 @@ def convert(self, text: str) -> dict[str, float]:
stop_sequences=["</METRICS>"],
),
).text

# metrics = {}
# for line in raw_ouput.split("\n"):
# if ":" in line:
# key, value = line.split(":")
# metrics[key.strip()] = float(value.strip())
# return metrics
metrics = {}
for line in raw_ouput.split("\n"):
for line in raw_output.split("\n"):
if ":" in line:
key, value = line.split(":")
metrics[key.strip()] = float(value.strip())
range_values = value.strip().split("-")
if len(range_values) == 2:
# Calculate the average of the range values
float_values = [float(x) for x in range_values]
metrics[key.strip()] = sum(float_values) / len(float_values)
else:
# If there's only one value, convert it directly
metrics[key.strip()] = float(value.strip())

return metrics
163 changes: 127 additions & 36 deletions projects/Mood2SpotifyRec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,84 @@
import streamlit.components.v1 as components
import spotify
import SeedData
import os
from llm import Mood2SpotifyRec
from dotenv import load_dotenv
from PIL import Image


# Load environment variables
load_dotenv()

# Streamlit page background
page_bg_img = f"""
<style>
[data-testid="stAppViewContainer"] > .main {{
background-image: url("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQlo3P16G9dPsRjdEhx77zlzr2idJqnXZpLuQ90xNYrpQ&s");
background-size: cover;
background-position: center center;
background-repeat: no-repeat;
background-attachment: local;
}}
[data-testid="stHeader"] {{
background: rgba(0,0,0,0);
}}
</style>
"""
st.markdown(page_bg_img, unsafe_allow_html=True)

# =========================================================================================================
# Logo and Navigation
st.title("Mood 2 Spotify Recommendation")
mood2spotifyrec_logo = Image.open(
"resources/Mood2SpotifyRec_Logov2-fotor-bg-remover-20240423143851.png"
)
st.image(mood2spotifyrec_logo, width=300)
# change the size of the logo and center


# Inputs for Client ID and Client Secret
client_id = st.text_input("Client ID", "")
client_secret = st.text_input("Client Secret", "")
client_id = os.getenv("SPOTIPY_CLIENT_ID")
client_secret = os.getenv("SPOTIPY_CLIENT_SECRET")
sp = spotify.SpotifyTools(client_id, client_secret)
access_token = sp.get_access_token()

# =========================================================================================================
# The user input and recommendation section
converter = Mood2SpotifyRec(google_api_key=os.getenv("GOOGLE_API_KEY"))
# User input for mood
st.subheader("Enter your mood:")
mood = st.text_input("Mood", "")

# Output
if st.button("Get Tracks Recommendations:"):
# Validate inputs
if not client_id or not client_secret:
st.warning("Please fill in all fields.")
if not mood.strip():
st.write("*Please complete the missing fields.")
else:
st.header("Output")
st.write("This is the Spotify metrics based on your mood:")
mood_text = mood.strip()
metrics = converter.convert(mood_text)
st.write(f"`{metrics}`") # Display metrics in code block
rec_seed_data = SeedData.SeedData().to_seed_data(metrics)
user_id = sp.user_id
playlist_id = sp.add_recommend_tracks_to_playlist(
"Mood2Spotify Playlist", rec_seed_data
)
# Your playlist embed URL
playlist_embed_url = "https://open.spotify.com/embed/playlist/" + playlist_id
# Embed the playlist using an iframe
components.html(
f"""<iframe src="{playlist_embed_url}" width=100% height=700 frameborder="0" allowtransparency="true" allow="encrypted-media"></iframe>""",
height=700,
)


# =========================================================================================================
# API test and LLM suggestion test
# Section for Audio Features
# st.subheader("Get Audio Features")
# track_id = st.text_input("Spotify Track ID", "")
Expand Down Expand Up @@ -50,39 +118,62 @@
# else:
# st.error("Failed to retrieve access token for recommendations")

# Section for track recommendations
st.subheader("Get TEST Track Recommendations")
test_seed_data = SeedData.SeedData().get_test_seed_data()
st.write(test_seed_data)
if st.button("Get TEST Tracks Recommendations"):
if not client_id or not client_secret:
st.warning("Please fill in all fields.")
else:
if access_token:
recommended_tracks = sp.get_test_track_recommendation()
track_uris = [track["uri"] for track in recommended_tracks["tracks"]]
# Display the tracks as clickable links
for track in recommended_tracks["tracks"]:
track_name = track["name"]
artists = ", ".join(
artist["name"] for artist in track["album"]["artists"]
)
spotify_url = track["external_urls"]["spotify"]
st.write(f"[{track_name} by {artists}]({spotify_url})")

else:
st.error("Failed to retrieve access token for recommendations")
# Spotify recommendation API test
# st.subheader("Get TEST Track Recommendations")
# test_seed_data = SeedData.SeedData().get_test_seed_data()
# st.write(test_seed_data)
# if st.button("Get TEST Tracks Recommendations"):
# if not client_id or not client_secret:
# st.warning("Please fill in all fields.")
# else:
# if access_token:
# recommended_tracks = sp.get_test_track_recommendation()
# track_uris = [track["uri"] for track in recommended_tracks["tracks"]]
# # Display the tracks as clickable links
# for track in recommended_tracks["tracks"]:
# track_name = track["name"]
# artists = ", ".join(
# artist["name"] for artist in track["album"]["artists"]
# )
# spotify_url = track["external_urls"]["spotify"]
# st.write(f"[{track_name} by {artists}]({spotify_url})")

# else:
# st.error("Failed to retrieve access token for recommendations")

# Show embed TEST recommendation playlist
if st.button("Show playlist"):
user_id = sp.user_id
playlist_id = sp.add_recommend_tracks_to_playlist(
"Mood2Spotify Playlist", test_seed_data
)
# Your playlist embed URL
playlist_embed_url = "https://open.spotify.com/embed/playlist/" + playlist_id
# Embed the playlist using an iframe
components.html(
f"""<iframe src="{playlist_embed_url}" width=100% height=700 frameborder="0" allowtransparency="true" allow="encrypted-media"></iframe>""",
height=700,
)
# if st.button("Show playlist"):
# user_id = sp.user_id
# playlist_id = sp.add_recommend_tracks_to_playlist(
# "Mood2Spotify Playlist", test_seed_data
# )
# # Your playlist embed URL
# playlist_embed_url = "https://open.spotify.com/embed/playlist/" + playlist_id
# # Embed the playlist using an iframe
# components.html(
# f"""<iframe src="{playlist_embed_url}" width=100% height=700 frameborder="0" allowtransparency="true" allow="encrypted-media"></iframe>""",
# height=700,
# )


# =========================================================================================================
# Sidebar
st.sidebar.header("About")
st.sidebar.markdown(
"""
A project by [polskiXO](https://github.com/polskiXO) and [menamerai](https://github.com/menamerai) for EECE3092.
Mood2SpotifyRec is made using Google Gemini LLM to generate seed param based on user input mood to feed into Spotify API recommendation API for playlist generation.
"""
)

st.sidebar.header("Resources")
st.sidebar.markdown(
"""
- [GitHub Repository](https://github.com/polskiXO/python-beginner-projects/tree/2-natlang2latex-project/projects/Mood2SpotifyRec)
- [Streamlit Documentation](https://docs.streamlit.io/)
- [LaTeX Documentation](https://www.latex-project.org/help/documentation/)
- [Google Gemini](https://ai.google.dev/)
- [EECE3092 Course](https://www.ece.mcgill.ca/~ece309/)
"""
)
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 0 additions & 1 deletion projects/Mood2SpotifyRec/spotify.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def get_tracks_recommendations(self, seed_data):
)
return recommended_tracks

# TODO: remove when done
def get_test_track_recommendation(self):
testSeedData = SeedData.SeedData().get_test_seed_data()
return self.get_tracks_recommendations(testSeedData)
Expand Down

0 comments on commit 282bee2

Please sign in to comment.