diff --git a/demo/app.py b/demo/app.py index 0776dbf..93d2523 100644 --- a/demo/app.py +++ b/demo/app.py @@ -46,13 +46,6 @@ st.title("Cleaned Text") st.text_area(f"Total Length: {len(clean_text)}", f"{clean_text[:500]} . . .") - # I set this value as a quick safeguard but we should actually tokenize the text and count the number of real tokens. - if len(clean_text) > 4096 * 3: - st.warning( - f"Input text is too big ({len(clean_text)}). Using only a subset of it ({4096 * 3})." - ) - clean_text = clean_text[: 4096 * 3] - repo_name = st.selectbox("Select Repo", CURATED_REPOS) model_name = st.selectbox( "Select Model", @@ -67,6 +60,15 @@ with st.spinner("Downloading and Loading Model..."): model = load_llama_cpp_model(model_id=f"{repo_name}/{model_name}") + # ~4 characters per token is considered a reasonable default. + max_characters = model.n_ctx() * 4 + if len(clean_text) > max_characters: + st.warning( + f"Input text is too big ({len(clean_text)})." + f" Using only a subset of it ({max_characters})." + ) + clean_text = clean_text[:max_characters] + system_prompt = st.text_area("Podcast generation prompt", value=PODCAST_PROMPT) if st.button("Generate Podcast Script"):