Skip to content

Commit ba36ebe

Browse files
safer comparison
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent a9b9006 commit ba36ebe

File tree

1 file changed

+35
-21
lines changed

1 file changed

+35
-21
lines changed

examples/experimental/comparator.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def compare_cursor_description(
235235

236236
def _safe_compare(self, val1, val2):
237237
"""
238-
Safely compare two values, handling lists, dicts, and complex types.
239-
238+
Safely compare two values, handling Row objects and PyArrow tables.
239+
240240
Returns True if values are equal, False otherwise.
241241
"""
242242
try:
@@ -245,28 +245,40 @@ def _safe_compare(self, val1, val2):
245245
return True
246246
if val1 is None or val2 is None:
247247
return False
248-
249-
# For lists, tuples, and other sequences (but not strings)
248+
249+
# For Row objects, convert to dictionaries
250+
if hasattr(val1, "asDict") and hasattr(val2, "asDict"):
251+
return self._safe_compare(
252+
val1.asDict(recursive=True), val2.asDict(recursive=True)
253+
)
254+
255+
# For PyArrow arrays/tables
256+
if hasattr(val1, "to_pylist") and hasattr(val2, "to_pylist"):
257+
return val1.to_pylist() == val2.to_pylist()
258+
259+
# For lists and tuples
250260
if isinstance(val1, (list, tuple)) and isinstance(val2, (list, tuple)):
251261
if len(val1) != len(val2):
252262
return False
253263
return all(self._safe_compare(v1, v2) for v1, v2 in zip(val1, val2))
254-
264+
255265
# For dictionaries
256266
if isinstance(val1, dict) and isinstance(val2, dict):
257267
if set(val1.keys()) != set(val2.keys()):
258268
return False
259269
return all(self._safe_compare(val1[k], val2[k]) for k in val1.keys())
260-
261-
# For Row objects (which are tuples with special properties)
262-
if hasattr(val1, 'asDict') and hasattr(val2, 'asDict'):
263-
return self._safe_compare(val1.asDict(recursive=True), val2.asDict(recursive=True))
264-
265-
# Default comparison
266-
return val1 == val2
267-
except (ValueError, TypeError) as e:
268-
# If comparison fails (e.g., numpy arrays), convert to string
269-
return str(val1) == str(val2)
270+
271+
# Default comparison - ensure we always return a boolean
272+
result = val1 == val2
273+
# If result is not a simple boolean, use bool() to convert it
274+
return bool(result)
275+
276+
except (ValueError, TypeError):
277+
# Fallback to string comparison for problematic types
278+
try:
279+
return str(val1) == str(val2)
280+
except:
281+
return False
270282

271283
def compare_rows(
272284
self, thrift_rows: List[Row], sea_rows: List[Row], result: ComparisonResult
@@ -302,15 +314,17 @@ def compare_rows(
302314
# Check if dictionaries are different by comparing all fields
303315
all_fields = set(thrift_dict.keys()) | set(sea_dict.keys())
304316
dicts_differ = False
305-
317+
306318
for field in all_fields:
307319
if field not in thrift_dict or field not in sea_dict:
308320
dicts_differ = True
309321
break
310-
elif not self._safe_compare(thrift_dict.get(field), sea_dict.get(field)):
322+
elif not self._safe_compare(
323+
thrift_dict.get(field), sea_dict.get(field)
324+
):
311325
dicts_differ = True
312326
break
313-
327+
314328
if dicts_differ:
315329

316330
for field in all_fields:
@@ -353,9 +367,9 @@ def compare_rows(
353367
thrift_values = [m[1] for m in mismatches]
354368
sea_values = [m[2] for m in mismatches]
355369

356-
if all(self._safe_compare(v, thrift_values[0]) for v in thrift_values) and all(
357-
self._safe_compare(v, sea_values[0]) for v in sea_values
358-
):
370+
if all(
371+
self._safe_compare(v, thrift_values[0]) for v in thrift_values
372+
) and all(self._safe_compare(v, sea_values[0]) for v in sea_values):
359373
result.add_difference(
360374
f"Field '{field}' value mismatch in all rows",
361375
thrift_values[0],

0 commit comments

Comments
 (0)