7
7
import warnings
8
8
from collections import abc
9
9
from io import BytesIO , StringIO
10
- from typing import cast
10
+ from typing import TYPE_CHECKING , cast
11
11
12
12
import numpy as np
13
13
import pandas as pd
16
16
17
17
import cudf
18
18
from cudf ._lib .column import Column
19
- from cudf .api .types import is_hashable , is_scalar
19
+ from cudf .api .types import is_scalar
20
20
from cudf .core .buffer import acquire_spill_lock
21
21
from cudf .core .column_accessor import ColumnAccessor
22
22
from cudf .utils import ioutils
26
26
)
27
27
from cudf .utils .performance_tracking import _performance_tracking
28
28
29
+ if TYPE_CHECKING :
30
+ from cudf ._typing import DtypeObj
31
+
32
+
29
33
_CSV_HEX_TYPE_MAP = {
30
34
"hex" : np .dtype ("int64" ),
31
35
"hex64" : np .dtype ("int64" ),
@@ -158,33 +162,49 @@ def read_csv(
158
162
header = 0
159
163
160
164
hex_cols : list [abc .Hashable ] = []
161
- new_dtypes : list [plc .DataType ] | dict [abc .Hashable , plc .DataType ] = []
165
+ cudf_dtypes : list [DtypeObj ] | dict [abc .Hashable , DtypeObj ] | DtypeObj = []
166
+ plc_dtypes : list [plc .DataType ] | dict [abc .Hashable , plc .DataType ] = []
162
167
if dtype is not None :
163
168
if isinstance (dtype , abc .Mapping ):
164
- new_dtypes = {}
169
+ plc_dtypes = {}
170
+ cudf_dtypes = {}
165
171
for k , col_type in dtype .items ():
166
- if is_hashable (col_type ) and col_type in _CSV_HEX_TYPE_MAP :
172
+ if isinstance (col_type , str ) and col_type in _CSV_HEX_TYPE_MAP :
167
173
col_type = _CSV_HEX_TYPE_MAP [col_type ]
168
174
hex_cols .append (str (k ))
169
175
170
- new_dtypes [k ] = _get_plc_data_type_from_dtype (
171
- cudf .dtype (col_type )
172
- )
173
- elif cudf .api .types .is_scalar (dtype ) or isinstance (
174
- dtype , (np .dtype , pd .api .extensions .ExtensionDtype , type )
176
+ cudf_dtype = cudf .dtype (col_type )
177
+ cudf_dtypes [k ] = cudf_dtype
178
+ plc_dtypes [k ] = _get_plc_data_type_from_dtype (cudf_dtype )
179
+ elif isinstance (
180
+ dtype ,
181
+ (
182
+ str ,
183
+ np .dtype ,
184
+ pd .api .extensions .ExtensionDtype ,
185
+ cudf .core .dtypes ._BaseDtype ,
186
+ type ,
187
+ ),
175
188
):
176
- if is_hashable (dtype ) and dtype in _CSV_HEX_TYPE_MAP :
189
+ if isinstance (dtype , str ) and dtype in _CSV_HEX_TYPE_MAP :
177
190
dtype = _CSV_HEX_TYPE_MAP [dtype ]
178
191
hex_cols .append (0 )
179
-
180
- cast (list , new_dtypes ).append (_get_plc_data_type_from_dtype (dtype ))
192
+ else :
193
+ dtype = cudf .dtype (dtype )
194
+ cudf_dtypes = dtype
195
+ cast (list , plc_dtypes ).append (_get_plc_data_type_from_dtype (dtype ))
181
196
elif isinstance (dtype , abc .Collection ):
182
197
for index , col_dtype in enumerate (dtype ):
183
- if is_hashable (col_dtype ) and col_dtype in _CSV_HEX_TYPE_MAP :
198
+ if (
199
+ isinstance (col_dtype , str )
200
+ and col_dtype in _CSV_HEX_TYPE_MAP
201
+ ):
184
202
col_dtype = _CSV_HEX_TYPE_MAP [col_dtype ]
185
203
hex_cols .append (index )
186
-
187
- new_dtypes .append (_get_plc_data_type_from_dtype (col_dtype ))
204
+ else :
205
+ col_dtype = cudf .dtype (col_dtype )
206
+ cudf_dtypes .append (col_dtype )
207
+ plc_dtypes .append (_get_plc_data_type_from_dtype (col_dtype ))
188
208
else :
189
209
raise ValueError (
190
210
"dtype should be a scalar/str/list-like/dict-like"
@@ -243,7 +263,7 @@ def read_csv(
243
263
if hex_cols is not None :
244
264
options .set_parse_hex (list (hex_cols ))
245
265
246
- options .set_dtypes (new_dtypes )
266
+ options .set_dtypes (plc_dtypes )
247
267
248
268
if true_values is not None :
249
269
options .set_true_values ([str (val ) for val in true_values ])
@@ -266,15 +286,21 @@ def read_csv(
266
286
ca = ColumnAccessor (data , rangeindex = len (data ) == 0 )
267
287
df = cudf .DataFrame ._from_data (ca )
268
288
269
- if isinstance (dtype , abc .Mapping ):
270
- for k , v in dtype .items ():
271
- if isinstance (cudf .dtype (v ), cudf .CategoricalDtype ):
272
- df ._data [str (k )] = df ._data [str (k )].astype (v )
273
- elif dtype == "category" or isinstance (dtype , cudf .CategoricalDtype ):
289
+ # Cast result to categorical if specified in dtype=
290
+ # since categorical is not handled in pylibcudf
291
+ if isinstance (cudf_dtypes , dict ):
292
+ to_category = {
293
+ k : v
294
+ for k , v in cudf_dtypes .items ()
295
+ if isinstance (v , cudf .CategoricalDtype )
296
+ }
297
+ if to_category :
298
+ df = df .astype (to_category )
299
+ elif isinstance (cudf_dtypes , cudf .CategoricalDtype ):
274
300
df = df .astype (dtype )
275
- elif isinstance (dtype , abc . Collection ) and not is_scalar ( dtype ):
276
- for index , col_dtype in enumerate (dtype ):
277
- if isinstance (cudf . dtype ( col_dtype ) , cudf .CategoricalDtype ):
301
+ elif isinstance (cudf_dtypes , list ):
302
+ for index , col_dtype in enumerate (cudf_dtypes ):
303
+ if isinstance (col_dtype , cudf .CategoricalDtype ):
278
304
col_name = df ._column_names [index ]
279
305
df ._data [col_name ] = df ._data [col_name ].astype (col_dtype )
280
306
@@ -527,30 +553,11 @@ def _validate_args(
527
553
)
528
554
529
555
530
- def _get_plc_data_type_from_dtype (dtype ) -> plc .DataType :
556
+ def _get_plc_data_type_from_dtype (dtype : DtypeObj ) -> plc .DataType :
531
557
# TODO: Remove this work-around Dictionary types
532
558
# in libcudf are fully mapped to categorical columns:
533
559
# https://github.com/rapidsai/cudf/issues/3960
534
560
if isinstance (dtype , cudf .CategoricalDtype ):
561
+ # TODO: should we do this generally in dtype_to_pylibcudf_type?
535
562
dtype = dtype .categories .dtype
536
- elif dtype == "category" :
537
- dtype = "str"
538
-
539
- if isinstance (dtype , str ):
540
- if dtype == "date32" :
541
- return plc .DataType (plc .types .TypeId .TIMESTAMP_DAYS )
542
- elif dtype in ("date" , "date64" ):
543
- return plc .DataType (plc .types .TypeId .TIMESTAMP_MILLISECONDS )
544
- elif dtype == "timestamp" :
545
- return plc .DataType (plc .types .TypeId .TIMESTAMP_MILLISECONDS )
546
- elif dtype == "timestamp[us]" :
547
- return plc .DataType (plc .types .TypeId .TIMESTAMP_MICROSECONDS )
548
- elif dtype == "timestamp[s]" :
549
- return plc .DataType (plc .types .TypeId .TIMESTAMP_SECONDS )
550
- elif dtype == "timestamp[ms]" :
551
- return plc .DataType (plc .types .TypeId .TIMESTAMP_MILLISECONDS )
552
- elif dtype == "timestamp[ns]" :
553
- return plc .DataType (plc .types .TypeId .TIMESTAMP_NANOSECONDS )
554
-
555
- dtype = cudf .dtype (dtype )
556
563
return dtype_to_pylibcudf_type (dtype )
0 commit comments