Skip to content

Commit

Permalink
[FLINK-34229][table] Set CodeGeneratorContext of outer class as ances…
Browse files Browse the repository at this point in the history
…tor context when generate inner classes to avoid naming conflicts
  • Loading branch information
zoudan committed Feb 4, 2024
1 parent 5fe6f20 commit c30e7d3
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,15 @@ object CodeGenUtils {
s"Unsupported type($t) to generate hash code," +
s" the type($t) is not supported as a GROUP_BY/PARTITION_BY/JOIN_EQUAL/UNION field.")
case ARRAY =>
val subCtx = new CodeGeneratorContext(ctx.tableConfig, ctx.classLoader)
val subCtx = new CodeGeneratorContext(ctx.tableConfig, ctx.classLoader, ctx)
val genHash =
HashCodeGenerator.generateArrayHash(
subCtx,
t.asInstanceOf[ArrayType].getElementType,
"SubHashArray")
genHashFunction(ctx, subCtx, genHash, term)
case MULTISET | MAP =>
val subCtx = new CodeGeneratorContext(ctx.tableConfig, ctx.classLoader)
val subCtx = new CodeGeneratorContext(ctx.tableConfig, ctx.classLoader, ctx)
val (keyType, valueType) = t match {
case multiset: MultisetType =>
(multiset.getElementType, new IntType())
Expand All @@ -353,7 +353,7 @@ object CodeGenUtils {
case INTERVAL_DAY_TIME => s"${className[JLong]}.hashCode($term)"
case ROW | STRUCTURED_TYPE =>
val fieldCount = getFieldCount(t)
val subCtx = new CodeGeneratorContext(ctx.tableConfig, ctx.classLoader)
val subCtx = new CodeGeneratorContext(ctx.tableConfig, ctx.classLoader, ctx)
val genHash =
HashCodeGenerator.generateRowHash(subCtx, t, "SubHashRow", (0 until fieldCount).toArray)
genHashFunction(ctx, subCtx, genHash, term)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,8 @@ object HashAggCodeGenHelper {
ctx.tableConfig,
ctx.classLoader,
aggMapKeyType,
SortUtil.getAscendingSortSpec(Array.range(0, aggMapKeyType.getFieldCount)))
SortUtil.getAscendingSortSpec(Array.range(0, aggMapKeyType.getFieldCount)),
ctx)
val computer = sortCodeGenerator.generateNormalizedKeyComputer("AggMapKeyComputer")
val comparator = sortCodeGenerator.generateRecordComparator("AggMapValueComparator")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,38 @@ object ComparatorCodeGenerator {
name: String,
inputType: RowType,
sortSpec: SortSpec): GeneratedRecordComparator = {
gen(tableConfig, classLoader, name, inputType, sortSpec, null)
}

/**
* Generates a [[RecordComparator]] that can be passed to a Java compiler.
*
* @param tableConfig
* Table config.
* @param classLoader
* user ClassLoader.
* @param name
* Class name of the function. Does not need to be unique but has to be a valid Java class
* identifier.
* @param inputType
* input type.
* @param sortSpec
* sort specification.
* @param parentCtx
* parent CodeGeneratorContext to avoid name conflicts.
* @return
* A GeneratedRecordComparator
*/
def gen(
tableConfig: ReadableConfig,
classLoader: ClassLoader,
name: String,
inputType: RowType,
sortSpec: SortSpec,
parentCtx: CodeGeneratorContext): GeneratedRecordComparator = {
val baseClass = classOf[RecordComparator]

val ctx = new CodeGeneratorContext(tableConfig, classLoader)
val ctx = new CodeGeneratorContext(tableConfig, classLoader, parentCtx)
val className = newName(ctx, name)
val compareCode = GenerateUtils.generateRowCompare(ctx, inputType, sortSpec, "o1", "o2")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,23 @@ import scala.collection.mutable
* input type.
* @param sortSpec
* sort specification.
* @param parentCtx
* parent CodeGeneratorContext to avoid name conflicts. If the generated [[NormalizedKeyComputer]]
* and [[RecordComparator]] will be used as inner classes, a non-null value must be set.
*/
class SortCodeGenerator(
tableConfig: ReadableConfig,
classLoader: ClassLoader,
val input: RowType,
val sortSpec: SortSpec) {
val sortSpec: SortSpec,
parentCtx: CodeGeneratorContext) {

def this(
tableConfig: ReadableConfig,
classLoader: ClassLoader,
input: RowType,
sortSpec: SortSpec) =
this(tableConfig, classLoader, input, sortSpec, null)

private val MAX_NORMALIZED_KEY_LEN = 16

Expand Down Expand Up @@ -130,7 +141,7 @@ class SortCodeGenerator(
* A GeneratedNormalizedKeyComputer
*/
def generateNormalizedKeyComputer(name: String): GeneratedNormalizedKeyComputer = {
val ctx = new CodeGeneratorContext(tableConfig, classLoader)
val ctx = new CodeGeneratorContext(tableConfig, classLoader, parentCtx)
val className = newName(ctx, name)

val (keyFullyDetermines, numKeyBytes) = getKeyFullyDeterminesAndBytes
Expand Down Expand Up @@ -386,7 +397,7 @@ class SortCodeGenerator(
* A GeneratedRecordComparator
*/
def generateRecordComparator(name: String): GeneratedRecordComparator = {
ComparatorCodeGenerator.gen(tableConfig, classLoader, name, input, sortSpec)
ComparatorCodeGenerator.gen(tableConfig, classLoader, name, input, sortSpec, parentCtx)
}

def getter(t: LogicalType, index: Int): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,19 @@ class OperatorFusionCodegenITCase extends BatchTestBase {
)
}

@TestTemplate
def testMultipleHashAgg(): Unit = {
checkOpFusionCodegenResult(
"""
|SELECT * FROM
| (SELECT a, SUM(b) as b FROM x group by a) T1
| INNER JOIN
| (SELECT d, SUM(e) as e FROM y group by d) T2
| ON T1.a = T2.d
|""".stripMargin
)
}

@TestTemplate
def testGlobalHashAggWithKey2(): Unit = {
checkOpFusionCodegenResult(
Expand Down

0 comments on commit c30e7d3

Please sign in to comment.