diff --git a/.github/workflows/code-quality-checks.yml b/.github/workflows/code-quality-checks.yml index b6db61a3c..22db995c5 100644 --- a/.github/workflows/code-quality-checks.yml +++ b/.github/workflows/code-quality-checks.yml @@ -8,6 +8,16 @@ jobs: strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + dependency-version: ["default", "min"] + # Optimize matrix - test min/max on subset of Python versions + exclude: + - python-version: "3.12" + dependency-version: "min" + - python-version: "3.13" + dependency-version: "min" + + name: "Unit Tests (Python ${{ matrix.python-version }}, ${{ matrix.dependency-version }} deps)" + steps: #---------------------------------------------- # check-out repo and set-up python @@ -37,7 +47,7 @@ jobs: uses: actions/cache@v4 with: path: .venv - key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ matrix.dependency-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} #---------------------------------------------- # install dependencies if cache does not exist #---------------------------------------------- @@ -50,8 +60,31 @@ jobs: - name: Install library run: poetry install --no-interaction #---------------------------------------------- + # override with custom dependency versions + #---------------------------------------------- + - name: Install Python tools for custom versions + if: matrix.dependency-version != 'default' + run: poetry run pip install toml packaging + + - name: Generate requirements file + if: matrix.dependency-version != 'default' + run: | + poetry run python scripts/dependency_manager.py ${{ matrix.dependency-version }} --output requirements-${{ matrix.dependency-version }}.txt + echo "Generated requirements for ${{ matrix.dependency-version }} versions:" + cat requirements-${{ matrix.dependency-version }}.txt + + - name: Override with custom dependency versions + if: matrix.dependency-version != 'default' + run: poetry run pip install -r requirements-${{ matrix.dependency-version }}.txt + + #---------------------------------------------- # run test suite #---------------------------------------------- + - name: Show installed versions + run: | + echo "=== Dependency Version: ${{ matrix.dependency-version }} ===" + poetry run pip list + - name: Run tests run: poetry run python -m pytest tests/unit run-unit-tests-with-arrow: @@ -59,6 +92,15 @@ jobs: strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + dependency-version: ["default", "min"] + exclude: + - python-version: "3.12" + dependency-version: "min" + - python-version: "3.13" + dependency-version: "min" + + name: "Unit Tests + PyArrow (Python ${{ matrix.python-version }}, ${{ matrix.dependency-version }} deps)" + steps: #---------------------------------------------- # check-out repo and set-up python @@ -88,7 +130,7 @@ jobs: uses: actions/cache@v4 with: path: .venv-pyarrow - key: venv-pyarrow-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} + key: venv-pyarrow-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ matrix.dependency-version }}-${{ github.event.repository.name }}-${{ hashFiles('**/poetry.lock') }} #---------------------------------------------- # install dependencies if cache does not exist 
#---------------------------------------------- @@ -101,8 +143,30 @@ jobs: - name: Install library run: poetry install --no-interaction --all-extras #---------------------------------------------- + # override with custom dependency versions + #---------------------------------------------- + - name: Install Python tools for custom versions + if: matrix.dependency-version != 'default' + run: poetry run pip install toml packaging + + - name: Generate requirements file with pyarrow + if: matrix.dependency-version != 'default' + run: | + poetry run python scripts/dependency_manager.py ${{ matrix.dependency-version }} --output requirements-${{ matrix.dependency-version }}-arrow.txt + echo "Generated requirements for ${{ matrix.dependency-version }} versions with PyArrow:" + cat requirements-${{ matrix.dependency-version }}-arrow.txt + + - name: Override with custom dependency versions + if: matrix.dependency-version != 'default' + run: poetry run pip install -r requirements-${{ matrix.dependency-version }}-arrow.txt + #---------------------------------------------- # run test suite #---------------------------------------------- + - name: Show installed versions + run: | + echo "=== Dependency Version: ${{ matrix.dependency-version }} with PyArrow ===" + poetry run pip list + - name: Run tests run: poetry run python -m pytest tests/unit check-linting: diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c602d358..06c12bdc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Release History +# 4.1.2 (2025-08-22) +- Streaming ingestion support for PUT operation (databricks/databricks-sql-python#643 by @sreekanth-db) +- Removed use_threads argument on concat_tables for compatibility with pyarrow<14 (databricks/databricks-sql-python#684 by @jprakash-db) + +# 4.1.1 (2025-08-21) +- Add documentation for proxy support (databricks/databricks-sql-python#680 by @vikrantpuppala) +- Fix compatibility with urllib3<2 and add CI actions to improve dependency checks (databricks/databricks-sql-python#678 by @vikrantpuppala) + # 4.1.0 (2025-08-18) - Removed Codeowners (databricks/databricks-sql-python#623 by @jprakash-db) - Azure Service Principal Credential Provider (databricks/databricks-sql-python#621 by @jprakash-db) diff --git a/README.md b/README.md index a4c5a1307..d57efda1f 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ The Databricks SQL Connector for Python allows you to develop Python application This connector uses Arrow as the data-exchange format, and supports APIs (e.g. `fetchmany_arrow`) to directly fetch Arrow tables. Arrow tables are wrapped in the `ArrowQueue` class to provide a natural API to get several rows at a time. [PyArrow](https://arrow.apache.org/docs/python/index.html) is required to enable this and use these APIs, you can install it via `pip install pyarrow` or `pip install databricks-sql-connector[pyarrow]`. +The connector includes built-in support for HTTP/HTTPS proxy servers with multiple authentication methods including basic authentication and Kerberos/Negotiate authentication. See `docs/proxy.md` and `examples/proxy_authentication.py` for details. + You are welcome to file an issue here for general use cases. You can also contact Databricks Support [here](help.databricks.com). 
## Requirements diff --git a/docs/proxy.md b/docs/proxy.md new file mode 100644 index 000000000..2e0bec292 --- /dev/null +++ b/docs/proxy.md @@ -0,0 +1,232 @@ +# Proxy Support + +The Databricks SQL Connector supports connecting through HTTP and HTTPS proxy servers with various authentication methods. This feature automatically detects system proxy configuration and handles proxy authentication transparently. + +## Quick Start + +The connector automatically uses your system's proxy configuration when available: + +```python +from databricks import sql + +# Basic connection - uses system proxy automatically +with sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/endpoints/your-endpoint-id", + access_token="your-token" +) as connection: + # Your queries here... +``` + +For advanced proxy authentication (like Kerberos), specify the authentication method: + +```python +with sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/endpoints/your-endpoint-id", + access_token="your-token", + _proxy_auth_method="negotiate" # Enable Kerberos proxy auth +) as connection: + # Your queries here... +``` + +## Proxy Configuration + +### Environment Variables + +The connector follows standard proxy environment variable conventions: + +| Variable | Description | Example | +|----------|-------------|---------| +| `HTTP_PROXY` | Proxy for HTTP requests | `http://proxy.company.com:8080` | +| `HTTPS_PROXY` | Proxy for HTTPS requests | `https://proxy.company.com:8080` | +| `NO_PROXY` | Hosts to bypass proxy | `localhost,127.0.0.1,.company.com` | + +**Note**: The connector also recognizes lowercase versions (`http_proxy`, `https_proxy`, `no_proxy`). + +### Proxy URL Formats + +Basic proxy (no authentication): +```bash +export HTTPS_PROXY="http://proxy.company.com:8080" +``` + +Proxy with basic authentication: +```bash +export HTTPS_PROXY="http://username:password@proxy.company.com:8080" +``` + +## Authentication Methods + +The connector supports multiple proxy authentication methods via the `_proxy_auth_method` parameter: + +### 1. Basic Authentication (`basic` or `None`) + +**Default behavior** when credentials are provided in the proxy URL or when `_proxy_auth_method="basic"` is specified. + +```python +# Method 1: Credentials in proxy URL (recommended) +# Set environment: HTTPS_PROXY="http://user:pass@proxy.company.com:8080" +with sql.connect( + server_hostname="your-workspace.com", + http_path="/sql/1.0/endpoints/abc123", + access_token="your-token" + # No _proxy_auth_method needed - detected automatically +) as conn: + pass + +# Method 2: Explicit basic authentication +with sql.connect( + server_hostname="your-workspace.com", + http_path="/sql/1.0/endpoints/abc123", + access_token="your-token", + _proxy_auth_method="basic" # Explicit basic auth +) as conn: + pass +``` + +### 2. Kerberos/Negotiate Authentication (`negotiate`) + +For corporate environments using Kerberos authentication with proxy servers. 
+ +**Prerequisites:** +- Valid Kerberos tickets (run `kinit` first) +- Properly configured Kerberos environment + +```python +with sql.connect( + server_hostname="your-workspace.com", + http_path="/sql/1.0/endpoints/abc123", + access_token="your-token", + _proxy_auth_method="negotiate" # Enable Kerberos proxy auth +) as conn: + pass +``` + +**Kerberos Setup Example:** +```bash +# Obtain Kerberos tickets +kinit your-username@YOUR-DOMAIN.COM + +# Set proxy (no credentials in URL for Kerberos) +export HTTPS_PROXY="http://proxy.company.com:8080" + +# Run your Python script +python your_script.py +``` + +## Proxy Bypass + +The connector respects system proxy bypass rules. Requests to hosts listed in `NO_PROXY` or system bypass lists will connect directly, bypassing the proxy. + +```bash +# Bypass proxy for local and internal hosts +export NO_PROXY="localhost,127.0.0.1,*.internal.company.com,10.*" +``` + +## Advanced Configuration + +### Per-Request Proxy Decisions + +The connector automatically makes per-request decisions about proxy usage based on: + +1. **System proxy configuration** - Detected from environment variables +2. **Proxy bypass rules** - Honor `NO_PROXY` and system bypass settings +3. **Target host** - Check if the specific host should use proxy + +### Connection Pooling + +The connector maintains separate connection pools for direct and proxy connections, allowing efficient handling of mixed proxy/direct traffic. + +### SSL/TLS with Proxy + +HTTPS connections through HTTP proxies use the CONNECT method for SSL tunneling. The connector handles this automatically while preserving all SSL verification settings. + +## Troubleshooting + +### Common Issues + +**Problem**: Connection fails with proxy-related errors +``` +Solution: +1. Verify proxy environment variables are set correctly +2. Check if proxy requires authentication +3. Ensure proxy allows CONNECT method for HTTPS +4. Test proxy connectivity with curl: + curl -x $HTTPS_PROXY https://your-workspace.com +``` + +**Problem**: Kerberos authentication fails +``` +Solution: +1. Verify Kerberos tickets: klist +2. Renew tickets if expired: kinit +3. Check proxy supports negotiate authentication +4. Ensure time synchronization between client and KDC +``` + +**Problem**: Some requests bypass proxy unexpectedly +``` +Solution: +1. Check NO_PROXY environment variable +2. Review system proxy bypass settings +3. Verify the target hostname format +``` + +### Debug Logging + +Enable detailed logging to troubleshoot proxy issues: + +```python +import logging + +# Enable connector debug logging +logging.basicConfig(level=logging.DEBUG) +logging.getLogger("databricks.sql").setLevel(logging.DEBUG) + +# Enable urllib3 logging for HTTP details +logging.getLogger("urllib3").setLevel(logging.DEBUG) +``` + +### Testing Proxy Configuration + +Use the provided example script to test different proxy authentication methods: + +```bash +cd examples/ +python proxy_authentication.py +``` + +This script tests: +- Default proxy behavior +- Basic authentication +- Kerberos/Negotiate authentication + +## Examples + +See `examples/proxy_authentication.py` for a comprehensive demonstration of proxy authentication methods. + +## Implementation Details + +### How Proxy Detection Works + +1. **Environment Variables**: Check `HTTP_PROXY`/`HTTPS_PROXY` environment variables +2. **System Configuration**: Use Python's `urllib.request.getproxies()` to detect system settings +3. **Bypass Rules**: Honor `NO_PROXY` and `urllib.request.proxy_bypass()` rules +4. 
**Per-Request Logic**: Decide proxy usage for each request based on target host + +### Supported Proxy Types + +- **HTTP Proxies**: For both HTTP and HTTPS traffic (via CONNECT) +- **HTTPS Proxies**: Encrypted proxy connections +- **Authentication**: Basic, Negotiate/Kerberos +- **Bypass Rules**: Full support for NO_PROXY patterns + +### Connection Architecture + +The connector uses a unified HTTP client that maintains: +- **Direct Pool Manager**: For non-proxy connections +- **Proxy Pool Manager**: For proxy connections +- **Per-Request Routing**: Automatic selection based on target host + +This architecture ensures optimal performance and correct proxy handling across all connector operations. diff --git a/examples/README.md b/examples/README.md index 43d248dab..d73c58a6b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -42,3 +42,4 @@ this example the string `ExamplePartnerTag` will be added to the the user agent - **`custom_cred_provider.py`** shows how to pass a custom credential provider to bypass connector authentication. Please install databricks-sdk prior to running this example. - **`v3_retries_query_execute.py`** shows how to enable v3 retries in connector version 2.9.x including how to enable retries for non-default retry cases. - **`parameters.py`** shows how to use parameters in native and inline modes. +- **`proxy_authentication.py`** demonstrates how to connect through proxy servers using different authentication methods including basic authentication and Kerberos/Negotiate authentication. diff --git a/examples/proxy_authentication.py b/examples/proxy_authentication.py new file mode 100644 index 000000000..8547336b3 --- /dev/null +++ b/examples/proxy_authentication.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Example: Databricks SQL Connector with Proxy Authentication + +This example demonstrates how to connect to Databricks through a proxy server +using different authentication methods: +1. Basic authentication (username/password in proxy URL) +2. Kerberos/Negotiate authentication +3. 
Default system proxy behavior + +Prerequisites: +- Configure your system proxy settings (HTTP_PROXY/HTTPS_PROXY environment variables) +- For Kerberos: Ensure you have valid Kerberos tickets (kinit) +- Set your Databricks credentials in environment variables +""" + +import os +from databricks import sql +import logging + +# Configure logging to see proxy activity +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +# Uncomment for detailed debugging (shows HTTP requests/responses) +# logging.getLogger("urllib3").setLevel(logging.DEBUG) +# logging.getLogger("urllib3.connectionpool").setLevel(logging.DEBUG) + +def check_proxy_environment(): + """Check if proxy environment variables are configured.""" + proxy_vars = ['HTTP_PROXY', 'HTTPS_PROXY', 'http_proxy', 'https_proxy'] + configured_proxies = {var: os.environ.get(var) for var in proxy_vars if os.environ.get(var)} + + if configured_proxies: + print("✓ Proxy environment variables found:") + for var, value in configured_proxies.items(): + # Hide credentials in output for security + safe_value = value.split('@')[-1] if '@' in value else value + print(f" {var}: {safe_value}") + return True + else: + print("⚠ No proxy environment variables found") + print(" Set HTTP_PROXY and/or HTTPS_PROXY if using a proxy") + return False + +def test_connection(connection_params, test_name): + """Test a database connection with given parameters.""" + print(f"\n--- Testing {test_name} ---") + + try: + with sql.connect(**connection_params) as connection: + print("✓ Successfully connected!") + + with connection.cursor() as cursor: + # Test basic query + cursor.execute("SELECT current_user() as user, current_database() as database") + result = cursor.fetchone() + print(f"✓ Connected as user: {result.user}") + print(f"✓ Default database: {result.database}") + + # Test a simple computation + cursor.execute("SELECT 1 + 1 as result") + result = cursor.fetchone() + print(f"✓ Query result: 1 + 1 = {result.result}") + + return True + + except Exception as e: + print(f"✗ Connection failed: {e}") + return False + +def main(): + print("Databricks SQL Connector - Proxy Authentication Examples") + print("=" * 60) + + # Check proxy configuration + has_proxy = check_proxy_environment() + + # Get Databricks connection parameters + server_hostname = os.environ.get('DATABRICKS_SERVER_HOSTNAME') + http_path = os.environ.get('DATABRICKS_HTTP_PATH') + access_token = os.environ.get('DATABRICKS_TOKEN') + + if not all([server_hostname, http_path, access_token]): + print("\n✗ Missing required environment variables:") + print(" DATABRICKS_SERVER_HOSTNAME") + print(" DATABRICKS_HTTP_PATH") + print(" DATABRICKS_TOKEN") + return 1 + + print(f"\nConnecting to: {server_hostname}") + + # Base connection parameters + base_params = { + 'server_hostname': server_hostname, + 'http_path': http_path, + 'access_token': access_token + } + + success_count = 0 + total_tests = 0 + + # Test 1: Default proxy behavior (no _proxy_auth_method specified) + # This uses basic auth if credentials are in proxy URL, otherwise no auth + print("\n" + "="*60) + print("Test 1: Default Proxy Behavior") + print("Uses basic authentication if credentials are in proxy URL") + total_tests += 1 + if test_connection(base_params, "Default Proxy Behavior"): + success_count += 1 + + # Test 2: Explicit basic authentication + print("\n" + "="*60) + print("Test 2: Explicit Basic Authentication") + print("Explicitly requests basic authentication (same as default)") + 
total_tests += 1 + basic_params = base_params.copy() + basic_params['_proxy_auth_method'] = 'basic' + if test_connection(basic_params, "Basic Proxy Authentication"): + success_count += 1 + + # Test 3: Kerberos/Negotiate authentication + print("\n" + "="*60) + print("Test 3: Kerberos/Negotiate Authentication") + print("Uses Kerberos tickets for proxy authentication") + print("Note: Requires valid Kerberos tickets (run 'kinit' first)") + total_tests += 1 + kerberos_params = base_params.copy() + kerberos_params['_proxy_auth_method'] = 'negotiate' + if test_connection(kerberos_params, "Kerberos Proxy Authentication"): + success_count += 1 + + # Summary + print(f"\n{'='*60}") + print(f"Summary: {success_count}/{total_tests} tests passed") + + if success_count == total_tests: + print("✓ All proxy authentication methods working!") + return 0 + elif success_count > 0: + print("⚠ Some proxy authentication methods failed") + print("This may be normal depending on your proxy configuration") + return 0 + else: + print("✗ All proxy authentication methods failed") + if not has_proxy: + print("Consider checking your proxy configuration") + return 1 + +if __name__ == "__main__": + exit(main()) diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py new file mode 100644 index 000000000..f615d082c --- /dev/null +++ b/examples/query_tags_example.py @@ -0,0 +1,30 @@ +import os +import databricks.sql as sql + +""" +This example demonstrates how to use Query Tags. + +Query Tags are key-value pairs that can be attached to SQL executions and will appear +in the system.query.history table for analytical purposes. + +Format: "key1:value1,key2:value2,key3:value3" +""" + +print("=== Query Tags Example ===\n") + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + session_configuration={ + 'QUERY_TAGS': 'team:engineering,test:query-tags', + 'ansi_mode': False + } +) as connection: + + with connection.cursor() as cursor: + cursor.execute("SELECT 1") + result = cursor.fetchone() + print(f" Result: {result[0]}") + +print("\n=== Query Tags Example Complete ===") \ No newline at end of file diff --git a/examples/streaming_put.py b/examples/streaming_put.py new file mode 100644 index 000000000..4e7697099 --- /dev/null +++ b/examples/streaming_put.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +""" +Simple example of streaming PUT operations. + +This demonstrates the basic usage of streaming PUT with the __input_stream__ token. +""" + +import io +import os +from databricks import sql + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + with connection.cursor() as cursor: + # Create a simple data stream + data = b"Hello, streaming world!" 
+ stream = io.BytesIO(data) + + # Get catalog, schema, and volume from environment variables + catalog = os.getenv("DATABRICKS_CATALOG") + schema = os.getenv("DATABRICKS_SCHEMA") + volume = os.getenv("DATABRICKS_VOLUME") + + # Upload to Unity Catalog volume + cursor.execute( + f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/{volume}/hello.txt' OVERWRITE", + input_stream=stream + ) + + print("File uploaded successfully!") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a48793b2f..6f0f74710 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "4.1.0" +version = "4.1.2" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/scripts/dependency_manager.py b/scripts/dependency_manager.py new file mode 100644 index 000000000..d73d095f2 --- /dev/null +++ b/scripts/dependency_manager.py @@ -0,0 +1,234 @@ +""" +Dependency version management for testing. +Generates requirements files for min and default dependency versions. +For min versions, creates flexible constraints (e.g., >=1.2.5,<1.3.0) to allow +compatible patch updates instead of pinning exact versions. +""" + +import toml +import sys +import argparse +from packaging.specifiers import SpecifierSet +from packaging.requirements import Requirement +from pathlib import Path + +class DependencyManager: + def __init__(self, pyproject_path="pyproject.toml"): + self.pyproject_path = Path(pyproject_path) + self.dependencies = self._load_dependencies() + + # Map of packages that need specific transitive dependency constraints when downgraded + self.transitive_dependencies = { + 'pandas': { + # When pandas is downgraded to 1.x, ensure numpy compatibility + 'numpy': { + 'min_constraint': '>=1.16.5,<2.0.0', # pandas 1.x works with numpy 1.x + 'applies_when': lambda version: version.startswith('1.') + } + } + } + + def _load_dependencies(self): + """Load dependencies from pyproject.toml""" + with open(self.pyproject_path, 'r') as f: + pyproject = toml.load(f) + return pyproject['tool']['poetry']['dependencies'] + + def _parse_constraint(self, name, constraint): + """Parse a dependency constraint into version info""" + if isinstance(constraint, str): + return constraint, False # version_constraint, is_optional + elif isinstance(constraint, list): + # Handle complex constraints like pandas/pyarrow + first_constraint = constraint[0] + version = first_constraint['version'] + is_optional = first_constraint.get('optional', False) + return version, is_optional + elif isinstance(constraint, dict): + if 'version' in constraint: + return constraint['version'], constraint.get('optional', False) + return None, False + + def _extract_versions_from_specifier(self, spec_set_str): + """Extract minimum version from a specifier set""" + try: + # Handle caret (^) and tilde (~) constraints that packaging doesn't support + if spec_set_str.startswith('^'): + # ^1.2.3 means >=1.2.3, <2.0.0 + min_version = spec_set_str[1:] # Remove ^ + return min_version, None + elif spec_set_str.startswith('~'): + # ~1.2.3 means >=1.2.3, <1.3.0 + min_version = spec_set_str[1:] # Remove ~ + return min_version, None + + spec_set = SpecifierSet(spec_set_str) + min_version = None + + for spec in spec_set: + if spec.operator in ('>=', '=='): + min_version = spec.version + break + + return min_version, None + except Exception as e: + print(f"Warning: Could not parse constraint '{spec_set_str}': {e}", file=sys.stderr) + return None, 
None + + def _create_flexible_minimum_constraint(self, package_name, min_version): + """Create a flexible minimum constraint that allows compatible updates""" + try: + # Split version into parts + version_parts = min_version.split('.') + + if len(version_parts) >= 2: + major = version_parts[0] + minor = version_parts[1] + + # Special handling for packages that commonly have conflicts + # For these packages, use wider constraints to allow more compatibility + if package_name in ['requests', 'urllib3', 'pandas']: + # Use wider constraint: >=min_version,<next_major.0.0 (e.g. requests>=2.18.1,<3.0.0) + next_major = int(major) + 1 + upper_bound = f"{next_major}.0.0" + return f"{package_name}>={min_version},<{upper_bound}" + else: + # For other packages, use minor version constraint + # e.g., 1.2.5 becomes >=1.2.5,<1.3.0 + next_minor = int(minor) + 1 + upper_bound = f"{major}.{next_minor}.0" + return f"{package_name}>={min_version},<{upper_bound}" + else: + # If version doesn't have minor version, just use exact version + return f"{package_name}=={min_version}" + + except (ValueError, IndexError) as e: + print(f"Warning: Could not create flexible constraint for {package_name}=={min_version}: {e}", file=sys.stderr) + # Fallback to exact version + return f"{package_name}=={min_version}" + + def _get_transitive_dependencies(self, package_name, version, version_type): + """Get transitive dependencies that need specific constraints based on the main package version""" + transitive_reqs = [] + + if package_name in self.transitive_dependencies: + transitive_deps = self.transitive_dependencies[package_name] + + for dep_name, dep_config in transitive_deps.items(): + # Check if this transitive dependency applies for this version + if dep_config['applies_when'](version): + if version_type == "min": + # Use the predefined constraint for minimum versions + constraint = dep_config['min_constraint'] + transitive_reqs.append(f"{dep_name}{constraint}") + # For default version_type, we don't add transitive deps as Poetry handles them + + return transitive_reqs + + def generate_requirements(self, version_type="min", include_optional=False): + """ + Generate requirements for specified version type.
+ + Args: + version_type: "min" or "default" + include_optional: Whether to include optional dependencies + """ + requirements = [] + transitive_requirements = [] + + for name, constraint in self.dependencies.items(): + if name == 'python': + continue + + version_constraint, is_optional = self._parse_constraint(name, constraint) + if not version_constraint: + continue + + if is_optional and not include_optional: + continue + + if version_type == "default": + # For default, just use the constraint as-is (let poetry resolve) + requirements.append(f"{name}{version_constraint}") + elif version_type == "min": + min_version, _ = self._extract_versions_from_specifier(version_constraint) + if min_version: + # Create flexible constraint that allows patch updates for compatibility + flexible_constraint = self._create_flexible_minimum_constraint(name, min_version) + requirements.append(flexible_constraint) + + # Check if this package needs specific transitive dependencies + transitive_deps = self._get_transitive_dependencies(name, min_version, version_type) + transitive_requirements.extend(transitive_deps) + + # Combine main requirements with transitive requirements + all_requirements = requirements + transitive_requirements + + # Remove duplicates (prefer main requirements over transitive ones) + seen_packages = set() + final_requirements = [] + + # First add main requirements + for req in requirements: + package_name = Requirement(req).name + seen_packages.add(package_name) + final_requirements.append(req) + + # Then add transitive requirements that don't conflict + for req in transitive_requirements: + package_name = Requirement(req).name + if package_name not in seen_packages: + final_requirements.append(req) + + return final_requirements + + + def write_requirements_file(self, filename, version_type="min", include_optional=False): + """Write requirements to a file""" + requirements = self.generate_requirements(version_type, include_optional) + + with open(filename, 'w') as f: + if version_type == "min": + f.write(f"# Minimum compatible dependency versions generated from pyproject.toml\n") + f.write(f"# Uses flexible constraints to resolve compatibility conflicts:\n") + f.write(f"# - Common packages (requests, urllib3, pandas): >=min,=min, None: + """Handle PUT operation with streaming data. + + Args: + presigned_url: The presigned URL for upload + stream: Binary stream to upload + headers: HTTP headers + + Raises: + ProgrammingError: If no input stream is provided + OperationalError: If the upload fails + """ + + if not stream: + raise ProgrammingError( + "No input stream provided for streaming operation", + session_id_hex=self.connection.get_session_id_hex(), + ) + + r = self.connection.http_client.request( + HttpMethod.PUT, presigned_url, body=stream.read(), headers=headers + ) + + self._handle_staging_http_response(r) + @log_latency(StatementType.SQL) def _handle_staging_get( self, local_file: str, presigned_url: str, headers: Optional[dict] = None @@ -840,6 +892,7 @@ def execute( operation: str, parameters: Optional[TParameterCollection] = None, enforce_embedded_schema_correctness=False, + input_stream: Optional[BinaryIO] = None, ) -> "Cursor": """ Execute a query and wait for execution to complete. 
@@ -914,7 +967,8 @@ def execute( if self.active_result_set and self.active_result_set.is_staging_operation: self._handle_staging_operation( - staging_allowed_local_path=self.connection.staging_allowed_local_path + staging_allowed_local_path=self.connection.staging_allowed_local_path, + input_stream=input_stream, ) return self diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py index c31b5a3cf..7ccd69c54 100644 --- a/src/databricks/sql/common/unified_http_client.py +++ b/src/databricks/sql/common/unified_http_client.py @@ -10,6 +10,14 @@ from urllib3.util import make_headers from urllib3.exceptions import MaxRetryError +# Compatibility import for different urllib3 versions +try: + # If urllib3~=2.0 is installed + from urllib3 import BaseHTTPResponse +except ImportError: + # If urllib3~=1.0 is installed + from urllib3 import HTTPResponse as BaseHTTPResponse + from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType from databricks.sql.exc import RequestError from databricks.sql.common.http import HttpMethod @@ -222,7 +230,7 @@ def request_context( url: str, headers: Optional[Dict[str, str]] = None, **kwargs, - ) -> Generator[urllib3.BaseHTTPResponse, None, None]: + ) -> Generator[BaseHTTPResponse, None, None]: """ Context manager for making HTTP requests with proper resource cleanup. @@ -233,7 +241,7 @@ def request_context( **kwargs: Additional arguments passed to urllib3 request Yields: - urllib3.BaseHTTPResponse: The HTTP response object + BaseHTTPResponse: The HTTP response object """ logger.debug( "Making %s request to %s", method, urllib.parse.urlparse(url).netloc @@ -270,7 +278,7 @@ def request( url: str, headers: Optional[Dict[str, str]] = None, **kwargs, - ) -> urllib3.BaseHTTPResponse: + ) -> BaseHTTPResponse: """ Make an HTTP request. 
@@ -281,7 +289,7 @@ def request( **kwargs: Additional arguments passed to urllib3 request Returns: - urllib3.BaseHTTPResponse: The HTTP response object with data and metadata pre-loaded + BaseHTTPResponse: The HTTP response object with data and metadata pre-loaded """ with self.request_context(method, url, headers=headers, **kwargs) as response: # Read the response data to ensure it's available after context exit diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 9e6214648..9f96e8743 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -298,7 +298,7 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table": num_rows -= table_slice.num_rows logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows)) - return pyarrow.concat_tables(partial_result_chunks, use_threads=True) + return concat_table_chunks(partial_result_chunks) def remaining_rows(self) -> "pyarrow.Table": """ @@ -321,7 +321,7 @@ def remaining_rows(self) -> "pyarrow.Table": self.table_row_index += table_slice.num_rows self.table = self._create_next_table() self.table_row_index = 0 - return pyarrow.concat_tables(partial_result_chunks, use_threads=True) + return concat_table_chunks(partial_result_chunks) def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]: """Create next table at the given row offset""" @@ -880,7 +880,7 @@ def concat_table_chunks( result_table[j].extend(table_chunks[i].column_table[j]) return ColumnTable(result_table, table_chunks[0].column_names) else: - return pyarrow.concat_tables(table_chunks, use_threads=True) + return pyarrow.concat_tables(table_chunks) def build_client_context(server_hostname: str, version: str, **kwargs): diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index aeeb67974..dd7c56996 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -117,7 +117,7 @@ def test_long_running_query(self, extra_params): scale_factor = 1 with self.cursor(extra_params) as cursor: while duration < min_duration: - assert scale_factor < 1024, "Detected infinite loop" + assert scale_factor < 4096, "Detected infinite loop" start = time.time() cursor.execute( diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index 2798541ad..b2350bd98 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -2,6 +2,7 @@ import time from typing import Optional, List from unittest.mock import MagicMock, PropertyMock, patch +import io import pytest from urllib3.exceptions import MaxRetryError @@ -91,22 +92,78 @@ def _test_retry_disabled_with_message( assert error_msg_substring in str(cm.exception) +class SimpleHttpResponse: + """A simple HTTP response mock that works with both urllib3 v1.x and v2.x""" + + def __init__(self, status: int, headers: dict, redirect_location: Optional[str] = None): + # Import the correct HTTP message type that urllib3 v1.x expects + try: + from http.client import HTTPMessage + except ImportError: + from httplib import HTTPMessage + + self.status = status + # Create proper HTTPMessage for urllib3 v1.x compatibility + self.headers = HTTPMessage() + for key, value in headers.items(): + self.headers[key] = str(value) + self.msg = self.headers # For urllib3~=1.0.0 compatibility + self.reason = "Mocked Response" + self.version = 11 + self.length = 0 + self.length_remaining = 0 + self._redirect_location = redirect_location + self._body = b"" + 
self._fp = io.BytesIO(self._body) + self._url = "https://example.com" + + def get_redirect_location(self, *args, **kwargs): + """Return the redirect location or False""" + return False if self._redirect_location is None else self._redirect_location + + def read(self, amt=None): + """Mock read method for file-like behavior""" + return self._body + + def close(self): + """Mock close method""" + pass + + def drain_conn(self): + """Mock drain_conn method for urllib3 v2.x""" + pass + + def isclosed(self): + """Mock isclosed method for urllib3 v1.x""" + return False + + def release_conn(self): + """Mock release_conn method for thrift HTTP client""" + pass + + @property + def data(self): + """Mock data property for urllib3 v2.x""" + return self._body + + @property + def url(self): + """Mock url property""" + return self._url + + @url.setter + def url(self, value): + """Mock url setter""" + self._url = value + + @contextmanager def mocked_server_response( status: int = 200, headers: dict = {}, redirect_location: Optional[str] = None ): - """Context manager for patching urllib3 responses""" - - # When mocking mocking a BaseHTTPResponse for urllib3 the mock must include - # 1. A status code - # 2. A headers dict - # 3. mock.get_redirect_location() return falsy by default - - # `msg` is included for testing when urllib3~=1.0.0 is installed - mock_response = MagicMock(headers=headers, msg=headers, status=status) - mock_response.get_redirect_location.return_value = ( - False if redirect_location is None else redirect_location - ) + """Context manager for patching urllib3 responses with version compatibility""" + + mock_response = SimpleHttpResponse(status, headers, redirect_location) with patch("urllib3.connectionpool.HTTPSConnectionPool._get_conn") as getconn_mock: getconn_mock.return_value.getresponse.return_value = mock_response @@ -127,18 +184,14 @@ def mock_sequential_server_responses(responses: List[dict]): - redirect_location: str """ - mock_responses = [] - - # Each resp should have these members: - - for resp in responses: - _mock = MagicMock( - headers=resp["headers"], msg=resp["headers"], status=resp["status"] - ) - _mock.get_redirect_location.return_value = ( - False if resp["redirect_location"] is None else resp["redirect_location"] + mock_responses = [ + SimpleHttpResponse( + status=resp["status"], + headers=resp["headers"], + redirect_location=resp["redirect_location"] ) - mock_responses.append(_mock) + for resp in responses + ] with patch("urllib3.connectionpool.HTTPSConnectionPool._get_conn") as getconn_mock: getconn_mock.return_value.getresponse.side_effect = mock_responses @@ -475,21 +528,23 @@ def test_retry_abort_close_operation_on_404(self, extra_params, caplog): ], ) @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient._send_telemetry") - def test_retry_max_redirects_raises_too_many_redirects_exception( + def test_3xx_redirect_codes_are_not_retried( self, mock_send_telemetry, extra_params ): """GIVEN the connector is configured with a custom max_redirects - WHEN the DatabricksRetryPolicy is created - THEN the connector raises a MaxRedirectsError if that number is exceeded + WHEN the DatabricksRetryPolicy receives a 302 redirect + THEN the connector does not retry since 3xx codes are not retried per policy """ - max_redirects, expected_call_count = 1, 2 + max_redirects, expected_call_count = 1, 1 - # Code 302 is a redirect + # Code 302 is a redirect, but 3xx codes are not retried per policy + # Note: We don't set redirect_location because that would cause urllib3 v2.x 
+ # to follow redirects internally, bypassing our retry policy test with mocked_server_response( - status=302, redirect_location="/foo.bar" + status=302, redirect_location=None ) as mock_obj: - with pytest.raises(MaxRetryError) as cm: + with pytest.raises(RequestError): # Should get RequestError, not MaxRetryError with self.connection( extra_params={ **extra_params, @@ -498,8 +553,7 @@ def test_retry_max_redirects_raises_too_many_redirects_exception( } ): pass - assert "too many redirects" == str(cm.value.reason) - # Total call count should be 2 (original + 1 retry) + # Total call count should be 1 (original only, no retries for 3xx codes) assert mock_obj.return_value.getresponse.call_count == expected_call_count @pytest.mark.parametrize( @@ -510,21 +564,23 @@ def test_retry_max_redirects_raises_too_many_redirects_exception( ], ) @patch("databricks.sql.telemetry.telemetry_client.TelemetryClient._send_telemetry") - def test_retry_max_redirects_unset_doesnt_redirect_forever( + def test_3xx_codes_not_retried_regardless_of_max_redirects_setting( self, mock_send_telemetry, extra_params ): """GIVEN the connector is configured without a custom max_redirects - WHEN the DatabricksRetryPolicy is used - THEN the connector raises a MaxRedirectsError if that number is exceeded + WHEN the DatabricksRetryPolicy receives a 302 redirect + THEN the connector does not retry since 3xx codes are not retried per policy - This test effectively guarantees that regardless of _retry_max_redirects, - _stop_after_attempts_count is enforced. + This test confirms that 3xx codes (including redirects) are not retried + according to the DatabricksRetryPolicy regardless of redirect settings. """ - # Code 302 is a redirect + # Code 302 is a redirect, but 3xx codes are not retried per policy + # Note: We don't set redirect_location because that would cause urllib3 v2.x + # to follow redirects internally, bypassing our retry policy test with mocked_server_response( - status=302, redirect_location="/foo.bar/" + status=302, redirect_location=None ) as mock_obj: - with pytest.raises(MaxRetryError) as cm: + with pytest.raises(RequestError): # Should get RequestError, not MaxRetryError with self.connection( extra_params={ **extra_params, @@ -533,8 +589,8 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever( ): pass - # Total call count should be 6 (original + _retry_stop_after_attempts_count) - assert mock_obj.return_value.getresponse.call_count == 6 + # Total call count should be 1 (original only, no retries for 3xx codes) + assert mock_obj.return_value.getresponse.call_count == 1 @pytest.mark.parametrize( "extra_params", @@ -543,13 +599,13 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever( {"use_sea": True}, ], ) - def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count( + def test_3xx_codes_stop_request_immediately_no_retry_attempts( self, extra_params ): - # If I add another 503 or 302 here the test will fail with a MaxRetryError + # Since 3xx codes are not retried per policy, we only ever see the first 302 response responses = [ {"status": 302, "headers": {}, "redirect_location": "/foo.bar"}, - {"status": 500, "headers": {}, "redirect_location": None}, + {"status": 500, "headers": {}, "redirect_location": None}, # Never reached ] additional_settings = { @@ -568,7 +624,7 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count( ): pass - # The error should be the result of the 500, not because of too many requests. 
+ # The error should be the result of the 302, since 3xx codes are not retried assert "too many redirects" not in str(cm.value.message) assert "Error during request to server" in str(cm.value.message) diff --git a/tests/e2e/common/streaming_put_tests.py b/tests/e2e/common/streaming_put_tests.py new file mode 100644 index 000000000..83da10fd3 --- /dev/null +++ b/tests/e2e/common/streaming_put_tests.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +""" +E2E tests for streaming PUT operations. +""" + +import io +import logging +import pytest + +logger = logging.getLogger(__name__) + + +class PySQLStreamingPutTestSuiteMixin: + """Test suite for streaming PUT operations.""" + + def test_streaming_put_basic(self, catalog, schema): + """Test basic streaming PUT functionality.""" + + # Create test data + test_data = b"Hello, streaming world! This is test data." + filename = "streaming_put_test.txt" + file_path = f"/Volumes/{catalog}/{schema}/e2etests/{filename}" + + try: + with self.connection() as conn: + with conn.cursor() as cursor: + self._cleanup_test_file(file_path) + + with io.BytesIO(test_data) as stream: + cursor.execute( + f"PUT '__input_stream__' INTO '{file_path}'", + input_stream=stream + ) + + # Verify file exists + cursor.execute(f"LIST '/Volumes/{catalog}/{schema}/e2etests/'") + files = cursor.fetchall() + + # Check if our file is in the list + file_paths = [row[0] for row in files] + assert file_path in file_paths, f"File {file_path} not found in {file_paths}" + finally: + self._cleanup_test_file(file_path) + + def test_streaming_put_missing_stream(self, catalog, schema): + """Test that missing stream raises appropriate error.""" + + with self.connection() as conn: + with conn.cursor() as cursor: + # Test without providing stream + with pytest.raises(Exception): # Should fail + cursor.execute( + f"PUT '__input_stream__' INTO '/Volumes/{catalog}/{schema}/e2etests/test.txt'" + # Note: No input_stream parameter + ) + + def _cleanup_test_file(self, file_path): + """Clean up a test file if it exists.""" + try: + with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with conn.cursor() as cursor: + cursor.execute(f"REMOVE '{file_path}'") + logger.info("Successfully cleaned up test file: %s", file_path) + except Exception as e: + logger.error("Cleanup failed for %s: %s", file_path, e) \ No newline at end of file diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 53b7383e6..e04e348c9 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -50,6 +50,7 @@ from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin +from tests.e2e.common.streaming_put_tests import PySQLStreamingPutTestSuiteMixin from databricks.sql.exc import SessionAlreadyClosedError @@ -290,6 +291,7 @@ class TestPySQLCoreSuite( PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin, PySQLUCVolumeTestSuiteMixin, + PySQLStreamingPutTestSuiteMixin, ): validate_row_value_type = True validate_result = True @@ -899,7 +901,7 @@ def test_timestamps_arrow(self): ) def test_multi_timestamps_arrow(self, extra_params): with self.cursor( - {"session_configuration": {"ansi_mode": False}, **extra_params} + {"session_configuration": {"ansi_mode": False, "query_tags": "test:multi-timestamps,driver:python"}, **extra_params} ) as cursor: query, expected = self.multi_query() expected = [ diff --git a/tests/e2e/test_variant_types.py b/tests/e2e/test_variant_types.py new file mode 100644 index 
000000000..b5dc1f421 --- /dev/null +++ b/tests/e2e/test_variant_types.py @@ -0,0 +1,91 @@ +import pytest +from datetime import datetime +import json + +try: + import pyarrow +except ImportError: + pyarrow = None + +from tests.e2e.test_driver import PySQLPytestTestCase +from tests.e2e.common.predicates import pysql_supports_arrow + + +@pytest.mark.skipif(not pysql_supports_arrow(), reason="Requires arrow support") +class TestVariantTypes(PySQLPytestTestCase): + """Tests for the proper detection and handling of VARIANT type columns""" + + @pytest.fixture(scope="class") + def variant_table(self, connection_details): + """A pytest fixture that creates a test table and cleans up after tests""" + self.arguments = connection_details.copy() + table_name = "pysql_test_variant_types_table" + + with self.cursor() as cursor: + try: + # Create the table with variant columns + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS pysql_test_variant_types_table ( + id INTEGER, + variant_col VARIANT, + regular_string_col STRING + ) + """ + ) + + # Insert test records with different variant values + cursor.execute( + """ + INSERT INTO pysql_test_variant_types_table + VALUES + (1, PARSE_JSON('{"name": "John", "age": 30}'), 'regular string'), + (2, PARSE_JSON('[1, 2, 3, 4]'), 'another string') + """ + ) + yield table_name + finally: + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + + def test_variant_type_detection(self, variant_table): + """Test that VARIANT type columns are properly detected in schema""" + with self.cursor() as cursor: + cursor.execute(f"SELECT * FROM {variant_table} LIMIT 0") + + # Verify column types in description + assert ( + cursor.description[0][1] == "int" + ), "Integer column type not correctly identified" + assert ( + cursor.description[1][1] == "variant" + ), "VARIANT column type not correctly identified" + assert ( + cursor.description[2][1] == "string" + ), "String column type not correctly identified" + + def test_variant_data_retrieval(self, variant_table): + """Test that VARIANT data is properly retrieved and can be accessed as JSON""" + with self.cursor() as cursor: + cursor.execute(f"SELECT * FROM {variant_table} ORDER BY id") + rows = cursor.fetchall() + + # First row should have a JSON object + json_obj = rows[0][1] + assert isinstance( + json_obj, str + ), "VARIANT column should be returned as string" + + parsed = json.loads(json_obj) + assert parsed.get("name") == "John" + assert parsed.get("age") == 30 + + # Second row should have a JSON array + json_array = rows[1][1] + assert isinstance( + json_array, str + ), "VARIANT array should be returned as string" + + # Parsing to verify it's valid JSON array + parsed_array = json.loads(json_array) + assert isinstance(parsed_array, list) + assert parsed_array == [1, 2, 3, 4] diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f604f2874..26a898cb8 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -185,6 +185,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter + "QUERY_TAGS": "team:marketing,dashboard:abc123", # Supported parameter "unsupported_param": "value", # Unsupported parameter } catalog = "test_catalog" @@ -196,6 +197,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i "session_confs": { "ansi_mode": "FALSE", "statement_timeout": "3600", + "query_tags": 
"team:marketing,dashboard:abc123", }, "catalog": catalog, "schema": schema, @@ -641,6 +643,7 @@ def test_filter_session_configuration(self): "TIMEZONE": "UTC", "enable_photon": False, "MAX_FILE_PARTITION_BYTES": 128.5, + "QUERY_TAGS": "team:engineering,project:data-pipeline", "unsupported_param": "value", "ANOTHER_UNSUPPORTED": 42, } @@ -663,6 +666,7 @@ def test_filter_session_configuration(self): "timezone": "UTC", # string -> "UTC", key lowercased "enable_photon": "False", # boolean False -> "False", key lowercased "max_file_partition_bytes": "128.5", # float -> "128.5", key lowercased + "query_tags": "team:engineering,project:data-pipeline", } assert result == expected_result @@ -683,12 +687,14 @@ def test_filter_session_configuration(self): "ansi_mode": "false", # lowercase key "STATEMENT_TIMEOUT": 7200, # uppercase key "TiMeZoNe": "America/New_York", # mixed case key + "QueRy_TaGs": "team:marketing,test:case-insensitive", } result = _filter_session_configuration(case_insensitive_config) expected_case_result = { "ansi_mode": "false", "statement_timeout": "7200", "timezone": "America/New_York", + "query_tags": "team:marketing,test:case-insensitive", } assert result == expected_case_result diff --git a/tests/unit/test_sea_http_client.py b/tests/unit/test_sea_http_client.py index 10f10592d..39ecb58a7 100644 --- a/tests/unit/test_sea_http_client.py +++ b/tests/unit/test_sea_http_client.py @@ -96,7 +96,8 @@ def test_make_request_success(self, mock_get_auth_headers, sea_http_client): # Setup mock response mock_response = Mock() mock_response.status = 200 - mock_response.json.return_value = {"result": "success"} + # Mock response.data.decode() to return a valid JSON string + mock_response.data.decode.return_value = '{"result": "success"}' mock_response.__enter__ = Mock(return_value=mock_response) mock_response.__exit__ = Mock(return_value=None) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index e019e05a2..c135a846b 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -155,7 +155,7 @@ def test_socket_timeout_passthrough(self, mock_client_class): @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() + mock_session_config = {"ANSI_MODE": "FALSE", "QUERY_TAGS": "team:engineering,project:data-pipeline"} databricks.sql.connect( session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS ) diff --git a/tests/unit/test_streaming_put.py b/tests/unit/test_streaming_put.py new file mode 100644 index 000000000..2b9a9e6d6 --- /dev/null +++ b/tests/unit/test_streaming_put.py @@ -0,0 +1,113 @@ +import io +from unittest.mock import patch, Mock, MagicMock + +import pytest + +import databricks.sql.client as client + + +class TestStreamingPut: + """Unit tests for streaming PUT functionality.""" + + @pytest.fixture + def cursor(self): + return client.Cursor(connection=Mock(), backend=Mock()) + + def _setup_mock_staging_put_stream_response(self, mock_backend): + """Helper method to set up mock staging PUT stream response.""" + mock_result_set = Mock() + mock_result_set.is_staging_operation = True + mock_backend.execute_command.return_value = mock_result_set + + mock_row = Mock() + mock_row.operation = "PUT" + mock_row.localFile = "__input_stream__" + mock_row.presignedUrl = "https://example.com/upload" + mock_row.headers = "{}" + mock_result_set.fetchone.return_value = mock_row + + return mock_result_set + + def test_execute_with_valid_stream(self, cursor): + 
"""Test execute method with valid input stream.""" + + # Mock the backend response + self._setup_mock_staging_put_stream_response(cursor.backend) + + # Test with valid stream + test_stream = io.BytesIO(b"test data") + + with patch.object(cursor, "_handle_staging_put_stream") as mock_handler: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=test_stream, + ) + + # Verify staging handler was called + mock_handler.assert_called_once() + + def test_execute_with_none_stream_for_staging_put(self, cursor): + """Test execute method rejects None stream for streaming PUT operations.""" + + # Mock staging operation response for None case + self._setup_mock_staging_put_stream_response(cursor.backend) + + # None with __input_stream__ raises ProgrammingError + with pytest.raises(client.ProgrammingError) as excinfo: + cursor.execute( + "PUT '__input_stream__' INTO '/Volumes/test/cat/schema/vol/file.txt'", + input_stream=None, + ) + error_msg = str(excinfo.value) + assert "No input stream provided for streaming operation" in error_msg + + def test_handle_staging_put_stream_success(self, cursor): + """Test successful streaming PUT operation.""" + + presigned_url = "https://example.com/upload" + headers = {"Content-Type": "text/plain"} + + with patch.object( + cursor.connection.http_client, "request" + ) as mock_http_request: + mock_response = MagicMock() + mock_response.status = 200 + mock_response.data = b"success" + mock_http_request.return_value = mock_response + + test_stream = io.BytesIO(b"test data") + cursor._handle_staging_put_stream( + presigned_url=presigned_url, stream=test_stream, headers=headers + ) + + # Verify the HTTP client was called correctly + mock_http_request.assert_called_once() + call_args = mock_http_request.call_args + # Check positional arguments: (method, url, body=..., headers=...) 
+ assert call_args[0][0].value == "PUT" # First positional arg is method + assert call_args[0][1] == presigned_url # Second positional arg is url + # Check keyword arguments + assert call_args[1]["body"] == b"test data" + assert call_args[1]["headers"] == headers + + def test_handle_staging_put_stream_http_error(self, cursor): + """Test streaming PUT operation with HTTP error.""" + + presigned_url = "https://example.com/upload" + + with patch.object( + cursor.connection.http_client, "request" + ) as mock_http_request: + mock_response = MagicMock() + mock_response.status = 500 + mock_response.data = b"Internal Server Error" + mock_http_request.return_value = mock_response + + test_stream = io.BytesIO(b"test data") + with pytest.raises(client.OperationalError) as excinfo: + cursor._handle_staging_put_stream( + presigned_url=presigned_url, stream=test_stream + ) + + # Check for the actual error message format + assert "500" in str(excinfo.value) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 0445ace3e..7254b66cb 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -2330,7 +2330,7 @@ def test_execute_command_sets_complex_type_fields_correctly( [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), - http_client=MagicMock(), + http_client=MagicMock(), **complex_arg_types, ) thrift_backend.execute_command( @@ -2356,6 +2356,86 @@ def test_execute_command_sets_complex_type_fields_correctly( t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow ) + @unittest.skipIf(pyarrow is None, "Requires pyarrow") + def test_col_to_description(self): + test_cases = [ + ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}, "variant"), + ("normal_col", {}, "string"), + ("weird_field", {b"Spark:DataType:SqlName": b"Some unexpected value"}, "string"), + ("missing_field", None, "string"), # None field case + ] + + for column_name, field_metadata, expected_type in test_cases: + with self.subTest(column_name=column_name, expected_type=expected_type): + col = ttypes.TColumnDesc( + columnName=column_name, + typeDesc=self._make_type_desc(ttypes.TTypeId.STRING_TYPE), + ) + + field = ( + None + if field_metadata is None + else pyarrow.field(column_name, pyarrow.string(), metadata=field_metadata) + ) + + result = ThriftDatabricksClient._col_to_description(col, field) + + self.assertEqual(result[0], column_name) + self.assertEqual(result[1], expected_type) + self.assertIsNone(result[2]) + self.assertIsNone(result[3]) + self.assertIsNone(result[4]) + self.assertIsNone(result[5]) + self.assertIsNone(result[6]) + + @unittest.skipIf(pyarrow is None, "Requires pyarrow") + def test_hive_schema_to_description(self): + test_cases = [ + ( + [ + ("regular_col", ttypes.TTypeId.STRING_TYPE), + ("variant_col", ttypes.TTypeId.STRING_TYPE), + ], + [ + ("regular_col", {}), + ("variant_col", {b"Spark:DataType:SqlName": b"VARIANT"}), + ], + [("regular_col", "string"), ("variant_col", "variant")], + ), + ( + [("regular_col", ttypes.TTypeId.STRING_TYPE)], + None, # No arrow schema + [("regular_col", "string")], + ), + ] + + for columns, arrow_fields, expected_types in test_cases: + with self.subTest(arrow_fields=arrow_fields is not None): + t_table_schema = ttypes.TTableSchema( + columns=[ + ttypes.TColumnDesc( + columnName=name, typeDesc=self._make_type_desc(col_type) + ) + for name, col_type in columns + ] + ) + + schema_bytes = None + if arrow_fields: + fields = [ + pyarrow.field(name, pyarrow.string(), metadata=metadata) + for name, metadata in 
arrow_fields + ] + schema_bytes = pyarrow.schema(fields).serialize().to_pybytes() + + description = ThriftDatabricksClient._hive_schema_to_description( + t_table_schema, schema_bytes + ) + + for i, (expected_name, expected_type) in enumerate(expected_types): + self.assertEqual(description[i][0], expected_name) + self.assertEqual(description[i][1], expected_type) + if __name__ == "__main__": unittest.main()
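
The `docs/proxy.md` file added in this change describes per-request proxy selection (environment variables, `urllib.request.getproxies()`, and `NO_PROXY`/`proxy_bypass()` rules) but does not show code for that decision. The following is a minimal standalone sketch of that flow using only the Python standard library; the `proxy_for_url` helper and its structure are illustrative assumptions for this note, not the connector's actual implementation.

```python
# Illustrative sketch only (not databricks-sql-connector code): per-request
# proxy selection from system settings, as described in docs/proxy.md.
import urllib.request
from urllib.parse import urlparse
from typing import Optional


def proxy_for_url(url: str) -> Optional[str]:
    """Return the proxy URL to use for `url`, or None for a direct connection."""
    parsed = urlparse(url)
    host = parsed.hostname or ""

    # Honor NO_PROXY and platform bypass rules first.
    if urllib.request.proxy_bypass(host):
        return None

    # getproxies() reads HTTP_PROXY/HTTPS_PROXY (and OS-level settings),
    # returning a dict keyed by URL scheme.
    proxies = urllib.request.getproxies()
    return proxies.get(parsed.scheme)  # e.g. "http://proxy.company.com:8080"


if __name__ == "__main__":
    target = "https://your-workspace.cloud.databricks.com/sql/1.0/endpoints/abc123"
    print(proxy_for_url(target) or "direct connection")
```

A decision like this is what lets a unified HTTP client route each request to either a direct pool manager or a proxy pool manager, which is the architecture the proxy documentation above describes.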