diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/system/Avg.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/system/Avg.java index f794888c02..46fa2c8e96 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/system/Avg.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/system/Avg.java @@ -99,7 +99,9 @@ public Row transform(Table table, FunctionParams params) throws Exception { } Object[] targetValues = new Object[targetFields.size()]; for (int i = 0; i < targetValues.length; i++) { - targetValues[i] = targetSums[i] / counts[i]; + if (counts[i] != 0) { + targetValues[i] = targetSums[i] / counts[i]; + } } return new Row(new Header(targetFields), targetValues); } diff --git a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/system/Sum.java b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/system/Sum.java index a248f52003..71d217e280 100644 --- a/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/system/Sum.java +++ b/core/src/main/java/cn/edu/tsinghua/iginx/engine/shared/function/system/Sum.java @@ -70,6 +70,7 @@ public Row transform(Table table, FunctionParams params) throws Exception { } Object[] targetValues = new Object[targetFields.size()]; + long[] counts = new long[targetFields.size()]; for (int i = 0; i < targetFields.size(); i++) { Field targetField = targetFields.get(i); if (targetField.getType() == DataType.LONG) { @@ -102,6 +103,12 @@ public Row transform(Table table, FunctionParams params) throws Exception { throw new IllegalStateException( "Unexpected field type: " + fields.get(index).getType().toString()); } + counts[i]++; + } + } + for (int i = 0; i < targetValues.length; i++) { + if (counts[i] == 0) { + targetValues[i] = null; } } return new Row(new Header(targetFields), targetValues); diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/client/ImportFileIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/client/ImportFileIT.java index ea4dcbd3cb..4c94a34f5d 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/client/ImportFileIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/client/ImportFileIT.java @@ -80,6 +80,6 @@ public void testLoadData() { + "| 14| true| eee| 4.5| 4.0|\n" + "+---+------+----+----+------+\n" + "Total line number = 5\n"; - executor.executeAndCompare(query, expected); + executor.executeAndCompare(query1, expected1); } } diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/session/SessionIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/session/SessionIT.java index 48ebedf244..d1b7abf790 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/session/SessionIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/session/SessionIT.java @@ -533,8 +533,7 @@ public void sessionTest() throws SessionException, InterruptedException { if (dsStartKey > delEndKey || dsEndKey < delStartKey) { assertEquals(delDsAvg + pathNum, changeResultToDouble(dsResult.get(j)), delta); } else if (dsStartKey >= delStartKey && dsEndKey < delEndKey) { - // assertNull(dsResult.get(j)); - assertTrue(Double.isNaN((Double) dsResult.get(j))); + assertNull(dsResult.get(j)); } else if (dsStartKey < delStartKey) { assertEquals( (dsStartKey + delStartKey - 1) / 2.0 + pathNum, @@ -613,18 +612,12 @@ public void sessionTest() throws SessionException, InterruptedException { session.aggregateQuery(delDataInColumnPaths, START_KEY, END_KEY + 1, AggregateType.AVG); List delDataAvgResPaths = delDataAvgSet.getPaths(); Object[] delDataAvgResult = delDataAvgSet.getValues(); - assertEquals(dataInColumnLen, delDataAvgResPaths.size()); - assertEquals(dataInColumnLen, delDataAvgSet.getValues().length); - for (int i = 0; i < dataInColumnLen; i++) { + assertEquals(dataInColumnLen - deleteDataInColumnLen, delDataAvgResPaths.size()); + assertEquals(dataInColumnLen - deleteDataInColumnLen, delDataAvgSet.getValues().length); + for (int i = 0; i < dataInColumnLen - deleteDataInColumnLen; i++) { int pathNum = getPathNum(delDataAvgResPaths.get(i)); assertNotEquals(pathNum, -1); - if (pathNum < currPath + deleteDataInColumnLen) { // Here is the removed rows - // assertEquals("null", new String((byte[]) - // delDataAvgResult[i])); - assertTrue(Double.isNaN((Double) delDataAvgResult[i])); - } else { - assertEquals((START_KEY + END_KEY) / 2.0 + pathNum, delDataAvgResult[i]); - } + assertEquals((START_KEY + END_KEY) / 2.0 + pathNum, delDataAvgResult[i]); } // Test downsample function for the delete @@ -645,12 +638,7 @@ public void sessionTest() throws SessionException, InterruptedException { double avg = (dsKey + maxNum) / 2.0; int pathNum = getPathNum(dsDelDataResPaths.get(j)); assertNotEquals(pathNum, -1); - if (pathNum < currPath + deleteDataInColumnLen) { // Here is the removed rows - // assertNull(dsResult.get(j)); - assertTrue(Double.isNaN((Double) dsResult.get(j))); - } else { - assertEquals(avg + pathNum, changeResultToDouble(dsResult.get(j)), delta); - } + assertEquals(avg + pathNum, changeResultToDouble(dsResult.get(j)), delta); } } currPath += dataInColumnLen; diff --git a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java index 538d2be3f2..de49d30b3a 100644 --- a/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java +++ b/test/src/test/java/cn/edu/tsinghua/iginx/integration/func/sql/SQLSessionIT.java @@ -1636,6 +1636,54 @@ public void testAggregateQueryWithArithExpr() { } } + @Test + public void testAggregateQueryWithNullValues() { + String insert = "insert into test(key, a) values (0, 1), (1, 2), (2, 3);"; + executor.execute(insert); + insert = "insert into test(key, b) values (3, 1), (4, 2), (5, 3);"; + executor.execute(insert); + + String query = "select * from test;"; + String expected = + "ResultSets:\n" + + "+---+------+------+\n" + + "|key|test.a|test.b|\n" + + "+---+------+------+\n" + + "| 0| 1| null|\n" + + "| 1| 2| null|\n" + + "| 2| 3| null|\n" + + "| 3| null| 1|\n" + + "| 4| null| 2|\n" + + "| 5| null| 3|\n" + + "+---+------+------+\n" + + "Total line number = 6\n"; + executor.executeAndCompare(query, expected); + + query = "select avg(*), sum(*), count(*) from test where key < 3;"; + // key<3时,avg(test.b)和sum(test.b)的值是null + expected = + "ResultSets:\n" + + "+-----------+-----------+-------------+-------------+\n" + + "|avg(test.a)|sum(test.a)|count(test.a)|count(test.b)|\n" + + "+-----------+-----------+-------------+-------------+\n" + + "| 2.0| 6| 3| 0|\n" + + "+-----------+-----------+-------------+-------------+\n" + + "Total line number = 1\n"; + executor.executeAndCompare(query, expected); + + query = "select avg(*), sum(*), count(*) from test where key > 2;"; + // key>2时,avg(test.a)和sum(test.a)的值是null + expected = + "ResultSets:\n" + + "+-----------+-----------+-------------+-------------+\n" + + "|avg(test.b)|sum(test.b)|count(test.a)|count(test.b)|\n" + + "+-----------+-----------+-------------+-------------+\n" + + "| 2.0| 6| 0| 3|\n" + + "+-----------+-----------+-------------+-------------+\n" + + "Total line number = 1\n"; + executor.executeAndCompare(query, expected); + } + @Test public void testDownSampleQuery() { String statement = "SELECT %s(s1), %s(s4) FROM us.d1 OVER WINDOW (size 100 IN (0, 1000));";