Skip to content

Commit df69ef5

Browse files
authored
fix: coerce_primitive for serde decoded data (#5101)
* fix: fix json decode number
  Signed-off-by: fan <[email protected]>
* follow reviews
  Signed-off-by: fan <[email protected]>
* follow reviews
  Signed-off-by: fan <[email protected]>
* use fixed size space
  Signed-off-by: fan <[email protected]>
---------
Signed-off-by: fan <[email protected]>
1 parent fbbb61d commit df69ef5

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

arrow-json/src/reader/mod.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,9 @@ mod tests {
717717

718718
use arrow_array::cast::AsArray;
719719
use arrow_array::types::Int32Type;
720-
use arrow_array::{make_array, Array, BooleanArray, ListArray, StringArray, StructArray};
720+
use arrow_array::{
721+
make_array, Array, BooleanArray, Float64Array, ListArray, StringArray, StructArray,
722+
};
721723
use arrow_buffer::{ArrowNativeType, Buffer};
722724
use arrow_cast::display::{ArrayFormatter, FormatOptions};
723725
use arrow_data::ArrayDataBuilder;
@@ -2259,4 +2261,43 @@ mod tests {
22592261
.values();
22602262
assert_eq!(values, &[1699148028689, 2, 3, 4]);
22612263
}
2264+
2265+
/// Checks that a decoder built with `with_coerce_primitive(true)` accepts
/// serde-decoded JSON values (`serde_json::Value`) through `serialize`:
/// numeric and boolean values destined for `Utf8` columns must come out as
/// their string form, and large integers destined for a `Float64` column
/// must come out as the corresponding `f64`.
#[test]
fn test_coercing_primitive_into_string_decoder() {
    // Two integers that do not fit in an `i32`: one just past `i32::MAX`,
    // one close to `i64::MAX`.
    let beyond_i32 = i32::MAX as i64 + 10;
    let near_i64_max = i64::MAX - 10;

    // Rows 1-2 carry real strings in "b"/"c"; rows 3-4 carry numbers and
    // booleans there, which is what the coercion has to handle.
    let json_text = format!(
        r#"[{{"a": 1, "b": "A", "c": "T"}}, {{"a": 2, "b": "BB", "c": "F"}}, {{"a": {}, "b": 123, "c": false}}, {{"a": {}, "b": 789, "c": true}}]"#,
        beyond_i32, near_i64_max
    );
    let rows: Vec<serde_json::Value> = serde_json::from_str(&json_text).unwrap();

    let schema_ref = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float64, true),
        Field::new("b", DataType::Utf8, true),
        Field::new("c", DataType::Utf8, true),
    ]));

    // Feed the already-parsed values through the decoder and read one batch.
    let mut decoder = ReaderBuilder::new(schema_ref.clone())
        .with_coerce_primitive(true)
        .build_decoder()
        .unwrap();
    decoder.serialize(rows.as_slice()).unwrap();
    let batch = decoder.flush().unwrap().unwrap();

    let expected_a = Float64Array::from(vec![
        1.0,
        2.0,
        beyond_i32 as f64,
        near_i64_max as f64,
    ]);
    let expected_b = StringArray::from(vec!["A", "BB", "123", "789"]);
    let expected_c = StringArray::from(vec!["T", "F", "false", "true"]);
    let expected = RecordBatch::try_new(
        schema_ref,
        vec![
            Arc::new(expected_a),
            Arc::new(expected_b),
            Arc::new(expected_c),
        ],
    )
    .unwrap();

    assert_eq!(batch, expected);
}
22622303
}

arrow-json/src/reader/string_array.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,18 @@ impl<O: OffsetSizeTrait> ArrayDecoder for StringArrayDecoder<O> {
6161
TapeElement::Number(idx) if coerce_primitive => {
6262
data_capacity += tape.get_string(idx).len();
6363
}
64-
_ => return Err(tape.error(*p, "string")),
64+
TapeElement::I64(_)
65+
| TapeElement::I32(_)
66+
| TapeElement::F64(_)
67+
| TapeElement::F32(_)
68+
if coerce_primitive =>
69+
{
70+
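// Arbitrary per-value estimate of the formatted length: unlike `Number`, these binary-encoded values have no source string to measure at this point.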
71+
data_capacity += 10;
72+
}
73+
_ => {
74+
return Err(tape.error(*p, "string"));
75+
}
6576
}
6677
}
6778

@@ -89,6 +100,26 @@ impl<O: OffsetSizeTrait> ArrayDecoder for StringArrayDecoder<O> {
89100
TapeElement::Number(idx) if coerce_primitive => {
90101
builder.append_value(tape.get_string(idx));
91102
}
103+
TapeElement::I64(high) if coerce_primitive => match tape.get(p + 1) {
104+
TapeElement::I32(low) => {
105+
let val = (high as i64) << 32 | (low as u32) as i64;
106+
builder.append_value(val.to_string());
107+
}
108+
_ => unreachable!(),
109+
},
110+
TapeElement::I32(n) if coerce_primitive => {
111+
builder.append_value(n.to_string());
112+
}
113+
TapeElement::F32(n) if coerce_primitive => {
114+
builder.append_value(n.to_string());
115+
}
116+
TapeElement::F64(high) if coerce_primitive => match tape.get(p + 1) {
117+
TapeElement::F32(low) => {
118+
let val = f64::from_bits((high as u64) << 32 | low as u64);
119+
builder.append_value(val.to_string());
120+
}
121+
_ => unreachable!(),
122+
},
92123
_ => unreachable!(),
93124
}
94125
}

0 commit comments

Comments
 (0)