Skip to content

Commit 428dc18

Browse files
authored
Avoid more internal uses of cudf.dtype in favor of pre-defined, supported types (#17918)
Continuation of #17839. Authors: Matthew Roeschke (https://github.com/mroeschke). Approvers: Vyas Ramasubramani (https://github.com/vyasr). URL: #17918
1 parent 61e47bb commit 428dc18

File tree

7 files changed

+72
-68
lines changed

7 files changed

+72
-68
lines changed

python/cudf/cudf/core/buffer/buffer.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
1+
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
22

33
from __future__ import annotations
44

@@ -13,7 +13,6 @@
1313
import pylibcudf
1414
import rmm
1515

16-
import cudf
1716
from cudf.core.abc import Serializable
1817
from cudf.utils.string import format_bytes
1918

@@ -504,7 +503,7 @@ def get_ptr_and_size(array_interface: Mapping) -> tuple[int, int]:
504503

505504
shape = array_interface["shape"] or (1,)
506505
strides = array_interface["strides"]
507-
itemsize = cudf.dtype(array_interface["typestr"]).itemsize
506+
itemsize = numpy.dtype(array_interface["typestr"]).itemsize
508507
if strides is None or pylibcudf.column.is_c_contiguous(
509508
shape, strides, itemsize
510509
):

python/cudf/cudf/core/dtypes.py

-5
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,6 @@ def dtype(arbitrary):
6464
raise TypeError(f"Unsupported type {np_dtype}")
6565
return np_dtype
6666

67-
if isinstance(arbitrary, str) and arbitrary in {"hex", "hex32", "hex64"}:
68-
# read_csv only accepts "hex"
69-
# e.g. test_csv_reader_hexadecimals, test_csv_reader_hexadecimal_overflow
70-
return arbitrary
71-
7267
# use `pandas_dtype` to try and interpret
7368
# `arbitrary` as a Pandas extension type.
7469
# Return the corresponding NumPy/cuDF type.

python/cudf/cudf/core/scalar.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -476,16 +476,16 @@ def __repr__(self) -> str:
476476
# https://github.com/numpy/numpy/issues/17552
477477
return f"{self.__class__.__name__}({self.value!s}, dtype={self.dtype})"
478478

479-
def _binop_result_dtype_or_error(self, other, op):
479+
def _binop_result_dtype_or_error(self, other, op) -> np.dtype:
480480
if op in {"__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"}:
481-
return np.bool_
481+
return np.dtype(np.bool_)
482482

483483
out_dtype = get_allowed_combinations_for_operator(
484484
self.dtype, other.dtype, op
485485
)
486486

487487
# datetime handling
488-
if out_dtype in {"M", "m"}:
488+
if out_dtype.kind in {"M", "m"}:
489489
if self.dtype.char in {"M", "m"} and other.dtype.char not in {
490490
"M",
491491
"m",
@@ -505,7 +505,7 @@ def _binop_result_dtype_or_error(self, other, op):
505505
return np.dtype(f"m8[{res}]")
506506
return np.result_type(self.dtype, other.dtype)
507507

508-
return cudf.dtype(out_dtype)
508+
return out_dtype
509509

510510
def _binaryop(self, other, op: str):
511511
if is_scalar(other):

python/cudf/cudf/core/tools/numeric.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def to_numeric(
174174
type_set = list(np.typecodes["UnsignedInteger"])
175175

176176
for t in type_set:
177-
downcast_dtype = cudf.dtype(t)
177+
downcast_dtype = np.dtype(t)
178178
if downcast_dtype.itemsize <= col.dtype.itemsize:
179179
if col.can_cast_safely(downcast_dtype):
180180
col = col.cast(downcast_dtype)

python/cudf/cudf/io/csv.py

+53-46
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88
from collections import abc
99
from io import BytesIO, StringIO
10-
from typing import cast
10+
from typing import TYPE_CHECKING, cast
1111

1212
import numpy as np
1313
import pandas as pd
@@ -16,7 +16,7 @@
1616

1717
import cudf
1818
from cudf._lib.column import Column
19-
from cudf.api.types import is_hashable, is_scalar
19+
from cudf.api.types import is_scalar
2020
from cudf.core.buffer import acquire_spill_lock
2121
from cudf.core.column_accessor import ColumnAccessor
2222
from cudf.utils import ioutils
@@ -26,6 +26,10 @@
2626
)
2727
from cudf.utils.performance_tracking import _performance_tracking
2828

29+
if TYPE_CHECKING:
30+
from cudf._typing import DtypeObj
31+
32+
2933
_CSV_HEX_TYPE_MAP = {
3034
"hex": np.dtype("int64"),
3135
"hex64": np.dtype("int64"),
@@ -158,33 +162,49 @@ def read_csv(
158162
header = 0
159163

160164
hex_cols: list[abc.Hashable] = []
161-
new_dtypes: list[plc.DataType] | dict[abc.Hashable, plc.DataType] = []
165+
cudf_dtypes: list[DtypeObj] | dict[abc.Hashable, DtypeObj] | DtypeObj = []
166+
plc_dtypes: list[plc.DataType] | dict[abc.Hashable, plc.DataType] = []
162167
if dtype is not None:
163168
if isinstance(dtype, abc.Mapping):
164-
new_dtypes = {}
169+
plc_dtypes = {}
170+
cudf_dtypes = {}
165171
for k, col_type in dtype.items():
166-
if is_hashable(col_type) and col_type in _CSV_HEX_TYPE_MAP:
172+
if isinstance(col_type, str) and col_type in _CSV_HEX_TYPE_MAP:
167173
col_type = _CSV_HEX_TYPE_MAP[col_type]
168174
hex_cols.append(str(k))
169175

170-
new_dtypes[k] = _get_plc_data_type_from_dtype(
171-
cudf.dtype(col_type)
172-
)
173-
elif cudf.api.types.is_scalar(dtype) or isinstance(
174-
dtype, (np.dtype, pd.api.extensions.ExtensionDtype, type)
176+
cudf_dtype = cudf.dtype(col_type)
177+
cudf_dtypes[k] = cudf_dtype
178+
plc_dtypes[k] = _get_plc_data_type_from_dtype(cudf_dtype)
179+
elif isinstance(
180+
dtype,
181+
(
182+
str,
183+
np.dtype,
184+
pd.api.extensions.ExtensionDtype,
185+
cudf.core.dtypes._BaseDtype,
186+
type,
187+
),
175188
):
176-
if is_hashable(dtype) and dtype in _CSV_HEX_TYPE_MAP:
189+
if isinstance(dtype, str) and dtype in _CSV_HEX_TYPE_MAP:
177190
dtype = _CSV_HEX_TYPE_MAP[dtype]
178191
hex_cols.append(0)
179-
180-
cast(list, new_dtypes).append(_get_plc_data_type_from_dtype(dtype))
192+
else:
193+
dtype = cudf.dtype(dtype)
194+
cudf_dtypes = dtype
195+
cast(list, plc_dtypes).append(_get_plc_data_type_from_dtype(dtype))
181196
elif isinstance(dtype, abc.Collection):
182197
for index, col_dtype in enumerate(dtype):
183-
if is_hashable(col_dtype) and col_dtype in _CSV_HEX_TYPE_MAP:
198+
if (
199+
isinstance(col_dtype, str)
200+
and col_dtype in _CSV_HEX_TYPE_MAP
201+
):
184202
col_dtype = _CSV_HEX_TYPE_MAP[col_dtype]
185203
hex_cols.append(index)
186-
187-
new_dtypes.append(_get_plc_data_type_from_dtype(col_dtype))
204+
else:
205+
col_dtype = cudf.dtype(col_dtype)
206+
cudf_dtypes.append(col_dtype)
207+
plc_dtypes.append(_get_plc_data_type_from_dtype(col_dtype))
188208
else:
189209
raise ValueError(
190210
"dtype should be a scalar/str/list-like/dict-like"
@@ -243,7 +263,7 @@ def read_csv(
243263
if hex_cols is not None:
244264
options.set_parse_hex(list(hex_cols))
245265

246-
options.set_dtypes(new_dtypes)
266+
options.set_dtypes(plc_dtypes)
247267

248268
if true_values is not None:
249269
options.set_true_values([str(val) for val in true_values])
@@ -266,15 +286,21 @@ def read_csv(
266286
ca = ColumnAccessor(data, rangeindex=len(data) == 0)
267287
df = cudf.DataFrame._from_data(ca)
268288

269-
if isinstance(dtype, abc.Mapping):
270-
for k, v in dtype.items():
271-
if isinstance(cudf.dtype(v), cudf.CategoricalDtype):
272-
df._data[str(k)] = df._data[str(k)].astype(v)
273-
elif dtype == "category" or isinstance(dtype, cudf.CategoricalDtype):
289+
# Cast result to categorical if specified in dtype=
290+
# since categorical is not handled in pylibcudf
291+
if isinstance(cudf_dtypes, dict):
292+
to_category = {
293+
k: v
294+
for k, v in cudf_dtypes.items()
295+
if isinstance(v, cudf.CategoricalDtype)
296+
}
297+
if to_category:
298+
df = df.astype(to_category)
299+
elif isinstance(cudf_dtypes, cudf.CategoricalDtype):
274300
df = df.astype(dtype)
275-
elif isinstance(dtype, abc.Collection) and not is_scalar(dtype):
276-
for index, col_dtype in enumerate(dtype):
277-
if isinstance(cudf.dtype(col_dtype), cudf.CategoricalDtype):
301+
elif isinstance(cudf_dtypes, list):
302+
for index, col_dtype in enumerate(cudf_dtypes):
303+
if isinstance(col_dtype, cudf.CategoricalDtype):
278304
col_name = df._column_names[index]
279305
df._data[col_name] = df._data[col_name].astype(col_dtype)
280306

@@ -527,30 +553,11 @@ def _validate_args(
527553
)
528554

529555

530-
def _get_plc_data_type_from_dtype(dtype) -> plc.DataType:
556+
def _get_plc_data_type_from_dtype(dtype: DtypeObj) -> plc.DataType:
531557
# TODO: Remove this work-around once Dictionary types
532558
# in libcudf are fully mapped to categorical columns:
533559
# https://github.com/rapidsai/cudf/issues/3960
534560
if isinstance(dtype, cudf.CategoricalDtype):
561+
# TODO: should we do this generally in dtype_to_pylibcudf_type?
535562
dtype = dtype.categories.dtype
536-
elif dtype == "category":
537-
dtype = "str"
538-
539-
if isinstance(dtype, str):
540-
if dtype == "date32":
541-
return plc.DataType(plc.types.TypeId.TIMESTAMP_DAYS)
542-
elif dtype in ("date", "date64"):
543-
return plc.DataType(plc.types.TypeId.TIMESTAMP_MILLISECONDS)
544-
elif dtype == "timestamp":
545-
return plc.DataType(plc.types.TypeId.TIMESTAMP_MILLISECONDS)
546-
elif dtype == "timestamp[us]":
547-
return plc.DataType(plc.types.TypeId.TIMESTAMP_MICROSECONDS)
548-
elif dtype == "timestamp[s]":
549-
return plc.DataType(plc.types.TypeId.TIMESTAMP_SECONDS)
550-
elif dtype == "timestamp[ms]":
551-
return plc.DataType(plc.types.TypeId.TIMESTAMP_MILLISECONDS)
552-
elif dtype == "timestamp[ns]":
553-
return plc.DataType(plc.types.TypeId.TIMESTAMP_NANOSECONDS)
554-
555-
dtype = cudf.dtype(dtype)
556563
return dtype_to_pylibcudf_type(dtype)

python/cudf/cudf/io/parquet.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def write_to_dataset(
527527
return metadata
528528

529529

530-
def _parse_metadata(meta) -> tuple[bool, Any, Any]:
530+
def _parse_metadata(meta) -> tuple[bool, Any, None | np.dtype]:
531531
file_is_range_index = False
532532
file_index_cols = None
533533
file_column_dtype = None
@@ -541,7 +541,7 @@ def _parse_metadata(meta) -> tuple[bool, Any, Any]:
541541
):
542542
file_is_range_index = True
543543
if "column_indexes" in meta and len(meta["column_indexes"]) == 1:
544-
file_column_dtype = meta["column_indexes"][0]["numpy_type"]
544+
file_column_dtype = np.dtype(meta["column_indexes"][0]["numpy_type"])
545545
return file_is_range_index, file_index_cols, file_column_dtype
546546

547547

@@ -2368,6 +2368,6 @@ def _process_metadata(
23682368
df.index.names = index_col
23692369

23702370
if df._num_columns == 0 and column_index_type is not None:
2371-
df._data.label_dtype = cudf.dtype(column_index_type)
2371+
df._data.label_dtype = column_index_type
23722372

23732373
return df

python/cudf/cudf/utils/dtypes.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,9 @@ def _get_nan_for_dtype(dtype: DtypeObj) -> DtypeObj:
430430
return np.float64("nan")
431431

432432

433-
def get_allowed_combinations_for_operator(dtype_l, dtype_r, op):
433+
def get_allowed_combinations_for_operator(
434+
dtype_l: np.dtype, dtype_r: np.dtype, op: str
435+
) -> np.dtype:
434436
error = TypeError(
435437
f"{op} not supported between {dtype_l} and {dtype_r} scalars"
436438
)
@@ -456,18 +458,19 @@ def get_allowed_combinations_for_operator(dtype_l, dtype_r, op):
456458
# special rules for string
457459
if dtype_l == "object" or dtype_r == "object":
458460
if (dtype_l == dtype_r == "object") and op == "__add__":
459-
return "str"
461+
return CUDF_STRING_DTYPE
460462
else:
461463
raise error
462464

463465
# Check if we can directly operate
464466

465467
for valid_combo in allowed:
466-
ltype, rtype, outtype = valid_combo
467-
if np.can_cast(dtype_l.char, ltype) and np.can_cast(
468-
dtype_r.char, rtype
468+
ltype, rtype, outtype = valid_combo # type: ignore[misc]
469+
if np.can_cast(dtype_l.char, ltype) and np.can_cast( # type: ignore[has-type]
470+
dtype_r.char,
471+
rtype, # type: ignore[has-type]
469472
):
470-
return outtype
473+
return np.dtype(outtype) # type: ignore[has-type]
471474

472475
raise error
473476

0 commit comments

Comments
 (0)