
Commit 048fae1

[PECOBLR-201] add variant support (#560)
1 parent 415fb53 commit 048fae1

File tree

src/databricks/sql/backend/thrift_backend.py
tests/e2e/test_variant_types.py
tests/unit/test_thrift_backend.py

3 files changed: +214 −14 lines changed


src/databricks/sql/backend/thrift_backend.py

Lines changed: 42 additions & 13 deletions
@@ -735,7 +735,7 @@ def convert_col(t_column_desc):
         return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])

     @staticmethod
-    def _col_to_description(col, session_id_hex=None):
+    def _col_to_description(col, field=None, session_id_hex=None):
         type_entry = col.typeDesc.types[0]

         if type_entry.primitiveEntry:
@@ -764,12 +764,39 @@ def _col_to_description(col, session_id_hex=None):
         else:
             precision, scale = None, None

+        # Extract variant type from field if available
+        if field is not None:
+            try:
+                # Check for variant type in metadata
+                if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
+                    sql_type = field.metadata.get(b"Spark:DataType:SqlName")
+                    if sql_type == b"VARIANT":
+                        cleaned_type = "variant"
+            except Exception as e:
+                logger.debug(f"Could not extract variant type from field: {e}")
+
         return col.columnName, cleaned_type, None, None, precision, scale, None

     @staticmethod
-    def _hive_schema_to_description(t_table_schema, session_id_hex=None):
+    def _hive_schema_to_description(
+        t_table_schema, schema_bytes=None, session_id_hex=None
+    ):
+        field_dict = {}
+        if pyarrow and schema_bytes:
+            try:
+                arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
+                # Build a dictionary mapping column names to fields
+                for field in arrow_schema:
+                    field_dict[field.name] = field
+            except Exception as e:
+                logger.debug(f"Could not parse arrow schema: {e}")
+
         return [
-            ThriftDatabricksClient._col_to_description(col, session_id_hex)
+            ThriftDatabricksClient._col_to_description(
+                col,
+                field_dict.get(col.columnName) if field_dict else None,
+                session_id_hex,
+            )
             for col in t_table_schema.columns
         ]

@@ -802,11 +829,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
             or direct_results.resultSet.hasMoreRows
         )

-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema,
-            self._session_id_hex,
-        )
-
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
@@ -819,6 +841,12 @@ def _results_message_to_execute_response(self, resp, operation_state):
         else:
             schema_bytes = None

+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema,
+            schema_bytes,
+            self._session_id_hex,
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         command_id = CommandId.from_thrift_handle(resp.operationHandle)

@@ -863,11 +891,6 @@ def get_execution_result(

         t_result_set_metadata_resp = resp.resultSetMetadata

-        description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema,
-            self._session_id_hex,
-        )
-
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
@@ -880,6 +903,12 @@ def get_execution_result(
         else:
             schema_bytes = None

+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema,
+            schema_bytes,
+            self._session_id_hex,
+        )
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         has_more_rows = resp.hasMoreRows
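
The detection works because VARIANT columns arrive over Thrift as plain STRING_TYPE, while the Arrow schema the server returns carries a Spark:DataType:SqlName metadata entry on the corresponding field. The standalone sketch below (illustration only, not driver code; the column names are made up) mimics the lookup added in _hive_schema_to_description and _col_to_description by serializing a pyarrow schema with that metadata and reading it back the same way the backend does:

    import pyarrow

    # Hypothetical columns for illustration: one VARIANT, one plain STRING.
    fields = [
        pyarrow.field(
            "variant_col",
            pyarrow.string(),
            metadata={b"Spark:DataType:SqlName": b"VARIANT"},
        ),
        pyarrow.field("regular_col", pyarrow.string()),
    ]
    schema_bytes = pyarrow.schema(fields).serialize().to_pybytes()

    # Same round-trip the backend performs on t_result_set_metadata_resp.arrowSchema.
    arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
    for field in arrow_schema:
        sql_name = (field.metadata or {}).get(b"Spark:DataType:SqlName")
        cleaned_type = "variant" if sql_name == b"VARIANT" else "string"
        print(field.name, cleaned_type)  # variant_col -> variant, regular_col -> string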

tests/e2e/test_variant_types.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+import pytest
+from datetime import datetime
+import json
+
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
+
+from tests.e2e.test_driver import PySQLPytestTestCase
+from tests.e2e.common.predicates import pysql_supports_arrow
+
+
+@pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
+class TestVariantTypes(PySQLPytestTestCase):
+    """Tests for the proper detection and handling of VARIANT type columns"""
+
+    @pytest.fixture(scope="class")
+    def variant_table(self, connection_details):
+        """A pytest fixture that creates a test table and cleans up after tests"""
+        self.arguments = connection_details.copy()
+        table_name = "pysql_test_variant_types_table"
+
+        with self.cursor() as cursor:
+            try:
+                # Create the table with variant columns
+                cursor.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table (
+                        id INTEGER,
+                        variant_col VARIANT,
+                        regular_string_col STRING
+                    )
+                    """
+                )
+
+                # Insert test records with different variant values
+                cursor.execute(
+                    """
+                    INSERT INTO pysql_test_variant_types_table
+                    VALUES
+                    (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'),
+                    (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string')
+                    """
+                )
+                yield table_name
+            finally:
+                cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
+
+    def test_variant_type_detection(self, variant_table):
+        """Test that VARIANT type columns are properly detected in schema"""
+        with self.cursor() as cursor:
+            cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0")
+
+            # Verify column types in description
+            assert (
+                cursor.description[0][1] == "int"
+            ), "Integer column type not correctly identified"
+            assert (
+                cursor.description[1][1] == "variant"
+            ), "VARIANT column type not correctly identified"
+            assert (
+                cursor.description[2][1] == "string"
+            ), "String column type not correctly identified"
+
+    def test_variant_data_retrieval(self, variant_table):
+        """Test that VARIANT data is properly retrieved and can be accessed as JSON"""
+        with self.cursor() as cursor:
+            cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id")
+            rows = cursor.fetchall()
+
+            # First row should have a JSON object
+            json_obj = rows[0][1]
+            assert isinstance(
+                json_obj, str
+            ), "VARIANT column should be returned as string"
+
+            parsed = json.loads(json_obj)
+            assert parsed.get("name") == "John"
+            assert parsed.get("age") == 30
+
+            # Second row should have a JSON array
+            json_array = rows[1][1]
+            assert isinstance(
+                json_array, str
+            ), "VARIANT array should be returned as string"
+
+            # Parsing to verify it's valid JSON array
+            parsed_array = json.loads(json_array)
+            assert isinstance(parsed_array, list)
+            assert parsed_array == [1, 2, 3, 4]
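
For context, this is roughly how the change surfaces to an application using the connector: the VARIANT column is reported as "variant" in cursor.description, and its values come back as JSON strings. A minimal sketch, assuming a reachable workspace and a table like the one created by the fixture above; the connection parameters below are placeholders, not real credentials:

    import json
    from databricks import sql

    # Placeholder credentials -- substitute real workspace values.
    with sql.connect(
        server_hostname="example.cloud.databricks.com",
        http_path="/sql/1.0/warehouses/abc123",
        access_token="dapi-...",
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute(
                "SELECT id, variant_col FROM pysql_test_variant_types_table ORDER BY id"
            )
            # With this commit, the second column's type code is "variant".
            print([col[1] for col in cursor.description])  # e.g. ['int', 'variant']
            for row_id, variant_value in cursor.fetchall():
                # VARIANT values arrive as JSON strings; decode them as needed.
                print(row_id, json.loads(variant_value))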

tests/unit/test_thrift_backend.py

Lines changed: 81 additions & 1 deletion
@@ -2330,7 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly(
23302330
[],
23312331
auth_provider=AuthProvider(),
23322332
ssl_options=SSLOptions(),
2333-
http_client=MagicMock(),
2333+
http_client=MagicMock(),
23342334
**complex_arg_types,
23352335
)
23362336
thrift_backend.execute_command(
@@ -2356,6 +2356,86 @@ def test_execute_command_sets_complex_type_fields_correctly(
23562356
t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
23572357
)
23582358

2359+
@unittest.skipIf(pyarrow is None, "Requires pyarrow")
2360+
def test_col_to_description(self):
2361+
test_cases = [
2362+
("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"),
2363+
("normal_col", {}, "string"),
2364+
("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"),
2365+
("missing_field", None, "string"), # None field case
2366+
]
2367+
2368+
for column_name, field_metadata, expected_type in test_cases:
2369+
with self.subTest(column_name=column_name, expected_type=expected_type):
2370+
col = ttypes.TColumnDesc(
2371+
columnName=column_name,
2372+
typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
2373+
)
2374+
2375+
field = (
2376+
None
2377+
if field_metadata is None
2378+
else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata)
2379+
)
2380+
2381+
result = ThriftDatabricksClient._col_to_description(col, field)
2382+
2383+
self.assertEqual(result[0], column_name)
2384+
self.assertEqual(result[1], expected_type)
2385+
self.assertIsNone(result[2])
2386+
self.assertIsNone(result[3])
2387+
self.assertIsNone(result[4])
2388+
self.assertIsNone(result[5])
2389+
self.assertIsNone(result[6])
2390+
2391+
@unittest.skipIf(pyarrow is None, "Requires pyarrow")
2392+
def test_hive_schema_to_description(self):
2393+
test_cases = [
2394+
(
2395+
[
2396+
("regular_col", ttypes.TTypeId.STRING_TYPE),
2397+
("variant_col", ttypes.TTypeId.STRING_TYPE),
2398+
],
2399+
[
2400+
("regular_col", {}),
2401+
("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}),
2402+
],
2403+
[("regular_col", "string"), ("variant_col", "variant")],
2404+
),
2405+
(
2406+
[("regular_col", ttypes.TTypeId.STRING_TYPE)],
2407+
None, # No arrow schema
2408+
[("regular_col", "string")],
2409+
),
2410+
]
2411+
2412+
for columns, arrow_fields, expected_types in test_cases:
2413+
with self.subTest(arrow_fields=arrow_fields is not None):
2414+
t_table_schema = ttypes.TTableSchema(
2415+
columns=[
2416+
ttypes.TColumnDesc(
2417+
columnName=name, typeDesc=self._make_type_desc(col_type)
2418+
)
2419+
for name, col_type in columns
2420+
]
2421+
)
2422+
2423+
schema_bytes = None
2424+
if arrow_fields:
2425+
fields = [
2426+
pyarrow.field(name, pyarrow.string(), metadata=metadata)
2427+
for name, metadata in arrow_fields
2428+
]
2429+
schema_bytes = pyarrow.schema(fields).serialize().to_pybytes()
2430+
2431+
description = ThriftDatabricksClient._hive_schema_to_description(
2432+
t_table_schema, schema_bytes
2433+
)
2434+
2435+
for i, (expected_name, expected_type) in enumerate(expected_types):
2436+
self.assertEqual(description[i][0], expected_name)
2437+
self.assertEqual(description[i][1], expected_type)
2438+
23592439

23602440
if __name__ == "__main__":
23612441
unittest.main()
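
The unit tests above build serialized Arrow schemas directly, which is a convenient way to exercise the metadata path without a live server. As a small, hypothetical downstream illustration (not part of this commit), a result consumer could dispatch on the new "variant" type code reported in cursor.description:

    import json

    # Hypothetical helper: decode a fetched row based on cursor.description type codes.
    def decode_row(description, row):
        decoded = []
        for (name, type_code, *_), value in zip(description, row):
            if type_code == "variant" and value is not None:
                decoded.append(json.loads(value))  # VARIANT arrives as a JSON string
            else:
                decoded.append(value)
        return decoded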
