Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customizable AI templates #11884

Merged
merged 10 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ dependencies {
exclude group: 'org.jetbrains.kotlin'
}


implementation 'org.apache.velocity:velocity-engine-core:2.3'
implementation platform('ai.djl:bom:0.30.0')
implementation 'ai.djl:api'
implementation 'ai.djl.huggingface:tokenizers'
Expand Down
1 change: 1 addition & 0 deletions src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
uses ai.djl.repository.RepositoryFactory;
uses ai.djl.repository.zoo.ZooProvider;
uses dev.langchain4j.spi.prompt.PromptTemplateFactory;
requires velocity.engine.core;
// endregion

// region: Lucene
Expand Down
30 changes: 26 additions & 4 deletions src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,6 @@
</children>
</HBox>

<ResizableTextArea
fx:id="instructionTextArea"
wrapText="true"/>

<GridPane hgap="10" vgap="10">
<columnConstraints>
<ColumnConstraints hgrow="ALWAYS" percentWidth="50" />
Expand Down Expand Up @@ -235,5 +231,31 @@
glyph="REFRESH"/>
</graphic>
</Button>

<HBox alignment="BASELINE_CENTER">
<Label styleClass="sectionHeader"
text="%Templates"
maxWidth="Infinity"
HBox.hgrow="ALWAYS"/>
<Button fx:id="templatesHelp"
prefWidth="20.0"/>
</HBox>

<ComboBox
fx:id="currentEditingTemplateComboBox"
maxWidth="1.7976931348623157E308"
HBox.hgrow="ALWAYS"/>

<ResizableTextArea
fx:id="currentEditingTemplateSourceTextArea"
wrapText="true"/>

<Button onAction="#onResetTemplatesButtonClick"
text="%Reset templates to default">
<graphic>
<JabRefIconView
glyph="REFRESH"/>
</graphic>
</Button>
</children>
</fx:root>
28 changes: 21 additions & 7 deletions src/main/java/org/jabref/gui/preferences/ai/AiTab.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.jabref.gui.preferences.AbstractPreferenceTabView;
import org.jabref.gui.preferences.PreferencesTab;
import org.jabref.gui.util.ViewModelListCellFactory;
import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.logic.help.HelpFile;
import org.jabref.logic.l10n.Localization;
import org.jabref.model.ai.AiProvider;
Expand Down Expand Up @@ -41,16 +42,19 @@ public class AiTab extends AbstractPreferenceTabView<AiTabViewModel> implements

@FXML private TextField apiBaseUrlTextField;
@FXML private SearchableComboBox<EmbeddingModel> embeddingModelComboBox;
@FXML private ResizableTextArea instructionTextArea;
@FXML private TextField temperatureTextField;
@FXML private IntegerInputField contextWindowSizeTextField;
@FXML private IntegerInputField documentSplitterChunkSizeTextField;
@FXML private IntegerInputField documentSplitterOverlapSizeTextField;
@FXML private IntegerInputField ragMaxResultsCountTextField;
@FXML private TextField ragMinScoreTextField;

@FXML private ComboBox<AiTemplate> currentEditingTemplateComboBox;
@FXML private ResizableTextArea currentEditingTemplateSourceTextArea;

@FXML private Button generalSettingsHelp;
@FXML private Button expertSettingsHelp;
@FXML private Button templatesHelp;

private final ControlsFxVisualizer visualizer = new ControlsFxVisualizer();

Expand All @@ -72,14 +76,14 @@ public void initialize() {
new ViewModelListCellFactory<AiProvider>()
.withText(AiProvider::toString)
.install(aiProviderComboBox);
aiProviderComboBox.setItems(viewModel.aiProvidersProperty());
aiProviderComboBox.itemsProperty().bind(viewModel.aiProvidersProperty());
aiProviderComboBox.valueProperty().bindBidirectional(viewModel.selectedAiProviderProperty());
aiProviderComboBox.disableProperty().bind(viewModel.disableBasicSettingsProperty());

new ViewModelListCellFactory<String>()
.withText(text -> text)
.install(chatModelComboBox);
chatModelComboBox.setItems(viewModel.chatModelsProperty());
chatModelComboBox.itemsProperty().bind(viewModel.chatModelsProperty());
chatModelComboBox.valueProperty().bindBidirectional(viewModel.selectedChatModelProperty());
chatModelComboBox.disableProperty().bind(viewModel.disableBasicSettingsProperty());

Expand Down Expand Up @@ -112,9 +116,6 @@ public void initialize() {
apiBaseUrlTextField.setDisable(newValue || viewModel.disableExpertSettingsProperty().get())
);

instructionTextArea.textProperty().bindBidirectional(viewModel.instructionProperty());
instructionTextArea.disableProperty().bind(viewModel.disableExpertSettingsProperty());

// bindBidirectional doesn't work well with number input fields ({@link IntegerInputField}, {@link DoubleInputField}),
// so they are expanded into `addListener` calls.

Expand Down Expand Up @@ -169,7 +170,6 @@ public void initialize() {
visualizer.initVisualization(viewModel.getChatModelValidationStatus(), chatModelComboBox);
visualizer.initVisualization(viewModel.getApiBaseUrlValidationStatus(), apiBaseUrlTextField);
visualizer.initVisualization(viewModel.getEmbeddingModelValidationStatus(), embeddingModelComboBox);
visualizer.initVisualization(viewModel.getSystemMessageValidationStatus(), instructionTextArea);
visualizer.initVisualization(viewModel.getTemperatureTypeValidationStatus(), temperatureTextField);
visualizer.initVisualization(viewModel.getTemperatureRangeValidationStatus(), temperatureTextField);
visualizer.initVisualization(viewModel.getMessageWindowSizeValidationStatus(), contextWindowSizeTextField);
Expand All @@ -180,9 +180,18 @@ public void initialize() {
visualizer.initVisualization(viewModel.getRagMinScoreRangeValidationStatus(), ragMinScoreTextField);
});

new ViewModelListCellFactory<AiTemplate>()
.withText(AiTemplate::getLocalizedName)
.install(currentEditingTemplateComboBox);
currentEditingTemplateComboBox.itemsProperty().bind(viewModel.templatesProperty());
currentEditingTemplateComboBox.valueProperty().bindBidirectional(viewModel.currentEditingTemplate());

currentEditingTemplateSourceTextArea.textProperty().bindBidirectional(viewModel.currentEditingTemplateSource());

ActionFactory actionFactory = new ActionFactory();
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_GENERAL_SETTINGS, dialogService, preferences.getExternalApplicationsPreferences()), generalSettingsHelp);
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_EXPERT_SETTINGS, dialogService, preferences.getExternalApplicationsPreferences()), expertSettingsHelp);
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_TEMPLATES, dialogService, preferences.getExternalApplicationsPreferences()), templatesHelp);
}

@Override
Expand All @@ -195,6 +204,11 @@ private void onResetExpertSettingsButtonClick() {
viewModel.resetExpertSettings();
}

@FXML
private void onResetTemplatesButtonClick() {
viewModel.resetTemplates();
}

public ReadOnlyBooleanProperty aiEnabledProperty() {
return enableAi.selectedProperty();
}
Expand Down
66 changes: 48 additions & 18 deletions src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.jabref.gui.preferences.ai;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand All @@ -21,6 +22,7 @@
import org.jabref.gui.preferences.PreferenceTabViewModel;
import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.preferences.CliPreferences;
import org.jabref.logic.util.LocalizedNumbers;
Expand Down Expand Up @@ -77,7 +79,18 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final StringProperty geminiApiBaseUrl = new SimpleStringProperty();
private final StringProperty huggingFaceApiBaseUrl = new SimpleStringProperty();

private final StringProperty instruction = new SimpleStringProperty();
private final ListProperty<AiTemplate> templatesList =
new SimpleListProperty<>(FXCollections.observableArrayList(AiTemplate.values()));
private final ObjectProperty<AiTemplate> currentEditingTemplate = new SimpleObjectProperty<>(AiTemplate.CHATTING_SYSTEM_MESSAGE);

private final Map<AiTemplate, StringProperty> templateSources = Map.of(
AiTemplate.CHATTING_SYSTEM_MESSAGE, new SimpleStringProperty(),
AiTemplate.CHATTING_USER_MESSAGE, new SimpleStringProperty(),
AiTemplate.SUMMARIZATION_CHUNK, new SimpleStringProperty(),
AiTemplate.SUMMARIZATION_COMBINE, new SimpleStringProperty()
);
private final StringProperty currentEditingTemplateSource = new SimpleStringProperty();

private final StringProperty temperature = new SimpleStringProperty();
private final IntegerProperty contextWindowSize = new SimpleIntegerProperty();
private final IntegerProperty documentSplitterChunkSize = new SimpleIntegerProperty();
Expand All @@ -94,7 +107,6 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final Validator chatModelValidator;
private final Validator apiBaseUrlValidator;
private final Validator embeddingModelValidator;
private final Validator instructionValidator;
private final Validator temperatureTypeValidator;
private final Validator temperatureRangeValidator;
private final Validator contextWindowSizeValidator;
Expand Down Expand Up @@ -214,6 +226,14 @@ public AiTabViewModel(CliPreferences preferences) {
}
});

this.currentEditingTemplateSource.addListener((observable, oldValue, newValue) -> {
templateSources.get(currentEditingTemplate.get()).set(newValue);
});

this.currentEditingTemplate.addListener((observable, oldValue, newValue) -> {
currentEditingTemplateSource.set(templateSources.get(newValue).get());
});

this.apiKeyValidator = new FunctionBasedValidator<>(
currentApiKey,
token -> !StringUtil.isBlank(token),
Expand All @@ -234,11 +254,6 @@ public AiTabViewModel(CliPreferences preferences) {
Objects::nonNull,
ValidationMessage.error(Localization.lang("Embedding model has to be provided")));

this.instructionValidator = new FunctionBasedValidator<>(
instruction,
message -> !StringUtil.isBlank(message),
ValidationMessage.error(Localization.lang("The instruction has to be provided")));

this.temperatureTypeValidator = new FunctionBasedValidator<>(
temperature,
temp -> LocalizedNumbers.stringToDouble(temp).isPresent(),
Expand Down Expand Up @@ -307,7 +322,12 @@ public void setValues() {
customizeExpertSettings.setValue(aiPreferences.getCustomizeExpertSettings());

selectedEmbeddingModel.setValue(aiPreferences.getEmbeddingModel());
instruction.setValue(aiPreferences.getInstruction());

Arrays.stream(AiTemplate.values()).forEach(template ->
templateSources.get(template).set(aiPreferences.getTemplate(template)));

currentEditingTemplateSource.set(templateSources.get(currentEditingTemplate.get()).get());

temperature.setValue(LocalizedNumbers.doubleToString(aiPreferences.getTemperature()));
contextWindowSize.setValue(aiPreferences.getContextWindowSize());
documentSplitterChunkSize.setValue(aiPreferences.getDocumentSplitterChunkSize());
Expand Down Expand Up @@ -345,7 +365,9 @@ public void storeSettings() {
aiPreferences.setGeminiApiBaseUrl(geminiApiBaseUrl.get() == null ? "" : geminiApiBaseUrl.get());
aiPreferences.setHuggingFaceApiBaseUrl(huggingFaceApiBaseUrl.get() == null ? "" : huggingFaceApiBaseUrl.get());

aiPreferences.setInstruction(instruction.get());
Arrays.stream(AiTemplate.values()).forEach(template ->
aiPreferences.setTemplate(template, templateSources.get(template).get()));

// We already check the correctness of temperature and RAG minimum score in validators, so we don't need to check it here.
aiPreferences.setTemperature(LocalizedNumbers.stringToDouble(oldLocale, temperature.get()).get());
aiPreferences.setContextWindowSize(contextWindowSize.get());
Expand All @@ -359,8 +381,6 @@ public void resetExpertSettings() {
String resetApiBaseUrl = AiDefaultPreferences.PROVIDERS_API_URLS.get(selectedAiProvider.get());
currentApiBaseUrl.set(resetApiBaseUrl);

instruction.set(AiDefaultPreferences.SYSTEM_MESSAGE);

int resetContextWindowSize = AiDefaultPreferences.CONTEXT_WINDOW_SIZES.getOrDefault(selectedAiProvider.get(), Map.of()).getOrDefault(currentChatModel.get(), 0);
contextWindowSize.set(resetContextWindowSize);

Expand All @@ -371,6 +391,13 @@ public void resetExpertSettings() {
ragMinScore.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.RAG_MIN_SCORE));
}

public void resetTemplates() {
Arrays.stream(AiTemplate.values()).forEach(template ->
templateSources.get(template).set(AiDefaultPreferences.TEMPLATES.get(template)));

currentEditingTemplateSource.set(templateSources.get(currentEditingTemplate.get()).get());
}

@Override
public boolean validateSettings() {
if (enableAi.get()) {
Expand All @@ -397,7 +424,6 @@ public boolean validateExpertSettings() {
List<Validator> validators = List.of(
apiBaseUrlValidator,
embeddingModelValidator,
instructionValidator,
temperatureTypeValidator,
temperatureRangeValidator,
contextWindowSizeValidator,
Expand Down Expand Up @@ -471,8 +497,16 @@ public BooleanProperty disableApiBaseUrlProperty() {
return disableApiBaseUrl;
}

public StringProperty instructionProperty() {
return instruction;
public ListProperty<AiTemplate> templatesProperty() {
return templatesList;
}

public ObjectProperty<AiTemplate> currentEditingTemplate() {
return currentEditingTemplate;
}

public StringProperty currentEditingTemplateSource() {
return currentEditingTemplateSource;
}

public StringProperty temperatureProperty() {
Expand Down Expand Up @@ -523,10 +557,6 @@ public ValidationStatus getEmbeddingModelValidationStatus() {
return embeddingModelValidator.getValidationStatus();
}

public ValidationStatus getSystemMessageValidationStatus() {
return instructionValidator.getValidationStatus();
}

public ValidationStatus getTemperatureTypeValidationStatus() {
return temperatureTypeValidator.getValidationStatus();
}
Expand Down
42 changes: 41 additions & 1 deletion src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.List;
import java.util.Map;

import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.model.ai.AiProvider;
import org.jabref.model.ai.EmbeddingModel;

Expand Down Expand Up @@ -67,7 +68,7 @@ public class AiDefaultPreferences {
public static final boolean CUSTOMIZE_SETTINGS = false;

public static final EmbeddingModel EMBEDDING_MODEL = EmbeddingModel.SENTENCE_TRANSFORMERS_ALL_MINILM_L12_V2;
public static final String SYSTEM_MESSAGE = "You are an AI assistant that analyses research papers. You answer questions about papers. You will be supplied with the necessary information. The supplied information will contain mentions of papers in form '@citationKey'. Whenever you refer to a paper, use its citation key in the same form with @ symbol. Whenever you find relevant information, always use the citation key. Here are the papers you are analyzing:\n";
public static final String SYSTEM_MESSAGE = "";
public static final double TEMPERATURE = 0.7;
public static final int DOCUMENT_SPLITTER_CHUNK_SIZE = 300;
public static final int DOCUMENT_SPLITTER_OVERLAP = 100;
Expand All @@ -76,6 +77,45 @@ public class AiDefaultPreferences {

public static final int CONTEXT_WINDOW_SIZE = 8196;

public static final Map<AiTemplate, String> TEMPLATES = Map.of(
AiTemplate.CHATTING_SYSTEM_MESSAGE,
"You are an AI assistant that analyses research papers. You answer questions about papers.\n" +
"You will be supplied with the necessary information. The supplied information will contain mentions of papers in form '@citationKey'.\n" +
"Whenever you refer to a paper, use its citation key in the same form with @ symbol. Whenever you find relevant information, always use the citation key.\n\n" +
"Here are the papers you are analyzing:\n" +
"#foreach( $entry in $entries )\n" +
"${CanonicalBibEntry.getCanonicalRepresentation($entry)}\n" +
"#end",
InAnYan marked this conversation as resolved.
Show resolved Hide resolved

AiTemplate.CHATTING_USER_MESSAGE, """
$message

Here is some relevant information for you:
#foreach( $excerpt in $excerpts )
${excerpt.citationKey()}:
${excerpt.text()}
#end""",

AiTemplate.SUMMARIZATION_CHUNK, """
Please provide an overview of the following text. It's a part of a scientific paper.
InAnYan marked this conversation as resolved.
Show resolved Hide resolved
The summary should include the main objectives, methodologies used, key findings, and conclusions.
Mention any significant experiments, data, or discussions presented in the paper.

DOCUMENT:
$document

OVERVIEW:""",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand that - this should be generated?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About that. There are two ways of asking AI to generate text:

  1. In form of chat: a list of messages. For this situation we need both system + user message templates.
  2. In form of text that should be completed. This is only one template + I've seen that this approach is typically chosen for summarization. And I think it's easier.

For the summarization I chose 2nd approach. That is why you are seeing "OVERVIEW" and "FINAL OVERVIEW"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because without those words AI won't understand what should it do


AiTemplate.SUMMARIZATION_COMBINE, """
You have written an overview of a scientific paper. You have been collecting notes from various parts
of the paper. Now your task is to combine all of the notes in one structured message.

SUMMARIES:
$summaries

FINAL OVERVIEW:"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand that - the FINAL OVERVIEW should be generated? Why is it part of the prompt?

);

public static int getContextWindowSize(AiProvider aiProvider, String model) {
return CONTEXT_WINDOW_SIZES.getOrDefault(aiProvider, Map.of()).getOrDefault(model, 0);
}
Expand Down
Loading
Loading