Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ES|QL categorize with multiple groupings #118173

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/118173.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118173
summary: ES|QL categorize with multiple groupings
area: Machine Learning
type: feature
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,16 @@ public static BlockHash buildCategorizeBlockHash(
List<GroupSpec> groups,
AggregatorMode aggregatorMode,
BlockFactory blockFactory,
AnalysisRegistry analysisRegistry
AnalysisRegistry analysisRegistry,
int emitBatchSize
) {
if (groups.size() != 1) {
throw new IllegalArgumentException("only a single CATEGORIZE group can used");
if (groups.size() == 1) {
return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
} else {
assert groups.get(0).isCategorize();
jan-elastic marked this conversation as resolved.
Show resolved Hide resolved
assert groups.subList(1, groups.size()).stream().noneMatch(GroupSpec::isCategorize);
Comment on lines +189 to +190
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit/bikeshed: throwing IllegalArgumentException would be more friendly towards test; when assertions trigger, they bring down a whole node because that's an error, not exception. It's probably fine, though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't happen, right? If this assertion fails, other code is broken (the verifier). I'll leave it as is unless you object. BTW, do you know if we run assertions in production?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assertions are disabled in Prod, and this indeed shouldn't happen. Occasionally, a bug slips through, though, and when it triggers an assertion, it kill the whole IT suite run because it kills a node. It's fine to leave as-is, though!

return new CategorizePackedValuesBlockHash(groups, blockFactory, aggregatorMode, analysisRegistry, emitBatchSize);
}

return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import java.util.Objects;

/**
* Base BlockHash implementation for {@code Categorize} grouping function.
* BlockHash implementation for {@code Categorize} grouping function.
*/
public class CategorizeBlockHash extends BlockHash {

Expand Down Expand Up @@ -95,12 +95,14 @@ public class CategorizeBlockHash extends BlockHash {
}
}

boolean seenNull() {
return seenNull;
}

@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
if (aggregatorMode.isInputPartial() == false) {
addInitial(page, addInput);
} else {
addIntermediate(page, addInput);
try (IntBlock block = add(page)) {
addInput.add(0, block);
}
}

Expand Down Expand Up @@ -129,50 +131,39 @@ public void close() {
Releasables.close(evaluator, categorizer);
}

private IntBlock add(Page page) {
return aggregatorMode.isInputPartial() == false ? addInitial(page) : addIntermediate(page);
}

/**
* Adds initial (raw) input to the state.
*/
private void addInitial(Page page, GroupingAggregatorFunction.AddInput addInput) {
try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel))) {
addInput.add(0, result);
}
IntBlock addInitial(Page page) {
return (IntBlock) evaluator.eval(page.getBlock(channel));
}

/**
* Adds intermediate state to the state.
*/
private void addIntermediate(Page page, GroupingAggregatorFunction.AddInput addInput) {
private IntBlock addIntermediate(Page page) {
if (page.getPositionCount() == 0) {
return;
return null;
}
BytesRefBlock categorizerState = page.getBlock(channel);
if (categorizerState.areAllValuesNull()) {
seenNull = true;
try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) {
addInput.add(0, newIds);
}
return;
}

Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
int fromId = idMap.containsKey(0) ? 0 : 1;
int toId = fromId + idMap.size();
for (int i = fromId; i < toId; i++) {
newIdsBuilder.appendInt(idMap.get(i));
}
try (IntBlock newIds = newIdsBuilder.build()) {
addInput.add(0, newIds);
}
return blockFactory.newConstantIntBlockWith(NULL_ORD, 1);
}
int[] ids = recategorize(categorizerState.getBytesRef(0, new BytesRef()), null);
return blockFactory.newIntArrayVector(ids, ids.length).asBlock();
}

/**
* Read intermediate state from a block.
*
* @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
* Reads the intermediate state from a block and recategorizes the provided IDs.
* If no IDs are provided, the IDs are the IDs in the categorizer's state in order.
* (So 0...N-1 or 1...N, depending on whether null is present.)
*/
private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
int[] recategorize(BytesRef bytes, int[] ids) {
Map<Integer, Integer> idMap = new HashMap<>();
try (StreamInput in = new BytesArray(bytes).streamInput()) {
if (in.readBoolean()) {
Expand All @@ -185,10 +176,20 @@ private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
// +1 because the 0 ordinal is reserved for null
idMap.put(oldCategoryId + 1, newCategoryId + 1);
}
return idMap;
} catch (IOException e) {
throw new RuntimeException(e);
}
if (ids == null) {
ids = new int[idMap.size()];
int idOffset = idMap.containsKey(0) ? 0 : 1;
for (int i = 0; i < idMap.size(); i++) {
ids[i] = i + idOffset;
}
}
for (int i = 0; i < ids.length; i++) {
ids[i] = idMap.get(ids[i]);
}
return ids;
}

/**
Expand All @@ -198,15 +199,19 @@ private Block buildIntermediateBlock() {
if (categorizer.getCategoryCount() == 0) {
return blockFactory.newConstantNullBlock(seenNull ? 1 : 0);
}
int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
// We're returning a block with N positions just because the Page must have all blocks with the same position count!
return blockFactory.newConstantBytesRefBlockWith(serializeCategorizer(), positionCount);
}

BytesRef serializeCategorizer() {
try (BytesStreamOutput out = new BytesStreamOutput()) {
out.writeBoolean(seenNull);
out.writeVInt(categorizer.getCategoryCount());
for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
category.writeTo(out);
}
// We're returning a block with N positions just because the Page must have all blocks with the same position count!
int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
return out.bytes().toBytesRef();
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.aggregation.blockhash;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BitArray;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.ReleasableIterator;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.index.analysis.AnalysisRegistry;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
* BlockHash implementation for {@code Categorize} grouping function as first
* grouping expression, followed by one or mode other grouping expressions.
* <p>
* For the first grouping (the {@code Categorize} grouping function), a
* {@code CategorizeBlockHash} is used, which outputs integers (category IDs).
* Next, a {@code PackedValuesBlockHash} is used on the category IDs and the
* other groupings (which are not {@code Categorize}s).
*/
public class CategorizePackedValuesBlockHash extends BlockHash {

private final List<GroupSpec> specs;
private final AggregatorMode aggregatorMode;
private final Block[] blocks;
private final CategorizeBlockHash categorizeBlockHash;
private final PackedValuesBlockHash packedValuesBlockHash;

CategorizePackedValuesBlockHash(
List<GroupSpec> specs,
BlockFactory blockFactory,
AggregatorMode aggregatorMode,
AnalysisRegistry analysisRegistry,
int emitBatchSize
) {
super(blockFactory);
this.specs = specs;
this.aggregatorMode = aggregatorMode;
blocks = new Block[specs.size()];

List<GroupSpec> delegateSpecs = new ArrayList<>();
delegateSpecs.add(new GroupSpec(0, ElementType.INT));
jan-elastic marked this conversation as resolved.
Show resolved Hide resolved
for (int i = 1; i < specs.size(); i++) {
delegateSpecs.add(new GroupSpec(i, specs.get(i).elementType()));
}

boolean success = false;
try {
categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry);
packedValuesBlockHash = new PackedValuesBlockHash(delegateSpecs, blockFactory, emitBatchSize);
success = true;
} finally {
if (success == false) {
close();
}
}
}

@Override
public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
try (IntBlock categories = getCategories(page)) {
blocks[0] = categories;
for (int i = 1; i < specs.size(); i++) {
blocks[i] = page.getBlock(specs.get(i).channel());
}
packedValuesBlockHash.add(new Page(blocks), addInput);
}
}

private IntBlock getCategories(Page page) {
if (aggregatorMode.isInputPartial() == false) {
return categorizeBlockHash.addInitial(page);
} else {
BytesRefBlock stateBlock = page.getBlock(0);
BytesRef stateBytes = stateBlock.getBytesRef(0, new BytesRef());
try (StreamInput in = new BytesArray(stateBytes).streamInput()) {
BytesRef categorizerState = in.readBytesRef();
int[] ids = in.readIntArray();
ids = categorizeBlockHash.recategorize(categorizerState, ids);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ definitely easier to grasp, thanks!

return blockFactory.newIntArrayVector(ids, ids.length).asBlock();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

@Override
public Block[] getKeys() {
Block[] keys = packedValuesBlockHash.getKeys();
if (aggregatorMode.isOutputPartial() == false) {
// For final output, the keys are the category regexes.
try (
BytesRefBlock regexes = (BytesRefBlock) categorizeBlockHash.getKeys()[0];
BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(keys[0].getPositionCount())
) {
IntVector idsVector = (IntVector) keys[0].asVector();
int idsOffset = categorizeBlockHash.seenNull() ? 0 : -1;
BytesRef scratch = new BytesRef();
for (int i = 0; i < idsVector.getPositionCount(); i++) {
int id = idsVector.getInt(i);
if (id == 0) {
builder.appendNull();
} else {
builder.appendBytesRef(regexes.getBytesRef(id + idsOffset, scratch));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for now: We're repeating, potentially, a lot of bytesref values here. I wonder if there is or it would make sense to have a BytesRefBlock that instead of all the bytesrefs, stores every value just once, and then a reference per position:

AAAAAA
BBBBBBB
AAAAAA
AAAAAA

->

// 1: AAAAAA
// 2: BBBBBBB
1
2
1
1

@nik9000 Something to consider for later? Maybe it's too specific for this. And anyway, the next EVAL or whatever will duplicate the value again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds like a nice thing to have, but definitely out of scope for this PR.

However, the next EVAL should not duplicate the value again.

If you have:

// 1: AAAAAA
// 2: BBBBBBB
1
2
1
1

then an efficient EVAL x=SUBSTRING(x, 1, 2) should give

// 1: AA
// 2: BB
1
2
1
1

without ever duplicating.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For that SUBSTRING to not duplicate, we would need to add that "hashtable" strategy in the BytesRefBlockBuilder. It looks goo (?), but I wonder if using that by default could perform negatively in some scenarios. Something to try eventually probably

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like worth trying in the future. Are you making a note (issue) of this, so that the idea doesn't get lost?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! I'll comment it with Nik, just in case it was considered and discarded already, and then I'll document it in an issue somewhere

jan-elastic marked this conversation as resolved.
Show resolved Hide resolved
}
}
keys[0].close();
keys[0] = builder.build();
}
} else {
// For intermediate output, the keys are the delegate PackedValuesBlockHash's
// keys, with the category IDs replaced by the categorizer's internal state
// together with the list of category IDs.
BytesRef state;
try (BytesStreamOutput out = new BytesStreamOutput()) {
out.writeBytesRef(categorizeBlockHash.serializeCategorizer());
// It's a bit inefficient to copy the IntVector's values into an int[]
// and discard the array soon after. IntVector should maybe expose the
// underlying array instead. TODO: investigate whether that's worth it
IntVector idsVector = (IntVector) keys[0].asVector();
int[] idsArray = new int[idsVector.getPositionCount()];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little afraid that we allocate potentially quite a bit of memory without asking the breaker first. I believe this will lead to tricky-to-debug situations when the memory pressure is already high and this leads to an OOM. Not sure how likely, but still.

The blockFactory has convenience methods preAdjustBreakerForInt and adjustBreaker that we better use here. That needs to be done carefully re. try/catching as not to have a circuit breaker leak.

@nik9000 wdyt? Should we play it safe here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of manually handling the memory here, maybe we should just do a idsVector.writeTo(...), so we remove a chunk of code from here, and avoid allocating anything else?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to track these. I'm not sure it has to be a blocker though. Until a few months ago aggs didn't track the a few similar things to this. OTOH, it could cause problems.....

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it looks like we just writeIntArray with this. In that case, yeah, I'd write the ids manually.

Copy link
Contributor

@alex-spies alex-spies Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I checked and we also don't really track memory in CategorizeBlockHash.getKeys; neither do we track the memory for the categorizer itself. Update: actually, we probably do - so that should be covered.

The problem here is that due to combinatorial explosion, the untracked memory when writing the idsVector can be a lot larger than the actual categorizer state.

E.g. STATS ... BY CATEGORIZE(message), field1, field2.

If there are n categories of messages, m distinct field1 values and o distinct field2 values, then the number of rows - and thus ids - will be n*m*o. And we're copying this twice: once into an int[] and another time when writing into out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll leave this up to you all to decide...

for (int i = 0; i < idsVector.getPositionCount(); i++) {
idsArray[i] = idsVector.getInt(i);
}
out.writeIntArray(idsArray);
jan-elastic marked this conversation as resolved.
Show resolved Hide resolved
state = out.bytes().toBytesRef();
} catch (IOException e) {
throw new RuntimeException(e);
}
keys[0].close();
keys[0] = blockFactory.newConstantBytesRefBlockWith(state, keys[0].getPositionCount());
}
return keys;
}

@Override
public IntVector nonEmpty() {
return packedValuesBlockHash.nonEmpty();
}

@Override
public BitArray seenGroupIds(BigArrays bigArrays) {
return packedValuesBlockHash.seenGroupIds(bigArrays);
}

@Override
public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
throw new UnsupportedOperationException();
}

@Override
public void close() {
Releasables.close(categorizeBlockHash, packedValuesBlockHash);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ public Operator get(DriverContext driverContext) {
if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
return new HashAggregationOperator(
aggregators,
() -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry),
() -> BlockHash.buildCategorizeBlockHash(
groups,
aggregatorMode,
driverContext.blockFactory(),
analysisRegistry,
maxPageSize
),
driverContext
);
}
Expand Down
Loading
Loading