diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 3d3587cae..9feb6e924 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -20,6 +20,7 @@ from databricks.sql.utils import ( ColumnTable, ColumnQueue, + concat_table_chunks, ) from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse from databricks.sql.telemetry.models.event import StatementType @@ -296,23 +297,6 @@ def _convert_columnar_table(self, table): return result - def merge_columnar(self, result1, result2) -> "ColumnTable": - """ - Function to merge / combining the columnar results into a single result - :param result1: - :param result2: - :return: - """ - - if result1.column_names != result2.column_names: - raise ValueError("The columns in the results don't match") - - merged_result = [ - result1.column_table[i] + result2.column_table[i] - for i in range(result1.num_columns) - ] - return ColumnTable(merged_result, result1.column_names) - def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. 
@@ -337,7 +321,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return pyarrow.concat_tables(partial_result_chunks, use_threads=True) + return concat_table_chunks(partial_result_chunks) def fetchmany_columnar(self, size: int): """ @@ -350,7 +334,7 @@ def fetchmany_columnar(self, size: int): results = self.results.next_n_rows(size) n_remaining_rows = size - results.num_rows self._next_row_index += results.num_rows - + partial_result_chunks = [results] while ( n_remaining_rows > 0 and not self.has_been_closed_server_side @@ -358,11 +342,11 @@ def fetchmany_columnar(self, size: int): ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = self.merge_columnar(results, partial_results) + partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return results + return concat_table_chunks(partial_result_chunks) def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" @@ -372,36 +356,34 @@ def fetchall_arrow(self) -> "pyarrow.Table": while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - if isinstance(results, ColumnTable) and isinstance( - partial_results, ColumnTable - ): - results = self.merge_columnar(results, partial_results) - else: - partial_result_chunks.append(partial_results) + partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows + result_table = concat_table_chunks(partial_result_chunks) # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table # Valid only for metadata commands result set - if isinstance(results, ColumnTable) and pyarrow: + if isinstance(result_table, ColumnTable) and 
def concat_table_chunks(
    table_chunks: List[Union["pyarrow.Table", "ColumnTable"]]
) -> Union["pyarrow.Table", "ColumnTable", List]:
    """Concatenate homogeneous result chunks into a single table.

    All chunks must be of the same kind — either all ``ColumnTable``
    (metadata-command results) or all ``pyarrow.Table`` — with identical
    column layouts.

    :param table_chunks: the chunks to combine, in order.
    :return: one combined table of the same kind as the inputs; when
        ``table_chunks`` is empty the (empty) input is returned unchanged,
        since no table type can be inferred from zero chunks.
    :raises ValueError: if ``ColumnTable`` chunks disagree on column names.
    """
    if not table_chunks:
        # Empty input: there is nothing to inspect to decide which table
        # type to construct, so hand the empty list back to the caller.
        return table_chunks

    first = table_chunks[0]
    if isinstance(first, ColumnTable):
        # Every chunk must share the first chunk's column layout before
        # the columns can be spliced together positionally.
        if any(chunk.column_names != first.column_names for chunk in table_chunks):
            raise ValueError("The columns in the results don't match")

        merged: List[List[Any]] = [[] for _ in range(first.num_columns)]
        for chunk in table_chunks:
            for col_idx, column in enumerate(chunk.column_table):
                merged[col_idx].extend(column)
        return ColumnTable(merged, first.column_names)

    # Arrow path: delegate to pyarrow's multi-threaded concatenation.
    return pyarrow.concat_tables(table_chunks, use_threads=True)
["col1", "col2"]) + column_table2 = ColumnTable([[3, 4], [7, 8]], ["col1", "col3"]) + + with pytest.raises(ValueError): + concat_table_chunks([column_table1, column_table2])