diff --git a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs index 4490290a1ab0..ee5eb0f1e280 100644 --- a/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs +++ b/src/query/service/src/pipelines/processors/transforms/transform_dictionary.rs @@ -57,14 +57,14 @@ use crate::sql::IndexType; macro_rules! sqlx_fetch_optional { ($pool:expr, $sql:expr, $key_type:ty, $val_type:ty, $format_val_fn:expr) => {{ - let res: Option<($key_type, $val_type)> = + let res: Option<($key_type, Option<$val_type>)> = sqlx::query_as(&$sql).fetch_optional($pool).await?; - Ok(res.map(|(_, v)| $format_val_fn(v))) + Ok(res.and_then(|(_, v)| v.map($format_val_fn))) }}; } macro_rules! fetch_single_row_by_sqlx { - ($pool:expr, $sql:expr, $key_scalar:expr, $val_type:ty, $format_val_fn:expr) => {{ + ($pool:expr, $sql:expr, $key_scalar:expr, $val_type:ty, $format_val_fn:expr) => { match $key_scalar { DataType::Boolean => { sqlx_fetch_optional!($pool, $sql, bool, $val_type, $format_val_fn) @@ -88,45 +88,62 @@ macro_rules! fetch_single_row_by_sqlx { $key_scalar, ))), } - }}; + }; } macro_rules! fetch_all_rows_by_sqlx { ($pool:expr, $sql:expr, $key_scalar:expr, $val_type:ty, $format_key_fn:expr) => { match $key_scalar { DataType::Boolean => { - let res: Vec<(bool, $val_type)> = sqlx::query_as($sql).fetch_all($pool).await?; + let res: Vec<(bool, Option<$val_type>)> = + sqlx::query_as($sql).fetch_all($pool).await?; res.into_iter() - .map(|(k, v)| ($format_key_fn(ScalarRef::Boolean(k)), v)) + .filter_map(|(key, val)| match (key, val) { + (k, Some(v)) => Some(($format_key_fn(ScalarRef::Boolean(k)), v)), + _ => None, + }) .collect() } DataType::String => { - let res: Vec<(String, $val_type)> = sqlx::query_as($sql).fetch_all($pool).await?; + let res: Vec<(String, Option<$val_type>)> = + sqlx::query_as($sql).fetch_all($pool).await?; res.into_iter() - .map(|(k, v)| ($format_key_fn(ScalarRef::String(&k)), v)) + .filter_map(|(key, val)| match (key, val) { + (k, Some(v)) => Some(($format_key_fn(ScalarRef::String(&k)), v)), + _ => None, + }) .collect() } DataType::Number(num_ty) => { with_integer_mapped_type!(|NUM_TYPE| match num_ty { NumberDataType::NUM_TYPE => { - let res: Vec<(NUM_TYPE, $val_type)> = + let res: Vec<(NUM_TYPE, Option<$val_type>)> = sqlx::query_as($sql).fetch_all($pool).await?; res.into_iter() - .map(|(k, v)| (format!("{}", k), v)) + .filter_map(|(key, val)| match (key, val) { + (k, Some(v)) => Some((format!("{}", k), v)), + _ => None, + }) .collect() } NumberDataType::Float32 => { - let res: Vec<(f32, $val_type)> = + let res: Vec<(f32, Option<$val_type>)> = sqlx::query_as($sql).fetch_all($pool).await?; res.into_iter() - .map(|(k, v)| (format!("{}", k), v)) + .filter_map(|(key, val)| match (key, val) { + (k, Some(v)) => Some((format!("{}", k), v)), + _ => None, + }) .collect() } NumberDataType::Float64 => { - let res: Vec<(f64, $val_type)> = + let res: Vec<(f64, Option<$val_type>)> = sqlx::query_as($sql).fetch_all($pool).await?; res.into_iter() - .map(|(k, v)| (format!("{}", k), v)) + .filter_map(|(key, val)| match (key, val) { + (k, Some(v)) => Some((format!("{}", k), v)), + _ => None, + }) .collect() } }) diff --git a/tests/sqllogictests/src/mock_source/mysql_source.rs b/tests/sqllogictests/src/mock_source/mysql_source.rs index 6a8c1708c2c2..c2b297034bfe 100644 --- a/tests/sqllogictests/src/mock_source/mysql_source.rs +++ b/tests/sqllogictests/src/mock_source/mysql_source.rs @@ -69,7 +69,7 @@ pub fn run_mysql_source() { struct Backend { table: String, schema: Vec, - block: Vec>, + block: Vec>>, prepared_id: u32, prepared: HashMap, Vec)>, @@ -112,27 +112,42 @@ impl Backend { }, ]; - let block = vec![ - vec![Value::Int(1), Value::Int(2), Value::Int(3), Value::Int(4)], + let block: Vec>> = vec![ vec![ - Value::Bytes("Alice".as_bytes().to_vec()), - Value::Bytes("Bob".as_bytes().to_vec()), - Value::Bytes("Lily".as_bytes().to_vec()), - Value::Bytes("Tom".as_bytes().to_vec()), + Some(Value::Int(1)), + Some(Value::Int(2)), + Some(Value::Int(3)), + Some(Value::Int(4)), + Some(Value::Int(5)), ], vec![ - Value::UInt(24), - Value::UInt(35), - Value::UInt(41), - Value::UInt(55), + Some(Value::Bytes("Alice".as_bytes().to_vec())), + Some(Value::Bytes("Bob".as_bytes().to_vec())), + Some(Value::Bytes("Lily".as_bytes().to_vec())), + Some(Value::Bytes("Tom".as_bytes().to_vec())), + None, ], vec![ - Value::Double(100.0), - Value::Double(200.1), - Value::Double(1000.20), - Value::Double(3000.55), + Some(Value::UInt(24)), + Some(Value::UInt(35)), + Some(Value::UInt(41)), + Some(Value::UInt(55)), + None, + ], + vec![ + Some(Value::Double(100.0)), + Some(Value::Double(200.1)), + Some(Value::Double(1000.20)), + Some(Value::Double(3000.55)), + None, + ], + vec![ + Some(Value::Int(1)), + Some(Value::Int(0)), + Some(Value::Int(1)), + Some(Value::Int(0)), + None, ], - vec![Value::Int(1), Value::Int(0), Value::Int(1), Value::Int(0)], ]; Self { @@ -242,7 +257,7 @@ impl MysqlShim for Backend { let key_field = self.schema[*key_idx].clone(); let key_column = self.block[*key_idx].clone(); - // find matched rows by compare key params. + // step-1: find matched rows by compare key params. let mut rows: Vec> = vec![]; match key_field.coltype { ColumnType::MYSQL_TYPE_TINY => { @@ -254,9 +269,11 @@ impl MysqlShim for Backend { let key = param.parse::().unwrap(); let key_param = Value::Int(key.into()); for (i, key) in key_column.iter().enumerate() { - if key == &key_param { - rows.push(Some(i)); - break; + if let Some(key) = key { + if key == &key_param { + rows.push(Some(i)); + break; + } } } } @@ -270,9 +287,11 @@ impl MysqlShim for Backend { let key = param.parse::().unwrap(); let key_param = Value::UInt(key); for (i, key) in key_column.iter().enumerate() { - if key == &key_param { - rows.push(Some(i)); - break; + if let Some(key) = key { + if key == &key_param { + rows.push(Some(i)); + break; + } } } } @@ -286,9 +305,11 @@ impl MysqlShim for Backend { let key = param.parse::().unwrap(); let key_param = Value::Int(key); for (i, key) in key_column.iter().enumerate() { - if key == &key_param { - rows.push(Some(i)); - break; + if let Some(key) = key { + if key == &key_param { + rows.push(Some(i)); + break; + } } } } @@ -302,9 +323,11 @@ impl MysqlShim for Backend { let key = param.parse::().unwrap(); let key_param = Value::Double(key); for (i, key) in key_column.iter().enumerate() { - if key == &key_param { - rows.push(Some(i)); - break; + if let Some(key) = key { + if key == &key_param { + rows.push(Some(i)); + break; + } } } } @@ -322,9 +345,11 @@ impl MysqlShim for Backend { let key = param_str.as_bytes().to_vec(); let key_param = Value::Bytes(key); for (i, key) in key_column.iter().enumerate() { - if key == &key_param { - rows.push(Some(i)); - break; + if let Some(key) = key { + if key == &key_param { + rows.push(Some(i)); + break; + } } } } @@ -332,6 +357,7 @@ impl MysqlShim for Backend { _ => {} } + // step-2: write columns based on the matched rows // return NULL if params not matched. if rows.is_empty() { return results.completed(0, 0); @@ -346,7 +372,6 @@ impl MysqlShim for Backend { let value_column2 = self.block[value_idx2].clone(); let cols = vec![value_field1.clone(), value_field2.clone()]; - let mut rw = results.start(&cols)?; for row in rows.into_iter().map(|r| r.unwrap()) { @@ -357,51 +382,56 @@ impl MysqlShim for Backend { (value2, value_field2.clone()), ] { match value { - Value::Bytes(v) => { - rw.write_col(v)?; - } - Value::Int(v) => match value_field.coltype { - ColumnType::MYSQL_TYPE_TINY => { - rw.write_col(v as i8)?; - } - ColumnType::MYSQL_TYPE_SHORT => { - rw.write_col(v as u16)?; - } - ColumnType::MYSQL_TYPE_LONG => { - rw.write_col(v as i32)?; - } - ColumnType::MYSQL_TYPE_LONGLONG => { + Some(val) => match val { + Value::Bytes(v) => { rw.write_col(v)?; } - _ => { - unreachable!() - } - }, - Value::UInt(v) => match value_field.coltype { - ColumnType::MYSQL_TYPE_TINY => { - rw.write_col(v as u8)?; - } - ColumnType::MYSQL_TYPE_SHORT => { - rw.write_col(v as u16)?; - } - ColumnType::MYSQL_TYPE_LONG => { - rw.write_col(v as u32)?; + Value::Int(v) => match value_field.coltype { + ColumnType::MYSQL_TYPE_TINY => { + rw.write_col(v as i8)?; + } + ColumnType::MYSQL_TYPE_SHORT => { + rw.write_col(v as u16)?; + } + ColumnType::MYSQL_TYPE_LONG => { + rw.write_col(v as i32)?; + } + ColumnType::MYSQL_TYPE_LONGLONG => { + rw.write_col(v)?; + } + _ => { + unreachable!() + } + }, + Value::UInt(v) => match value_field.coltype { + ColumnType::MYSQL_TYPE_TINY => { + rw.write_col(v as u8)?; + } + ColumnType::MYSQL_TYPE_SHORT => { + rw.write_col(v as u16)?; + } + ColumnType::MYSQL_TYPE_LONG => { + rw.write_col(v as u32)?; + } + ColumnType::MYSQL_TYPE_LONGLONG => { + rw.write_col(v)?; + } + _ => { + unreachable!() + } + }, + Value::Float(v) => { + rw.write_col(v)?; } - ColumnType::MYSQL_TYPE_LONGLONG => { + Value::Double(v) => { rw.write_col(v)?; } _ => { - unreachable!() + rw.write_col("")?; } }, - Value::Float(v) => { - rw.write_col(v)?; - } - Value::Double(v) => { - rw.write_col(v)?; - } - _ => { - rw.write_col("")?; + None => { + rw.write_col(None::)?; } } } diff --git a/tests/sqllogictests/suites/query/functions/02_0077_function_dict_get.test b/tests/sqllogictests/suites/query/functions/02_0077_function_dict_get.test index 4a84dee72401..7537504e584d 100644 --- a/tests/sqllogictests/suites/query/functions/02_0077_function_dict_get.test +++ b/tests/sqllogictests/suites/query/functions/02_0077_function_dict_get.test @@ -174,12 +174,12 @@ select dict_get(mysql_dic_id, 'id', 1), dict_get(mysql_dic_id, 'name', 1), dict_ query ITIFT select dict_get(mysql_dic_id, 'id', 5), dict_get(mysql_dic_id, 'name', 5), dict_get(mysql_dic_id, 'age', 5), dict_get(mysql_dic_id, 'salary', 5), dict_get(mysql_dic_id, 'active', 5) ---- -NULL NULL NULL NULL NULL +5 NULL NULL NULL NULL query ITIFT select dict_get(mysql_dic_id_not_null, 'id', 5), dict_get(mysql_dic_id_not_null, 'name', 5), dict_get(mysql_dic_id_not_null, 'age', 5), dict_get(mysql_dic_id_not_null, 'salary', 5), dict_get(mysql_dic_id_not_null, 'active', 5) ---- -0 default_name 0 0.0 0 +5 default_name 0 0.0 0 statement error 1006 select dict_get(mysql_dic_id, 'id2', 5)