diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 099e5bc48353a..dfb786fb362ba 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -748,7 +748,7 @@ Indexing Missing ^^^^^^^ - Bug in :meth:`DataFrame.fillna` and :meth:`Series.fillna` that would ignore the ``limit`` argument on :class:`.ExtensionArray` dtypes (:issue:`58001`) -- +- Bug in :meth:`DataFrame.fillna` where filling from another ``DataFrame`` of the same dtype could incorrectly cast the result to ``object`` dtype. (:issue:`61568`) MultiIndex ^^^^^^^^^^ diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 13585d7de6beb..e387e96e4101e 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -7145,7 +7145,24 @@ def fillna( else: new_data = self._mgr.fillna(value=value, limit=limit, inplace=inplace) elif isinstance(value, ABCDataFrame) and self.ndim == 2: - new_data = self.where(self.notna(), value)._mgr + filled_columns = {} + for col in self.columns: + lhs = self[col] + if col in value.columns: + rhs = value[col] + filled = lhs.where(notna(lhs), rhs) + + # restore original dtype if fallback to object occurred + if lhs.dtype == rhs.dtype and filled.dtype == object: + try: + filled = filled.astype(lhs.dtype) + except Exception: + pass + else: + filled = lhs + filled_columns[col] = filled + + new_data = type(self)(filled_columns, index=self.index)._mgr else: raise ValueError(f"invalid fill value with a {type(value)}") diff --git a/pandas/tests/frame/methods/test_fillna.py b/pandas/tests/frame/methods/test_fillna.py index 8915d6f205d65..fb2685a7aa2c6 100644 --- a/pandas/tests/frame/methods/test_fillna.py +++ b/pandas/tests/frame/methods/test_fillna.py @@ -795,3 +795,22 @@ def test_fillna_out_of_bounds_datetime(): msg = "Cannot cast 0001-01-01 00:00:00 to unit='ns' without overflow" with pytest.raises(OutOfBoundsDatetime, match=msg): df.fillna(Timestamp("0001-01-01")) + + +def test_fillna_dataframe_preserves_dtypes_mixed_columns(): + # GH#61568 + empty = DataFrame([[None] * 4] * 4, columns=list("ABCD"), dtype=np.float64) + full = DataFrame( + [ + [1.0, 2.0, "3.0", 4.0], + [5.0, 6.0, "7.0", 8.0], + [9.0, 10.0, "11.0", 12.0], + [13.0, 14.0, "15.0", 16.0], + ], + columns=list("ABCD"), + ) + result = empty.fillna(full) + expected_dtypes = Series( + {"A": "float64", "B": "float64", "C": "object", "D": "float64"} + ) + tm.assert_series_equal(result.dtypes, expected_dtypes)