Skip to content

Commit

Permalink
Implement model selection for analysis creation
Browse files Browse the repository at this point in the history
  • Loading branch information
fmagin committed Oct 11, 2024
1 parent f955a14 commit 80ddacc
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import ai.reveng.toolkit.ghidra.binarysimularity.ui.functionsimularity.FunctionSimularityDockableDialog;
import ai.reveng.toolkit.ghidra.core.RevEngAIAnalysisStatusChanged;
import ai.reveng.toolkit.ghidra.core.services.api.GhidraRevengService;
import ai.reveng.toolkit.ghidra.core.services.api.ModelName;
import ai.reveng.toolkit.ghidra.core.services.api.types.AnalysisStatus;
import ai.reveng.toolkit.ghidra.core.services.api.types.BinaryHash;
import ai.reveng.toolkit.ghidra.core.services.api.types.BinaryID;
import ai.reveng.toolkit.ghidra.core.services.api.types.ProgramWithBinaryID;
import ai.reveng.toolkit.ghidra.core.services.function.export.ExportFunctionBoundariesService;
import ai.reveng.toolkit.ghidra.core.services.logging.ReaiLoggingService;
import docking.action.builder.ActionBuilder;
import docking.widgets.OptionDialog;
import ghidra.app.context.ProgramActionContext;
import ghidra.app.context.ProgramLocationActionContext;
import ghidra.app.plugin.PluginCategoryNames;
Expand Down Expand Up @@ -131,11 +133,28 @@ private void setupActions() {
"Program has not been auto-analyzed by Ghidra yet. Please run auto-analysis first.");
return;
}
// Get the available models
monitor.setMessage("Getting available models...");
var models = apiService.getAvailableModels();
var suggestedModel = apiService.getModelNameForProgram(context.getProgram(), models);
// Show user a dropdown menu to pick the model
var selectedModel = OptionDialog.showInputChoiceDialog(
null,
ReaiPluginPackage.WINDOW_PREFIX + "Create new Analysis for Binary",
"Select a model to use for analysis",
models.stream().map(ModelName::modelName).toArray(String[]::new),
suggestedModel.modelName(),
OptionDialog.QUESTION_MESSAGE);

if (selectedModel == null) {
// User canceled the model choice dialog, so we cancel the analysis task
return;
}
monitor.setMessage("Uploading binary...");
apiService.upload(context.getProgram());
monitor.setProgress(99);
monitor.setMessage("Launching Analysis");
ProgramWithBinaryID binID = apiService.analyse(context.getProgram());
ProgramWithBinaryID binID = apiService.analyse(context.getProgram(), new ModelName(selectedModel));
Msg.showInfo(this, null, ReaiPluginPackage.WINDOW_PREFIX + "Create new Analysis for Binary",
"Analysis is running for: " + binID + "\n"
+ "You will be notified when the analysis is complete.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,12 @@ public List<Collection> collections() {
}

public ProgramWithBinaryID analyse(Program program) {
var binID = analyse(program,
getModelNameForProgram(program).orElseThrow()
);
return new ProgramWithBinaryID(program, binID);
return analyse(program, getModelNameForProgram(program));
}

public BinaryID analyse(Program program, ModelName modelName){
public ProgramWithBinaryID analyse(Program program, ModelName modelName){
if (programMap.containsKey(program)){
return programMap.get(program);
return new ProgramWithBinaryID(program, programMap.get(program));
}

AnalysisOptionsBuilder builder = new AnalysisOptionsBuilder();
Expand All @@ -311,19 +308,22 @@ public BinaryID analyse(Program program, ModelName modelName){

var binID = api.analyse(builder);
programMap.put(program, binID);
return binID;
return new ProgramWithBinaryID(program, binID);
}

private Optional<ModelName> getModelNameForProgram(Program program){
// TODO: Model name choice will be removed from the client API in the future
private ModelName getModelNameForProgram(Program program){
return getModelNameForProgram(program, this.api.models());
}

public ModelName getModelNameForProgram(Program program, List<ModelName> models){
var s = models.stream().map (ModelName::modelName);
var format = program.getOptions("Program Information").getString("Executable Format", null);
if (format.equals(ElfLoader.ELF_NAME)){
return Optional.of(new ModelName("binnet-0.3-x86-linux"));
s = s.filter(modelName -> modelName.contains("linux"));
} else if (format.equals(PeLoader.PE_NAME)) {
return Optional.of(new ModelName("binnet-0.3-x86-windows"));
s = s.filter(modelName -> modelName.contains("windows"));
}
return Optional.empty();

return new ModelName(s.sorted(Collections.reverseOrder()).toList().get(0));
}

private List<FunctionBoundary> exportFunctionBoundaries(Program program){
Expand Down Expand Up @@ -399,5 +399,9 @@ public String health(){
return api.healthMessage();
}

public List<ModelName> getAvailableModels(){
return api.models();
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ public String getAnalysisLogs(BinaryID binID) {
}

@Override
public List<ModelInfo> models() {
public List<ModelName> models() {
return null;
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,12 @@ private HttpRequest.Builder requestBuilderForEndpoint(String endpoint){
return requestBuilder;
}
@Override
public List<ModelInfo> models(){
public List<ModelName> models(){
JSONObject jsonResponse = sendRequest(requestBuilderForEndpoint("models").GET().build());
List<ModelInfo> result = new ArrayList<>();
List<ModelName> result = new ArrayList<>();
jsonResponse.getJSONArray("models").forEach((Object o) -> {
result.add(ModelInfo.fromJSONObject((JSONObject) o));
JSONObject obj = (JSONObject) o;
result.add(new ModelName(obj.getString("model_name")));
});
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ List<FunctionMatch> annSymbolsForFunctions(List<FunctionID> fID,

List<Collection> collectionQuickSearch(ModelName modelName);

List<ModelInfo> models();
List<ModelName> models();

List<Collection> collectionQuickSearch(String searchTerm);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import ai.reveng.toolkit.ghidra.ReaiPluginPackage;
import ai.reveng.toolkit.ghidra.core.models.ReaiConfig;
import ai.reveng.toolkit.ghidra.core.services.logging.ReaiLoggingService;
import ai.reveng.toolkit.ghidra.core.ui.wizard.panels.UserAvailableModelsPanel;
import ai.reveng.toolkit.ghidra.core.ui.wizard.panels.UserCredentialsPanel;
import docking.wizard.AbstractMagePanelManager;
import docking.wizard.IllegalPanelStateException;
Expand All @@ -38,7 +37,6 @@ public SetupWizardManager(WizardState<SetupWizardStateKey> initialState, PluginT
protected List<MagePanel<SetupWizardStateKey>> createPanels() {
List<MagePanel<SetupWizardStateKey>> panels = new ArrayList<MagePanel<SetupWizardStateKey>>();
panels.add(new UserCredentialsPanel(tool));
panels.add(new UserAvailableModelsPanel());

return panels;
}
Expand Down

This file was deleted.

0 comments on commit 80ddacc

Please sign in to comment.