diff --git a/examples/volume_operations.py b/examples/volume_operations.py new file mode 100644 index 000000000..8de12fc39 --- /dev/null +++ b/examples/volume_operations.py @@ -0,0 +1,71 @@ +""" +This example demonstrates how to use the UCVolumeClient with Unity Catalog Volumes. +""" + +from databricks import sql +import os + +host = os.getenv("DATABRICKS_SERVER_HOSTNAME") +http_path = os.getenv("DATABRICKS_HTTP_PATH") +access_token = os.getenv("DATABRICKS_TOKEN") +catalog = os.getenv("DATABRICKS_CATALOG") +schema = os.getenv("DATABRICKS_SCHEMA") + +if not all([host, http_path, access_token, catalog, schema]): + print("Error: Please set all required environment variables") + print("Required: DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, DATABRICKS_TOKEN, DATABRICKS_CATALOG, DATABRICKS_SCHEMA") + exit(1) + +# Type assertions for the linter +assert host and http_path and access_token and catalog and schema + +# Connect to Databricks +with sql.connect( + server_hostname=host, + http_path=http_path, + access_token=access_token, +) as connection: + + # Get the UC volume client + volume_client = connection.get_uc_volume_client() + + # Example volume name (change this to match your setup) + volume_name = "sv-volume" + + print(f"Using volume: /Volumes/{catalog}/{schema}/{volume_name}/") + print() + + # Check if a file exists + exists = volume_client.object_exists(catalog, schema, volume_name, "sample-1.txt") + print(f"File 'sample-1.txt' exists: {exists}") + + # Check if a file in subdirectory exists + exists = volume_client.object_exists(catalog, schema, volume_name, "dir-1/sample-1.txt") + print(f"File 'dir-1/sample-1.txt' exists: {exists}") + + # Check if a directory exists + exists = volume_client.object_exists(catalog, schema, volume_name, "dir-1") + print(f"Directory 'dir-1' exists: {exists}") + + # Check if a file exists + exists = volume_client.object_exists(catalog, schema, volume_name, "sample-2.txt") + print(f"File 'sample-2.txt' exists: {exists}") + + # Check if a file in subdirectory exists + exists = volume_client.object_exists(catalog, schema, volume_name, "dir-2/sample-2.txt") + print(f"File 'dir-2/sample-2.txt' exists: {exists}") + + # Check if a directory exists + exists = volume_client.object_exists(catalog, schema, volume_name, "dir-1/") + print(f"Directory 'dir-1/' exists: {exists}") + + + # Case-insensitive check + exists = volume_client.object_exists(catalog, schema, volume_name, "SAMPLE-1.txt", case_sensitive=False) + print(f"File 'SAMPLE-1.txt' exists (case-insensitive): {exists}") + + exists = volume_client.object_exists(catalog, schema, volume_name, "dir-1/SAMPLE-1.txt", case_sensitive=False) + print(f"File 'dir-1/SAMPLE-1.txt' exists (case-insensitive): {exists}") + + print("\nVolume operations example completed!") + diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index b4cd78cf8..0a5eaf3ef 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -47,6 +47,7 @@ from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence +from databricks.sql.volume.uc_volume_client import UCVolumeClient from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -433,6 +434,15 @@ def cursor( self._cursors.append(cursor) return cursor + def get_uc_volume_client(self) -> UCVolumeClient: + if not self.open: + raise InterfaceError( + "Cannot create UC volume client from closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + return UCVolumeClient(connection=self) + def close(self) -> None: """Close the underlying session and mark all associated cursors as closed.""" self._close() diff --git a/src/databricks/sql/volume/__init__.py b/src/databricks/sql/volume/__init__.py new file mode 100644 index 000000000..b9cc1c1f9 --- /dev/null +++ b/src/databricks/sql/volume/__init__.py @@ -0,0 +1,3 @@ +from .uc_volume_client import UCVolumeClient + +__all__ = ['UCVolumeClient'] \ No newline at end of file diff --git a/src/databricks/sql/volume/uc_volume_client.py b/src/databricks/sql/volume/uc_volume_client.py new file mode 100644 index 000000000..76269fece --- /dev/null +++ b/src/databricks/sql/volume/uc_volume_client.py @@ -0,0 +1,72 @@ + +import logging +from typing import TYPE_CHECKING, List + +from databricks.sql.exc import OperationalError, ProgrammingError, ServerOperationError +from .volume_utils import ( + parse_path, + build_volume_path, + names_match, + validate_volume_inputs, + DIRECTORY_NOT_FOUND_ERROR +) + +# Avoid circular import +if TYPE_CHECKING: + from databricks.sql.client import Connection + +logger = logging.getLogger(__name__) + + +class UCVolumeClient: + + def __init__(self, connection: "Connection"): + self.connection = connection + self.session_id_hex = connection.get_session_id_hex() + + + def _execute_list_query(self, query: str) -> List: + """Execute LIST query and handle common errors.""" + try: + with self.connection.cursor() as cursor: + cursor.execute(query) + return cursor.fetchall() + except ServerOperationError as e: + if DIRECTORY_NOT_FOUND_ERROR in str(e): + return [] # Directory doesn't exist + raise OperationalError(f"Query failed: {str(e)}", session_id_hex=self.session_id_hex) from e + except Exception as e: + raise OperationalError(f"Query failed: {str(e)}", session_id_hex=self.session_id_hex) from e + + def object_exists(self, catalog: str, schema: str, volume: str, path: str, case_sensitive: bool = True) -> bool: + + validate_volume_inputs(catalog, schema, volume, path, self.session_id_hex) + + if not path.strip(): + return False + + folder, filename = parse_path(path) + volume_path = build_volume_path(catalog, schema, volume, folder) + query = f"LIST '{volume_path}'" + logger.debug(f"Executing query: {query}") + + results = self._execute_list_query(query) + if not results: + return False + + # Check if our file exists in results + # Row structure: [path, name, size, modification_time] + # Example: ['/Volumes/catalog/schema/volume/dir/file.txt', 'file.txt', 1024, 1752757716901] + # For directories: both path and name end with '/' (e.g., '/Volumes/.../dir/', 'dir/') + for row in results: + if len(row) > 1: + found_name = str(row[1]) # Second column is the filename + + # Remove trailing slash from directories + if found_name.endswith('/'): + found_name = found_name[:-1] + + if names_match(found_name, filename, case_sensitive): + return True + + return False \ No newline at end of file diff --git a/src/databricks/sql/volume/volume_utils.py b/src/databricks/sql/volume/volume_utils.py new file mode 100644 index 000000000..44a7d9e32 --- /dev/null +++ b/src/databricks/sql/volume/volume_utils.py @@ -0,0 +1,56 @@ +from typing import Tuple, Optional + +from databricks.sql.exc import ProgrammingError + +# Constants +VOLUME_PATH_TEMPLATE = "/Volumes/{catalog}/{schema}/{volume}/" +DIRECTORY_NOT_FOUND_ERROR = "No such file or directory" + + +def validate_volume_inputs(catalog: str, schema: str, volume: str, path: str, session_id_hex: Optional[str] = None) -> None: + if not all([catalog, schema, volume, path]): + raise ProgrammingError( + "All parameters (catalog, schema, volume, path) are required", + session_id_hex=session_id_hex + ) + + +def parse_path(path: str) -> Tuple[str, str]: + if not path or path == '/': + return '', '' + + # Handle trailing slash - treat "dir-1/" as looking for directory "dir-1" + path = path.rstrip('/') + + if '/' in path: + folder, filename = path.rsplit('/', 1) + else: + folder, filename = '', path + return folder, filename + + +def escape_path_component(component: str) -> str: + """Escape path component to prevent SQL injection. + """ + return component.replace("'", "''") + + +def build_volume_path(catalog: str, schema: str, volume: str, folder: str = "") -> str: + catalog_escaped = escape_path_component(catalog) + schema_escaped = escape_path_component(schema) + volume_escaped = escape_path_component(volume) + volume_path = VOLUME_PATH_TEMPLATE.format( + catalog=catalog_escaped, + schema=schema_escaped, + volume=volume_escaped + ) + if folder: + folder_escaped = escape_path_component(folder) + volume_path += folder_escaped + "/" + return volume_path + + +def names_match(found_name: str, target_name: str, case_sensitive: bool) -> bool: + if case_sensitive: + return found_name == target_name + return found_name.lower() == target_name.lower() \ No newline at end of file