From f50d9ab91279086b053c68458075996aa0434695 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 29 Jul 2025 14:50:16 +0530 Subject: [PATCH 1/3] preliminary complex types support Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 48 +++++++++++++++++ tests/e2e/test_complex_types.py | 27 +++++++--- tests/unit/test_client.py | 4 +- tests/unit/test_downloader.py | 6 ++- tests/unit/test_telemetry_retry.py | 56 +++++++++++++------- 5 files changed, 110 insertions(+), 31 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index a6a0a298b..b7d0b3428 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import Any, List, Optional, TYPE_CHECKING import logging @@ -82,6 +83,47 @@ def __init__( arrow_schema_bytes=execute_response.arrow_schema_bytes, ) + assert isinstance( + self.backend, SeaDatabricksClient + ), "SeaResultSet must be used with SeaDatabricksClient" + + def _convert_complex_types_to_string( + self, rows: "pyarrow.Table" + ) -> "pyarrow.Table": + """ + Convert complex types (array, struct, map) to string representation. + Args: + rows: Input PyArrow table + Returns: + PyArrow table with complex types converted to strings + """ + + if not pyarrow: + raise ImportError( + "PyArrow is not installed: _use_arrow_native_complex_types = False requires pyarrow" + ) + + def convert_complex_column_to_string(col: "pyarrow.Array") -> "pyarrow.Array": + python_values = col.to_pylist() + json_strings = [ + (None if val is None else json.dumps(val)) for val in python_values + ] + return pyarrow.array(json_strings, type=pyarrow.string()) + + converted_columns = [] + for col in rows.columns: + converted_col = col + if ( + pyarrow.types.is_list(col.type) + or pyarrow.types.is_large_list(col.type) + or pyarrow.types.is_struct(col.type) + or pyarrow.types.is_map(col.type) + ): + converted_col = convert_complex_column_to_string(col) + converted_columns.append(converted_col) + + return pyarrow.Table.from_arrays(converted_columns, names=rows.column_names) + def _convert_json_types(self, row: List[str]) -> List[Any]: """ Convert string values in the row to appropriate Python types based on column metadata. @@ -200,6 +242,9 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": if isinstance(self.results, JsonQueue): results = self._convert_json_to_arrow_table(results) + if not self.backend._use_arrow_native_complex_types: + results = self._convert_complex_types_to_string(results) + self._next_row_index += results.num_rows return results @@ -213,6 +258,9 @@ def fetchall_arrow(self) -> "pyarrow.Table": if isinstance(self.results, JsonQueue): results = self._convert_json_to_arrow_table(results) + if not self.backend._use_arrow_native_complex_types: + results = self._convert_complex_types_to_string(results) + self._next_row_index += results.num_rows return results diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py index c8a3a0781..6b954449d 100644 --- a/tests/e2e/test_complex_types.py +++ b/tests/e2e/test_complex_types.py @@ -54,10 +54,19 @@ def table_fixture(self, connection_details): ("map_array_col", list), ], ) - def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_read_complex_types_as_arrow( + self, field, expected_type, table_fixture, backend_params + ): """Confirms the return types of a complex type field when reading as arrow""" - with self.cursor() as cursor: + with self.cursor(extra_params=backend_params) as cursor: result = cursor.execute( "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone() @@ -75,11 +84,17 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): ("map_array_col"), ], ) - def test_read_complex_types_as_string(self, field, table_fixture): + @pytest.mark.parametrize( + "backend_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_read_complex_types_as_string(self, field, table_fixture, backend_params): """Confirms the return type of a complex type that is returned as a string""" - with self.cursor( - extra_params={"_use_arrow_native_complex_types": False} - ) as cursor: + extra_params = {**backend_params, "_use_arrow_native_complex_types": False} + with self.cursor(extra_params=extra_params) as cursor: result = cursor.execute( "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 4271f0d7d..19375cde3 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -262,9 +262,7 @@ def test_negative_fetch_throws_exception(self): mock_backend = Mock() mock_backend.fetch_results.return_value = (Mock(), False, 0) - result_set = ThriftResultSet( - Mock(), Mock(), mock_backend - ) + result_set = ThriftResultSet(Mock(), Mock(), mock_backend) with self.assertRaises(ValueError) as e: result_set.fetchmany(-1) diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index ed782a801..c514980ee 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -26,12 +26,14 @@ class DownloaderTests(unittest.TestCase): def _setup_time_mock_for_download(self, mock_time, end_time): """Helper to setup time mock that handles logging system calls.""" call_count = [0] + def time_side_effect(): call_count[0] += 1 if call_count[0] <= 2: # First two calls (validation, start_time) return 1000 else: # All subsequent calls (logging, duration calculation) return end_time + mock_time.side_effect = time_side_effect @patch("time.time", return_value=1000) @@ -104,7 +106,7 @@ def test_run_get_response_not_ok(self, mock_time): @patch("time.time") def test_run_uncompressed_successful(self, mock_time): self._setup_time_mock_for_download(mock_time, 1000.5) - + http_client = DatabricksHttpClient.get_instance() file_bytes = b"1234567890" * 10 settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) @@ -133,7 +135,7 @@ def test_run_uncompressed_successful(self, mock_time): @patch("time.time") def test_run_compressed_successful(self, mock_time): self._setup_time_mock_for_download(mock_time, 1000.2) - + http_client = DatabricksHttpClient.get_instance() file_bytes = b"1234567890" * 10 compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py index 11055b558..94137c5b1 100644 --- a/tests/unit/test_telemetry_retry.py +++ b/tests/unit/test_telemetry_retry.py @@ -6,7 +6,8 @@ from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory from databricks.sql.auth.retry import DatabricksRetryPolicy -PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' +PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn" + def create_mock_conn(responses): """Creates a mock connection object whose getresponse() method yields a series of responses.""" @@ -16,15 +17,18 @@ def create_mock_conn(responses): mock_http_response = MagicMock() mock_http_response.status = resp.get("status") mock_http_response.headers = resp.get("headers", {}) - body = resp.get("body", b'{}') + body = resp.get("body", b"{}") mock_http_response.fp = io.BytesIO(body) + def release(): mock_http_response.fp.close() + mock_http_response.release_conn = release mock_http_responses.append(mock_http_response) mock_conn.getresponse.side_effect = mock_http_responses return mock_conn + class TestTelemetryClientRetries: @pytest.fixture(autouse=True) def setup_and_teardown(self): @@ -49,28 +53,28 @@ def get_client(self, session_id, num_retries=3): host_url="test.databricks.com", ) client = TelemetryClientFactory.get_telemetry_client(session_id) - + retry_policy = DatabricksRetryPolicy( delay_min=0.01, delay_max=0.02, stop_after_attempts_duration=2.0, - stop_after_attempts_count=num_retries, + stop_after_attempts_count=num_retries, delay_default=0.1, force_dangerous_codes=[], - urllib3_kwargs={'total': num_retries} + urllib3_kwargs={"total": num_retries}, ) adapter = client._http_client.session.adapters.get("https://") adapter.max_retries = retry_policy return client @pytest.mark.parametrize( - "status_code, description", - [ - (401, "Unauthorized"), - (403, "Forbidden"), - (501, "Not Implemented"), - (200, "Success"), - ], + "status_code, description", + [ + (401, "Unauthorized"), + (403, "Forbidden"), + (501, "Not Implemented"), + (200, "Success"), + ], ) def test_non_retryable_status_codes_are_not_retried(self, status_code, description): """ @@ -80,7 +84,9 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti client = self.get_client(f"session-{status_code}") mock_responses = [{"status": status_code}] - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + with patch( + PATCH_TARGET, return_value=create_mock_conn(mock_responses) + ) as mock_get_conn: client.export_failure_log("TestError", "Test message") TelemetryClientFactory.close(client._session_id_hex) @@ -92,16 +98,26 @@ def test_exceeds_retry_count_limit(self): Verifies that the client respects the Retry-After header and retries on 429, 502, 503. """ num_retries = 3 - expected_total_calls = num_retries + 1 + expected_total_calls = num_retries + 1 retry_after = 1 client = self.get_client("session-exceed-limit", num_retries=num_retries) - mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}] - - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + mock_responses = [ + {"status": 503, "headers": {"Retry-After": str(retry_after)}}, + {"status": 429}, + {"status": 502}, + {"status": 503}, + ] + + with patch( + PATCH_TARGET, return_value=create_mock_conn(mock_responses) + ) as mock_get_conn: start_time = time.time() client.export_failure_log("TestError", "Test message") TelemetryClientFactory.close(client._session_id_hex) end_time = time.time() - - assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls - assert end_time - start_time > retry_after \ No newline at end of file + + assert ( + mock_get_conn.return_value.getresponse.call_count + == expected_total_calls + ) + assert end_time - start_time > retry_after From 87d73da62da21f6737d1df47568e438cf870ecd3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 29 Jul 2025 14:52:09 +0530 Subject: [PATCH 2/3] nit: cleaner reorganise Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index b7d0b3428..8bbc36a8c 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -106,7 +106,7 @@ def _convert_complex_types_to_string( def convert_complex_column_to_string(col: "pyarrow.Array") -> "pyarrow.Array": python_values = col.to_pylist() json_strings = [ - (None if val is None else json.dumps(val)) for val in python_values + (json.dumps(val) if val is not None else None) for val in python_values ] return pyarrow.array(json_strings, type=pyarrow.string()) From f74836dcaed6f5c5c6c9de443235d2b9cb212c27 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 29 Jul 2025 16:15:13 +0530 Subject: [PATCH 3/3] fix type issues Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/result_set.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index 8bbc36a8c..ae25576da 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -27,6 +27,8 @@ class SeaResultSet(ResultSet): """ResultSet implementation for SEA backend.""" + backend: SeaDatabricksClient + def __init__( self, connection: Connection, @@ -83,10 +85,6 @@ def __init__( arrow_schema_bytes=execute_response.arrow_schema_bytes, ) - assert isinstance( - self.backend, SeaDatabricksClient - ), "SeaResultSet must be used with SeaDatabricksClient" - def _convert_complex_types_to_string( self, rows: "pyarrow.Table" ) -> "pyarrow.Table":