Commit 60b3204

mgarrard authored and facebook-github-bot committed
Fix StringDtype vs object dtype mismatch in Data._safecast_df (facebook#4795)
Summary:
Prompted after some failures in exports not related to my changes: https://github.com/facebook/Ax/actions/runs/21225324302/job/61070638075?fbclid=IwY2xjawPeXMFleHRuA2FlbQIxMQBicmlkETFRTkR6WlE4NHVrd3IyQXNlc3J0YwZhcHBfaWQBMAABHjTAiZi71n24w95hvzEewrKNPKOGzJisgR7t4qJ3APRMYlusgFC-gu7RLiSb_aem_Zk3pmTDonCFsJvZCTkpeMA

Pandas 2.0+ changed default string column dtype from `object` to `StringDtype(na_value=nan)`. The `_safecast_df()` method doesn't properly handle the comparison between `StringDtype` and `np.dtype("O")` because they are different types that don't compare equal.

Add explicit check for `pd.StringDtype` to force casting when needed.

Differential Revision: D91185469
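The dtype mismatch described in the summary can be reproduced in isolation. The snippet below is a minimal sketch, not code from Ax: it builds one column with an explicitly requested `StringDtype` (via `dtype="string"`) and one with plain `object` dtype, then checks how each compares against `np.dtype("O")`. The variable names are illustrative only.

```python
import numpy as np
import pandas as pd

# Illustrative columns (not from Ax): one with the pandas string extension
# dtype requested explicitly, one with the classic NumPy object dtype.
string_col = pd.Series(["a", "b"], dtype="string")
object_col = pd.Series(["a", "b"], dtype=object)

# The extension dtype is an instance of pd.StringDtype ...
is_string_ext = isinstance(string_col.dtype, pd.StringDtype)

# ... and it does not compare equal to NumPy's object dtype.
string_equals_object = string_col.dtype == np.dtype("O")

# A plain object column does compare equal to np.dtype("O").
object_equals_object = object_col.dtype == np.dtype("O")

# Casting the string column to object yields a NumPy object dtype again.
cast_back = string_col.astype(object)
```

An `isinstance` check against `pd.StringDtype` is therefore a direct way to detect such columns, independent of how the equality comparison with another dtype turns out.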
1 parent 3e34f99 commit 60b3204

1 file changed: +7 -3 lines changed

ax/core/data.py

Lines changed: 7 additions & 3 deletions
@@ -195,9 +195,13 @@ def _safecast_df(cls: type[TData], df: pd.DataFrame) -> pd.DataFrame:
             if col in df.columns.values and coltype is not Any:
                 # Pandas timestamp handlng is weird
                 dtype = "datetime64[ns]" if coltype is pd.Timestamp else coltype
-                if (dtype != dtypes[col]) and not (
-                    coltype is int and df.loc[:, col].isnull().any()
-                ):
+                current_dtype = dtypes[col]
+                # Handle StringDtype -> object conversion (pandas 2.0+ compatibility)
+                needs_cast = (
+                    isinstance(current_dtype, pd.StringDtype)
+                    or (dtype != current_dtype)
+                ) and not (coltype is int and df.loc[:, col].isnull().any())
+                if needs_cast:
                     df[col] = df[col].astype(dtype)
         df.reset_index(inplace=True, drop=True)
         return df
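The cast decision introduced by this change can be exercised outside of `Data`. The following is a rough sketch under stated assumptions: `needs_cast` as a standalone function taking a `pd.Series` is hypothetical (in the commit it is a local variable inside `_safecast_df`), and the sample columns are made up for illustration.

```python
import pandas as pd


def needs_cast(series: pd.Series, coltype: type) -> bool:
    """Hypothetical standalone version of the condition in the diff above."""
    current_dtype = series.dtype
    # Force a cast for StringDtype columns, or when the dtypes differ ...
    wants_cast = isinstance(current_dtype, pd.StringDtype) or (
        coltype != current_dtype
    )
    # ... but never cast to int while the column still contains nulls,
    # since NaN cannot be represented in a NumPy integer column.
    int_with_nulls = coltype is int and bool(series.isnull().any())
    return bool(wants_cast and not int_with_nulls)


# Illustrative inputs (not taken from Ax).
string_col = pd.Series(["a", "b"], dtype="string")
float_col = pd.Series([1.0, 2.0])
nullable_int_col = pd.Series([1.0, None])

cast_string = needs_cast(string_col, str)        # StringDtype always casts
cast_float = needs_cast(float_col, float)        # already float64, no cast
cast_int = needs_cast(nullable_int_col, int)     # nulls block the int cast
```

Keeping the null guard as the final `and not (...)` term preserves the pre-existing behavior: the new `isinstance` branch only widens the set of columns that get cast and does not override the integer-with-nulls exemption.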
