[PECOBLR-201] add variant support #560


Merged · 7 commits · Aug 22, 2025
55 changes: 42 additions & 13 deletions src/databricks/sql/backend/thrift_backend.py
@@ -735,7 +735,7 @@ def convert_col(t_column_desc):
return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])

@staticmethod
def _col_to_description(col, session_id_hex=None):
def _col_to_description(col, field=None, session_id_hex=None):
type_entry = col.typeDesc.types[0]

if type_entry.primitiveEntry:
@@ -764,12 +764,39 @@ def _col_to_description(col, session_id_hex=None):
else:
precision, scale = None, None

# Extract variant type from field if available
if field is not None:
try:
# Check for variant type in metadata
if field.metadata and b"Spark:DataType:SqlName" in field.metadata:
sql_type = field.metadata.get(b"Spark:DataType:SqlName")
if sql_type == b"VARIANT":
cleaned_type = "variant"
except Exception as e:
logger.debug(f"Could not extract variant type from field: {e}")

return col.columnName, cleaned_type, None, None, precision, scale, None

@staticmethod
def _hive_schema_to_description(t_table_schema, session_id_hex=None):
def _hive_schema_to_description(
t_table_schema, schema_bytes=None, session_id_hex=None
):
field_dict = {}
if pyarrow and schema_bytes:
try:
arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
# Build a dictionary mapping column names to fields
for field in arrow_schema:
field_dict[field.name] = field
except Exception as e:
logger.debug(f"Could not parse arrow schema: {e}")

return [
ThriftDatabricksClient._col_to_description(col, session_id_hex)
ThriftDatabricksClient._col_to_description(
col,
field_dict.get(col.columnName) if field_dict else None,
session_id_hex,
)
for col in t_table_schema.columns
]

@@ -802,11 +829,6 @@ def _results_message_to_execute_response(self, resp, operation_state):
or direct_results.resultSet.hasMoreRows
)

description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema,
self._session_id_hex,
)

if pyarrow:
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
@@ -819,6 +841,12 @@
else:
schema_bytes = None

description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema,
schema_bytes,
self._session_id_hex,
)

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
command_id = CommandId.from_thrift_handle(resp.operationHandle)

@@ -863,11 +891,6 @@ def get_execution_result(

t_result_set_metadata_resp = resp.resultSetMetadata

description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema,
self._session_id_hex,
)

if pyarrow:
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
@@ -880,6 +903,12 @@
else:
schema_bytes = None

description = self._hive_schema_to_description(
t_result_set_metadata_resp.schema,
schema_bytes,
self._session_id_hex,
)

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
has_more_rows = resp.hasMoreRows
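
For reference, a minimal sketch (not part of the diff) of the Arrow schema metadata round-trip the change above relies on, assuming pyarrow is installed; the field name and metadata key mirror the ones checked in _col_to_description:

import pyarrow

# Server side: a VARIANT column arrives as a string field tagged with Spark metadata.
field = pyarrow.field(
    "variant_col",
    pyarrow.string(),
    metadata={b"Spark:DataType:SqlName": b"VARIANT"},
)
schema_bytes = pyarrow.schema([field]).serialize().to_pybytes()

# Client side: _hive_schema_to_description deserializes the bytes and
# _col_to_description inspects the per-field metadata to report "variant".
arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
for f in arrow_schema:
    sql_type = (f.metadata or {}).get(b"Spark:DataType:SqlName")
    print(f.name, "variant" if sql_type == b"VARIANT" else "string")
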
91 changes: 91 additions & 0 deletions tests/e2e/test_variant_types.py
@@ -0,0 +1,91 @@
import pytest
from datetime import datetime
import json

try:
import pyarrow
except ImportError:
pyarrow = None

from tests.e2e.test_driver import PySQLPytestTestCase
from tests.e2e.common.predicates import pysql_supports_arrow


@pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support")
class TestVariantTypes(PySQLPytestTestCase):
"""Tests for the proper detection and handling of VARIANT type columns"""

@pytest.fixture(scope="class")
def variant_table(self, connection_details):
"""A pytest fixture that creates a test table and cleans up after tests"""
self.arguments = connection_details.copy()
table_name = "pysql_test_variant_types_table"

with self.cursor() as cursor:
try:
# Create the table with variant columns
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table (
id INTEGER,
variant_col VARIANT,
regular_string_col STRING
)
"""
)

# Insert test records with different variant values
cursor.execute(
"""
INSERT INTO pysql_test_variant_types_table
VALUES
(1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'),
(2, PARSE_JSON('[1, 2, 3, 4]'), 'another string')
"""
)
yield table_name
finally:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")

def test_variant_type_detection(self, variant_table):
"""Test that VARIANT type columns are properly detected in schema"""
with self.cursor() as cursor:
cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0")

# Verify column types in description
assert (
cursor.description[0][1] == "int"
), "Integer column type not correctly identified"
assert (
cursor.description[1][1] == "variant"
), "VARIANT column type not correctly identified"
assert (
cursor.description[2][1] == "string"
), "String column type not correctly identified"

def test_variant_data_retrieval(self, variant_table):
"""Test that VARIANT data is properly retrieved and can be accessed as JSON"""
with self.cursor() as cursor:
cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id")
rows = cursor.fetchall()

# First row should have a JSON object
json_obj = rows[0][1]
assert isinstance(
json_obj, str
), "VARIANT column should be returned as string"

parsed = json.loads(json_obj)
assert parsed.get("name") == "John"
assert parsed.get("age") == 30

# Second row should have a JSON array
json_array = rows[1][1]
assert isinstance(
json_array, str
), "VARIANT array should be returned as string"

# Parse to verify it's a valid JSON array
parsed_array = json.loads(json_array)
assert isinstance(parsed_array, list)
assert parsed_array == [1, 2, 3, 4]
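
For context, a minimal application-level sketch of what the change surfaces to callers of the connector (connection parameters are placeholders; the table name is the one created by the fixture above):

import json
from databricks import sql

with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<warehouse-http-path>",
    access_token="<token>",
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT variant_col FROM pysql_test_variant_types_table")
        # cursor.description now reports "variant" rather than "string"
        print(cursor.description[0][1])
        # VARIANT values are returned as JSON strings and can be parsed client-side
        value = json.loads(cursor.fetchone()[0])
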
82 changes: 81 additions & 1 deletion tests/unit/test_thrift_backend.py
Expand Up @@ -2330,7 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly(
[],
auth_provider=AuthProvider(),
ssl_options=SSLOptions(),
http_client=MagicMock(),
http_client=MagicMock(),
**complex_arg_types,
)
thrift_backend.execute_command(
@@ -2356,6 +2356,86 @@
t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow
)

@unittest.skipIf(pyarrow is None, "Requires pyarrow")
def test_col_to_description(self):
test_cases = [
("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"),
("normal_col", {}, "string"),
("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"),
("missing_field", None, "string"), # None field case
]

for column_name, field_metadata, expected_type in test_cases:
with self.subTest(column_name=column_name, expected_type=expected_type):
col = ttypes.TColumnDesc(
columnName=column_name,
typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE),
)

field = (
None
if field_metadata is None
else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata)
)

result = ThriftDatabricksClient._col_to_description(col, field)

self.assertEqual(result[0], column_name)
self.assertEqual(result[1], expected_type)
self.assertIsNone(result[2])
self.assertIsNone(result[3])
self.assertIsNone(result[4])
self.assertIsNone(result[5])
self.assertIsNone(result[6])

@unittest.skipIf(pyarrow is None, "Requires pyarrow")
def test_hive_schema_to_description(self):
test_cases = [
(
[
("regular_col", ttypes.TTypeId.STRING_TYPE),
("variant_col", ttypes.TTypeId.STRING_TYPE),
],
[
("regular_col", {}),
("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}),
],
[("regular_col", "string"), ("variant_col", "variant")],
),
(
[("regular_col", ttypes.TTypeId.STRING_TYPE)],
None, # No arrow schema
[("regular_col", "string")],
),
]

for columns, arrow_fields, expected_types in test_cases:
with self.subTest(arrow_fields=arrow_fields is not None):
t_table_schema = ttypes.TTableSchema(
columns=[
ttypes.TColumnDesc(
columnName=name, typeDesc=self._make_type_desc(col_type)
)
for name, col_type in columns
]
)

schema_bytes = None
if arrow_fields:
fields = [
pyarrow.field(name, pyarrow.string(), metadata=metadata)
for name, metadata in arrow_fields
]
schema_bytes = pyarrow.schema(fields).serialize().to_pybytes()

description = ThriftDatabricksClient._hive_schema_to_description(
t_table_schema, schema_bytes
)

for i, (expected_name, expected_type) in enumerate(expected_types):
self.assertEqual(description[i][0], expected_name)
self.assertEqual(description[i][1], expected_type)


if __name__ == "__main__":
unittest.main()