From b2cf39b80c0cb53787c496d67d1b6ffe314a8bc0 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 22 Jul 2025 13:18:48 +0530 Subject: [PATCH 1/3] fixed --- src/databricks/sql/result_set.py | 54 ++++++++++++-------------------- src/databricks/sql/utils.py | 22 +++++++++++++ tests/unit/test_util.py | 41 +++++++++++++++++++++++- 3 files changed, 82 insertions(+), 35 deletions(-) diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 074877d32..9ed0188bf 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -20,7 +20,12 @@ from databricks.sql.types import Row from databricks.sql.exc import RequestError, CursorAlreadyClosedError -from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue +from databricks.sql.utils import ( + ExecuteResponse, + ColumnTable, + ColumnQueue, + concat_table_chunks, +) logger = logging.getLogger(__name__) @@ -251,23 +256,6 @@ def _convert_arrow_table(self, table): res = df.to_numpy(na_value=None, dtype="object") return [ResultRow(*v) for v in res] - def merge_columnar(self, result1, result2) -> "ColumnTable": - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. @@ -292,7 +280,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return pyarrow.concat_tables(partial_result_chunks, use_threads=True) + return concat_table_chunks(partial_result_chunks) def fetchmany_columnar(self, size: int): """ @@ -305,7 +293,7 @@ def fetchmany_columnar(self, size: int): results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows - + partial_result_chunks = [results] while ( n_remaining_rows > 0 and not self.has_been_closed_server_side @@ -313,11 +301,11 @@ def fetchmany_columnar(self, size: int): ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) + partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return results + return concat_table_chunks(partial_result_chunks) def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" @@ -327,36 +315,34 @@ def fetchall_arrow(self) -> "pyarrow.Table": while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - partial_result_chunks.append(partial_results) + partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows + result_table = concat_table_chunks(partial_result_chunks) # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: + if isinstance(result_table, ColumnTable) and pyarrow: data = { name: col - for name, col in zip(results.column_names, results.column_table) + for name, col in zip( + result_table.column_names, result_table.column_table + ) } return pyarrow.Table.from_pydict(data) - return pyarrow.concat_tables(partial_result_chunks, use_threads=True) + return result_table def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows - + partial_result_chunks = [results] while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - results = self.merge_columnar(results, partial_results) + partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows - return results + return concat_table_chunks(partial_result_chunks) def fetchone(self) -> Optional[Row]: """ diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index a3e3e1dd0..d62f2394f 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -785,3 +785,25 @@ def _create_python_tuple(t_col_value_wrapper): result[i] = None return tuple(result) + + +def concat_table_chunks( + table_chunks: List[Union["pyarrow.Table", ColumnTable]] +) -> Union["pyarrow.Table", ColumnTable]: + if len(table_chunks) == 0: + return table_chunks + + if isinstance(table_chunks[0], ColumnTable): + ## Check if all have the same column names + if not all( + table.column_names == table_chunks[0].column_names for table in table_chunks + ): + raise ValueError("The columns in the results don't match") + + result_table = table_chunks[0].column_table + for i in range(1, len(table_chunks)): + for j in range(table_chunks[i].num_columns): + result_table[j].extend(table_chunks[i].column_table[j]) + return ColumnTable(result_table, table_chunks[0].column_names) + else: + return pyarrow.concat_tables(table_chunks, use_threads=True) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index a47ab786f..713342b2e 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -1,8 +1,17 @@ import decimal import datetime from datetime import timezone, timedelta +import pytest +from databricks.sql.utils import ( + convert_to_assigned_datatypes_in_column_table, + ColumnTable, + concat_table_chunks, +) -from databricks.sql.utils import convert_to_assigned_datatypes_in_column_table +try: + import pyarrow +except ImportError: + pyarrow = None class TestUtils: @@ -122,3 +131,33 @@ def test_convert_to_assigned_datatypes_in_column_table(self): for index, entry in enumerate(converted_column_table): assert entry[0] == expected_convertion[index][0] assert isinstance(entry[0], expected_convertion[index][1]) + + def test_concat_table_chunks_column_table(self): + column_table1 = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"]) + column_table2 = ColumnTable([[3, 4], [7, 8]], ["col1", "col2"]) + + result_table = concat_table_chunks([column_table1, column_table2]) + + assert result_table.column_table == [[1, 2, 3, 4], [5, 6, 7, 8]] + assert result_table.column_names == ["col1", "col2"] + + @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") + def test_concat_table_chunks_arrow_table(self): + arrow_table1 = pyarrow.Table.from_pydict({"col1": [1, 2], "col2": [5, 6]}) + arrow_table2 = pyarrow.Table.from_pydict({"col1": [3, 4], "col2": [7, 8]}) + + result_table = concat_table_chunks([arrow_table1, arrow_table2]) + assert result_table.column_names == ["col1", "col2"] + assert result_table.column("col1").to_pylist() == [1, 2, 3, 4] + assert result_table.column("col2").to_pylist() == [5, 6, 7, 8] + + def test_concat_table_chunks_empty(self): + result_table = concat_table_chunks([]) + assert result_table == [] + + def test_concat_table_chunks__incorrect_column_names_error(self): + column_table1 = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"]) + column_table2 = ColumnTable([[3, 4], [7, 8]], ["col1", "col3"]) + + with pytest.raises(ValueError): + concat_table_chunks([column_table1, column_table2]) From 3232768fb68672d5474eb4fe7bb2ab1a21862a20 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 22 Jul 2025 15:21:13 +0530 Subject: [PATCH 2/3] Minor fix --- src/databricks/sql/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index d62f2394f..0c2dd54e5 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -800,8 +800,8 @@ def concat_table_chunks( ): raise ValueError("The columns in the results don't match") - result_table = table_chunks[0].column_table - for i in range(1, len(table_chunks)): + result_table = [[] for _ in range(table_chunks[0].num_columns)] + for i in range(0, len(table_chunks)): for j in range(table_chunks[i].num_columns): result_table[j].extend(table_chunks[i].column_table[j]) return ColumnTable(result_table, table_chunks[0].column_names) From 6eba353c5251198522e3db9d2eefc94ea0e6f2cd Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Tue, 22 Jul 2025 15:26:19 +0530 Subject: [PATCH 3/3] more types --- src/databricks/sql/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 0c2dd54e5..9a70ed38d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -800,7 +800,7 @@ def concat_table_chunks( ): raise ValueError("The columns in the results don't match") - result_table = [[] for _ in range(table_chunks[0].num_columns)] + result_table: List[List[Any]] = [[] for _ in range(table_chunks[0].num_columns)] for i in range(0, len(table_chunks)): for j in range(table_chunks[i].num_columns): result_table[j].extend(table_chunks[i].column_table[j])