
Experiment with Extension as a DataType #7398


Draft: wants to merge 16 commits into main
143 changes: 143 additions & 0 deletions arrow-array/src/array/extension_array.rs
@@ -0,0 +1,143 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{any::Any, sync::Arc};

use arrow_data::ArrayData;
use arrow_schema::{extension::DynExtensionType, ArrowError, DataType};

use super::{make_array, Array, ArrayRef};

/// Array type for DataType::Extension
#[derive(Debug)]
pub struct ExtensionArray {
data_type: DataType,
storage: ArrayRef,
}

impl ExtensionArray {
/// Try to create a new ExtensionArray
pub fn try_new(
extension: Arc<dyn DynExtensionType + Send + Sync>,
storage: ArrayRef,
) -> Result<Self, ArrowError> {
Ok(Self {
data_type: DataType::Extension(extension),
storage,
})
}

/// Create a new ExtensionArray
pub fn new(extension: Arc<dyn DynExtensionType + Send + Sync>, storage: ArrayRef) -> Self {
Self::try_new(extension, storage).unwrap()
}

/// Return the underlying storage array
pub fn storage(&self) -> &ArrayRef {
&self.storage
}

/// Return a new array with new storage of the same type
pub fn with_storage(&self, new_storage: ArrayRef) -> Self {
assert_eq!(new_storage.data_type(), self.storage.data_type());
Self {
data_type: self.data_type.clone(),
storage: new_storage,
}
}
}

impl From<ArrayData> for ExtensionArray {
fn from(data: ArrayData) -> Self {
if let DataType::Extension(extension) = data.data_type() {
let storage_data = ArrayData::try_new(
extension.storage_type().clone(),
data.len(),
data.nulls().map(|b| b.buffer()).cloned(),
data.offset(),
data.buffers().to_vec(),
data.child_data().to_vec(),
)
.unwrap();

Self {
data_type: data.data_type().clone(),
storage: make_array(storage_data),
}
} else {
panic!("{} is not Extension", data.data_type())
}
}
}

impl Array for ExtensionArray {
fn as_any(&self) -> &dyn Any {
self
}

fn to_data(&self) -> ArrayData {
let storage_data = self.storage.to_data();
ArrayData::try_new(
self.data_type.clone(),
storage_data.len(),
storage_data.nulls().map(|b| b.buffer()).cloned(),
storage_data.offset(),
storage_data.buffers().to_vec(),
storage_data.child_data().to_vec(),
)
.unwrap()
}

fn into_data(self) -> ArrayData {
self.to_data()
}

fn data_type(&self) -> &DataType {
&self.data_type
}

fn slice(&self, offset: usize, length: usize) -> ArrayRef {
Arc::new(Self {
data_type: self.data_type.clone(),
storage: self.storage.slice(offset, length),
})
}

fn len(&self) -> usize {
self.storage.len()
}

fn is_empty(&self) -> bool {
self.storage.is_empty()
}

fn offset(&self) -> usize {
self.storage.offset()
}

fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> {
self.storage.nulls()
}

fn get_buffer_memory_size(&self) -> usize {
self.storage.get_buffer_memory_size()
}

fn get_array_memory_size(&self) -> usize {
self.storage.get_array_memory_size()
}
}
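
For context (not part of this PR's diff), a minimal usage sketch of the API introduced above. It assumes the `TestExtension` helper from `arrow_schema::extension` used in the `arrow-data` test further down (with `DataType::Utf8` as its storage type, implementing the required `DynExtensionType + Send + Sync` bound) and that `ExtensionArray` is re-exported at the `arrow_array` crate root:

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, ExtensionArray, StringArray};
use arrow_schema::{extension::TestExtension, DataType};

fn main() {
    // Extension type with Utf8 storage, mirroring the arrow-data test below.
    let extension = Arc::new(TestExtension {
        storage_type: DataType::Utf8,
    });

    // Any array whose data type matches the extension's storage type can back it.
    let storage: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let array = ExtensionArray::new(extension, storage);

    // Length, nulls, offsets, and buffers are all delegated to the storage array.
    assert_eq!(array.len(), 3);
    assert!(matches!(array.data_type(), DataType::Extension(_)));
    assert_eq!(array.storage().len(), 3);
}
```
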
4 changes: 4 additions & 0 deletions arrow-array/src/array/mod.rs
@@ -76,6 +76,9 @@ mod list_view_array;

pub use list_view_array::*;

mod extension_array;
pub use extension_array::*;

use crate::iterator::ArrayIter;

/// An array in the [arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html)
@@ -829,6 +832,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef,
DataType::Decimal128(_, _) => Arc::new(Decimal128Array::from(data)) as ArrayRef,
DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef,
DataType::Extension(_) => Arc::new(ExtensionArray::from(data)) as ArrayRef,
dt => panic!("Unexpected data type {dt:?}"),
}
}
21 changes: 20 additions & 1 deletion arrow-data/src/data.rs
@@ -151,6 +151,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
}
}
}
DataType::Extension(extension) => new_buffers(extension.storage_type(), capacity),
}
}

@@ -590,6 +591,12 @@ impl ArrayData {

/// Returns a new [`ArrayData`] valid for `data_type` containing `len` null values
pub fn new_null(data_type: &DataType, len: usize) -> Self {
if let DataType::Extension(extension) = data_type {
let mut storage_data = Self::new_null(extension.storage_type(), len);
storage_data.data_type = data_type.clone();
return storage_data;
}

let bit_len = bit_util::ceil(len, 8);
let zeroed = |len: usize| Buffer::from(MutableBuffer::from_len_zeroed(len));

@@ -1664,6 +1671,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout {
}
}
DataType::Dictionary(key_type, _value_type) => layout(key_type),
DataType::Extension(extension) => layout(extension.storage_type()),
}
}

@@ -2119,7 +2127,7 @@ impl From<ArrayData> for ArrayDataBuilder {
#[cfg(test)]
mod tests {
use super::*;
use arrow_schema::{Field, Fields};
use arrow_schema::{extension::TestExtension, Field, Fields};

// See arrow/tests/array_data_validation.rs for test of array validation

@@ -2448,4 +2456,15 @@ mod tests {
assert!(array.is_null(i));
}
}

#[test]
fn test_data_extension() {
let data_type = DataType::Extension(Arc::new(TestExtension {
storage_type: DataType::Utf8,
}));
let array_null = ArrayData::new_null(&data_type, 3);
assert_eq!(array_null.len(), 3);
assert_eq!(array_null.data_type(), &data_type);
assert_eq!(array_null.null_count(), 3);
}
}
1 change: 1 addition & 0 deletions arrow-data/src/equal/mod.rs
@@ -123,6 +123,7 @@ fn equal_values(
DataType::Float16 => primitive_equal::<f16>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Map(_, _) => list_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
DataType::RunEndEncoded(_, _) => run_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Extension(_) => unimplemented!("Extension not implemented"),
}
}

3 changes: 3 additions & 0 deletions arrow-data/src/transform/mod.rs
@@ -276,6 +276,7 @@ fn build_extend(array: &ArrayData) -> Extend {
UnionMode::Dense => union::build_extend_dense(array),
},
DataType::RunEndEncoded(_, _) => todo!(),
DataType::Extension(_) => unimplemented!("Extension not implemented"),
Review comment (Contributor):
This highlights the challenge with this approach, every kernel/function must now be updated to accommodate this new extension type. Not only does this stand a very high chance of regressing functionality, but also is IMO not in the spirit of extension types.

Reply (Member Author):
Yes, the tradeoff is taking on the burden of extensibility (so that other projects don't have to implement it separately, for example, by redefining wrapper DataTypes or Arrays). This extra burden hasn't caused issues in Arrow C++/pyarrow that I'm aware of but absolutely fair that it might here.

}
}

@@ -332,6 +333,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
UnionMode::Dense => union::extend_nulls_dense,
},
DataType::RunEndEncoded(_, _) => todo!(),
DataType::Extension(_) => unimplemented!("Extension not implemented"),
})
}

@@ -590,6 +592,7 @@ impl<'a> MutableArrayData<'a> {
MutableArrayData::new(child_arrays, use_nulls, array_capacity)
})
.collect::<Vec<_>>(),
DataType::Extension(_) => unimplemented!("Extension not implemented"),
};

// Get the dictionary if any, and if it is a concatenation of multiple
1 change: 1 addition & 0 deletions arrow-integration-test/src/datatype.rs
@@ -345,6 +345,7 @@ pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value {
json!({"name": "map", "keysSorted": keys_sorted})
}
DataType::RunEndEncoded(_, _) => todo!(),
DataType::Extension(extension) => data_type_to_json(extension.storage_type()),
}
}

46 changes: 42 additions & 4 deletions arrow-ipc/src/convert.rs
@@ -18,6 +18,7 @@
//! Utilities for converting between IPC types and native Arrow types

use arrow_buffer::Buffer;
use arrow_schema::extension::DynExtensionTypeFactory;
use arrow_schema::*;
use flatbuffers::{
FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, Verifiable, Verifier,
@@ -194,8 +195,16 @@ impl From<crate::Field<'_>> for Field {
}
}

/// Deserialize an ipc [`crate::Schema`] from flat buffers to an arrow [Schema]
pub fn fb_to_schema(fb: crate::Schema) -> Schema {
fb_to_schema_with_extension_factory(fb, None).unwrap()
}

/// Deserialize an ipc [`crate::Schema`] from flat buffers to an arrow [Schema] with extension support
pub fn fb_to_schema_with_extension_factory(
fb: crate::Schema,
extension_factory: Option<&dyn DynExtensionTypeFactory>,
) -> Result<Schema, ArrowError> {
let mut fields: Vec<Field> = vec![];
let c_fields = fb.fields().unwrap();
let len = c_fields.len();
@@ -207,7 +216,15 @@ pub fn fb_to_schema(fb: crate::Schema) -> Schema {
}
_ => (),
};
let field: Field = c_field.into();
if let Some(factory) = extension_factory {
if let Some(extension) = factory.make_from_field(&field)? {
fields.push(field.clone().with_data_type(DataType::Extension(extension)));
continue;
}
}

fields.push(field);
}

let mut metadata: HashMap<String, String> = HashMap::default();
@@ -224,7 +241,8 @@ pub fn fb_to_schema(fb: crate::Schema) -> Schema {
}
}
}

Ok(Schema::new_with_metadata(fields, metadata))
}

/// Try deserialize flat buffer format bytes into a schema
Expand Down Expand Up @@ -514,7 +532,24 @@ pub(crate) fn build_field<'a>(
) -> WIPOffset<crate::Field<'a>> {
// Optional custom metadata.
let mut fb_metadata = None;

// Handle extension type metadata if applicable
if let DataType::Extension(extension) = field.data_type() {
let mut field_metadata = HashMap::from([
(
"ARROW:extension:name".to_string(),
extension.extension_name().to_string(),
),
(
"ARROW:extension:metadata".to_string(),
extension.serialized_metadata(),
),
]);

for (k, v) in field.metadata() {
field_metadata.insert(k.clone(), v.clone());
}

// Write the merged metadata (extension keys plus any user-provided keys).
fb_metadata = Some(metadata_to_fb(fbb, &field_metadata));
} else if !field.metadata().is_empty() {
fb_metadata = Some(metadata_to_fb(fbb, field.metadata()));
};

@@ -883,6 +918,9 @@ pub(crate) fn get_fb_field_type<'a>(
children: Some(fbb.create_vector(&children[..])),
}
}
DataType::Extension(extension) => {
get_fb_field_type(extension.storage_type(), dictionary_tracker, fbb)
}
}
}
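
A rough sketch of the read side (a hypothetical helper, not this PR's factory API): an extension factory passed to `fb_to_schema_with_extension_factory` would presumably inspect the `ARROW:extension:name` and `ARROW:extension:metadata` keys that `build_field` writes above, along these lines:

```rust
use arrow_schema::Field;

/// Hypothetical helper (not part of this PR): read the extension name and
/// serialized metadata back out of a deserialized Field, using the same
/// "ARROW:extension:*" keys that build_field writes above.
fn extension_info(field: &Field) -> Option<(&str, Option<&str>)> {
    let name = field.metadata().get("ARROW:extension:name")?;
    let serialized = field.metadata().get("ARROW:extension:metadata");
    Some((name.as_str(), serialized.map(String::as_str)))
}
```
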
