diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a138c8a..06c12bdc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Release History +# 4.1.2 (2025-08-22) +- Streaming ingestion support for PUT operation (databricks/databricks-sql-python#643 by @sreekanth-db) +- Removed use_threads argument on concat_tables for compatibility with pyarrow<14 (databricks/databricks-sql-python#684 by @jprakash-db) + # 4.1.1 (2025-08-21) - Add documentation for proxy support (databricks/databricks-sql-python#680 by @vikrantpuppala) - Fix compatibility with urllib3<2 and add CI actions to improve dependency checks (databricks/databricks-sql-python#678 by @vikrantpuppala) diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py new file mode 100644 index 00000000..f615d082 --- /dev/null +++ b/examples/query_tags_example.py @@ -0,0 +1,30 @@ +import os +import databricks.sql as sql + +""" +This example demonstrates how to use Query Tags. + +Query Tags are key-value pairs that can be attached to SQL executions and will appear +in the system.query.history table for analytical purposes. + +Format: "key1:value1,key2:value2,key3:value3" +""" + +print("=== Query Tags Example ===\n") + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + session_configuration={ + 'QUERY_TAGS': 'team:engineering,test:query-tags', + 'ansi_mode': False + } +) as connection: + + with connection.cursor() as cursor: + cursor.execute("SELECT 1") + result = cursor.fetchone() + print(f" Result: {result[0]}") + +print("\n=== Query Tags Example Complete ===") \ No newline at end of file diff --git a/examples/streaming_put.py b/examples/streaming_put.py new file mode 100644 index 00000000..4e769709 --- /dev/null +++ b/examples/streaming_put.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +""" +Simple example of streaming PUT operations. + +This demonstrates the basic usage of streaming PUT with the __input_stream__ token. +""" + +import io +import os +from databricks import sql + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + # Create a simple data stream + data = b"Hello, streaming world!" 
+ stream = io.BytesIO(data) + + # Get catalog, schema, and volume from environment variables + catalog = os.getenv("DATABRICKS_CATALOG") + schema = os.getenv("DATABRICKS_SCHEMA") + volume = os.getenv("DATABRICKS_VOLUME") + + # Upload to Unity Catalog volume + cursor.execute( + f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/{volume}/hello.txt' OVERWRITE", + input_stream=stream + ) + + print("File uploaded successfully!") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0c0342d2..6f0f7471 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.1.1" +version = "4.1.2" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index f7c9f5e9..31b5cbb7 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -68,7 +68,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "4.1.1" +__version__ = "4.1.2" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 46ce8c98..61ecf969 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -15,6 +15,7 @@ "STATEMENT_TIMEOUT": "0", "TIMEZONE": "UTC", "USE_CACHED_RESULT": "true", + "QUERY_TAGS": "", } diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index 02c88aa6..d2b10e71 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -735,7 +735,7 @@ def convert_col(t_column_desc): return pyarrow.schema([convert_col(col) for col in t_table_schema.columns]) @staticmethod - def _col_to_description(col, session_id_hex=None): + def _col_to_description(col, field=None, session_id_hex=None): type_entry = col.typeDesc.types[0] if type_entry.primitiveEntry: @@ -764,12 +764,39 @@ def _col_to_description(col, session_id_hex=None): else: precision, scale = None, None + # Extract variant type from field if available + if field is not None: + try: + # Check for variant type in metadata + if field.metadata and b"Spark:DataType:SqlName" in field.metadata: + sql_type = field.metadata.get(b"Spark:DataType:SqlName") + if sql_type == b"VARIANT": + cleaned_type = "variant" + except Exception as e: + logger.debug(f"Could not extract variant type from field: {e}") + return col.columnName, cleaned_type, None, None, precision, scale, None @staticmethod - def _hive_schema_to_description(t_table_schema, session_id_hex=None): + def _hive_schema_to_description( + t_table_schema, schema_bytes=None, session_id_hex=None + ): + field_dict = {} + if pyarrow and schema_bytes: + try: + arrow_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes)) + # Build a dictionary mapping column names to fields + for field in arrow_schema: + field_dict[field.name] = field + except Exception as e: + logger.debug(f"Could not parse arrow schema: {e}") + return [ - ThriftDatabricksClient._col_to_description(col, session_id_hex) + ThriftDatabricksClient._col_to_description( + col, + field_dict.get(col.columnName) if field_dict else None, + session_id_hex, + ) for col in t_table_schema.columns ] @@ -802,11 +829,6 @@ def _results_message_to_execute_response(self, resp, operation_state): 
or direct_results.resultSet.hasMoreRows ) - description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema, - self._session_id_hex, - ) - if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema @@ -819,6 +841,12 @@ def _results_message_to_execute_response(self, resp, operation_state): else: schema_bytes = None + description = self._hive_schema_to_description( + t_result_set_metadata_resp.schema, + schema_bytes, + self._session_id_hex, + ) + lz4_compressed = t_result_set_metadata_resp.lz4Compressed command_id = CommandId.from_thrift_handle(resp.operationHandle) @@ -863,11 +891,6 @@ def get_execution_result( t_result_set_metadata_resp = resp.resultSetMetadata - description = self._hive_schema_to_description( - t_result_set_metadata_resp.schema, - self._session_id_hex, - ) - if pyarrow: schema_bytes = ( t_result_set_metadata_resp.arrowSchema @@ -880,6 +903,12 @@ def get_execution_result( else: schema_bytes = None + description = self._hive_schema_to_description( + t_result_set_metadata_resp.schema, + schema_bytes, + self._session_id_hex, + ) + lz4_compressed = t_result_set_metadata_resp.lz4Compressed is_staging_operation = t_result_set_metadata_resp.isStagingOperation has_more_rows = resp.hasMoreRows diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 3cd7bcac..78a01142 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,5 +1,5 @@ import time -from typing import Dict, Tuple, List, Optional, Any, Union, Sequence +from typing import Dict, Tuple, List, Optional, Any, Union, Sequence, BinaryIO import pandas try: @@ -662,7 +662,9 @@ def _check_not_closed(self): ) def _handle_staging_operation( - self, staging_allowed_local_path: Union[None, str, List[str]] + self, + staging_allowed_local_path: Union[None, str, List[str]], + input_stream: Optional[BinaryIO] = None, ): """Fetch the HTTP request instruction from a staging ingestion command and call the designated handler. @@ -671,6 +673,28 @@ def _handle_staging_operation( is not descended from staging_allowed_local_path. 
""" + assert self.active_result_set is not None + row = self.active_result_set.fetchone() + assert row is not None + + # May be real headers, or could be json string + headers = ( + json.loads(row.headers) if isinstance(row.headers, str) else row.headers + ) + headers = dict(headers) if headers else {} + + # Handle __input_stream__ token for PUT operations + if ( + row.operation == "PUT" + and getattr(row, "localFile", None) == "__input_stream__" + ): + return self._handle_staging_put_stream( + presigned_url=row.presignedUrl, + stream=input_stream, + headers=headers, + ) + + # For non-streaming operations, validate staging_allowed_local_path if isinstance(staging_allowed_local_path, type(str())): _staging_allowed_local_paths = [staging_allowed_local_path] elif isinstance(staging_allowed_local_path, type(list())): @@ -685,10 +709,6 @@ def _handle_staging_operation( os.path.abspath(i) for i in _staging_allowed_local_paths ] - assert self.active_result_set is not None - row = self.active_result_set.fetchone() - assert row is not None - # Must set to None in cases where server response does not include localFile abs_localFile = None @@ -711,19 +731,16 @@ def _handle_staging_operation( session_id_hex=self.connection.get_session_id_hex(), ) - # May be real headers, or could be json string - headers = ( - json.loads(row.headers) if isinstance(row.headers, str) else row.headers - ) - handler_args = { "presigned_url": row.presignedUrl, "local_file": abs_localFile, - "headers": dict(headers) or {}, + "headers": headers, } logger.debug( - f"Attempting staging operation indicated by server: {row.operation} - {getattr(row, 'localFile', '')}" + "Attempting staging operation indicated by server: %s - %s", + row.operation, + getattr(row, "localFile", ""), ) # TODO: Create a retry loop here to re-attempt if the request times out or fails @@ -762,6 +779,10 @@ def _handle_staging_put( HttpMethod.PUT, presigned_url, body=fh.read(), headers=headers ) + self._handle_staging_http_response(r) + + def _handle_staging_http_response(self, r): + # fmt: off # HTTP status codes OK = 200 @@ -784,6 +805,37 @@ def _handle_staging_put( + "but not yet applied on the server. It's possible this command may fail later." ) + @log_latency(StatementType.SQL) + def _handle_staging_put_stream( + self, + presigned_url: str, + stream: BinaryIO, + headers: dict = {}, + ) -> None: + """Handle PUT operation with streaming data. + + Args: + presigned_url: The presigned URL for upload + stream: Binary stream to upload + headers: HTTP headers + + Raises: + ProgrammingError: If no input stream is provided + OperationalError: If the upload fails + """ + + if not stream: + raise ProgrammingError( + "No input stream provided for streaming operation", + session_id_hex=self.connection.get_session_id_hex(), + ) + + r = self.connection.http_client.request( + HttpMethod.PUT, presigned_url, body=stream.read(), headers=headers + ) + + self._handle_staging_http_response(r) + @log_latency(StatementType.SQL) def _handle_staging_get( self, local_file: str, presigned_url: str, headers: Optional[dict] = None @@ -840,6 +892,7 @@ def execute( operation: str, parameters: Optional[TParameterCollection] = None, enforce_embedded_schema_correctness=False, + input_stream: Optional[BinaryIO] = None, ) -> "Cursor": """ Execute a query and wait for execution to complete. 
@@ -914,7 +967,8 @@ def execute( if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.connection.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path, + input_stream=input_stream, ) return self diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 9e621464..9f96e874 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -298,7 +298,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": num_rows -= table_slice.num_rows logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) - return pyarrow.concat_tables(partial_result_chunks, use_threads=True) + return concat_table_chunks(partial_result_chunks) def remaining_rows(self) -> "pyarrow.Table": """ @@ -321,7 +321,7 @@ def remaining_rows(self) -> "pyarrow.Table": self.table_row_index += table_slice.num_rows self.table = self._create_next_table() self.table_row_index = 0 - return pyarrow.concat_tables(partial_result_chunks, use_threads=True) + return concat_table_chunks(partial_result_chunks) def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]: """Create next table at the given row offset""" @@ -880,7 +880,7 @@ def concat_table_chunks( result_table[j].extend(table_chunks[i].column_table[j]) return ColumnTable(result_table, table_chunks[0].column_names) else: - return pyarrow.concat_tables(table_chunks, use_threads=True) + return pyarrow.concat_tables(table_chunks) def build_client_context(server_hostname: str, version: str, **kwargs): diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index aeeb6797..dd7c5699 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -117,7 +117,7 @@ def test_long_running_query(self, extra_params): scale_factor = 1 with self.cursor(extra_params) as cursor: while duration < min_duration: - assert scale_factor < 1024, "Detected infinite loop" + assert scale_factor < 4096, "Detected infinite loop" start = time.time() cursor.execute( diff --git a/tests/e2e/common/streaming_put_tests.py b/tests/e2e/common/streaming_put_tests.py new file mode 100644 index 00000000..83da10fd --- /dev/null +++ b/tests/e2e/common/streaming_put_tests.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +""" +E2E tests for streaming PUT operations. +""" + +import io +import logging +import pytest + +logger = logging.getLogger(__name__) + + +class PySQLStreamingPutTestSuiteMixin: + """Test suite for streaming PUT operations.""" + + def test_streaming_put_basic(self, catalog, schema): + """Test basic streaming PUT functionality.""" + + # Create test data + test_data = b"Hello, streaming world! This is test data." 
+ filename = "streaming_put_test.txt" + file_path = f"/Volumes/{catalog}/{schema}/e2etests/{filename}" + + try: + with self.connection() as conn: + with conn.cursor() as cursor: + self._cleanup_test_file(file_path) + + with io.BytesIO(test_data) as stream: + cursor.execute( + f"PUT '__input_stream__' INTO '{file_path}'", + input_stream=stream + ) + + # Verify file exists + cursor.execute(f"LIST '/Volumes/{catalog}/{schema}/e2etests/'") + files = cursor.fetchall() + + # Check if our file is in the list + file_paths = [row[0] for row in files] + assert file_path in file_paths, f"File {file_path} not found in {file_paths}" + finally: + self._cleanup_test_file(file_path) + + def test_streaming_put_missing_stream(self, catalog, schema): + """Test that missing stream raises appropriate error.""" + + with self.connection() as conn: + with conn.cursor() as cursor: + # Test without providing stream + with pytest.raises(Exception): # Should fail + cursor.execute( + f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/test.txt'" + # Note: No input_stream parameter + ) + + def _cleanup_test_file(self, file_path): + """Clean up a test file if it exists.""" + try: + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with conn.cursor() as cursor: + cursor.execute(f"REMOVE '{file_path}'") + logger.info("Successfully cleaned up test file: %s", file_path) + except Exception as e: + logger.error("Cleanup failed for %s: %s", file_path, e) \ No newline at end of file diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 53b7383e..e04e348c 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -50,6 +50,7 @@ from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin +from tests.e2e.common.streaming_put_tests import PySQLStreamingPutTestSuiteMixin from databricks.sql.exc import SessionAlreadyClosedError @@ -290,6 +291,7 @@ class TestPySQLCoreSuite( PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin, PySQLUCVolumeTestSuiteMixin, + PySQLStreamingPutTestSuiteMixin, ): validate_row_value_type = True validate_result = True @@ -899,7 +901,7 @@ def test_timestamps_arrow(self): ) def test_multi_timestamps_arrow(self, extra_params): with self.cursor( - {"session_configuration": {"ansi_mode": False}, **extra_params} + {"session_configuration": {"ansi_mode": False, "query_tags": "test:multi-timestamps,driver:python"}, **extra_params} ) as cursor: query, expected = self.multi_query() expected = [ diff --git a/tests/e2e/test_variant_types.py b/tests/e2e/test_variant_types.py new file mode 100644 index 00000000..b5dc1f42 --- /dev/null +++ b/tests/e2e/test_variant_types.py @@ -0,0 +1,91 @@ +import pytest +from datetime import datetime +import json + +try: + import pyarrow +except ImportError: + pyarrow = None + +from tests.e2e.test_driver import PySQLPytestTestCase +from tests.e2e.common.predicates import pysql_supports_arrow + + +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support") +class TestVariantTypes(PySQLPytestTestCase): + """Tests for the proper detection and handling of VARIANT type columns""" + + @pytest.fixture(scope="class") + def variant_table(self, connection_details): + """A pytest fixture that creates a test table and cleans up after tests""" + self.arguments = connection_details.copy() + table_name = "pysql_test_variant_types_table" + + with self.cursor() as cursor: + try: + # Create the table with variant columns + 
cursor.execute( + """ + CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table ( + id INTEGER, + variant_col VARIANT, + regular_string_col STRING + ) + """ + ) + + # Insert test records with different variant values + cursor.execute( + """ + INSERT INTO pysql_test_variant_types_table + VALUES + (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'), + (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string') + """ + ) + yield table_name + finally: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + + def test_variant_type_detection(self, variant_table): + """Test that VARIANT type columns are properly detected in schema""" + with self.cursor() as cursor: + cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0") + + # Verify column types in description + assert ( + cursor.description[0][1] == "int" + ), "Integer column type not correctly identified" + assert ( + cursor.description[1][1] == "variant" + ), "VARIANT column type not correctly identified" + assert ( + cursor.description[2][1] == "string" + ), "String column type not correctly identified" + + def test_variant_data_retrieval(self, variant_table): + """Test that VARIANT data is properly retrieved and can be accessed as JSON""" + with self.cursor() as cursor: + cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id") + rows = cursor.fetchall() + + # First row should have a JSON object + json_obj = rows[0][1] + assert isinstance( + json_obj, str + ), "VARIANT column should be returned as string" + + parsed = json.loads(json_obj) + assert parsed.get("name") == "John" + assert parsed.get("age") == 30 + + # Second row should have a JSON array + json_array = rows[1][1] + assert isinstance( + json_array, str + ), "VARIANT array should be returned as string" + + # Parsing to verify it's valid JSON array + parsed_array = json.loads(json_array) + assert isinstance(parsed_array, list) + assert parsed_array == [1, 2, 3, 4] diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f604f287..26a898cb 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -185,6 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter + "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter "unsupported_param": "value", # Unsupported parameter } catalog = "test_catalog" @@ -196,6 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i "session_confs": { "ansi_mode": "FALSE", "statement_timeout": "3600", + "query_tags": "team:marketing,dashboard:abc123", }, "catalog": catalog, "schema": schema, @@ -641,6 +643,7 @@ def test_filter_session_configuration(self): "TIMEZONE": "UTC", "enable_photon": False, "MAX_FILE_PARTITION_BYTES": 128.5, + "QUERY_TAGS": "team:engineering,project:data-pipeline", "unsupported_param": "value", "ANOTHER_UNSUPPORTED": 42, } @@ -663,6 +666,7 @@ def test_filter_session_configuration(self): "timezone": "UTC", # string -> "UTC", key lowercased "enable_photon": "False", # boolean False -> "False", key lowercased "max_file_partition_bytes": "128.5", # float -> "128.5", key lowercased + "query_tags": "team:engineering,project:data-pipeline", } assert result == expected_result @@ -683,12 +687,14 @@ def test_filter_session_configuration(self): "ansi_mode": "false", # lowercase key "STATEMENT_TIMEOUT": 7200, # uppercase key "TiMeZoNe": "America/New_York", # mixed case key + 
"QueRy_TaGs": "team:marketing,test:case-insensitive", } result = _filter_session_configuration(case_insensitive_config) expected_case_result = { "ansi_mode": "false", "statement_timeout": "7200", "timezone": "America/New_York", + "query_tags": "team:marketing,test:case-insensitive", } assert result == expected_case_result diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index e019e05a..c135a846 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -155,7 +155,7 @@ def test_socket_timeout_passthrough(self, mock_client_class): @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() + mock_session_config = {"ANSI_MODE": "FALSE", "QUERY_TAGS": "team:engineering,project:data-pipeline"} databricks.sql.connect( session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS ) diff --git a/tests/unit/test_streaming_put.py b/tests/unit/test_streaming_put.py new file mode 100644 index 00000000..2b9a9e6d --- /dev/null +++ b/tests/unit/test_streaming_put.py @@ -0,0 +1,113 @@ +import io +from unittest.mock import patch, Mock, MagicMock + +import pytest + +import databricks.sql.client as client + + +class TestStreamingPut: + """Unit tests for streaming PUT functionality.""" + + @pytest.fixture + def cursor(self): + return client.Cursor(connection=Mock(), backend=Mock()) + + def _setup_mock_staging_put_stream_response(self, mock_backend): + """Helper method to set up mock staging PUT stream response.""" + mock_result_set = Mock() + mock_result_set.is_staging_operation = True + mock_backend.execute_command.return_value = mock_result_set + + mock_row = Mock() + mock_row.operation = "PUT" + mock_row.localFile = "__input_stream__" + mock_row.presignedUrl = "https://example.com/upload" + mock_row.headers = "{}" + mock_result_set.fetchone.return_value = mock_row + + return mock_result_set + + def test_execute_with_valid_stream(self, cursor): + """Test execute method with valid input stream.""" + + # Mock the backend response + self._setup_mock_staging_put_stream_response(cursor.backend) + + # Test with valid stream + test_stream = io.BytesIO(b"test data") + + with patch.object(cursor, "_handle_staging_put_stream") as mock_handler: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=test_stream, + ) + + # Verify staging handler was called + mock_handler.assert_called_once() + + def test_execute_with_none_stream_for_staging_put(self, cursor): + """Test execute method rejects None stream for streaming PUT operations.""" + + # Mock staging operation response for None case + self._setup_mock_staging_put_stream_response(cursor.backend) + + # None with __input_stream__ raises ProgrammingError + with pytest.raises(client.ProgrammingError) as excinfo: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=None, + ) + error_msg = str(excinfo.value) + assert "No input stream provided for streaming operation" in error_msg + + def test_handle_staging_put_stream_success(self, cursor): + """Test successful streaming PUT operation.""" + + presigned_url = "https://example.com/upload" + headers = {"Content-Type": "text/plain"} + + with patch.object( + cursor.connection.http_client, "request" + ) as mock_http_request: + mock_response = MagicMock() + mock_response.status = 200 + mock_response.data = b"success" + mock_http_request.return_value = mock_response + + test_stream = 
io.BytesIO(b"test data") + cursor._handle_staging_put_stream( + presigned_url=presigned_url, stream=test_stream, headers=headers + ) + + # Verify the HTTP client was called correctly + mock_http_request.assert_called_once() + call_args = mock_http_request.call_args + # Check positional arguments: (method, url, body=..., headers=...) + assert call_args[0][0].value == "PUT" # First positional arg is method + assert call_args[0][1] == presigned_url # Second positional arg is url + # Check keyword arguments + assert call_args[1]["body"] == b"test data" + assert call_args[1]["headers"] == headers + + def test_handle_staging_put_stream_http_error(self, cursor): + """Test streaming PUT operation with HTTP error.""" + + presigned_url = "https://example.com/upload" + + with patch.object( + cursor.connection.http_client, "request" + ) as mock_http_request: + mock_response = MagicMock() + mock_response.status = 500 + mock_response.data = b"Internal Server Error" + mock_http_request.return_value = mock_response + + test_stream = io.BytesIO(b"test data") + with pytest.raises(client.OperationalError) as excinfo: + cursor._handle_staging_put_stream( + presigned_url=presigned_url, stream=test_stream + ) + + # Check for the actual error message format + assert "500" in str(excinfo.value) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 0445ace3..7254b66c 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -2330,7 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **complex_arg_types, ) thrift_backend.execute_command( @@ -2356,6 +2356,86 @@ def test_execute_command_sets_complex_type_fields_correctly( t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow ) + @unittest.skipIf(pyarrow is None, "Requires pyarrow") + def test_col_to_description(self): + test_cases = [ + ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"), + ("normal_col", {}, "string"), + ("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"), + ("missing_field", None, "string"), # None field case + ] + + for column_name, field_metadata, expected_type in test_cases: + with self.subTest(column_name=column_name, expected_type=expected_type): + col = ttypes.TColumnDesc( + columnName=column_name, + typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE), + ) + + field = ( + None + if field_metadata is None + else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata) + ) + + result = ThriftDatabricksClient._col_to_description(col, field) + + self.assertEqual(result[0], column_name) + self.assertEqual(result[1], expected_type) + self.assertIsNone(result[2]) + self.assertIsNone(result[3]) + self.assertIsNone(result[4]) + self.assertIsNone(result[5]) + self.assertIsNone(result[6]) + + @unittest.skipIf(pyarrow is None, "Requires pyarrow") + def test_hive_schema_to_description(self): + test_cases = [ + ( + [ + ("regular_col", ttypes.TTypeId.STRING_TYPE), + ("variant_col", ttypes.TTypeId.STRING_TYPE), + ], + [ + ("regular_col", {}), + ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}), + ], + [("regular_col", "string"), ("variant_col", "variant")], + ), + ( + [("regular_col", ttypes.TTypeId.STRING_TYPE)], + None, # No arrow schema + [("regular_col", "string")], + ), + ] + + for columns, arrow_fields, expected_types in test_cases: + with 
self.subTest(arrow_fields=arrow_fields is not None): + t_table_schema = ttypes.TTableSchema( + columns=[ + ttypes.TColumnDesc( + columnName=name, typeDesc=self._make_type_desc(col_type) + ) + for name, col_type in columns + ] + ) + + schema_bytes = None + if arrow_fields: + fields = [ + pyarrow.field(name, pyarrow.string(), metadata=metadata) + for name, metadata in arrow_fields + ] + schema_bytes = pyarrow.schema(fields).serialize().to_pybytes() + + description = ThriftDatabricksClient._hive_schema_to_description( + t_table_schema, schema_bytes + ) + + for i, (expected_name, expected_type) in enumerate(expected_types): + self.assertEqual(description[i][0], expected_name) + self.assertEqual(description[i][1], expected_type) + if __name__ == "__main__": unittest.main()
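
Editor's note: a minimal combined usage sketch of the two features this diff introduces (streaming PUT through the '__input_stream__' token with the new input_stream= argument to cursor.execute, and the QUERY_TAGS session configuration). It mirrors examples/query_tags_example.py and examples/streaming_put.py above; the tag values and the target volume path are illustrative placeholders, and the same DATABRICKS_* environment variables are assumed to be set.

    import io
    import os

    from databricks import sql

    # Placeholders: catalog/schema/volume must point at an existing Unity Catalog volume.
    catalog = os.getenv("DATABRICKS_CATALOG")
    schema = os.getenv("DATABRICKS_SCHEMA")
    volume = os.getenv("DATABRICKS_VOLUME")

    with sql.connect(
        server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
        http_path=os.getenv("DATABRICKS_HTTP_PATH"),
        access_token=os.getenv("DATABRICKS_TOKEN"),
        # QUERY_TAGS is forwarded as a session conf and surfaces in system.query.history.
        session_configuration={"QUERY_TAGS": "team:engineering,job:ingest-demo"},
    ) as connection:
        with connection.cursor() as cursor:
            # '__input_stream__' tells the server to take the upload body from the
            # provided stream, so no local file or staging_allowed_local_path is needed.
            with io.BytesIO(b"streamed payload") as stream:
                cursor.execute(
                    f"PUT '__input_stream__' INTO "
                    f"'/Volumes/{catalog}/{schema}/{volume}/demo.txt' OVERWRITE",
                    input_stream=stream,
                )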