diff --git a/src/test/java/org/rogmann/llm/demo/DemoVerboseMain.java b/src/test/java/org/rogmann/llm/demo/DemoVerboseMain.java index 2019732..78ddabb 100644 --- a/src/test/java/org/rogmann/llm/demo/DemoVerboseMain.java +++ b/src/test/java/org/rogmann/llm/demo/DemoVerboseMain.java @@ -58,13 +58,14 @@ public static void main(String[] args) throws IOException, LlmConfigException { //try (LlmExecutor executor = new LlmWorkerPoolBusySpin(nThreads)) { //try (LlmExecutor executor = new LlmExecutorSingleThread()) { - final int maxBatchSize = 3; + // Set maxBatchSize = 3 to get three different beams. + final int maxBatchSize = 1; final BloomModel model = new BloomModel(modelReader, maxBatchSize, executor); tsStartInfer = Instant.now(); //String inputSentence = "Auf der Wiese läuft ein Hund hinter"; - String inputSentence = "Der Hund heißt Karl. Die Katze heißt Mimi. Wie nennt Mimi den Hund?"; - //String inputSentence = "Translate to Chinese: I write a program in Java."; + //String inputSentence = "Der Hund heißt Karl. Die Katze heißt Mimi. Wie nennt Mimi den Hund?"; + String inputSentence = "Translate to Chinese: I write a program in Java."; //String inputSentence = "What is the capital of France?"; //String inputSentence = "Translate to chinese: cat."; //String inputSentence = "¿Quién era Joan Miró?"; @@ -121,7 +122,7 @@ public static void main(String[] args) throws IOException, LlmConfigException { } final int numCandidates = idxCandidates.size(); - if (idxInf == 2 && numCandidates >= numBeams) { + if (idxInf == 2 && numCandidates >= numBeams && numBeams > batchSize) { inputIds = new int[numBeams][]; inputIds[0] = nextInputIds[0]; for (int j = 1; j < numBeams; j++) {