Skip to content

uc volume client setup #640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions examples/volume_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
This example demonstrates how to use the UCVolumeClient with Unity Catalog Volumes.
"""

from databricks import sql
import os

host = os.getenv("DATABRICKS_SERVER_HOSTNAME")
http_path = os.getenv("DATABRICKS_HTTP_PATH")
access_token = os.getenv("DATABRICKS_TOKEN")
catalog = os.getenv("DATABRICKS_CATALOG")
schema = os.getenv("DATABRICKS_SCHEMA")

if not all([host, http_path, access_token, catalog, schema]):
print("Error: Please set all required environment variables")
print("Required: DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, DATABRICKS_TOKEN, DATABRICKS_CATALOG, DATABRICKS_SCHEMA")
exit(1)

# Type assertions for the linter
assert host and http_path and access_token and catalog and schema

# Connect to Databricks
with sql.connect(
server_hostname=host,
http_path=http_path,
access_token=access_token,
) as connection:

# Get the UC volume client
volume_client = connection.get_uc_volume_client()

# Example volume name (change this to match your setup)
volume_name = "sv-volume"

print(f"Using volume: /Volumes/{catalog}/{schema}/{volume_name}/")
print()

# Check if a file exists
exists = volume_client.object_exists(catalog, schema, volume_name, "sample-1.txt")
print(f"File 'sample-1.txt' exists: {exists}")

# Check if a file in subdirectory exists
exists = volume_client.object_exists(catalog, schema, volume_name, "dir-1/sample-1.txt")
print(f"File 'dir-1/sample-1.txt' exists: {exists}")

# Check if a directory exists
exists = volume_client.object_exists(catalog, schema, volume_name, "dir-1")
print(f"Directory 'dir-1' exists: {exists}")

# Check if a file exists
exists = volume_client.object_exists(catalog, schema, volume_name, "sample-2.txt")
print(f"File 'sample-2.txt' exists: {exists}")

# Check if a file in subdirectory exists
exists = volume_client.object_exists(catalog, schema, volume_name, "dir-2/sample-2.txt")
print(f"File 'dir-2/sample-2.txt' exists: {exists}")

# Check if a directory exists
exists = volume_client.object_exists(catalog, schema, volume_name, "dir-1/")
print(f"Directory 'dir-1/' exists: {exists}")


# Case-insensitive check
exists = volume_client.object_exists(catalog, schema, volume_name, "SAMPLE-1.txt", case_sensitive=False)
print(f"File 'SAMPLE-1.txt' exists (case-insensitive): {exists}")

exists = volume_client.object_exists(catalog, schema, volume_name, "dir-1/SAMPLE-1.txt", case_sensitive=False)
print(f"File 'dir-1/SAMPLE-1.txt' exists (case-insensitive): {exists}")

print("\nVolume operations example completed!")

10 changes: 10 additions & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from databricks.sql.types import Row, SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
from databricks.sql.volume.uc_volume_client import UCVolumeClient

from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
Expand Down Expand Up @@ -433,6 +434,15 @@ def cursor(
self._cursors.append(cursor)
return cursor

def get_uc_volume_client(self) -> UCVolumeClient:
if not self.open:
raise InterfaceError(
"Cannot create UC volume client from closed connection",
session_id_hex=self.get_session_id_hex(),
)

return UCVolumeClient(connection=self)

def close(self) -> None:
"""Close the underlying session and mark all associated cursors as closed."""
self._close()
Expand Down
3 changes: 3 additions & 0 deletions src/databricks/sql/volume/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .uc_volume_client import UCVolumeClient

__all__ = ['UCVolumeClient']
72 changes: 72 additions & 0 deletions src/databricks/sql/volume/uc_volume_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@

import logging
from typing import TYPE_CHECKING, List

from databricks.sql.exc import OperationalError, ProgrammingError, ServerOperationError
from .volume_utils import (
parse_path,
build_volume_path,
names_match,
validate_volume_inputs,
DIRECTORY_NOT_FOUND_ERROR
)

# Avoid circular import
if TYPE_CHECKING:
from databricks.sql.client import Connection

logger = logging.getLogger(__name__)


class UCVolumeClient:

def __init__(self, connection: "Connection"):
self.connection = connection
self.session_id_hex = connection.get_session_id_hex()


def _execute_list_query(self, query: str) -> List:
"""Execute LIST query and handle common errors."""
try:
with self.connection.cursor() as cursor:
cursor.execute(query)
return cursor.fetchall()
except ServerOperationError as e:
if DIRECTORY_NOT_FOUND_ERROR in str(e):
return [] # Directory doesn't exist
raise OperationalError(f"Query failed: {str(e)}", session_id_hex=self.session_id_hex) from e
except Exception as e:
raise OperationalError(f"Query failed: {str(e)}", session_id_hex=self.session_id_hex) from e

def object_exists(self, catalog: str, schema: str, volume: str, path: str, case_sensitive: bool = True) -> bool:

validate_volume_inputs(catalog, schema, volume, path, self.session_id_hex)

if not path.strip():
return False

folder, filename = parse_path(path)
volume_path = build_volume_path(catalog, schema, volume, folder)
query = f"LIST '{volume_path}'"
logger.debug(f"Executing query: {query}")

results = self._execute_list_query(query)
if not results:
return False

# Check if our file exists in results
# Row structure: [path, name, size, modification_time]
# Example: ['/Volumes/catalog/schema/volume/dir/file.txt', 'file.txt', 1024, 1752757716901]
# For directories: both path and name end with '/' (e.g., '/Volumes/.../dir/', 'dir/')
for row in results:
if len(row) > 1:
found_name = str(row[1]) # Second column is the filename

# Remove trailing slash from directories
if found_name.endswith('/'):
found_name = found_name[:-1]

if names_match(found_name, filename, case_sensitive):
return True

return False
56 changes: 56 additions & 0 deletions src/databricks/sql/volume/volume_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Tuple, Optional

from databricks.sql.exc import ProgrammingError

# Constants
VOLUME_PATH_TEMPLATE = "/Volumes/{catalog}/{schema}/{volume}/"
DIRECTORY_NOT_FOUND_ERROR = "No such file or directory"


def validate_volume_inputs(catalog: str, schema: str, volume: str, path: str, session_id_hex: Optional[str] = None) -> None:
if not all([catalog, schema, volume, path]):
raise ProgrammingError(
"All parameters (catalog, schema, volume, path) are required",
session_id_hex=session_id_hex
)


def parse_path(path: str) -> Tuple[str, str]:
if not path or path == '/':
return '', ''

# Handle trailing slash - treat "dir-1/" as looking for directory "dir-1"
path = path.rstrip('/')

if '/' in path:
folder, filename = path.rsplit('/', 1)
else:
folder, filename = '', path
return folder, filename


def escape_path_component(component: str) -> str:
"""Escape path component to prevent SQL injection.
"""
return component.replace("'", "''")


def build_volume_path(catalog: str, schema: str, volume: str, folder: str = "") -> str:
catalog_escaped = escape_path_component(catalog)
schema_escaped = escape_path_component(schema)
volume_escaped = escape_path_component(volume)
volume_path = VOLUME_PATH_TEMPLATE.format(
catalog=catalog_escaped,
schema=schema_escaped,
volume=volume_escaped
)
if folder:
folder_escaped = escape_path_component(folder)
volume_path += folder_escaped + "/"
return volume_path


def names_match(found_name: str, target_name: str, case_sensitive: bool) -> bool:
if case_sensitive:
return found_name == target_name
return found_name.lower() == target_name.lower()