Skip to content

Commit

Permalink
wip: numeric to double
Browse files Browse the repository at this point in the history
  • Loading branch information
lupko committed Feb 5, 2024
1 parent 3828e0b commit 6e308e7
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 76 deletions.
269 changes: 194 additions & 75 deletions c/driver/postgresql/copy/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <algorithm>
#include <cinttypes>
#include <limits>
#include <memory>
#include <string>
#include <utility>
Expand Down Expand Up @@ -242,26 +243,58 @@ class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader {
}
};

// // Converts COPY resulting from the Postgres NUMERIC type into a string.
// Rewritten based on the Postgres implementation of NUMERIC cast to string in
// src/backend/utils/adt/numeric.c : get_str_from_var() (Note that in the initial source,
// DEC_DIGITS is always 4 and DBASE is always 10000).
// Base class for readers of Postgres NUMERIC type. Code in this class provides
// common utility methods that are useful for both conversion of NUMERIC to
// Arrow string and NUMERIC to Arrow double.
//
// Briefly, the Postgres representation of "numeric" is an array of int16_t ("digits")
// from most significant to least significant. Each "digit" is a value between 0000 and
// 9999. There are weight + 1 digits before the decimal point and dscale digits after the
// decimal point. Both of those values can be zero or negative. A "sign" component
// encodes the positive or negativeness of the value and is also used to encode special
// values (inf, -inf, and nan).
//
// The methods implemented here are responsible for reading input data and preparing the
// string representation of the value.
//
// The conversion methods are rewritten based on the Postgres implementation of
// NUMERIC cast to string in src/backend/utils/adt/numeric.c : get_str_from_var().
// (Note that in the initial source, DEC_DIGITS is always 4 and NBASE is always 10000.)
class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader {
public:
ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array,
ArrowError* error) override {
// -1 for NULL
if (field_size_bytes < 0) {
return ArrowArrayAppendNull(array, 1);
}
protected:
// Number of decimal digits per Postgres digit
static const int kDecDigits = 4;
// The "base" of the Postgres representation (i.e., each "digit" is 0 to 9999)
static const int kNBase = 10000;
// Valid values for the sign component
static const uint16_t kNumericPos = 0x0000;
static const uint16_t kNumericNeg = 0x4000;
static const uint16_t kNumericNAN = 0xC000;
static const uint16_t kNumericPinf = 0xD000;
static const uint16_t kNumericNinf = 0xF000;

int16_t ndigits_;
int16_t weight_;
uint16_t sign_;
uint16_t dscale_;
std::vector<int16_t> digits_;

// Upper bound on the number of characters needed to hold the string form of the
// NUMERIC value described by weight_ and dscale_.
//
// The integer part accounts for (weight_ + 1) * kDecDigits characters, with a floor
// of one character; dscale_ + kDecDigits + 2 further characters are added on top
// (presumably slack for the sign, the decimal point and digits written past dscale_
// -- confirm against get_str_from_var()).
int64_t max_chars_required_() const {
  const int64_t integer_chars = (weight_ + 1) * kDecDigits;
  const int64_t integer_part = integer_chars < 1 ? 1 : integer_chars;
  return integer_part + (dscale_ + kDecDigits + 2);
}

// Reads all data for a single NUMERIC value.
//
// If the input has issues, returns non-zero error code and sets the
// Arrow error.
//
// On success, populates ndigits_, weight_, sign_, dscale_ and digits_.
ArrowErrorCode ReadInputDigit(ArrowBufferView* data, ArrowError* error) {
// Read the input
if (data->size_bytes < static_cast<int64_t>(4 * sizeof(int16_t))) {
ArrowErrorSet(error,
Expand All @@ -272,64 +305,44 @@ class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader {
return EINVAL;
}

int16_t ndigits = ReadUnsafe<int16_t>(data);
int16_t weight = ReadUnsafe<int16_t>(data);
uint16_t sign = ReadUnsafe<uint16_t>(data);
uint16_t dscale = ReadUnsafe<uint16_t>(data);
ndigits_ = ReadUnsafe<int16_t>(data);
weight_ = ReadUnsafe<int16_t>(data);
sign_ = ReadUnsafe<uint16_t>(data);
dscale_ = ReadUnsafe<uint16_t>(data);

if (data->size_bytes < static_cast<int64_t>(ndigits * sizeof(int16_t))) {
if (data->size_bytes < static_cast<int64_t>(ndigits_ * sizeof(int16_t))) {
ArrowErrorSet(error,
"Expected at least %d bytes of field data for numeric digits copy "
"data but only %d bytes of input remain",
static_cast<int>(ndigits * sizeof(int16_t)),
static_cast<int>(ndigits_ * sizeof(int16_t)),
static_cast<int>(data->size_bytes)); // NOLINT(runtime/int)
return EINVAL;
}

digits_.clear();
for (int16_t i = 0; i < ndigits; i++) {
for (int16_t i = 0; i < ndigits_; i++) {
digits_.push_back(ReadUnsafe<int16_t>(data));
}

// Handle special values
std::string special_value;
switch (sign) {
case kNumericNAN:
special_value = std::string("nan");
break;
case kNumericPinf:
special_value = std::string("inf");
break;
case kNumericNinf:
special_value = std::string("-inf");
break;
case kNumericPos:
case kNumericNeg:
special_value = std::string("");
break;
default:
ArrowErrorSet(error,
"Unexpected value for sign read from Postgres numeric field: %d",
static_cast<int>(sign));
return EINVAL;
}

if (!special_value.empty()) {
NANOARROW_RETURN_NOT_OK(
ArrowBufferAppend(data_, special_value.data(), special_value.size()));
NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(offsets_, data_->size_bytes));
return AppendValid(array);
}
return 0;
}

// Calculate string space requirement
int64_t max_chars_required = std::max<int64_t>(1, (weight + 1) * kDecDigits);
max_chars_required += dscale + kDecDigits + 2;
NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(data_, max_chars_required));
char* out0 = reinterpret_cast<char*>(data_->data + data_->size_bytes);
char* out = out0;
// Converts NUMERIC value to string. The result is written to the target and will
// not be null-terminated.
//
// This code is a rewrite of PostgreSQL function get_str_from_var() found in
// src/backend/utils/adt/numeric.c.
//
// This method has two assumptions:
//
// - the target buffer is allocated and large enough to hold the result
// - the NUMERIC value is non-special value (e.g. not +/- infinity or NaN)
int64_t DigitsToString(char** target) {
char* out = *target;
char* out0 = *target;

// Build output string in-place, starting with the negative sign
if (sign == kNumericNeg) {
if (sign_ == kNumericNeg) {
*out++ = '-';
}

Expand All @@ -338,12 +351,12 @@ class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader {
int d1;
int16_t dig;

if (weight < 0) {
d = weight + 1;
if (weight_ < 0) {
d = weight_ + 1;
*out++ = '0';
} else {
for (d = 0; d <= weight; d++) {
if (d < ndigits) {
for (d = 0; d <= weight_; d++) {
if (d < ndigits_) {
dig = digits_[d];
} else {
dig = 0;
Expand All @@ -370,12 +383,12 @@ class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader {
// keep here.
int64_t actual_chars_required = out - out0;

if (dscale > 0) {
if (dscale_ > 0) {
*out++ = '.';
actual_chars_required += dscale + 1;
actual_chars_required += dscale_ + 1;

for (int i = 0; i < dscale; d++, i += kDecDigits) {
if (d >= 0 && d < ndigits) {
for (int i = 0; i < dscale_; d++, i += kDecDigits) {
if (d >= 0 && d < ndigits_) {
dig = digits_[d];
} else {
dig = 0;
Expand All @@ -391,25 +404,128 @@ class PostgresCopyNumericFieldReader : public PostgresCopyFieldReader {
}
}

return actual_chars_required;
}
};

// Reader that turns the COPY form of a Postgres NUMERIC value into its decimal
// text and appends that text to a string array.
class PostgresCopyNumericToStrFieldReader : public PostgresCopyNumericFieldReader {
 public:
  ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array,
                      ArrowError* error) override {
    // A negative field size is how NULL is signalled
    if (field_size_bytes < 0) {
      return ArrowArrayAppendNull(array, 1);
    }

    NANOARROW_RETURN_NOT_OK(ReadInputDigit(data, error));

    // nan / inf / -inf are emitted as fixed text; a sign component that is not one
    // of the known values is rejected
    const bool is_regular_number = sign_ == kNumericPos || sign_ == kNumericNeg;
    if (!is_regular_number) {
      std::string special_text;
      if (sign_ == kNumericNAN) {
        special_text = "nan";
      } else if (sign_ == kNumericPinf) {
        special_text = "inf";
      } else if (sign_ == kNumericNinf) {
        special_text = "-inf";
      } else {
        ArrowErrorSet(error,
                      "Unexpected value for sign read from Postgres numeric field: %d",
                      static_cast<int>(sign_));
        return EINVAL;
      }

      NANOARROW_RETURN_NOT_OK(
          ArrowBufferAppend(data_, special_text.data(), special_text.size()));
      NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(offsets_, data_->size_bytes));
      return AppendValid(array);
    }

    // Reserve the worst case up front, then render directly into the data buffer
    NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(data_, max_chars_required_()));
    char* write_pos = reinterpret_cast<char*>(data_->data + data_->size_bytes);
    const int64_t text_length = DigitsToString(&write_pos);

    // Commit the rendered characters and record the end offset of this value
    data_->size_bytes += text_length;
    NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(offsets_, data_->size_bytes));
    return AppendValid(array);
  }
};

private:
std::vector<int16_t> digits_;
// Converts COPY resulting from the Postgres NUMERIC type into a double.
//
// Similar to Postgres numericvar_to_double_no_overflow() method found in
// src/backend/utils/adt/numeric.c, this reader will first convert the NUMERIC
// to string and then use strtod() to get a double value.
class PostgresCopyNumericToDoubleFieldReader : public PostgresCopyNumericFieldReader {
 public:
  // Reads one NUMERIC field from `data` and appends it to `array` as a double.
  //
  // A negative `field_size_bytes` denotes NULL and appends a null instead. Returns
  // EINVAL (with `error` set) when the sign component is not one of the known
  // values, and forwards any non-zero code from ReadInputDigit() or from the buffer
  // append.
  ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array,
                      ArrowError* error) override {
    // -1 for NULL
    if (field_size_bytes < 0) {
      return ArrowArrayAppendNull(array, 1);
    }

    NANOARROW_RETURN_NOT_OK(ReadInputDigit(data, error));

    double value = 0;
    switch (sign_) {
      case kNumericNAN:
        value = std::numeric_limits<double>::quiet_NaN();
        break;
      case kNumericPinf:
        value = std::numeric_limits<double>::infinity();
        break;
      case kNumericNinf:
        value = -std::numeric_limits<double>::infinity();
        break;
      case kNumericPos:
      case kNumericNeg: {
        // Scratch buffer owned by std::string so it is released on every path; the
        // extra byte guarantees room for the terminator added below.
        std::string text(static_cast<size_t>(max_chars_required_()) + 1, '\0');
        char* target = &text[0];
        const int64_t text_length = DigitsToString(&target);

        // DigitsToString() does not null-terminate its output, so terminate at
        // exactly the reported length (not one past it) before parsing.
        text[static_cast<size_t>(text_length)] = '\0';

        // NOTE(review): strtod() is locale-sensitive while the rendered text always
        // uses '.' as the decimal point -- confirm the process locale is compatible.
        value = strtod(text.c_str(), nullptr);
        break;
      }
      default:
        ArrowErrorSet(error,
                      "Unexpected value for sign read from Postgres numeric field: %d",
                      static_cast<int>(sign_));
        return EINVAL;
    }

    NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &value, sizeof(double)));
    return AppendValid(array);
  }
};

// Reader for Pg->Arrow conversions whose Arrow representation is simply the
Expand Down Expand Up @@ -761,6 +877,9 @@ static inline ArrowErrorCode MakeCopyFieldReader(
case PostgresTypeId::kFloat8:
*out = std::make_unique<PostgresCopyNetworkEndianFieldReader<uint64_t>>();
return NANOARROW_OK;
case PostgresTypeId::kNumeric:
*out = std::make_unique<PostgresCopyNumericToDoubleFieldReader>();
return NANOARROW_OK;
default:
return ErrorCantConvert(error, pg_type, schema_view);
}
Expand All @@ -776,7 +895,7 @@ static inline ArrowErrorCode MakeCopyFieldReader(
*out = std::make_unique<PostgresCopyBinaryFieldReader>();
return NANOARROW_OK;
case PostgresTypeId::kNumeric:
*out = std::make_unique<PostgresCopyNumericFieldReader>();
*out = std::make_unique<PostgresCopyNumericToStrFieldReader>();
return NANOARROW_OK;
default:
return ErrorCantConvert(error, pg_type, schema_view);
Expand Down
3 changes: 2 additions & 1 deletion c/driver/postgresql/postgres_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ class PostgresType {

// ---- Numeric/Decimal-------------------
case PostgresTypeId::kNumeric:
NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_STRING));
// NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_STRING));
NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_DOUBLE));
break;

// ---- Binary/string --------------------
Expand Down

0 comments on commit 6e308e7

Please sign in to comment.