Skip to content

Commit

Permalink
fix bug while mysql table contains null values
Browse files Browse the repository at this point in the history
  • Loading branch information
Dragonliu2018 committed Nov 30, 2024
1 parent 7f3ec69 commit c0981d3
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ use crate::sql::IndexType;

macro_rules! sqlx_fetch_optional {
($pool:expr, $sql:expr, $key_type:ty, $val_type:ty, $format_val_fn:expr) => {{
let res: Option<($key_type, $val_type)> =
let res: Option<(Option<$key_type>, Option<$val_type>)> =
sqlx::query_as(&$sql).fetch_optional($pool).await?;
Ok(res.map(|(_, v)| $format_val_fn(v)))
Ok(res.and_then(|(_, v)| v.map($format_val_fn)))
}};
}

macro_rules! fetch_single_row_by_sqlx {
($pool:expr, $sql:expr, $key_scalar:expr, $val_type:ty, $format_val_fn:expr) => {{
($pool:expr, $sql:expr, $key_scalar:expr, $val_type:ty, $format_val_fn:expr) => {
match $key_scalar {
DataType::Boolean => {
sqlx_fetch_optional!($pool, $sql, bool, $val_type, $format_val_fn)
Expand All @@ -88,45 +88,62 @@ macro_rules! fetch_single_row_by_sqlx {
$key_scalar,
))),
}
}};
};
}

macro_rules! fetch_all_rows_by_sqlx {
($pool:expr, $sql:expr, $key_scalar:expr, $val_type:ty, $format_key_fn:expr) => {
match $key_scalar {
DataType::Boolean => {
let res: Vec<(bool, $val_type)> = sqlx::query_as($sql).fetch_all($pool).await?;
let res: Vec<(Option<bool>, Option<$val_type>)> =
sqlx::query_as($sql).fetch_all($pool).await?;
res.into_iter()
.map(|(k, v)| ($format_key_fn(ScalarRef::Boolean(k)), v))
.filter_map(|(key, val)| match (key, val) {
(Some(k), Some(v)) => Some(($format_key_fn(ScalarRef::Boolean(k)), v)),
_ => None,
})
.collect()
}
DataType::String => {
let res: Vec<(String, $val_type)> = sqlx::query_as($sql).fetch_all($pool).await?;
let res: Vec<(Option<String>, Option<$val_type>)> =
sqlx::query_as($sql).fetch_all($pool).await?;
res.into_iter()
.map(|(k, v)| ($format_key_fn(ScalarRef::String(&k)), v))
.filter_map(|(key, val)| match (key, val) {
(Some(k), Some(v)) => Some(($format_key_fn(ScalarRef::String(&k)), v)),
_ => None,
})
.collect()
}
DataType::Number(num_ty) => {
with_integer_mapped_type!(|NUM_TYPE| match num_ty {
NumberDataType::NUM_TYPE => {
let res: Vec<(NUM_TYPE, $val_type)> =
let res: Vec<(Option<NUM_TYPE>, Option<$val_type>)> =
sqlx::query_as($sql).fetch_all($pool).await?;
res.into_iter()
.map(|(k, v)| (format!("{}", k), v))
.filter_map(|(key, val)| match (key, val) {
(Some(k), Some(v)) => Some((format!("{}", k), v)),
_ => None,
})
.collect()
}
NumberDataType::Float32 => {
let res: Vec<(f32, $val_type)> =
let res: Vec<(Option<f32>, Option<$val_type>)> =
sqlx::query_as($sql).fetch_all($pool).await?;
res.into_iter()
.map(|(k, v)| (format!("{}", k), v))
.filter_map(|(key, val)| match (key, val) {
(Some(k), Some(v)) => Some((format!("{}", k), v)),
_ => None,
})
.collect()
}
NumberDataType::Float64 => {
let res: Vec<(f64, $val_type)> =
let res: Vec<(Option<f64>, Option<$val_type>)> =
sqlx::query_as($sql).fetch_all($pool).await?;
res.into_iter()
.map(|(k, v)| (format!("{}", k), v))
.filter_map(|(key, val)| match (key, val) {
(Some(k), Some(v)) => Some((format!("{}", k), v)),
_ => None,
})
.collect()
}
})
Expand Down
170 changes: 100 additions & 70 deletions tests/sqllogictests/src/mock_source/mysql_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub fn run_mysql_source() {
struct Backend {
table: String,
schema: Vec<Column>,
block: Vec<Vec<Value>>,
block: Vec<Vec<Option<Value>>>,

prepared_id: u32,
prepared: HashMap<u32, (usize, Vec<usize>, Vec<Expr>)>,
Expand Down Expand Up @@ -112,27 +112,42 @@ impl Backend {
},
];

let block = vec![
vec![Value::Int(1), Value::Int(2), Value::Int(3), Value::Int(4)],
let block: Vec<Vec<Option<Value>>> = vec![
vec![
Value::Bytes("Alice".as_bytes().to_vec()),
Value::Bytes("Bob".as_bytes().to_vec()),
Value::Bytes("Lily".as_bytes().to_vec()),
Value::Bytes("Tom".as_bytes().to_vec()),
Some(Value::Int(1)),
Some(Value::Int(2)),
Some(Value::Int(3)),
Some(Value::Int(4)),
Some(Value::Int(5)),
],
vec![
Value::UInt(24),
Value::UInt(35),
Value::UInt(41),
Value::UInt(55),
Some(Value::Bytes("Alice".as_bytes().to_vec())),
Some(Value::Bytes("Bob".as_bytes().to_vec())),
Some(Value::Bytes("Lily".as_bytes().to_vec())),
Some(Value::Bytes("Tom".as_bytes().to_vec())),
None,
],
vec![
Value::Double(100.0),
Value::Double(200.1),
Value::Double(1000.20),
Value::Double(3000.55),
Some(Value::UInt(24)),
Some(Value::UInt(35)),
Some(Value::UInt(41)),
Some(Value::UInt(55)),
None,
],
vec![
Some(Value::Double(100.0)),
Some(Value::Double(200.1)),
Some(Value::Double(1000.20)),
Some(Value::Double(3000.55)),
None,
],
vec![
Some(Value::Int(1)),
Some(Value::Int(0)),
Some(Value::Int(1)),
Some(Value::Int(0)),
None,
],
vec![Value::Int(1), Value::Int(0), Value::Int(1), Value::Int(0)],
];

Self {
Expand Down Expand Up @@ -242,7 +257,7 @@ impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
let key_field = self.schema[*key_idx].clone();
let key_column = self.block[*key_idx].clone();

// find matched rows by compare key params.
// step-1: find matched rows by compare key params.
let mut rows: Vec<Option<usize>> = vec![];
match key_field.coltype {
ColumnType::MYSQL_TYPE_TINY => {
Expand All @@ -254,9 +269,11 @@ impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
let key = param.parse::<bool>().unwrap();
let key_param = Value::Int(key.into());
for (i, key) in key_column.iter().enumerate() {
if key == &key_param {
rows.push(Some(i));
break;
if let Some(key) = key {
if key == &key_param {
rows.push(Some(i));
break;
}
}
}
}
Expand All @@ -270,9 +287,11 @@ impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
let key = param.parse::<u64>().unwrap();
let key_param = Value::UInt(key);
for (i, key) in key_column.iter().enumerate() {
if key == &key_param {
rows.push(Some(i));
break;
if let Some(key) = key {
if key == &key_param {
rows.push(Some(i));
break;
}
}
}
}
Expand All @@ -286,9 +305,11 @@ impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
let key = param.parse::<i64>().unwrap();
let key_param = Value::Int(key);
for (i, key) in key_column.iter().enumerate() {
if key == &key_param {
rows.push(Some(i));
break;
if let Some(key) = key {
if key == &key_param {
rows.push(Some(i));
break;
}
}
}
}
Expand All @@ -302,9 +323,11 @@ impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
let key = param.parse::<f64>().unwrap();
let key_param = Value::Double(key);
for (i, key) in key_column.iter().enumerate() {
if key == &key_param {
rows.push(Some(i));
break;
if let Some(key) = key {
if key == &key_param {
rows.push(Some(i));
break;
}
}
}
}
Expand All @@ -322,16 +345,19 @@ impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
let key = param_str.as_bytes().to_vec();
let key_param = Value::Bytes(key);
for (i, key) in key_column.iter().enumerate() {
if key == &key_param {
rows.push(Some(i));
break;
if let Some(key) = key {
if key == &key_param {
rows.push(Some(i));
break;
}
}
}
}
}
_ => {}
}

// step-2: write columns based on the matched rows
// return NULL if params not matched.
if rows.is_empty() {
return results.completed(0, 0);
Expand All @@ -346,7 +372,6 @@ impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
let value_column2 = self.block[value_idx2].clone();

let cols = vec![value_field1.clone(), value_field2.clone()];

let mut rw = results.start(&cols)?;

for row in rows.into_iter().map(|r| r.unwrap()) {
Expand All @@ -357,51 +382,56 @@ impl<W: io::Read + io::Write> MysqlShim<W> for Backend {
(value2, value_field2.clone()),
] {
match value {
Value::Bytes(v) => {
rw.write_col(v)?;
}
Value::Int(v) => match value_field.coltype {
ColumnType::MYSQL_TYPE_TINY => {
rw.write_col(v as i8)?;
}
ColumnType::MYSQL_TYPE_SHORT => {
rw.write_col(v as u16)?;
}
ColumnType::MYSQL_TYPE_LONG => {
rw.write_col(v as i32)?;
}
ColumnType::MYSQL_TYPE_LONGLONG => {
Some(val) => match val {
Value::Bytes(v) => {
rw.write_col(v)?;
}
_ => {
unreachable!()
}
},
Value::UInt(v) => match value_field.coltype {
ColumnType::MYSQL_TYPE_TINY => {
rw.write_col(v as u8)?;
}
ColumnType::MYSQL_TYPE_SHORT => {
rw.write_col(v as u16)?;
}
ColumnType::MYSQL_TYPE_LONG => {
rw.write_col(v as u32)?;
Value::Int(v) => match value_field.coltype {
ColumnType::MYSQL_TYPE_TINY => {
rw.write_col(v as i8)?;
}
ColumnType::MYSQL_TYPE_SHORT => {
rw.write_col(v as u16)?;
}
ColumnType::MYSQL_TYPE_LONG => {
rw.write_col(v as i32)?;
}
ColumnType::MYSQL_TYPE_LONGLONG => {
rw.write_col(v)?;
}
_ => {
unreachable!()
}
},
Value::UInt(v) => match value_field.coltype {
ColumnType::MYSQL_TYPE_TINY => {
rw.write_col(v as u8)?;
}
ColumnType::MYSQL_TYPE_SHORT => {
rw.write_col(v as u16)?;
}
ColumnType::MYSQL_TYPE_LONG => {
rw.write_col(v as u32)?;
}
ColumnType::MYSQL_TYPE_LONGLONG => {
rw.write_col(v)?;
}
_ => {
unreachable!()
}
},
Value::Float(v) => {
rw.write_col(v)?;
}
ColumnType::MYSQL_TYPE_LONGLONG => {
Value::Double(v) => {
rw.write_col(v)?;
}
_ => {
unreachable!()
rw.write_col("")?;
}
},
Value::Float(v) => {
rw.write_col(v)?;
}
Value::Double(v) => {
rw.write_col(v)?;
}
_ => {
rw.write_col("")?;
None => {
rw.write_col(None::<i64>)?;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,12 @@ select dict_get(mysql_dic_id, 'id', 1), dict_get(mysql_dic_id, 'name', 1), dict_
query ITIFT
select dict_get(mysql_dic_id, 'id', 5), dict_get(mysql_dic_id, 'name', 5), dict_get(mysql_dic_id, 'age', 5), dict_get(mysql_dic_id, 'salary', 5), dict_get(mysql_dic_id, 'active', 5)
----
NULL NULL NULL NULL NULL
5 NULL NULL NULL NULL

query ITIFT
select dict_get(mysql_dic_id_not_null, 'id', 5), dict_get(mysql_dic_id_not_null, 'name', 5), dict_get(mysql_dic_id_not_null, 'age', 5), dict_get(mysql_dic_id_not_null, 'salary', 5), dict_get(mysql_dic_id_not_null, 'active', 5)
----
0 default_name 0 0.0 0
5 default_name 0 0.0 0

statement error 1006
select dict_get(mysql_dic_id, 'id2', 5)
Expand Down

0 comments on commit c0981d3

Please sign in to comment.