Skip to content

Commit

Permalink
Fix bug with handling of null values in dictionaries (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb authored Feb 4, 2025
1 parent 999d672 commit 2fffb96
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 4 deletions.
32 changes: 30 additions & 2 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use datafusion::arrow::array::{
downcast_array, AnyDictionaryArray, Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, LargeStringArray,
PrimitiveArray, RunArray, StringArray, StringViewArray,
PrimitiveArray, PrimitiveBuilder, RunArray, StringArray, StringViewArray,
};
use datafusion::arrow::compute::kernels::cast;
use datafusion::arrow::compute::take;
Expand Down Expand Up @@ -245,6 +245,34 @@ fn invoke_array_array<R: InvokeResult>(
}
}

/// Null out every key that points at a null entry in the dictionary values.
///
/// Only the keys are rewritten; `values` is passed through unchanged, so the surviving keys keep
/// their original indices. For example
/// keys = `[0, 1, 2, 3]`, values = `[null, "a", null, "b"]`
/// becomes
/// keys = `[null, 1, null, 3]`, values = `[null, "a", null, "b"]`
///
/// Arrow / `DataFusion` assumes that nulls in a dictionary are encoded by the keys, not by the values.
/// Not following this invariant causes invalid dictionary arrays to be built later on inside of `DataFusion`
/// when arrays are concatenated and such.
fn remap_dictionary_key_nulls(keys: PrimitiveArray<Int64Type>, values: ArrayRef) -> DictionaryArray<Int64Type> {
    // fast path: no nulls in values, so no key can reference one
    if values.null_count() == 0 {
        return DictionaryArray::new(keys, values);
    }

    // exactly one output key is produced per input key, so size the builder up front
    let mut new_keys_builder = PrimitiveBuilder::<Int64Type>::with_capacity(keys.len());

    for key in &keys {
        match key {
            // keep a key only when the value it references is non-null
            Some(k) if !values.is_null(k.as_usize()) => new_keys_builder.append_value(k),
            // the key was already null, or it referenced a null value
            _ => new_keys_builder.append_null(),
        }
    }

    DictionaryArray::new(new_keys_builder.finish(), values)
}

fn invoke_array_scalars<R: InvokeResult>(
json_array: &ArrayRef,
path: &[JsonPath],
Expand Down Expand Up @@ -281,7 +309,7 @@ fn invoke_array_scalars<R: InvokeResult>(
let type_ids = values.as_union().type_ids();
keys = mask_dictionary_keys(&keys, type_ids);
}
Ok(Arc::new(DictionaryArray::new(keys, values)))
Ok(Arc::new(remap_dictionary_key_nulls(keys, values)))
} else {
// this is what cast would do under the hood to unpack a dictionary into an array of its values
Ok(take(&values, json_array.keys(), None)?)
Expand Down
66 changes: 64 additions & 2 deletions tests/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, RecordBatch};
use datafusion::arrow::datatypes::{Field, Int8Type, Schema};
use datafusion::arrow::array::{Array, ArrayRef, DictionaryArray, RecordBatch};
use datafusion::arrow::datatypes::{Field, Int64Type, Int8Type, Schema};
use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType};
use datafusion::assert_batches_eq;
use datafusion::common::ScalarValue;
Expand Down Expand Up @@ -1280,6 +1280,68 @@ async fn test_dict_haystack() {
assert_batches_eq!(expected, &batches);
}

/// Panics if any non-null key of the dictionary `array` references a null entry in its values.
///
/// `array` must be a `DictionaryArray<Int64Type>`; any other array type panics on the downcast.
fn check_for_null_dictionary_values(array: &dyn Array) {
    let dict = array
        .as_any()
        .downcast_ref::<DictionaryArray<Int64Type>>()
        .expect("expected a DictionaryArray<Int64Type>");
    let values = dict.values();

    // Walk the keys once and look each one up, instead of scanning the whole key list
    // again for every null value.
    for key in dict.keys().iter().flatten() {
        let index = usize::try_from(key).unwrap();
        // Keys beyond the end of `values` are not this check's concern and are ignored.
        assert!(
            index >= values.len() || !values.is_null(index),
            "key {} points to a null dictionary value; keys: {:?}, values: {:?}",
            index,
            dict.keys(),
            values
        );
    }
}

/// Test that we don't output nulls in dictionary values.
/// This can cause issues with arrow-rs and DataFusion; they expect nulls to be in keys.
///
/// Runs both `json_get` and `json_get_str` against the dictionary-encoded `data` table, checks the
/// rendered output, and then checks every result batch for keys that reference null values.
#[tokio::test]
async fn test_dict_get_no_null_values() {
    let ctx = build_dict_schema().await;

    let sql = "select json_get(x, 'baz') v from data";
    // Every row is padded to the width of the border so the table compares equal line by line.
    let expected = [
        "+------------+",
        "| v          |",
        "+------------+",
        "|            |",
        "| {str=fizz} |",
        "|            |",
        "| {str=abcd} |",
        "|            |",
        "| {str=fizz} |",
        "| {str=fizz} |",
        "| {str=fizz} |",
        "| {str=fizz} |",
        "|            |",
        "+------------+",
    ];
    let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();
    assert_batches_eq!(expected, &batches);
    for batch in batches {
        check_for_null_dictionary_values(batch.column(0).as_ref());
    }

    let sql = "select json_get_str(x, 'baz') v from data";
    let expected = [
        "+------+",
        "| v    |",
        "+------+",
        "|      |",
        "| fizz |",
        "|      |",
        "| abcd |",
        "|      |",
        "| fizz |",
        "| fizz |",
        "| fizz |",
        "| fizz |",
        "|      |",
        "+------+",
    ];
    let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();
    assert_batches_eq!(expected, &batches);
    for batch in batches {
        check_for_null_dictionary_values(batch.column(0).as_ref());
    }
}

#[tokio::test]
async fn test_dict_haystack_filter() {
let sql = "select json_data v from dicts where json_get(json_data, 'foo') is not null";
Expand Down

0 comments on commit 2fffb96

Please sign in to comment.