Skip to content

Commit

Permalink
Rename bucket to gcs_bucket in GCSToS3Operator (apache#33031)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Vincent <[email protected]>
  • Loading branch information
hankehly and vincbeck authored Sep 26, 2023
1 parent 20b7cfc commit b6499ac
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 19 deletions.
27 changes: 20 additions & 7 deletions airflow/providers/amazon/aws/transfers/gcs_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class GCSToS3Operator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:GCSToS3Operator`
:param bucket: The Google Cloud Storage bucket to find the objects. (templated)
:param gcs_bucket: The Google Cloud Storage bucket to find the objects. (templated)
:param bucket: (Deprecated) Use ``gcs_bucket`` instead.
:param prefix: Prefix string which filters objects whose name begin with
this prefix. (templated)
:param delimiter: (Deprecated) The delimiter by which you want to filter the objects. (templated)
Expand Down Expand Up @@ -87,7 +88,7 @@ class GCSToS3Operator(BaseOperator):
"""

template_fields: Sequence[str] = (
"bucket",
"gcs_bucket",
"prefix",
"delimiter",
"dest_s3_key",
Expand All @@ -99,7 +100,8 @@ class GCSToS3Operator(BaseOperator):
def __init__(
self,
*,
bucket: str,
gcs_bucket: str | None = None,
bucket: str | None = None,
prefix: str | None = None,
delimiter: str | None = None,
gcp_conn_id: str = "google_cloud_default",
Expand All @@ -117,7 +119,18 @@ def __init__(
) -> None:
super().__init__(**kwargs)

self.bucket = bucket
if bucket:
warnings.warn(
"The ``bucket`` parameter is deprecated and will be removed in a future version. "
"Please use ``gcs_bucket`` instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
self.gcs_bucket = bucket
if gcs_bucket:
self.gcs_bucket = gcs_bucket
if not (bucket or gcs_bucket):
raise ValueError("You must pass either ``bucket`` or ``gcs_bucket``.")
self.prefix = prefix
self.gcp_conn_id = gcp_conn_id
self.dest_aws_conn_id = dest_aws_conn_id
Expand Down Expand Up @@ -161,13 +174,13 @@ def execute(self, context: Context) -> list[str]:

self.log.info(
"Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s",
self.bucket,
self.gcs_bucket,
self.delimiter,
self.prefix,
)

list_kwargs = {
"bucket_name": self.bucket,
"bucket_name": self.gcs_bucket,
"prefix": self.prefix,
"delimiter": self.delimiter,
"user_project": self.gcp_user_project,
Expand Down Expand Up @@ -206,7 +219,7 @@ def execute(self, context: Context) -> list[str]:
if gcs_files:
for file in gcs_files:
with gcs_hook.provide_file(
object_name=file, bucket_name=self.bucket, user_project=self.gcp_user_project
object_name=file, bucket_name=self.gcs_bucket, user_project=self.gcp_user_project
) as local_tmp_file:
dest_key = os.path.join(self.dest_s3_key, file)
self.log.info("Saving file to %s", dest_key)
Expand Down
56 changes: 45 additions & 11 deletions tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_execute__match_glob(self, mock_hook):

operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
dest_aws_conn_id="aws_default",
dest_s3_key=S3_BUCKET,
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_execute_incremental(self, mock_hook):
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_execute_without_replace(self, mock_hook):
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_execute_without_replace_with_folder_structure(self, mock_hook, dest_s3_
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_execute(self, mock_hook):
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand All @@ -196,6 +196,40 @@ def test_execute(self, mock_hook):
assert sorted(MOCK_FILES) == sorted(uploaded_files)
assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))

@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_gcs_bucket_rename_compatibility(self, mock_hook):
    """
    Tests the same conditions as `test_execute` using the deprecated `bucket` parameter instead of
    `gcs_bucket`, and additionally verifies that constructing the operator with neither `bucket`
    nor `gcs_bucket` raises a ValueError. This test can be removed when the `bucket` parameter
    is removed.
    """
    mock_hook.return_value.list.return_value = MOCK_FILES
    # Hoisted out of the temp-file fixture: the expected warning text does not depend on it.
    bucket_param_deprecated_message = (
        "The ``bucket`` parameter is deprecated and will be removed in a future version. "
        "Please use ``gcs_bucket`` instead."
    )
    with NamedTemporaryFile() as f:
        gcs_provide_file = mock_hook.return_value.provide_file
        gcs_provide_file.return_value.__enter__.return_value.name = f.name
        with pytest.deprecated_call(match=bucket_param_deprecated_message):
            operator = GCSToS3Operator(
                task_id=TASK_ID,
                bucket=GCS_BUCKET,
                prefix=PREFIX,
                # match_glob is used instead of the deprecated ``delimiter`` so the only
                # deprecation warning emitted is the one for ``bucket`` being asserted above.
                match_glob=DELIMITER,
                dest_aws_conn_id="aws_default",
                dest_s3_key=S3_BUCKET,
                replace=False,
            )
        hook, _ = _create_test_bucket()
        # we expect all MOCK_FILES to be uploaded
        # and all MOCK_FILES to be present at the S3 bucket
        uploaded_files = operator.execute(None)
        assert sorted(MOCK_FILES) == sorted(uploaded_files)
        assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/"))
    # Passing neither ``bucket`` nor ``gcs_bucket`` must fail fast with a clear error;
    # this does not need the temp-file fixture, so it sits outside the ``with`` block.
    with pytest.raises(ValueError) as excinfo:
        GCSToS3Operator(task_id=TASK_ID, dest_s3_key=S3_BUCKET)
    assert str(excinfo.value) == "You must pass either ``bucket`` or ``gcs_bucket``."

@mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook")
def test_execute_with_replace(self, mock_hook):
mock_hook.return_value.list.return_value = MOCK_FILES
Expand All @@ -206,7 +240,7 @@ def test_execute_with_replace(self, mock_hook):
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand All @@ -233,7 +267,7 @@ def test_execute_incremental_with_replace(self, mock_hook):
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand Down Expand Up @@ -261,7 +295,7 @@ def test_execute_should_handle_with_default_dest_s3_extra_args(self, s3_mock_hoo
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand All @@ -284,7 +318,7 @@ def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, m
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand All @@ -310,7 +344,7 @@ def test_execute_with_s3_acl_policy(self, mock_load_file, mock_gcs_hook):
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand All @@ -335,7 +369,7 @@ def test_execute_without_keep_director_structure(self, mock_hook):
with pytest.deprecated_call(match=deprecated_call_match):
operator = GCSToS3Operator(
task_id=TASK_ID,
bucket=GCS_BUCKET,
gcs_bucket=GCS_BUCKET,
prefix=PREFIX,
delimiter=DELIMITER,
dest_aws_conn_id="aws_default",
Expand Down
2 changes: 1 addition & 1 deletion tests/system/providers/amazon/aws/example_gcs_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def upload_gcs_file(bucket_name: str, object_name: str, user_project: str):
# [START howto_transfer_gcs_to_s3]
gcs_to_s3 = GCSToS3Operator(
task_id="gcs_to_s3",
bucket=gcs_bucket,
gcs_bucket=gcs_bucket,
dest_s3_key=f"s3://{s3_bucket}/{s3_key}",
replace=True,
gcp_user_project=gcp_user_project,
Expand Down

0 comments on commit b6499ac

Please sign in to comment.