Skip to content

Commit af8061e

Browse files
committed
Arrow: support Decimal64 and Bytes
1 parent e36eea4 commit af8061e

8 files changed

Lines changed: 631 additions & 43 deletions

File tree

sea-orm-arrow/src/lib.rs

Lines changed: 178 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ pub fn arrow_array_to_value(
145145
.ok_or_else(|| type_err("BooleanArray", "Boolean", array))?;
146146
Ok(Value::Bool(Some(arr.value(row))))
147147
}
148+
// Binary types
149+
ColumnType::Binary(_) | ColumnType::VarBinary(_) => arrow_to_bytes(array, row),
148150
// Decimal types
149151
ColumnType::Decimal(_) | ColumnType::Money(_) => arrow_to_decimal(array, row),
150152
// Date/time types: delegate to feature-gated helpers.
@@ -208,6 +210,27 @@ pub fn is_datetime_column(col_type: &ColumnType) -> bool {
208210
)
209211
}
210212

213+
// ---------------------------------------------------------------------------
214+
// Binary helpers
215+
// ---------------------------------------------------------------------------
216+
217+
fn arrow_to_bytes(array: &dyn Array, row: usize) -> Result<Value, ArrowError> {
218+
if let Some(arr) = array.as_any().downcast_ref::<BinaryArray>() {
219+
return Ok(Value::Bytes(Some(arr.value(row).to_vec())));
220+
}
221+
if let Some(arr) = array.as_any().downcast_ref::<LargeBinaryArray>() {
222+
return Ok(Value::Bytes(Some(arr.value(row).to_vec())));
223+
}
224+
if let Some(arr) = array.as_any().downcast_ref::<FixedSizeBinaryArray>() {
225+
return Ok(Value::Bytes(Some(arr.value(row).to_vec())));
226+
}
227+
Err(type_err(
228+
"BinaryArray, LargeBinaryArray, or FixedSizeBinaryArray",
229+
"Binary/VarBinary",
230+
array,
231+
))
232+
}
233+
211234
// ---------------------------------------------------------------------------
212235
// Decimal helpers
213236
// ---------------------------------------------------------------------------
@@ -222,6 +245,13 @@ fn arrow_to_decimal(array: &dyn Array, row: usize) -> Result<Value, ArrowError>
222245
return decimal128_to_value(value, precision, scale);
223246
}
224247

248+
if let Some(arr) = array.as_any().downcast_ref::<Decimal64Array>() {
249+
let value = arr.value(row);
250+
let precision = arr.precision();
251+
let scale = arr.scale();
252+
return decimal64_to_value(value, precision, scale);
253+
}
254+
225255
if let Some(arr) = array.as_any().downcast_ref::<Decimal256Array>() {
226256
let value = arr.value(row);
227257
let precision = arr.precision();
@@ -230,12 +260,51 @@ fn arrow_to_decimal(array: &dyn Array, row: usize) -> Result<Value, ArrowError>
230260
}
231261

232262
Err(type_err(
233-
"Decimal128Array or Decimal256Array",
263+
"Decimal64Array, Decimal128Array, or Decimal256Array",
234264
"Decimal",
235265
array,
236266
))
237267
}
238268

269+
#[cfg(feature = "with-rust_decimal")]
270+
fn decimal64_to_value(value: i64, _precision: u8, scale: i8) -> Result<Value, ArrowError> {
271+
use sea_query::prelude::Decimal;
272+
273+
if scale < 0 {
274+
#[cfg(feature = "with-bigdecimal")]
275+
return decimal64_to_bigdecimal(value, scale);
276+
277+
#[cfg(not(feature = "with-bigdecimal"))]
278+
return Err(ArrowError::Unsupported(format!(
279+
"Decimal64 with negative scale={scale} not supported by rust_decimal. \
280+
Enable 'with-bigdecimal' feature."
281+
)));
282+
}
283+
284+
let decimal = Decimal::from_i128_with_scale(value as i128, scale as u32);
285+
Ok(Value::Decimal(Some(decimal)))
286+
}
287+
288+
#[cfg(not(feature = "with-rust_decimal"))]
289+
fn decimal64_to_value(_value: i64, _precision: u8, _scale: i8) -> Result<Value, ArrowError> {
290+
#[cfg(feature = "with-bigdecimal")]
291+
return decimal64_to_bigdecimal(_value, _scale);
292+
293+
#[cfg(not(feature = "with-bigdecimal"))]
294+
Err(ArrowError::Unsupported(
295+
"Decimal64Array requires 'with-rust_decimal' or 'with-bigdecimal' feature".into(),
296+
))
297+
}
298+
299+
#[cfg(feature = "with-bigdecimal")]
300+
fn decimal64_to_bigdecimal(value: i64, scale: i8) -> Result<Value, ArrowError> {
301+
use sea_query::prelude::bigdecimal::{BigDecimal, num_bigint::BigInt};
302+
303+
let bigint = BigInt::from(value);
304+
let decimal = BigDecimal::new(bigint, scale as i64);
305+
Ok(Value::BigDecimal(Some(Box::new(decimal))))
306+
}
307+
239308
#[cfg(feature = "with-rust_decimal")]
240309
fn decimal128_to_value(value: i128, precision: u8, scale: i8) -> Result<Value, ArrowError> {
241310
use sea_query::prelude::Decimal;
@@ -577,6 +646,7 @@ fn null_value_for_type(col_type: &ColumnType) -> Value {
577646
ColumnType::Float => Value::Float(None),
578647
ColumnType::Double => Value::Double(None),
579648
ColumnType::String(_) | ColumnType::Text | ColumnType::Char(_) => Value::String(None),
649+
ColumnType::Binary(_) | ColumnType::VarBinary(_) => Value::Bytes(None),
580650
ColumnType::Boolean => Value::Bool(None),
581651
#[cfg(feature = "with-rust_decimal")]
582652
ColumnType::Decimal(_) | ColumnType::Money(_) => Value::Decimal(None),
@@ -770,6 +840,28 @@ pub fn values_to_arrow_array(
770840
.collect();
771841
Ok(Arc::new(BinaryArray::from(bufs)))
772842
}
843+
DataType::LargeBinary => {
844+
let bufs: Vec<Option<&[u8]>> = values
845+
.iter()
846+
.map(|v| match v {
847+
Value::Bytes(Some(b)) => Some(b.as_slice()),
848+
_ => None,
849+
})
850+
.collect();
851+
Ok(Arc::new(LargeBinaryArray::from(bufs)))
852+
}
853+
DataType::FixedSizeBinary(byte_width) => {
854+
let mut builder = FixedSizeBinaryBuilder::with_capacity(values.len(), *byte_width);
855+
for v in values {
856+
match v {
857+
Value::Bytes(Some(b)) => builder.append_value(b.as_slice()).map_err(|e| {
858+
ArrowError::Unsupported(format!("FixedSizeBinary append error: {e}"))
859+
})?,
860+
_ => builder.append_null(),
861+
}
862+
}
863+
Ok(Arc::new(builder.finish()))
864+
}
773865
DataType::Date32 => {
774866
let arr: Date32Array = values.iter().map(extract_date32).collect();
775867
Ok(Arc::new(arr))
@@ -835,6 +927,18 @@ pub fn values_to_arrow_array(
835927
};
836928
Ok(arr)
837929
}
930+
DataType::Decimal64(precision, scale) => {
931+
let arr: Decimal64Array = values
932+
.iter()
933+
.map(|v| extract_decimal64(v, *scale))
934+
.collect();
935+
let arr = arr
936+
.with_precision_and_scale(*precision, *scale)
937+
.map_err(|e| {
938+
ArrowError::Unsupported(format!("Invalid Decimal64 precision/scale: {e}"))
939+
})?;
940+
Ok(Arc::new(arr))
941+
}
838942
DataType::Decimal128(precision, scale) => {
839943
let arr: Decimal128Array = values
840944
.iter()
@@ -1018,6 +1122,30 @@ pub fn option_values_to_arrow_array(
10181122
.collect();
10191123
Ok(Arc::new(BinaryArray::from(bufs)))
10201124
}
1125+
DataType::LargeBinary => {
1126+
let bufs: Vec<Option<&[u8]>> = values
1127+
.iter()
1128+
.map(|v| match v {
1129+
Some(Value::Bytes(Some(b))) => Some(b.as_slice()),
1130+
_ => None,
1131+
})
1132+
.collect();
1133+
Ok(Arc::new(LargeBinaryArray::from(bufs)))
1134+
}
1135+
DataType::FixedSizeBinary(byte_width) => {
1136+
let mut builder = FixedSizeBinaryBuilder::with_capacity(values.len(), *byte_width);
1137+
for v in values {
1138+
match v {
1139+
Some(Value::Bytes(Some(b))) => {
1140+
builder.append_value(b.as_slice()).map_err(|e| {
1141+
ArrowError::Unsupported(format!("FixedSizeBinary append error: {e}"))
1142+
})?
1143+
}
1144+
_ => builder.append_null(),
1145+
}
1146+
}
1147+
Ok(Arc::new(builder.finish()))
1148+
}
10211149
DataType::Date32 => {
10221150
let arr: Date32Array = values.iter().map(extract_date32_option).collect();
10231151
Ok(Arc::new(arr))
@@ -1091,6 +1219,18 @@ pub fn option_values_to_arrow_array(
10911219
};
10921220
Ok(arr)
10931221
}
1222+
DataType::Decimal64(precision, scale) => {
1223+
let arr: Decimal64Array = values
1224+
.iter()
1225+
.map(|v| extract_decimal64_option(v, *scale))
1226+
.collect();
1227+
let arr = arr
1228+
.with_precision_and_scale(*precision, *scale)
1229+
.map_err(|e| {
1230+
ArrowError::Unsupported(format!("Invalid Decimal64 precision/scale: {e}"))
1231+
})?;
1232+
Ok(Arc::new(arr))
1233+
}
10941234
DataType::Decimal128(precision, scale) => {
10951235
let arr: Decimal128Array = values
10961236
.iter()
@@ -1281,6 +1421,43 @@ fn offset_dt_to_timestamp(
12811421
// Decimal extraction helpers
12821422
// ---------------------------------------------------------------------------
12831423

1424+
fn extract_decimal64_option(v: &Option<Value>, target_scale: i8) -> Option<i64> {
1425+
extract_decimal64(v.as_ref()?, target_scale)
1426+
}
1427+
1428+
fn extract_decimal64(v: &Value, target_scale: i8) -> Option<i64> {
1429+
#[cfg(feature = "with-rust_decimal")]
1430+
if let Value::Decimal(Some(d)) = v {
1431+
let mantissa = d.mantissa();
1432+
let current_scale = d.scale() as i8;
1433+
let scale_diff = target_scale - current_scale;
1434+
let scaled = if scale_diff >= 0 {
1435+
mantissa * 10i128.pow(scale_diff as u32)
1436+
} else {
1437+
mantissa / 10i128.pow((-scale_diff) as u32)
1438+
};
1439+
return i64::try_from(scaled).ok();
1440+
}
1441+
#[cfg(feature = "with-bigdecimal")]
1442+
if let Value::BigDecimal(Some(d)) = v {
1443+
return bigdecimal_to_i64(d, target_scale);
1444+
}
1445+
let _ = (v, target_scale);
1446+
None
1447+
}
1448+
1449+
#[cfg(feature = "with-bigdecimal")]
1450+
fn bigdecimal_to_i64(
1451+
d: &sea_query::prelude::bigdecimal::BigDecimal,
1452+
target_scale: i8,
1453+
) -> Option<i64> {
1454+
use sea_query::prelude::bigdecimal::ToPrimitive;
1455+
1456+
let rescaled = d.clone().with_scale(target_scale as i64);
1457+
let (digits, _) = rescaled.into_bigint_and_exponent();
1458+
digits.to_i64()
1459+
}
1460+
12841461
fn extract_decimal128_option(v: &Option<Value>, target_scale: i8) -> Option<i128> {
12851462
extract_decimal128(v.as_ref()?, target_scale)
12861463
}

sea-orm-macros/src/derives/arrow_schema.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ pub fn expand_derive_arrow_schema(
5959
} else if meta.path.is_ident("arrow_comment") {
6060
let lit: LitStr = meta.value()?.parse()?;
6161
arrow_attrs.comment = Some(lit.value());
62+
} else if meta.path.is_ident("arrow_byte_width") {
63+
let lit: LitInt = meta.value()?.parse()?;
64+
arrow_attrs.byte_width = Some(lit.base10_parse()?);
6265
} else if meta.path.is_ident("column_type") {
6366
let lit: LitStr = meta.value()?.parse()?;
6467
column_type_str = Some(lit.value());
@@ -121,6 +124,7 @@ struct ArrowFieldAttrs {
121124
timezone: Option<String>,
122125
comment: Option<String>,
123126
nullable_attr: bool,
127+
byte_width: Option<i32>,
124128
}
125129

126130
struct ArrowFieldInfo {
@@ -189,16 +193,21 @@ fn column_type_to_arrow_datatype(col_type: &str, arrow_attrs: &ArrowFieldAttrs)
189193
let final_precision = arrow_attrs.precision.unwrap_or(precision);
190194
let final_scale = arrow_attrs.scale.unwrap_or(scale);
191195

192-
if final_precision <= 38 {
196+
if final_precision <= 18 {
197+
quote! { DataType::Decimal64(#final_precision, #final_scale) }
198+
} else if final_precision <= 38 {
193199
quote! { DataType::Decimal128(#final_precision, #final_scale) }
194200
} else {
195201
quote! { DataType::Decimal256(#final_precision, #final_scale) }
196202
}
197203
} else if col_type.starts_with("Money(") {
198-
// Money type - default to Decimal128(19, 4)
199204
let precision = arrow_attrs.precision.unwrap_or(19);
200205
let scale = arrow_attrs.scale.unwrap_or(4);
201-
quote! { DataType::Decimal128(#precision, #scale) }
206+
if precision <= 18 {
207+
quote! { DataType::Decimal64(#precision, #scale) }
208+
} else {
209+
quote! { DataType::Decimal128(#precision, #scale) }
210+
}
202211
} else if col_type == "TinyInteger" {
203212
quote! { DataType::Int8 }
204213
} else if col_type == "SmallInteger" {
@@ -252,7 +261,11 @@ fn column_type_to_arrow_datatype(col_type: &str, arrow_attrs: &ArrowFieldAttrs)
252261
} else if col_type == "TimestampWithTimeZone" {
253262
generate_timestamp_datatype(arrow_attrs, true)
254263
} else if col_type.starts_with("Binary(") || col_type.starts_with("VarBinary(") {
255-
quote! { DataType::Binary }
264+
if let Some(bw) = arrow_attrs.byte_width {
265+
quote! { DataType::FixedSizeBinary(#bw) }
266+
} else {
267+
quote! { DataType::Binary }
268+
}
256269
} else if col_type == "Json" || col_type == "JsonBinary" {
257270
quote! { DataType::Utf8 }
258271
} else if col_type == "Uuid" {
@@ -295,7 +308,9 @@ fn rust_type_to_arrow_datatype(field_type: &Type, arrow_attrs: &ArrowFieldAttrs)
295308
s if s.contains("Decimal") => {
296309
let precision = arrow_attrs.precision.unwrap_or(38);
297310
let scale = arrow_attrs.scale.unwrap_or(10);
298-
if precision <= 38 {
311+
if precision <= 18 {
312+
quote! { DataType::Decimal64(#precision, #scale) }
313+
} else if precision <= 38 {
299314
quote! { DataType::Decimal128(#precision, #scale) }
300315
} else {
301316
quote! { DataType::Decimal256(#precision, #scale) }
@@ -313,6 +328,13 @@ fn rust_type_to_arrow_datatype(field_type: &Type, arrow_attrs: &ArrowFieldAttrs)
313328
}
314329
s if s.contains("Date") => quote! { DataType::Date32 },
315330
s if s.contains("Time") => quote! { DataType::Time64(TimeUnit::Microsecond) },
331+
"Vec<u8>" => {
332+
if let Some(bw) = arrow_attrs.byte_width {
333+
quote! { DataType::FixedSizeBinary(#bw) }
334+
} else {
335+
quote! { DataType::Binary }
336+
}
337+
}
316338
_ => quote! { DataType::Binary }, // Safe fallback
317339
}
318340
}

0 commit comments

Comments
 (0)