Skip to content

Commit

Permalink
[fix](nereids) Fix query rewrite by mv fail when self join
Browse files Browse the repository at this point in the history
  • Loading branch information
seawinde committed Dec 28, 2023
1 parent e5b6826 commit 424be79
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableBiMap.Builder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -59,54 +61,84 @@ public static RelationMapping of(ImmutableBiMap<MappedRelation, MappedRelation>
*/
public static List<RelationMapping> generate(List<CatalogRelation> sources, List<CatalogRelation> targets) {
// Construct tmp map, key is the table qualifier, value is the corresponding catalog relations
LinkedListMultimap<Long, MappedRelation> sourceTableRelationIdMap = LinkedListMultimap.create();
HashMultimap<Long, MappedRelation> sourceTableRelationIdMap = HashMultimap.create();
for (CatalogRelation relation : sources) {
sourceTableRelationIdMap.put(getTableQualifier(relation.getTable()),
MappedRelation.of(relation.getRelationId(), relation));
}
LinkedListMultimap<Long, MappedRelation> targetTableRelationIdMap = LinkedListMultimap.create();
HashMultimap<Long, MappedRelation> targetTableRelationIdMap = HashMultimap.create();
for (CatalogRelation relation : targets) {
targetTableRelationIdMap.put(getTableQualifier(relation.getTable()),
MappedRelation.of(relation.getRelationId(), relation));
}
Set<Long> sourceTableKeySet = sourceTableRelationIdMap.keySet();
List<List<Pair<MappedRelation, MappedRelation>>> mappedRelations = new ArrayList<>();
List<List<RelationMapping>> mappedRelations = new ArrayList<>();

for (Long sourceTableQualifier : sourceTableKeySet) {
List<MappedRelation> sourceMappedRelations = sourceTableRelationIdMap.get(sourceTableQualifier);
List<MappedRelation> targetMappedRelations = targetTableRelationIdMap.get(sourceTableQualifier);
Set<MappedRelation> sourceMappedRelations = sourceTableRelationIdMap.get(sourceTableQualifier);
Set<MappedRelation> targetMappedRelations = targetTableRelationIdMap.get(sourceTableQualifier);
if (targetMappedRelations.isEmpty()) {
continue;
}
// if source and target relation appear once, just map them
if (targetMappedRelations.size() == 1 && sourceMappedRelations.size() == 1) {
mappedRelations.add(ImmutableList.of(Pair.of(sourceMappedRelations.get(0),
targetMappedRelations.get(0))));
ImmutableBiMap.Builder<MappedRelation, MappedRelation> biMapBuilder = ImmutableBiMap.builder();
mappedRelations.add(ImmutableList.of(
RelationMapping.of(biMapBuilder.put(sourceMappedRelations.iterator().next(),
targetMappedRelations.iterator().next()).build())));
continue;
}
// relation appear more than once, should cartesian them
ImmutableList<Pair<MappedRelation, MappedRelation>> relationMapping = Lists.cartesianProduct(
sourceTableRelationIdMap.get(sourceTableQualifier), targetMappedRelations)
// relation appear more than once, should cartesian them and power set to correct combination
// if query is a0, a1, view is b0, b1
// relationMapping will be
// a0 b0
// a0 b1
// a1 b0
// a1 b1
ImmutableList<Pair<MappedRelation, MappedRelation>> relationMapping = Sets.cartesianProduct(
sourceMappedRelations, targetMappedRelations)
.stream()
.map(listPair -> Pair.of(listPair.get(0), listPair.get(1)))
.collect(ImmutableList.toImmutableList());
mappedRelations.add(relationMapping);
}

int mappedRelationCount = mappedRelations.size();

return Lists.cartesianProduct(mappedRelations).stream()
.map(mappedRelationList -> {
Builder<MappedRelation, MappedRelation> mapBuilder = ImmutableBiMap.builder();
for (int relationIndex = 0; relationIndex < mappedRelationCount; relationIndex++) {
mapBuilder.put(mappedRelationList.get(relationIndex).key(),
mappedRelationList.get(relationIndex).value());
// the mapping in relationMappingPowerList should be bi-direction
// [
// {a0-b0, a1-b1}
// {a1-b0, a0-b1}
// ]
List<RelationMapping> relationMappingPowerList = new ArrayList<>();
int relationMappingSize = relationMapping.size();
int relationMappingMinSize = Math.min(sourceMappedRelations.size(), targetMappedRelations.size());
for (int i = 0; i < relationMappingSize; i++) {
HashBiMap<MappedRelation, MappedRelation> relationBiMap = HashBiMap.create();
relationBiMap.put(relationMapping.get(i).key(), relationMapping.get(i).value());
for (int j = i + 1; j < relationMappingSize; j++) {
if (!relationBiMap.containsKey(relationMapping.get(j).key())
&& !relationBiMap.containsValue(relationMapping.get(j).value())) {
relationBiMap.put(relationMapping.get(j).key(), relationMapping.get(j).value());
}
return RelationMapping.of(mapBuilder.build());
})
}
// mapping should contain min num of relation in source or target at least
if (relationBiMap.size() >= relationMappingMinSize) {
relationMappingPowerList.add(RelationMapping.of(ImmutableBiMap.copyOf(relationBiMap)));
}
}
mappedRelations.add(relationMappingPowerList);
}
// mappedRelations product and merge into each relationMapping
return Lists.cartesianProduct(mappedRelations).stream()
.map(RelationMapping::merge)
.collect(ImmutableList.toImmutableList());
}

public static RelationMapping merge(List<RelationMapping> relationMappings) {
Builder<MappedRelation, MappedRelation> mappingBuilder = ImmutableBiMap.builder();
for (RelationMapping relationMapping : relationMappings) {
relationMapping.getMappedRelationMap().forEach(mappingBuilder::put);
}
return RelationMapping.of(mappingBuilder.build());
}

private static Long getTableQualifier(TableIf tableIf) {
return tableIf.getId();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
import java.util.ArrayList;
import java.util.List;

/**MappingTest*/
/**
* MappingTest
*/
public class MappingTest extends TestWithFeService {

@Override
Expand Down Expand Up @@ -275,6 +277,73 @@ public void testGenerateMapping4() {
assertRelationMapping(generateRelationMapping.get(1), expectedRelationMapping, expectedSlotMapping);
}

@Test
public void testGenerateMapping5() {
Plan sourcePlan = PlanChecker.from(connectContext)
.analyze("SELECT orders.*, l1.* "
+ "FROM\n"
+ " orders,\n"
+ " lineitem l1,\n"
+ " lineitem l2\n"
+ "WHERE\n"
+ " l1.l_orderkey = l2.l_orderkey\n"
+ " AND l1.l_orderkey = o_orderkey")
.getPlan();

Plan targetPlan = PlanChecker.from(connectContext)
.analyze("SELECT orders.*, l1.* "
+ "FROM\n"
+ " lineitem l1,\n"
+ " orders,\n"
+ " lineitem l2\n"
+ "WHERE\n"
+ " l1.l_orderkey = l2.l_orderkey\n"
+ " AND l2.l_orderkey = o_orderkey")
.getPlan();
List<CatalogRelation> sourceRelations = new ArrayList<>();
sourcePlan.accept(RelationCollector.INSTANCE, sourceRelations);

List<CatalogRelation> targetRelations = new ArrayList<>();
targetPlan.accept(RelationCollector.INSTANCE, targetRelations);

List<RelationMapping> generateRelationMapping = RelationMapping.generate(sourceRelations, targetRelations);
Assertions.assertNotNull(generateRelationMapping);
Assertions.assertEquals(2, generateRelationMapping.size());

// expected slot mapping
BiMap<ExprId, ExprId> expectedSlotMapping = HashBiMap.create();
expectedSlotMapping.put(new ExprId(0), new ExprId(2));
expectedSlotMapping.put(new ExprId(1), new ExprId(3));
expectedSlotMapping.put(new ExprId(2), new ExprId(4));
expectedSlotMapping.put(new ExprId(3), new ExprId(0));
expectedSlotMapping.put(new ExprId(4), new ExprId(1));
expectedSlotMapping.put(new ExprId(5), new ExprId(5));
expectedSlotMapping.put(new ExprId(6), new ExprId(6));
// expected relation mapping
BiMap<RelationId, RelationId> expectedRelationMapping = HashBiMap.create();
expectedRelationMapping.put(new RelationId(0), new RelationId(1));
expectedRelationMapping.put(new RelationId(1), new RelationId(0));
expectedRelationMapping.put(new RelationId(2), new RelationId(2));
assertRelationMapping(generateRelationMapping.get(1), expectedRelationMapping, expectedSlotMapping);

// expected slot mapping
expectedSlotMapping = HashBiMap.create();
expectedSlotMapping.put(new ExprId(0), new ExprId(2));
expectedSlotMapping.put(new ExprId(1), new ExprId(3));
expectedSlotMapping.put(new ExprId(2), new ExprId(4));
expectedSlotMapping.put(new ExprId(3), new ExprId(5));
expectedSlotMapping.put(new ExprId(4), new ExprId(6));
expectedSlotMapping.put(new ExprId(5), new ExprId(0));
expectedSlotMapping.put(new ExprId(6), new ExprId(1));
// expected relation mapping
expectedRelationMapping = HashBiMap.create();
expectedRelationMapping.put(new RelationId(0), new RelationId(1));
expectedRelationMapping.put(new RelationId(1), new RelationId(2));
expectedRelationMapping.put(new RelationId(2), new RelationId(0));
assertRelationMapping(generateRelationMapping.get(0), expectedRelationMapping, expectedSlotMapping);
}


private void assertRelationMapping(RelationMapping relationMapping,
BiMap<RelationId, RelationId> expectRelationMapping,
BiMap<ExprId, ExprId> expectSlotMapping) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,17 @@
-- !query7_0_after --
3 3 2023-12-11

-- !query8_0_before --
1 0 8 0 10.0000 10.50 9.50
2 0 2 0 11.5000 11.50 11.50
3 0 0 0 23.0000 33.50 12.50
4 0 0 0 43.2000 43.20 43.20
5 0 0 0 28.7000 56.20 1.20

-- !query8_0_after --
1 0 8 0 10.0000 10.50 9.50
2 0 2 0 11.5000 11.50 11.50
3 0 0 0 23.0000 33.50 12.50
4 0 0 0 43.2000 43.20 43.20
5 0 0 0 28.7000 56.20 1.20

Original file line number Diff line number Diff line change
Expand Up @@ -361,4 +361,43 @@ suite("outer_join") {
order_qt_query7_0_after "${query7_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv7_0"""


// self join test
def mv8_0 = """
select
a.o_orderkey,
count(distinct a.o_orderstatus) num1,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
AVG(a.o_totalprice) num5,
MAX(b.o_totalprice) num6,
MIN(a.o_totalprice) num7
from
orders a
left outer join orders b
on a.o_orderkey = b.o_orderkey
and a.o_custkey = b.o_custkey
group by a.o_orderkey;
"""
def query8_0 = """
select
a.o_orderkey,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate = '2023-12-08' AND b.o_orderdate = '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num2,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority = 1 AND a.o_orderdate >= '2023-12-01' AND a.o_orderdate <= '2023-12-09' THEN a.o_shippriority+b.o_custkey ELSE 0 END) num3,
SUM(CASE WHEN a.o_orderstatus = 'o' AND a.o_shippriority in (1,2) AND a.o_orderdate >= '2023-12-08' AND b.o_orderdate <= '2023-12-09' THEN a.o_shippriority-b.o_custkey ELSE 0 END) num4,
AVG(a.o_totalprice) num5,
MAX(b.o_totalprice) num6,
MIN(a.o_totalprice) num7
from
orders a
left outer join orders b
on a.o_orderkey = b.o_orderkey
and a.o_custkey = b.o_custkey
group by a.o_orderkey;
"""
order_qt_query8_0_before "${query8_0}"
check_rewrite(mv8_0, query8_0, "mv8_0")
order_qt_query8_0_after "${query8_0}"
sql """ DROP MATERIALIZED VIEW IF EXISTS mv8_0"""
}

0 comments on commit 424be79

Please sign in to comment.