@@ -235,8 +235,8 @@ def compare_cursor_description(
235
235
236
236
def _safe_compare (self , val1 , val2 ):
237
237
"""
238
- Safely compare two values, handling lists, dicts, and complex types .
239
-
238
+ Safely compare two values, handling Row objects and PyArrow tables .
239
+
240
240
Returns True if values are equal, False otherwise.
241
241
"""
242
242
try :
@@ -245,28 +245,40 @@ def _safe_compare(self, val1, val2):
245
245
return True
246
246
if val1 is None or val2 is None :
247
247
return False
248
-
249
- # For lists, tuples, and other sequences (but not strings)
248
+
249
+ # For Row objects, convert to dictionaries
250
+ if hasattr (val1 , "asDict" ) and hasattr (val2 , "asDict" ):
251
+ return self ._safe_compare (
252
+ val1 .asDict (recursive = True ), val2 .asDict (recursive = True )
253
+ )
254
+
255
+ # For PyArrow arrays/tables
256
+ if hasattr (val1 , "to_pylist" ) and hasattr (val2 , "to_pylist" ):
257
+ return val1 .to_pylist () == val2 .to_pylist ()
258
+
259
+ # For lists and tuples
250
260
if isinstance (val1 , (list , tuple )) and isinstance (val2 , (list , tuple )):
251
261
if len (val1 ) != len (val2 ):
252
262
return False
253
263
return all (self ._safe_compare (v1 , v2 ) for v1 , v2 in zip (val1 , val2 ))
254
-
264
+
255
265
# For dictionaries
256
266
if isinstance (val1 , dict ) and isinstance (val2 , dict ):
257
267
if set (val1 .keys ()) != set (val2 .keys ()):
258
268
return False
259
269
return all (self ._safe_compare (val1 [k ], val2 [k ]) for k in val1 .keys ())
260
-
261
- # For Row objects (which are tuples with special properties)
262
- if hasattr (val1 , 'asDict' ) and hasattr (val2 , 'asDict' ):
263
- return self ._safe_compare (val1 .asDict (recursive = True ), val2 .asDict (recursive = True ))
264
-
265
- # Default comparison
266
- return val1 == val2
267
- except (ValueError , TypeError ) as e :
268
- # If comparison fails (e.g., numpy arrays), convert to string
269
- return str (val1 ) == str (val2 )
270
+
271
+ # Default comparison - ensure we always return a boolean
272
+ result = val1 == val2
273
+ # If result is not a simple boolean, use bool() to convert it
274
+ return bool (result )
275
+
276
+ except (ValueError , TypeError ):
277
+ # Fallback to string comparison for problematic types
278
+ try :
279
+ return str (val1 ) == str (val2 )
280
+ except :
281
+ return False
270
282
271
283
def compare_rows (
272
284
self , thrift_rows : List [Row ], sea_rows : List [Row ], result : ComparisonResult
@@ -302,15 +314,17 @@ def compare_rows(
302
314
# Check if dictionaries are different by comparing all fields
303
315
all_fields = set (thrift_dict .keys ()) | set (sea_dict .keys ())
304
316
dicts_differ = False
305
-
317
+
306
318
for field in all_fields :
307
319
if field not in thrift_dict or field not in sea_dict :
308
320
dicts_differ = True
309
321
break
310
- elif not self ._safe_compare (thrift_dict .get (field ), sea_dict .get (field )):
322
+ elif not self ._safe_compare (
323
+ thrift_dict .get (field ), sea_dict .get (field )
324
+ ):
311
325
dicts_differ = True
312
326
break
313
-
327
+
314
328
if dicts_differ :
315
329
316
330
for field in all_fields :
@@ -353,9 +367,9 @@ def compare_rows(
353
367
thrift_values = [m [1 ] for m in mismatches ]
354
368
sea_values = [m [2 ] for m in mismatches ]
355
369
356
- if all (self . _safe_compare ( v , thrift_values [ 0 ]) for v in thrift_values ) and all (
357
- self ._safe_compare (v , sea_values [0 ]) for v in sea_values
358
- ):
370
+ if all (
371
+ self ._safe_compare (v , thrift_values [0 ]) for v in thrift_values
372
+ ) and all ( self . _safe_compare ( v , sea_values [ 0 ]) for v in sea_values ) :
359
373
result .add_difference (
360
374
f"Field '{ field } ' value mismatch in all rows" ,
361
375
thrift_values [0 ],
0 commit comments