From 7521a617c121ead96a21ca47959a53b2db2da090 Mon Sep 17 00:00:00 2001 From: Mariatta Wijaya Date: Wed, 10 May 2023 15:42:01 -0700 Subject: [PATCH] Feat: Threaded MutationsBatcher (#722) - Batch mutations in a thread to allow concurrent batching - Flush the batch every second - Add flow control to control inflight requests Co-authored-by: Mattie Fu --- docs/batcher.rst | 6 + docs/usage.rst | 1 + google/cloud/bigtable/batcher.py | 366 +++++++++++++++++++++++++------ google/cloud/bigtable/table.py | 6 +- tests/unit/test_batcher.py | 218 ++++++++++++------ 5 files changed, 469 insertions(+), 128 deletions(-) create mode 100644 docs/batcher.rst diff --git a/docs/batcher.rst b/docs/batcher.rst new file mode 100644 index 000000000..9ac335be1 --- /dev/null +++ b/docs/batcher.rst @@ -0,0 +1,6 @@ +Mutations Batching +~~~~~~~~~~~~~~~~~~ + +.. automodule:: google.cloud.bigtable.batcher + :members: + :show-inheritance: diff --git a/docs/usage.rst b/docs/usage.rst index 33bf7bb7f..73a32b039 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -17,6 +17,7 @@ Using the API row-data row-filters row-set + batcher In the hierarchy of API concepts diff --git a/google/cloud/bigtable/batcher.py b/google/cloud/bigtable/batcher.py index 3c23f4436..6b06ec060 100644 --- a/google/cloud/bigtable/batcher.py +++ b/google/cloud/bigtable/batcher.py @@ -13,104 +13,251 @@ # limitations under the License. """User friendly container for Google Cloud Bigtable MutationBatcher.""" +import threading +import queue +import concurrent.futures +import atexit -FLUSH_COUNT = 1000 -MAX_MUTATIONS = 100000 -MAX_ROW_BYTES = 5242880 # 5MB +from google.api_core.exceptions import from_grpc_status +from dataclasses import dataclass -class MaxMutationsError(ValueError): - """The number of mutations for bulk request is too big.""" +FLUSH_COUNT = 100 # after this many elements, send out the batch + +MAX_MUTATION_SIZE = 20 * 1024 * 1024 # 20MB # after this many bytes, send out the batch + +MAX_OUTSTANDING_BYTES = 100 * 1024 * 1024 # 100MB # max inflight byte size. + +MAX_OUTSTANDING_ELEMENTS = 100000 # max inflight mutations. + + +class MutationsBatchError(Exception): + """Error in the batch request""" + + def __init__(self, message, exc): + self.exc = exc + self.message = message + super().__init__(self.message) + + +class _MutationsBatchQueue(object): + """Private Threadsafe Queue to hold rows for batching.""" + + def __init__(self, max_mutation_bytes=MAX_MUTATION_SIZE, flush_count=FLUSH_COUNT): + """Specify the queue constraints""" + self._queue = queue.Queue() + self.total_mutation_count = 0 + self.total_size = 0 + self.max_mutation_bytes = max_mutation_bytes + self.flush_count = flush_count + + def get(self): + """Retrieve an item from the queue. Recalculate queue size.""" + row = self._queue.get() + mutation_size = row.get_mutations_size() + self.total_mutation_count -= len(row._get_mutations()) + self.total_size -= mutation_size + return row + + def put(self, item): + """Insert an item to the queue. Recalculate queue size.""" + + mutation_count = len(item._get_mutations()) + + self._queue.put(item) + + self.total_size += item.get_mutations_size() + self.total_mutation_count += mutation_count + + def full(self): + """Check if the queue is full.""" + if ( + self.total_mutation_count >= self.flush_count + or self.total_size >= self.max_mutation_bytes + ): + return True + return False + + def empty(self): + return self._queue.empty() + + +@dataclass +class _BatchInfo: + """Keeping track of size of a batch""" + + mutations_count: int = 0 + rows_count: int = 0 + mutations_size: int = 0 + + +class _FlowControl(object): + def __init__( + self, + max_mutations=MAX_OUTSTANDING_ELEMENTS, + max_mutation_bytes=MAX_OUTSTANDING_BYTES, + ): + """Control the inflight requests. Keep track of the mutations, row bytes and row counts. + As requests to backend are being made, adjust the number of mutations being processed. + + If threshold is reached, block the flow. + Reopen the flow as requests are finished. + """ + self.max_mutations = max_mutations + self.max_mutation_bytes = max_mutation_bytes + self.inflight_mutations = 0 + self.inflight_size = 0 + self.event = threading.Event() + self.event.set() + + def is_blocked(self): + """Returns True if: + + - inflight mutations >= max_mutations, or + - inflight bytes size >= max_mutation_bytes, or + """ + + return ( + self.inflight_mutations >= self.max_mutations + or self.inflight_size >= self.max_mutation_bytes + ) + + def control_flow(self, batch_info): + """ + Calculate the resources used by this batch + """ + + self.inflight_mutations += batch_info.mutations_count + self.inflight_size += batch_info.mutations_size + self.set_flow_control_status() + + def wait(self): + """ + Wait until flow control pushback has been released. + It awakens as soon as `event` is set. + """ + self.event.wait() + + def set_flow_control_status(self): + """Check the inflight mutations and size. + + If values exceed the allowed threshold, block the event. + """ + if self.is_blocked(): + self.event.clear() # sleep + else: + self.event.set() # awaken the threads + + def release(self, batch_info): + """ + Release the resources. + Decrement the row size to allow enqueued mutations to be run. + """ + self.inflight_mutations -= batch_info.mutations_count + self.inflight_size -= batch_info.mutations_size + self.set_flow_control_status() class MutationsBatcher(object): """A MutationsBatcher is used in batch cases where the number of mutations - is large or unknown. It will store DirectRows in memory until one of the - size limits is reached, or an explicit call to flush() is performed. When - a flush event occurs, the DirectRows in memory will be sent to Cloud + is large or unknown. It will store :class:`DirectRow` in memory until one of the + size limits is reached, or an explicit call to :func:`flush()` is performed. When + a flush event occurs, the :class:`DirectRow` in memory will be sent to Cloud Bigtable. Batching mutations is more efficient than sending individual request. This class is not suited for usage in systems where each mutation must be guaranteed to be sent, since calling mutate may only result in an - in-memory change. In a case of a system crash, any DirectRows remaining in + in-memory change. In a case of a system crash, any :class:`DirectRow` remaining in memory will not necessarily be sent to the service, even after the - completion of the mutate() method. + completion of the :func:`mutate()` method. - TODO: Performance would dramatically improve if this class had the - capability of asynchronous, parallel RPCs. + Note on thread safety: The same :class:`MutationBatcher` cannot be shared by multiple end-user threads. :type table: class :param table: class:`~google.cloud.bigtable.table.Table`. :type flush_count: int :param flush_count: (Optional) Max number of rows to flush. If it - reaches the max number of rows it calls finish_batch() to mutate the - current row batch. Default is FLUSH_COUNT (1000 rows). + reaches the max number of rows it calls finish_batch() to mutate the + current row batch. Default is FLUSH_COUNT (1000 rows). :type max_row_bytes: int :param max_row_bytes: (Optional) Max number of row mutations size to - flush. If it reaches the max number of row mutations size it calls - finish_batch() to mutate the current row batch. Default is MAX_ROW_BYTES - (5 MB). + flush. If it reaches the max number of row mutations size it calls + finish_batch() to mutate the current row batch. Default is MAX_ROW_BYTES + (5 MB). + + :type flush_interval: float + :param flush_interval: (Optional) The interval (in seconds) between asynchronous flush. + Default is 1 second. """ - def __init__(self, table, flush_count=FLUSH_COUNT, max_row_bytes=MAX_ROW_BYTES): - self.rows = [] - self.total_mutation_count = 0 - self.total_size = 0 + def __init__( + self, + table, + flush_count=FLUSH_COUNT, + max_row_bytes=MAX_MUTATION_SIZE, + flush_interval=1, + ): + self._rows = _MutationsBatchQueue( + max_mutation_bytes=max_row_bytes, flush_count=flush_count + ) self.table = table - self.flush_count = flush_count - self.max_row_bytes = max_row_bytes + self._executor = concurrent.futures.ThreadPoolExecutor() + atexit.register(self.close) + self._timer = threading.Timer(flush_interval, self.flush) + self._timer.start() + self.flow_control = _FlowControl( + max_mutations=MAX_OUTSTANDING_ELEMENTS, + max_mutation_bytes=MAX_OUTSTANDING_BYTES, + ) + self.futures_mapping = {} + self.exceptions = queue.Queue() + + @property + def flush_count(self): + return self._rows.flush_count + + @property + def max_row_bytes(self): + return self._rows.max_mutation_bytes + + def __enter__(self): + """Starting the MutationsBatcher as a context manager""" + return self def mutate(self, row): """Add a row to the batch. If the current batch meets one of the size - limits, the batch is sent synchronously. + limits, the batch is sent asynchronously. For example: - .. literalinclude:: snippets.py + .. literalinclude:: snippets_table.py :start-after: [START bigtable_api_batcher_mutate] :end-before: [END bigtable_api_batcher_mutate] :dedent: 4 :type row: class - :param row: class:`~google.cloud.bigtable.row.DirectRow`. + :param row: :class:`~google.cloud.bigtable.row.DirectRow`. :raises: One of the following: - * :exc:`~.table._BigtableRetryableError` if any - row returned a transient error. - * :exc:`RuntimeError` if the number of responses doesn't - match the number of rows that were retried - * :exc:`.batcher.MaxMutationsError` if any row exceeds max - mutations count. - """ - mutation_count = len(row._get_mutations()) - if mutation_count > MAX_MUTATIONS: - raise MaxMutationsError( - "The row key {} exceeds the number of mutations {}.".format( - row.row_key, mutation_count - ) - ) - - if (self.total_mutation_count + mutation_count) >= MAX_MUTATIONS: - self.flush() - - self.rows.append(row) - self.total_mutation_count += mutation_count - self.total_size += row.get_mutations_size() + * :exc:`~.table._BigtableRetryableError` if any row returned a transient error. + * :exc:`RuntimeError` if the number of responses doesn't match the number of rows that were retried + """ + self._rows.put(row) - if self.total_size >= self.max_row_bytes or len(self.rows) >= self.flush_count: - self.flush() + if self._rows.full(): + self._flush_async() def mutate_rows(self, rows): """Add multiple rows to the batch. If the current batch meets one of the size - limits, the batch is sent synchronously. + limits, the batch is sent asynchronously. For example: - .. literalinclude:: snippets.py + .. literalinclude:: snippets_table.py :start-after: [START bigtable_api_batcher_mutate_rows] :end-before: [END bigtable_api_batcher_mutate_rows] :dedent: 4 @@ -119,28 +266,119 @@ def mutate_rows(self, rows): :param rows: list:[`~google.cloud.bigtable.row.DirectRow`]. :raises: One of the following: - * :exc:`~.table._BigtableRetryableError` if any - row returned a transient error. - * :exc:`RuntimeError` if the number of responses doesn't - match the number of rows that were retried - * :exc:`.batcher.MaxMutationsError` if any row exceeds max - mutations count. + * :exc:`~.table._BigtableRetryableError` if any row returned a transient error. + * :exc:`RuntimeError` if the number of responses doesn't match the number of rows that were retried """ for row in rows: self.mutate(row) def flush(self): - """Sends the current. batch to Cloud Bigtable. + """Sends the current batch to Cloud Bigtable synchronously. For example: - .. literalinclude:: snippets.py + .. literalinclude:: snippets_table.py :start-after: [START bigtable_api_batcher_flush] :end-before: [END bigtable_api_batcher_flush] :dedent: 4 + :raises: + * :exc:`.batcherMutationsBatchError` if there's any error in the mutations. + """ + rows_to_flush = [] + while not self._rows.empty(): + rows_to_flush.append(self._rows.get()) + response = self._flush_rows(rows_to_flush) + return response + + def _flush_async(self): + """Sends the current batch to Cloud Bigtable asynchronously. + + :raises: + * :exc:`.batcherMutationsBatchError` if there's any error in the mutations. + """ + + rows_to_flush = [] + mutations_count = 0 + mutations_size = 0 + rows_count = 0 + batch_info = _BatchInfo() + + while not self._rows.empty(): + row = self._rows.get() + mutations_count += len(row._get_mutations()) + mutations_size += row.get_mutations_size() + rows_count += 1 + rows_to_flush.append(row) + batch_info.mutations_count = mutations_count + batch_info.rows_count = rows_count + batch_info.mutations_size = mutations_size + + if ( + rows_count >= self.flush_count + or mutations_size >= self.max_row_bytes + or mutations_count >= self.flow_control.max_mutations + or mutations_size >= self.flow_control.max_mutation_bytes + or self._rows.empty() # submit when it reached the end of the queue + ): + # wait for resources to become available, before submitting any new batch + self.flow_control.wait() + # once unblocked, submit a batch + # event flag will be set by control_flow to block subsequent thread, but not blocking this one + self.flow_control.control_flow(batch_info) + future = self._executor.submit(self._flush_rows, rows_to_flush) + self.futures_mapping[future] = batch_info + future.add_done_callback(self._batch_completed_callback) + + # reset and start a new batch + rows_to_flush = [] + mutations_size = 0 + rows_count = 0 + mutations_count = 0 + batch_info = _BatchInfo() + + def _batch_completed_callback(self, future): + """Callback for when the mutation has finished. + + Raise exceptions if there's any. + Release the resources locked by the flow control and allow enqueued tasks to be run. + """ + + processed_rows = self.futures_mapping[future] + self.flow_control.release(processed_rows) + del self.futures_mapping[future] + + def _flush_rows(self, rows_to_flush): + """Mutate the specified rows. + + :raises: + * :exc:`.batcherMutationsBatchError` if there's any error in the mutations. + """ + responses = [] + if len(rows_to_flush) > 0: + response = self.table.mutate_rows(rows_to_flush) + + for result in response: + if result.code != 0: + exc = from_grpc_status(result.code, result.message) + self.exceptions.put(exc) + responses.append(result) + + return responses + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Clean up resources. Flush and shutdown the ThreadPoolExecutor.""" + self.close() + + def close(self): + """Clean up resources. Flush and shutdown the ThreadPoolExecutor. + Any errors will be raised. + + :raises: + * :exc:`.batcherMutationsBatchError` if there's any error in the mutations. """ - if len(self.rows) != 0: - self.table.mutate_rows(self.rows) - self.total_mutation_count = 0 - self.total_size = 0 - self.rows = [] + self.flush() + self._executor.shutdown(wait=True) + atexit.unregister(self.close) + if self.exceptions.qsize() > 0: + exc = list(self.exceptions.queue) + raise MutationsBatchError("Errors in batch mutations.", exc=exc) diff --git a/google/cloud/bigtable/table.py b/google/cloud/bigtable/table.py index 8605992ba..e3191a729 100644 --- a/google/cloud/bigtable/table.py +++ b/google/cloud/bigtable/table.py @@ -32,7 +32,7 @@ from google.cloud.bigtable.column_family import _gc_rule_from_pb from google.cloud.bigtable.column_family import ColumnFamily from google.cloud.bigtable.batcher import MutationsBatcher -from google.cloud.bigtable.batcher import FLUSH_COUNT, MAX_ROW_BYTES +from google.cloud.bigtable.batcher import FLUSH_COUNT, MAX_MUTATION_SIZE from google.cloud.bigtable.encryption_info import EncryptionInfo from google.cloud.bigtable.policy import Policy from google.cloud.bigtable.row import AppendRow @@ -844,7 +844,9 @@ def drop_by_prefix(self, row_key_prefix, timeout=None): request={"name": self.name, "row_key_prefix": _to_bytes(row_key_prefix)} ) - def mutations_batcher(self, flush_count=FLUSH_COUNT, max_row_bytes=MAX_ROW_BYTES): + def mutations_batcher( + self, flush_count=FLUSH_COUNT, max_row_bytes=MAX_MUTATION_SIZE + ): """Factory to create a mutation batcher associated with this instance. For example: diff --git a/tests/unit/test_batcher.py b/tests/unit/test_batcher.py index 9ae6ed175..a238b2852 100644 --- a/tests/unit/test_batcher.py +++ b/tests/unit/test_batcher.py @@ -14,122 +14,118 @@ import mock +import time + import pytest from google.cloud.bigtable.row import DirectRow +from google.cloud.bigtable.batcher import ( + _FlowControl, + MutationsBatcher, + MutationsBatchError, +) TABLE_ID = "table-id" TABLE_NAME = "/tables/" + TABLE_ID -def _make_mutation_batcher(table, **kw): - from google.cloud.bigtable.batcher import MutationsBatcher - - return MutationsBatcher(table, **kw) - - def test_mutation_batcher_constructor(): table = _Table(TABLE_NAME) - - mutation_batcher = _make_mutation_batcher(table) - assert table is mutation_batcher.table + with MutationsBatcher(table) as mutation_batcher: + assert table is mutation_batcher.table def test_mutation_batcher_mutate_row(): table = _Table(TABLE_NAME) - mutation_batcher = _make_mutation_batcher(table=table) + with MutationsBatcher(table=table) as mutation_batcher: - rows = [ - DirectRow(row_key=b"row_key"), - DirectRow(row_key=b"row_key_2"), - DirectRow(row_key=b"row_key_3"), - DirectRow(row_key=b"row_key_4"), - ] + rows = [ + DirectRow(row_key=b"row_key"), + DirectRow(row_key=b"row_key_2"), + DirectRow(row_key=b"row_key_3"), + DirectRow(row_key=b"row_key_4"), + ] - mutation_batcher.mutate_rows(rows) - mutation_batcher.flush() + mutation_batcher.mutate_rows(rows) assert table.mutation_calls == 1 def test_mutation_batcher_mutate(): table = _Table(TABLE_NAME) - mutation_batcher = _make_mutation_batcher(table=table) + with MutationsBatcher(table=table) as mutation_batcher: - row = DirectRow(row_key=b"row_key") - row.set_cell("cf1", b"c1", 1) - row.set_cell("cf1", b"c2", 2) - row.set_cell("cf1", b"c3", 3) - row.set_cell("cf1", b"c4", 4) - - mutation_batcher.mutate(row) + row = DirectRow(row_key=b"row_key") + row.set_cell("cf1", b"c1", 1) + row.set_cell("cf1", b"c2", 2) + row.set_cell("cf1", b"c3", 3) + row.set_cell("cf1", b"c4", 4) - mutation_batcher.flush() + mutation_batcher.mutate(row) assert table.mutation_calls == 1 def test_mutation_batcher_flush_w_no_rows(): table = _Table(TABLE_NAME) - mutation_batcher = _make_mutation_batcher(table=table) - mutation_batcher.flush() + with MutationsBatcher(table=table) as mutation_batcher: + mutation_batcher.flush() assert table.mutation_calls == 0 def test_mutation_batcher_mutate_w_max_flush_count(): table = _Table(TABLE_NAME) - mutation_batcher = _make_mutation_batcher(table=table, flush_count=3) + with MutationsBatcher(table=table, flush_count=3) as mutation_batcher: - row_1 = DirectRow(row_key=b"row_key_1") - row_2 = DirectRow(row_key=b"row_key_2") - row_3 = DirectRow(row_key=b"row_key_3") + row_1 = DirectRow(row_key=b"row_key_1") + row_2 = DirectRow(row_key=b"row_key_2") + row_3 = DirectRow(row_key=b"row_key_3") - mutation_batcher.mutate(row_1) - mutation_batcher.mutate(row_2) - mutation_batcher.mutate(row_3) + mutation_batcher.mutate(row_1) + mutation_batcher.mutate(row_2) + mutation_batcher.mutate(row_3) assert table.mutation_calls == 1 -@mock.patch("google.cloud.bigtable.batcher.MAX_MUTATIONS", new=3) -def test_mutation_batcher_mutate_with_max_mutations_failure(): - from google.cloud.bigtable.batcher import MaxMutationsError - +@mock.patch("google.cloud.bigtable.batcher.MAX_OUTSTANDING_ELEMENTS", new=3) +def test_mutation_batcher_mutate_w_max_mutations(): table = _Table(TABLE_NAME) - mutation_batcher = _make_mutation_batcher(table=table) + with MutationsBatcher(table=table) as mutation_batcher: - row = DirectRow(row_key=b"row_key") - row.set_cell("cf1", b"c1", 1) - row.set_cell("cf1", b"c2", 2) - row.set_cell("cf1", b"c3", 3) - row.set_cell("cf1", b"c4", 4) + row = DirectRow(row_key=b"row_key") + row.set_cell("cf1", b"c1", 1) + row.set_cell("cf1", b"c2", 2) + row.set_cell("cf1", b"c3", 3) - with pytest.raises(MaxMutationsError): mutation_batcher.mutate(row) + assert table.mutation_calls == 1 + -@mock.patch("google.cloud.bigtable.batcher.MAX_MUTATIONS", new=3) -def test_mutation_batcher_mutate_w_max_mutations(): +def test_mutation_batcher_mutate_w_max_row_bytes(): table = _Table(TABLE_NAME) - mutation_batcher = _make_mutation_batcher(table=table) + with MutationsBatcher( + table=table, max_row_bytes=3 * 1024 * 1024 + ) as mutation_batcher: - row = DirectRow(row_key=b"row_key") - row.set_cell("cf1", b"c1", 1) - row.set_cell("cf1", b"c2", 2) - row.set_cell("cf1", b"c3", 3) + number_of_bytes = 1 * 1024 * 1024 + max_value = b"1" * number_of_bytes - mutation_batcher.mutate(row) - mutation_batcher.flush() + row = DirectRow(row_key=b"row_key") + row.set_cell("cf1", b"c1", max_value) + row.set_cell("cf1", b"c2", max_value) + row.set_cell("cf1", b"c3", max_value) + + mutation_batcher.mutate(row) assert table.mutation_calls == 1 -def test_mutation_batcher_mutate_w_max_row_bytes(): +def test_mutations_batcher_flushed_when_closed(): table = _Table(TABLE_NAME) - mutation_batcher = _make_mutation_batcher( - table=table, max_row_bytes=3 * 1024 * 1024 - ) + mutation_batcher = MutationsBatcher(table=table, max_row_bytes=3 * 1024 * 1024) number_of_bytes = 1 * 1024 * 1024 max_value = b"1" * number_of_bytes @@ -137,13 +133,108 @@ def test_mutation_batcher_mutate_w_max_row_bytes(): row = DirectRow(row_key=b"row_key") row.set_cell("cf1", b"c1", max_value) row.set_cell("cf1", b"c2", max_value) - row.set_cell("cf1", b"c3", max_value) mutation_batcher.mutate(row) + assert table.mutation_calls == 0 + + mutation_batcher.close() + + assert table.mutation_calls == 1 + + +def test_mutations_batcher_context_manager_flushed_when_closed(): + table = _Table(TABLE_NAME) + with MutationsBatcher( + table=table, max_row_bytes=3 * 1024 * 1024 + ) as mutation_batcher: + + number_of_bytes = 1 * 1024 * 1024 + max_value = b"1" * number_of_bytes + + row = DirectRow(row_key=b"row_key") + row.set_cell("cf1", b"c1", max_value) + row.set_cell("cf1", b"c2", max_value) + + mutation_batcher.mutate(row) assert table.mutation_calls == 1 +@mock.patch("google.cloud.bigtable.batcher.MutationsBatcher.flush") +def test_mutations_batcher_flush_interval(mocked_flush): + table = _Table(TABLE_NAME) + flush_interval = 0.5 + mutation_batcher = MutationsBatcher(table=table, flush_interval=flush_interval) + + assert mutation_batcher._timer.interval == flush_interval + mocked_flush.assert_not_called() + + time.sleep(0.4) + mocked_flush.assert_not_called() + + time.sleep(0.1) + mocked_flush.assert_called_once_with() + + mutation_batcher.close() + + +def test_mutations_batcher_response_with_error_codes(): + from google.rpc.status_pb2 import Status + + mocked_response = [Status(code=1), Status(code=5)] + + with mock.patch("tests.unit.test_batcher._Table") as mocked_table: + table = mocked_table.return_value + mutation_batcher = MutationsBatcher(table=table) + + row1 = DirectRow(row_key=b"row_key") + row2 = DirectRow(row_key=b"row_key") + table.mutate_rows.return_value = mocked_response + + mutation_batcher.mutate_rows([row1, row2]) + with pytest.raises(MutationsBatchError) as exc: + mutation_batcher.close() + assert exc.value.message == "Errors in batch mutations." + assert len(exc.value.exc) == 2 + + assert exc.value.exc[0].message == mocked_response[0].message + assert exc.value.exc[1].message == mocked_response[1].message + + +def test_flow_control_event_is_set_when_not_blocked(): + flow_control = _FlowControl() + + flow_control.set_flow_control_status() + assert flow_control.event.is_set() + + +def test_flow_control_event_is_not_set_when_blocked(): + flow_control = _FlowControl() + + flow_control.inflight_mutations = flow_control.max_mutations + flow_control.inflight_size = flow_control.max_mutation_bytes + + flow_control.set_flow_control_status() + assert not flow_control.event.is_set() + + +@mock.patch("concurrent.futures.ThreadPoolExecutor.submit") +def test_flush_async_batch_count(mocked_executor_submit): + table = _Table(TABLE_NAME) + mutation_batcher = MutationsBatcher(table=table, flush_count=2) + + number_of_bytes = 1 * 1024 * 1024 + max_value = b"1" * number_of_bytes + for index in range(5): + row = DirectRow(row_key=f"row_key_{index}") + row.set_cell("cf1", b"c1", max_value) + mutation_batcher.mutate(row) + mutation_batcher._flush_async() + + # 3 batches submitted. 2 batches of 2 items, and the last one a single item batch. + assert mocked_executor_submit.call_count == 3 + + class _Instance(object): def __init__(self, client=None): self._client = client @@ -156,5 +247,8 @@ def __init__(self, name, client=None): self.mutation_calls = 0 def mutate_rows(self, rows): + from google.rpc.status_pb2 import Status + self.mutation_calls += 1 - return rows + + return [Status(code=0) for _ in rows]