diff --git a/awswrangler/_utils.py b/awswrangler/_utils.py
index 01070ad64..28859d3d7 100644
--- a/awswrangler/_utils.py
+++ b/awswrangler/_utils.py
@@ -833,7 +833,7 @@ def block_waiting_available_thread(seq: Sequence[Future], max_workers: int) -> N
 
 def check_schema_changes(columns_types: dict[str, str], table_input: dict[str, Any] | None, mode: str) -> None:
     """Check schema changes."""
-    if (table_input is not None) and (mode in ("append", "overwrite_partitions")):
+    if (table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
         catalog_cols: dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
         for c, t in columns_types.items():
             if c not in catalog_cols:
diff --git a/awswrangler/catalog/_create.py b/awswrangler/catalog/_create.py
index e81726429..09d55d7dc 100644
--- a/awswrangler/catalog/_create.py
+++ b/awswrangler/catalog/_create.py
@@ -31,7 +31,7 @@ def _update_if_necessary(
     if value is not None:
         if key not in dic or dic[key] != value:
             dic[key] = value
-            if mode in ("append", "overwrite_partitions"):
+            if mode in ("append", "overwrite_partitions", "overwrite_files"):
                 return "update"
     return mode
 
@@ -150,9 +150,10 @@ def _create_table(  # noqa: PLR0912,PLR0915
 
     client_glue = _utils.client(service_name="glue", session=boto3_session)
     skip_archive: bool = not catalog_versioning
-    if mode not in ("overwrite", "append", "overwrite_partitions", "update"):
+    if mode not in ("overwrite", "append", "overwrite_partitions", "overwrite_files", "update"):
         raise exceptions.InvalidArgument(
-            f"{mode} is not a valid mode. It must be 'overwrite', 'append' or 'overwrite_partitions'."
+            f"{mode} is not a valid mode. It must be 'overwrite', "
+            f"'append', 'overwrite_partitions' or 'overwrite_files'."
         )
     args: dict[str, Any] = _catalog_id(
         catalog_id=catalog_id,
@@ -304,7 +305,7 @@ def _create_parquet_table(
     _logger.debug("catalog_table_input: %s", catalog_table_input)
 
     table_input: dict[str, Any]
-    if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
+    if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
         table_input = catalog_table_input
         is_table_updated = _update_table_input(table_input, columns_types)
@@ -366,7 +367,7 @@ def _create_orc_table(
     _logger.debug("catalog_table_input: %s", catalog_table_input)
 
     table_input: dict[str, Any]
-    if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
+    if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
         table_input = catalog_table_input
         is_table_updated = _update_table_input(table_input, columns_types)
@@ -436,7 +437,7 @@ def _create_csv_table(
         _utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
 
     table_input: dict[str, Any]
-    if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
+    if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
         table_input = catalog_table_input
         is_table_updated = _update_table_input(table_input, columns_types, allow_reorder=False)
@@ -508,7 +509,7 @@ def _create_json_table(
     table_input: dict[str, Any]
     if schema_evolution is False:
         _utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
-    if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
+    if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions", "overwrite_files")):
         table_input = catalog_table_input
         is_table_updated = _update_table_input(table_input, columns_types)
@@ -1098,7 +1099,7 @@ def create_csv_table(
         If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
     schema_evolution
         If True allows schema evolution (new or missing columns), otherwise a exception will be raised.
-        (Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
+        (Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
         Related tutorial:
         https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
     sep
@@ -1278,7 +1279,7 @@ def create_json_table(
         If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
     schema_evolution
         If True allows schema evolution (new or missing columns), otherwise a exception will be raised.
-        (Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
+        (Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
         Related tutorial:
         https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
     serde_library
diff --git a/awswrangler/s3/_write.py b/awswrangler/s3/_write.py
index 80840ce05..4669d5831 100644
--- a/awswrangler/s3/_write.py
+++ b/awswrangler/s3/_write.py
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import logging
-import uuid
 from abc import ABC, abstractmethod
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, NamedTuple
@@ -48,7 +47,7 @@ def _extract_dtypes_from_table_input(table_input: dict[str, Any]) -> dict[str, s
 def _apply_dtype(
     df: pd.DataFrame, dtype: dict[str, str], catalog_table_input: dict[str, Any] | None, mode: str
 ) -> pd.DataFrame:
-    if mode in ("append", "overwrite_partitions"):
+    if mode in ("append", "overwrite_partitions", "overwrite_files"):
         if catalog_table_input is not None:
             catalog_types: dict[str, str] | None = _extract_dtypes_from_table_input(table_input=catalog_table_input)
             if catalog_types is not None:
@@ -72,6 +71,7 @@ def _validate_args(
     columns_comments: dict[str, str] | None,
     columns_parameters: dict[str, dict[str, str]] | None,
     execution_engine: Enum,
+    max_rows_by_file: int | None = None,
 ) -> None:
     if df.empty is True:
         _logger.warning("Empty DataFrame will be written.")
@@ -107,6 +107,11 @@ def _validate_args(
         raise exceptions.InvalidArgumentValue(
             "Please pass a value greater than 1 for the number of buckets for bucketing."
         )
+    elif mode == "overwrite_files" and (max_rows_by_file or bucketing_info):
+        raise exceptions.InvalidArgumentValue(
+            "When mode is set to 'overwrite_files', the "
+            "`max_rows_by_file` and `bucketing_info` arguments cannot be set."
+        )
 
 
 class _SanitizeResult(NamedTuple):
@@ -279,7 +284,6 @@ def write(  # noqa: PLR0913
         partitions_values: dict[str, list[str]] = {}
         mode = "append" if mode is None else mode
-        filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
         cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
         s3_client = _utils.client(service_name="s3", session=boto3_session)
@@ -328,7 +332,6 @@ def write(  # noqa: PLR0913
             paths = self._write_to_s3(
                 df,
                 path=path,
-                filename_prefix=filename_prefix,
                 schema=schema,
                 index=index,
                 cpus=cpus,
diff --git a/awswrangler/s3/_write_dataset.py b/awswrangler/s3/_write_dataset.py
index c6d78a90a..fc09922cc 100644
--- a/awswrangler/s3/_write_dataset.py
+++ b/awswrangler/s3/_write_dataset.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import logging
+import uuid
 from typing import Any, Callable
 
 import boto3
@@ -18,6 +19,26 @@
 _logger: logging.Logger = logging.getLogger(__name__)
 
 
+def _load_mode_and_filename_prefix(*, mode: str | None, filename_prefix: str | None = None) -> tuple[str, str]:
+    if mode is None:
+        mode = "append"
+
+    if mode == "overwrite_files":
+        # In `overwrite_files` mode, we need to create deterministic
+        # filenames to ensure that the same files are always overwritten:
+        if filename_prefix is None:
+            filename_prefix = "data"
+        random_filename_suffix = ""
+        mode = "append"
+    else:
+        random_filename_suffix = uuid.uuid4().hex
+
+    if filename_prefix is None:
+        filename_prefix = ""
+    filename_prefix = filename_prefix + random_filename_suffix
+    return mode, filename_prefix
+
+
 def _get_bucketing_series(df: pd.DataFrame, bucketing_info: typing.BucketingInfoTuple) -> pd.Series:
     bucket_number_series = (
         df[bucketing_info[0]]
@@ -201,7 +222,7 @@ def _to_dataset(
     concurrent_partitioning: bool,
     df: pd.DataFrame,
     path_root: str,
-    filename_prefix: str,
+    filename_prefix: str | None,
     index: bool,
     use_threads: bool | int,
     mode: str,
@@ -212,6 +233,7 @@ def _to_dataset(
 ) -> tuple[list[str], dict[str, list[str]]]:
     path_root = path_root if path_root.endswith("/") else f"{path_root}/"
     # Evaluate mode
+    mode, filename_prefix = _load_mode_and_filename_prefix(mode=mode, filename_prefix=filename_prefix)
     if mode not in ["append", "overwrite", "overwrite_partitions"]:
         raise exceptions.InvalidArgumentValue(
             f"{mode} is a invalid mode, please use append, overwrite or overwrite_partitions."
diff --git a/awswrangler/s3/_write_orc.py b/awswrangler/s3/_write_orc.py
index f3038a2e3..c5f434a3a 100644
--- a/awswrangler/s3/_write_orc.py
+++ b/awswrangler/s3/_write_orc.py
@@ -326,7 +326,7 @@ def to_orc(
     partition_cols: list[str] | None = None,
     bucketing_info: BucketingInfoTuple | None = None,
     concurrent_partitioning: bool = False,
-    mode: Literal["append", "overwrite", "overwrite_partitions"] | None = None,
+    mode: Literal["append", "overwrite", "overwrite_partitions", "overwrite_files"] | None = None,
     catalog_versioning: bool = False,
     schema_evolution: bool = True,
     database: str | None = None,
@@ -414,7 +414,7 @@ def to_orc(
         If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
     schema_evolution
         If True allows schema evolution (new or missing columns), otherwise a exception will be raised. True by default.
-        (Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
+        (Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
         Related tutorial:
         https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
     database
@@ -646,6 +646,7 @@ def to_orc(
         columns_comments=columns_comments,
         columns_parameters=columns_parameters,
         execution_engine=engine.get(),
+        max_rows_by_file=max_rows_by_file,
     )
 
     # Evaluating compression
diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py
index d30fc6d60..9bef70999 100644
--- a/awswrangler/s3/_write_parquet.py
+++ b/awswrangler/s3/_write_parquet.py
@@ -353,7 +353,7 @@ def to_parquet(
     partition_cols: list[str] | None = None,
     bucketing_info: BucketingInfoTuple | None = None,
     concurrent_partitioning: bool = False,
-    mode: Literal["append", "overwrite", "overwrite_partitions"] | None = None,
+    mode: Literal["append", "overwrite", "overwrite_partitions", "overwrite_files"] | None = None,
     catalog_versioning: bool = False,
     schema_evolution: bool = True,
     database: str | None = None,
@@ -444,7 +444,7 @@ def to_parquet(
         If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
     schema_evolution
         If True allows schema evolution (new or missing columns), otherwise a exception will be raised. True by default.
-        (Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
+        (Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
         Related tutorial:
         https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
     database
@@ -704,6 +704,7 @@ def to_parquet(
         columns_comments=columns_comments,
         columns_parameters=columns_parameters,
         execution_engine=engine.get(),
+        max_rows_by_file=max_rows_by_file,
     )
 
     # Evaluating compression
diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py
index 4911e7696..827a440ec 100644
--- a/awswrangler/s3/_write_text.py
+++ b/awswrangler/s3/_write_text.py
@@ -4,7 +4,6 @@
 
 import csv
 import logging
-import uuid
 from typing import TYPE_CHECKING, Any, Literal, cast
 
 import boto3
@@ -98,7 +97,7 @@ def to_csv(  # noqa: PLR0912,PLR0915
     partition_cols: list[str] | None = None,
     bucketing_info: BucketingInfoTuple | None = None,
     concurrent_partitioning: bool = False,
-    mode: Literal["append", "overwrite", "overwrite_partitions"] | None = None,
+    mode: Literal["append", "overwrite", "overwrite_partitions", "overwrite_files"] | None = None,
     catalog_versioning: bool = False,
     schema_evolution: bool = False,
     dtype: dict[str, str] | None = None,
@@ -180,7 +179,7 @@ def to_csv(  # noqa: PLR0912,PLR0915
         If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
     schema_evolution
         If True allows schema evolution (new or missing columns), otherwise a exception will be raised.
-        (Only considered if dataset=True and mode in ("append", "overwrite_partitions")). False by default.
+        (Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files")). False by default.
         Related tutorial:
         https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
     database
@@ -469,7 +468,6 @@ def to_csv(  # noqa: PLR0912,PLR0915
     partitions_values: dict[str, list[str]] = {}
     mode = "append" if mode is None else mode
-    filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
     s3_client = _utils.client(service_name="s3", session=boto3_session)
 
     # Sanitize table to respect Athena's standards
@@ -661,7 +659,7 @@ def to_json(  # noqa: PLR0912,PLR0915
     partition_cols: list[str] | None = None,
     bucketing_info: BucketingInfoTuple | None = None,
     concurrent_partitioning: bool = False,
-    mode: Literal["append", "overwrite", "overwrite_partitions"] | None = None,
+    mode: Literal["append", "overwrite", "overwrite_partitions", "overwrite_files"] | None = None,
     catalog_versioning: bool = False,
     schema_evolution: bool = True,
     dtype: dict[str, str] | None = None,
@@ -726,7 +724,7 @@ def to_json(  # noqa: PLR0912,PLR0915
         If True and `mode="overwrite"`, creates an archived version of the table catalog before updating it.
     schema_evolution
         If True allows schema evolution (new or missing columns), otherwise a exception will be raised.
-        (Only considered if dataset=True and mode in ("append", "overwrite_partitions"))
+        (Only considered if dataset=True and mode in ("append", "overwrite_partitions", "overwrite_files"))
         Related tutorial:
         https://aws-sdk-pandas.readthedocs.io/en/3.11.0/tutorials/014%20-%20Schema%20Evolution.html
     database
@@ -919,7 +917,6 @@ def to_json(  # noqa: PLR0912,PLR0915
     partitions_values: dict[str, list[str]] = {}
     mode = "append" if mode is None else mode
-    filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
     s3_client = _utils.client(service_name="s3", session=boto3_session)
 
     # Sanitize table to respect Athena's standards
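
For context, a minimal usage sketch (not part of the patch) of how the new mode would be called once this diff is applied; the bucket, prefix, and filename_prefix values below are placeholders:

# Usage sketch (assumption, not part of the patch): with mode="overwrite_files"
# the dataset writers skip the random UUID filename suffix (falling back to the
# prefix "data" when no filename_prefix is given), so re-running the same write
# replaces the previously written objects instead of adding new ones. Combining
# this mode with max_rows_by_file or bucketing_info raises InvalidArgumentValue.
import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"]})

wr.s3.to_parquet(
    df=df,
    path="s3://my-bucket/my-prefix/",   # placeholder bucket/prefix
    dataset=True,
    mode="overwrite_files",
    filename_prefix="daily_snapshot_",  # optional; defaults to "data" in this mode
)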