Skip to content

Commit

Permalink
fix get_json_value.evaluate_with_value() error
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangli20 committed Nov 8, 2023
1 parent ca59570 commit e78d163
Showing 1 changed file with 43 additions and 30 deletions.
73 changes: 43 additions & 30 deletions native-engine/datafusion-ext-functions/src/spark_get_json_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl HiveGetJsonObjectEvaluator {
serde_json::Value::Number(number) => Ok(Some(number.to_string())),
serde_json::Value::Bool(b) => Ok(Some(b.to_string())),
serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
serde_json::to_string(&value).map(Some).map_err(|_| {
serde_json::to_string(value_ref).map(Some).map_err(|_| {
HiveGetJsonObjectError::InvalidInput("array to json error".to_string())
})
}
Expand Down Expand Up @@ -408,7 +408,13 @@ impl HiveGetJsonObjectMatcher {

#[cfg(test)]
mod test {
use crate::spark_get_json_object::HiveGetJsonObjectEvaluator;
use crate::spark_get_json_object::{
spark_get_parsed_json_object, spark_parse_json, HiveGetJsonObjectEvaluator,
};
use arrow::array::{AsArray, StringArray};
use datafusion::common::ScalarValue;
use datafusion::logical_expr::ColumnarValue;
use std::sync::Arc;

#[test]
fn test_hive_demo() {
Expand Down Expand Up @@ -525,40 +531,47 @@ mod test {
]
}
}"#;
let input_array = Arc::new(StringArray::from(vec![input]));
let parsed = spark_parse_json(&[ColumnarValue::Array(input_array)]).unwrap();

let path = "$.message.location.county";
assert_eq!(
HiveGetJsonObjectEvaluator::try_new(path)
.unwrap()
.evaluate(input)
.unwrap(),
Some(r#"["浦东","西直门"]"#.to_owned())
);
//let path = ColumnarValue::Scalar(ScalarValue::from("$.NOT_EXISTED"));
//let r = spark_get_parsed_json_object(&[parsed.clone(), path]).unwrap().into_array(1);
//let v = r.as_string::<i32>().iter().next().unwrap();
//assert_eq!(v, None);

let path = "$.message.name";
assert!(HiveGetJsonObjectEvaluator::try_new(path)
let path = ColumnarValue::Scalar(ScalarValue::from("$.message.location.county"));
let r = spark_get_parsed_json_object(&[parsed.clone(), path])
.unwrap()
.evaluate(input)
.into_array(1);
let v = r.as_string::<i32>().iter().next().unwrap();
assert_eq!(v, Some(r#"["浦东","西直门"]"#));

let path = ColumnarValue::Scalar(ScalarValue::from("$.message.location.NOT_EXISTED"));
let r = spark_get_parsed_json_object(&[parsed.clone(), path])
.unwrap()
.into_array(1);
let v = r.as_string::<i32>().iter().next().unwrap();
assert_eq!(v, Some(r#"[]"#));

let path = ColumnarValue::Scalar(ScalarValue::from("$.message.name"));
let r = spark_get_parsed_json_object(&[parsed.clone(), path])
.unwrap()
.contains("Asher"));
.into_array(1);
let v = r.as_string::<i32>().iter().next().unwrap();
assert!(v.unwrap().contains("Asher"));

let path = "$.message.location.city";
assert_eq!(
HiveGetJsonObjectEvaluator::try_new(path)
.unwrap()
.evaluate(input)
.unwrap(),
Some(r#"["1.234",1.234]"#.to_owned())
);
let path = ColumnarValue::Scalar(ScalarValue::from("$.message.location.city"));
let r = spark_get_parsed_json_object(&[parsed.clone(), path])
.unwrap()
.into_array(1);
let v = r.as_string::<i32>().iter().next().unwrap();
assert_eq!(v, Some(r#"["1.234",1.234]"#));

let path = "$.message.location[0]";
assert_eq!(
HiveGetJsonObjectEvaluator::try_new(path)
.unwrap()
.evaluate(input)
.unwrap(),
Some(r#"{"city":"1.234","county":"浦东"}"#.to_owned())
);
let path = ColumnarValue::Scalar(ScalarValue::from("$.message.location[0]"));
let r = spark_get_parsed_json_object(&[parsed.clone(), path])
.unwrap()
.into_array(1);
let v = r.as_string::<i32>().iter().next().unwrap();
assert_eq!(v, Some(r#"{"city":"1.234","county":"浦东"}"#));
}
}

0 comments on commit e78d163

Please sign in to comment.