diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 24b5b3fd4..2c45e7202 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -56,17 +56,13 @@ jobs: - name: Install dependencies run: | uv sync --extra files --group test-spark-3.5 --group test-pydantic-2 --group dev - uv pip install -U flake8-commas # Set the `CODEQL-PYTHON` environment variable to the Python executable # that includes the dependencies echo "CODEQL_PYTHON=$(which python)" >> $GITHUB_ENV - - name: Run flake8 - run: python3 -m flake8 --config setup.cfg . - - name: Run mypy - run: python3 -m mypy --config-file setup.cfg onetl + run: python3 -m mypy onetl codeql: name: CodeQL diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 163d56801..a74588f69 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -116,24 +116,19 @@ repos: hooks: - id: ruff-format priority: 8 + - id: ruff + args: [--fix] + priority: 9 - repo: local hooks: - - id: flake8 - name: flake8 - entry: python3 -m flake8 - language: system - types: [python] - files: ^(onetl|tests)/.*$ - pass_filenames: true - priority: 9 - id: mypy name: mypy - entry: python3 -m mypy --config-file setup.cfg onetl + entry: python3 -m mypy onetl language: system types: [python] pass_filenames: false - priority: 9 + priority: 10 - id: towncrier name: towncrier entry: towncrier build --draft @@ -153,6 +148,5 @@ ci: skip: - chmod # failing in pre-commit.ci - docker-compose-check # cannot run on pre-commit.ci - - flake8 # checked with Github Actions - mypy # checked with Github Actions - towncrier # checked with Github Actions diff --git a/onetl/__init__.py b/onetl/__init__.py index 25a65ac02..87c7729c0 100644 --- a/onetl/__init__.py +++ b/onetl/__init__.py @@ -5,6 +5,8 @@ from onetl.plugins import import_plugins from onetl.version import __version__ +__all__ = ["__version__"] + def plugins_auto_import(): """ diff --git a/onetl/_metrics/__init__.py b/onetl/_metrics/__init__.py index 046fb6a39..516ada0a7 100644 --- a/onetl/_metrics/__init__.py +++ b/onetl/_metrics/__init__.py @@ -10,8 +10,8 @@ __all__ = [ "SparkCommandMetrics", "SparkDriverMetrics", - "SparkMetricsRecorder", "SparkExecutorMetrics", "SparkInputMetrics", + "SparkMetricsRecorder", "SparkOutputMetrics", ] diff --git a/onetl/_metrics/extract.py b/onetl/_metrics/extract.py index 836ce10b5..c390f6cec 100644 --- a/onetl/_metrics/extract.py +++ b/onetl/_metrics/extract.py @@ -29,7 +29,7 @@ def _get_int(data: dict[SparkSQLMetricNames, list[str]], key: Any) -> int | None: try: return int(data[key][0]) - except Exception: + except (IndexError, KeyError, ValueError, TypeError): return None @@ -38,7 +38,7 @@ def _get_bytes(data: dict[SparkSQLMetricNames, list[str]], key: Any) -> int | No raw_value = data[key][0] normalized_value = NON_BYTE_SIZE.sub("", raw_value) return int(ByteSize.validate(normalized_value)) - except Exception: + except (IndexError, KeyError, ValueError, TypeError): return None diff --git a/onetl/_metrics/listener/__init__.py b/onetl/_metrics/listener/__init__.py index 476e8e086..8ae322fef 100644 --- a/onetl/_metrics/listener/__init__.py +++ b/onetl/_metrics/listener/__init__.py @@ -15,15 +15,15 @@ ) __all__ = [ - "SparkListenerTask", - "SparkListenerTaskStatus", - "SparkListenerTaskMetrics", - "SparkListenerStage", - "SparkListenerStageStatus", - "SparkListenerJob", - "SparkListenerJobStatus", "SparkListenerExecution", "SparkListenerExecutionStatus", - "SparkSQLMetricNames", + "SparkListenerJob", + 
"SparkListenerJobStatus", + "SparkListenerStage", + "SparkListenerStageStatus", + "SparkListenerTask", + "SparkListenerTaskMetrics", + "SparkListenerTaskStatus", "SparkMetricsListener", + "SparkSQLMetricNames", ] diff --git a/onetl/_metrics/listener/base.py b/onetl/_metrics/listener/base.py index f12b7bf10..be428112c 100644 --- a/onetl/_metrics/listener/base.py +++ b/onetl/_metrics/listener/base.py @@ -17,14 +17,15 @@ class BaseSparkListener: """Base no-op SparkListener implementation. See `SparkListener `_ interface. - """ + """ # noqa: E501 spark: SparkSession def activate(self): start_callback_server(self.spark) - # passing python listener object directly to addSparkListener or removeSparkListener leads to creating new java object each time. + # passing python listener object directly to addSparkListener or removeSparkListener + # leads to creating new java object each time. # But removeSparkListener call has effect only on the same Java object passed to removeSparkListener. # So we need to explicitly create Java object, and then pass it both calls. gateway = get_java_gateway(self.spark) @@ -32,12 +33,12 @@ def activate(self): java_list.append(self) self._java_listener = java_list[0] - spark_context = self.spark.sparkContext._jsc.sc() # noqa: WPS437 + spark_context = self.spark.sparkContext._jsc.sc() # noqa: SLF001 spark_context.addSparkListener(self._java_listener) def deactivate(self): with suppress(Exception): - spark_context = self.spark.sparkContext._jsc.sc() # noqa: WPS437 + spark_context = self.spark.sparkContext._jsc.sc() # noqa: SLF001 spark_context.removeSparkListener(self._java_listener) with suppress(Exception): @@ -50,7 +51,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.deactivate() - def __del__(self): # noqa: WPS603 + def __del__(self): # If current object is collected by GC, deactivate listener # and free bind Java object self.deactivate() @@ -60,119 +61,119 @@ def equals(self, other): # so we compare string representation which should contain some form of id return other.toString() == self._java_listener.toString() - def toString(self): + def toString(self): # noqa: N802 return type(self).__qualname__ + "@" + hex(id(self)) - def hashCode(self): + def hashCode(self): # noqa: N802 return hash(self) # no cover: start # method names are important for Java interface compatibility! 
- def onApplicationEnd(self, application): + def onApplicationEnd(self, application): # noqa: N802 pass - def onApplicationStart(self, application): + def onApplicationStart(self, application): # noqa: N802 pass - def onBlockManagerAdded(self, block_manager): + def onBlockManagerAdded(self, block_manager): # noqa: N802 pass - def onBlockManagerRemoved(self, block_manager): + def onBlockManagerRemoved(self, block_manager): # noqa: N802 pass - def onBlockUpdated(self, block): + def onBlockUpdated(self, block): # noqa: N802 pass - def onEnvironmentUpdate(self, environment): + def onEnvironmentUpdate(self, environment): # noqa: N802 pass - def onExecutorAdded(self, executor): + def onExecutorAdded(self, executor): # noqa: N802 pass - def onExecutorMetricsUpdate(self, executor): + def onExecutorMetricsUpdate(self, executor): # noqa: N802 pass - def onExecutorRemoved(self, executor): + def onExecutorRemoved(self, executor): # noqa: N802 pass - def onExecutorBlacklisted(self, event): + def onExecutorBlacklisted(self, event): # noqa: N802 pass - def onExecutorBlacklistedForStage(self, event): + def onExecutorBlacklistedForStage(self, event): # noqa: N802 pass - def onExecutorExcluded(self, event): + def onExecutorExcluded(self, event): # noqa: N802 pass - def onExecutorExcludedForStage(self, event): + def onExecutorExcludedForStage(self, event): # noqa: N802 pass - def onExecutorUnblacklisted(self, event): + def onExecutorUnblacklisted(self, event): # noqa: N802 pass - def onExecutorUnexcluded(self, event): + def onExecutorUnexcluded(self, event): # noqa: N802 pass - def onJobStart(self, event): + def onJobStart(self, event): # noqa: N802 pass - def onJobEnd(self, event): + def onJobEnd(self, event): # noqa: N802 pass - def onNodeBlacklisted(self, node): + def onNodeBlacklisted(self, node): # noqa: N802 pass - def onNodeBlacklistedForStage(self, stage): + def onNodeBlacklistedForStage(self, stage): # noqa: N802 pass - def onNodeExcluded(self, node): + def onNodeExcluded(self, node): # noqa: N802 pass - def onNodeExcludedForStage(self, node): + def onNodeExcludedForStage(self, node): # noqa: N802 pass - def onNodeUnblacklisted(self, node): + def onNodeUnblacklisted(self, node): # noqa: N802 pass - def onNodeUnexcluded(self, node): + def onNodeUnexcluded(self, node): # noqa: N802 pass - def onOtherEvent(self, event): + def onOtherEvent(self, event): # noqa: N802 pass - def onResourceProfileAdded(self, resource_profile): + def onResourceProfileAdded(self, resource_profile): # noqa: N802 pass - def onSpeculativeTaskSubmitted(self, task): + def onSpeculativeTaskSubmitted(self, task): # noqa: N802 pass - def onStageCompleted(self, event): + def onStageCompleted(self, event): # noqa: N802 pass - def onStageExecutorMetrics(self, metrics): + def onStageExecutorMetrics(self, metrics): # noqa: N802 pass - def onStageSubmitted(self, event): + def onStageSubmitted(self, event): # noqa: N802 pass - def onTaskEnd(self, event): + def onTaskEnd(self, event): # noqa: N802 pass - def onTaskGettingResult(self, task): + def onTaskGettingResult(self, task): # noqa: N802 pass - def onTaskStart(self, event): + def onTaskStart(self, event): # noqa: N802 pass - def onUnpersistRDD(self, rdd): + def onUnpersistRDD(self, rdd): # noqa: N802 pass - def onUnschedulableTaskSetAdded(self, task_set): + def onUnschedulableTaskSetAdded(self, task_set): # noqa: N802 pass - def onUnschedulableTaskSetRemoved(self, task_set): + def onUnschedulableTaskSetRemoved(self, task_set): # noqa: N802 pass # no cover: stop class Java: - implements 
= ["org.apache.spark.scheduler.SparkListenerInterface"] + implements = ["org.apache.spark.scheduler.SparkListenerInterface"] # noqa: RUF012 diff --git a/onetl/_metrics/listener/execution.py b/onetl/_metrics/listener/execution.py index fbcd5bdf4..1ba0aa49c 100644 --- a/onetl/_metrics/listener/execution.py +++ b/onetl/_metrics/listener/execution.py @@ -18,7 +18,7 @@ def __str__(self): return self.value -class SparkSQLMetricNames(str, Enum): # noqa: WPS338 +class SparkSQLMetricNames(str, Enum): # Metric names passed to SQLMetrics.createMetric(...) # But only those we're interested in. @@ -60,10 +60,7 @@ class SparkListenerExecution: @property def jobs(self) -> list[SparkListenerJob]: - result = [] - for job_id in sorted(self._jobs.keys()): - result.append(self._jobs[job_id]) - return result + return [self._jobs[job_id] for job_id in sorted(self._jobs.keys())] def on_execution_start(self, event): # https://github.com/apache/spark/blob/v3.5.7/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala#L44-L58 diff --git a/onetl/_metrics/listener/job.py b/onetl/_metrics/listener/job.py index 68551889e..11ef58533 100644 --- a/onetl/_metrics/listener/job.py +++ b/onetl/_metrics/listener/job.py @@ -31,10 +31,7 @@ class SparkListenerJob: @property def stages(self) -> list[SparkListenerStage]: - result = [] - for stage_id in sorted(self._stages.keys()): - result.append(self._stages[stage_id]) - return result + return [self._stages[stage_id] for stage_id in sorted(self._stages.keys())] @classmethod def create(cls, event): @@ -50,7 +47,7 @@ def create(cls, event): stage_ids = scala_seq_to_python_list(event.stageIds()) stage_infos = scala_seq_to_python_list(event.stageInfos()) for stage_id, stage_info in zip(stage_ids, stage_infos): - result._stages[stage_id] = SparkListenerStage.create(stage_info) # noqa: WPS437 + result._stages[stage_id] = SparkListenerStage.create(stage_info) return result diff --git a/onetl/_metrics/listener/listener.py b/onetl/_metrics/listener/listener.py index 33139046d..5c60cdc96 100644 --- a/onetl/_metrics/listener/listener.py +++ b/onetl/_metrics/listener/listener.py @@ -50,14 +50,14 @@ def __enter__(self): self.reset() return super().__enter__() - def onOtherEvent(self, event): + def onOtherEvent(self, event): # noqa: N802 class_name = event.getClass().getName() if class_name == self.SQL_START_CLASS_NAME: self.onExecutionStart(event) elif class_name == self.SQL_STOP_CLASS_NAME: self.onExecutionEnd(event) - def onExecutionStart(self, event): + def onExecutionStart(self, event): # noqa: N802 execution_id = event.executionId() description = event.description() execution = SparkListenerExecution( @@ -67,7 +67,7 @@ def onExecutionStart(self, event): self._recorded_executions[execution_id] = execution execution.on_execution_start(event) - def onExecutionEnd(self, event): + def onExecutionEnd(self, event): # noqa: N802 execution_id = event.executionId() execution = self._recorded_executions.get(execution_id) if execution: @@ -76,7 +76,7 @@ def onExecutionEnd(self, event): # Get execution metrics from SQLAppStatusStore, # as SparkListenerSQLExecutionEnd event does not provide them: # https://github.com/apache/spark/blob/v3.5.7/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala - session_status_store = self.spark._jsparkSession.sharedState().statusStore() # noqa: WPS437 + session_status_store = self.spark._jsparkSession.sharedState().statusStore() # noqa: SLF001 raw_execution = session_status_store.execution(execution.id).get() metrics = 
raw_execution.metrics() metric_values = session_status_store.executionMetrics(execution.id) @@ -90,7 +90,7 @@ def onExecutionEnd(self, event): continue execution.metrics[SparkSQLMetricNames(metric_name)].append(metric_value.get()) - def onJobStart(self, event): + def onJobStart(self, event): # noqa: N802 execution_id = event.properties().get("spark.sql.execution.id") execution_thread_id = event.properties().get(self.THREAD_ID_KEY) if execution_id is None: @@ -114,22 +114,22 @@ def onJobStart(self, event): execution.on_job_start(event) - def onJobEnd(self, event): + def onJobEnd(self, event): # noqa: N802 for execution in self._recorded_executions.values(): execution.on_job_end(event) - def onStageSubmitted(self, event): + def onStageSubmitted(self, event): # noqa: N802 for execution in self._recorded_executions.values(): execution.on_stage_start(event) - def onStageCompleted(self, event): + def onStageCompleted(self, event): # noqa: N802 for execution in self._recorded_executions.values(): execution.on_stage_end(event) - def onTaskStart(self, event): + def onTaskStart(self, event): # noqa: N802 for execution in self._recorded_executions.values(): execution.on_task_start(event) - def onTaskEnd(self, event): + def onTaskEnd(self, event): # noqa: N802 for execution in self._recorded_executions.values(): execution.on_task_end(event) diff --git a/onetl/_metrics/listener/stage.py b/onetl/_metrics/listener/stage.py index c5ac9b077..67ccf2878 100644 --- a/onetl/_metrics/listener/stage.py +++ b/onetl/_metrics/listener/stage.py @@ -29,10 +29,7 @@ class SparkListenerStage: @property def tasks(self) -> list[SparkListenerTask]: - result = [] - for task_id in sorted(self._tasks.keys()): - result.append(self._tasks[task_id]) - return result + return [self._tasks[task_id] for task_id in sorted(self._tasks.keys())] @classmethod def create(cls, stage_info): diff --git a/onetl/_util/file.py b/onetl/_util/file.py index 517647587..7e59af4a6 100644 --- a/onetl/_util/file.py +++ b/onetl/_util/file.py @@ -22,7 +22,7 @@ def get_file_hash( ) -> hashlib._Hash: """Get file hash by path and algorithm""" digest = hashlib.new(algorithm) - with open(path, "rb") as file: + with Path(path).open("rb") as file: chunk = file.read(chunk_size) while chunk: digest.update(chunk) @@ -36,13 +36,16 @@ def is_file_readable(path: str | os.PathLike) -> Path: path = Path(os.path.expandvars(path)).expanduser().resolve() if not path.exists(): - raise FileNotFoundError(f"File '{path}' does not exist") + msg = f"File '{path}' does not exist" + raise FileNotFoundError(msg) if not path.is_file(): - raise NotAFileError(f"{path_repr(path)} is not a file") + msg = f"{path_repr(path)} is not a file" + raise NotAFileError(msg) if not os.access(path, os.R_OK): - raise OSError(f"No read access to file {path_repr(path)}") + msg = f"No read access to file {path_repr(path)}" + raise OSError(msg) return path @@ -71,5 +74,5 @@ def generate_temp_path(root: PurePath) -> PurePath: from etl_entities.process import ProcessStackManager current_process = ProcessStackManager.get_current() - current_dt = datetime.now().strftime(DATETIME_FORMAT) + current_dt = datetime.now().strftime(DATETIME_FORMAT) # noqa: DTZ005 return root / "onetl" / current_process.host / current_process.full_name / current_dt diff --git a/onetl/_util/hadoop.py b/onetl/_util/hadoop.py index d9b3684d2..20ad11ac9 100644 --- a/onetl/_util/hadoop.py +++ b/onetl/_util/hadoop.py @@ -12,4 +12,4 @@ def get_hadoop_config(spark_session: SparkSession): """ Get ``org.apache.hadoop.conf.Configuration`` 
object """ - return spark_session.sparkContext._jsc.hadoopConfiguration() # type: ignore[attr-defined] + return spark_session.sparkContext._jsc.hadoopConfiguration() # type: ignore[attr-defined] # noqa: SLF001 diff --git a/onetl/_util/java.py b/onetl/_util/java.py index 1e315032d..2efd88433 100644 --- a/onetl/_util/java.py +++ b/onetl/_util/java.py @@ -13,7 +13,7 @@ def get_java_gateway(spark_session: SparkSession) -> JavaGateway: """ Get py4j Java gateway object """ - return spark_session._sc._gateway # noqa: WPS437 # type: ignore[attr-defined] + return spark_session._sc._gateway # noqa: SLF001 # type: ignore[attr-defined] def try_import_java_class(spark_session: SparkSession, name: str): diff --git a/onetl/_util/scala.py b/onetl/_util/scala.py index 15ff55bb1..5c35040fe 100644 --- a/onetl/_util/scala.py +++ b/onetl/_util/scala.py @@ -9,15 +9,12 @@ def get_default_scala_version(spark_version: Version) -> Version: """ Get default Scala version for specific Spark version """ - if spark_version.major == 2: + if spark_version.major == 2: # noqa: PLR2004 return Version("2.11") - if spark_version.major == 3: + if spark_version.major == 3: # noqa: PLR2004 return Version("2.12") return Version("2.13") def scala_seq_to_python_list(seq) -> list: - result = [] - for i in range(seq.size()): - result.append(seq.apply(i)) - return result + return [seq.apply(i) for i in range(seq.size())] diff --git a/onetl/_util/spark.py b/onetl/_util/spark.py index e57a46684..460a338c9 100644 --- a/onetl/_util/spark.py +++ b/onetl/_util/spark.py @@ -23,7 +23,7 @@ SPARK_JOB_GROUP_PROPERTY = "spark.jobGroup.id" -def stringify(value: Any, quote: bool = False) -> Any: # noqa: WPS212 +def stringify(value: Any, *, quote: bool = False) -> Any: """ Convert values to strings. @@ -57,10 +57,10 @@ def stringify(value: Any, quote: bool = False) -> Any: # noqa: WPS212 """ if isinstance(value, dict): - return {stringify(k): stringify(v, quote) for k, v in value.items()} + return {stringify(k): stringify(v, quote=quote) for k, v in value.items()} if isinstance(value, list): - return [stringify(v, quote) for v in value] + return [stringify(v, quote=quote) for v in value] if value is None: return "null" @@ -89,7 +89,7 @@ def inject_spark_param(conf: RuntimeConfig, name: str, value: Any): """ original_value = conf.get(name, None) - try: # noqa: WPS243 + try: conf.unset(name) if value is not None: conf.set(name, value) @@ -137,16 +137,15 @@ def estimate_dataframe_size(df: DataFrame) -> int: """ try: - spark_context = df._sc - size_estimator = spark_context._jvm.org.apache.spark.util.SizeEstimator # type: ignore[union-attr] - return size_estimator.estimate(df._jdf) - except Exception: + jvm = df._sc._jvm # type: ignore[attr-defined] # noqa: SLF001 + return jvm.org.apache.spark.util.SizeEstimator.estimate(df._jdf) # type: ignore[union-attr] # noqa: SLF001 + except Exception: # noqa: BLE001 # SizeEstimator uses Java reflection which may behave differently in different Java versions, # and also may be prohibited. return 0 -def get_executor_total_cores(spark_session: SparkSession, include_driver: bool = False) -> tuple[float, dict]: +def get_executor_total_cores(spark_session: SparkSession, *, include_driver: bool = False) -> tuple[float, dict]: """ Calculate maximum number of cores which can be used by Spark on all executors. 
@@ -165,8 +164,8 @@ def get_executor_total_cores(spark_session: SparkSession, include_driver: bool = expected_cores: float if master.startswith("local"): # no executors, only driver - scheduler = spark_session._jsc.sc().schedulerBackend() # type: ignore - expected_cores = scheduler.totalCores() # type: ignore + scheduler = spark_session._jsc.sc().schedulerBackend() # type: ignore[attr-defined] # noqa: SLF001 + expected_cores = scheduler.totalCores() config["spark.master"] = f"local[{expected_cores}]" else: cores = int(conf.get("spark.executor.cores", "1")) @@ -180,8 +179,8 @@ def get_executor_total_cores(spark_session: SparkSession, include_driver: bool = # If user haven't executed anything in current session, number of executors will be 0. # # Yes, scheduler can refuse to provide executors == maxExecutors: - # On Yarn - queue size limit is reached, or other application has higher priority, so executors were preempted. - # on K8S - namespace has not enough resources. + # On Yarn - queue size limit is reached, or other application has higher priority, + # so executors were preempted. on K8S - namespace has not enough resources. # So pessimistic approach is preferred. dynamic_executors = conf.get("spark.dynamicAllocation.maxExecutors", "infinity") diff --git a/onetl/_util/version.py b/onetl/_util/version.py index e83e23823..fded45896 100644 --- a/onetl/_util/version.py +++ b/onetl/_util/version.py @@ -28,7 +28,10 @@ class Version: def __init__(self, version: str): self._raw_str = version self._raw_parts = re.split("[.-]", version) - self._numeric_parts = [int(part) for part in self._raw_parts if part.isdigit()] + self._numeric_parts = tuple(int(part) for part in self._raw_parts if part.isdigit()) + + def __hash__(self): + return hash(self._raw_str) @property def major(self) -> int: @@ -71,7 +74,7 @@ def patch(self) -> int: >>> Version("5.6").patch 0 """ - return self._numeric_parts[2] if len(self._numeric_parts) > 2 else 0 + return self._numeric_parts[2] if len(self._numeric_parts) > 2 else 0 # noqa: PLR2004 @property def raw_parts(self) -> list[str]: @@ -101,7 +104,7 @@ def __getitem__(self, item): >>> Version("1.2.3-alpha")[3] Traceback (most recent call last): ... - IndexError: list index out of range + IndexError: tuple index out of range """ return self._numeric_parts[item] @@ -197,9 +200,11 @@ def min_digits(self, num_parts: int) -> Version: ValueError: Version '5.6' does not have enough numeric components for requested format (expected at least 3). """ if len(self._numeric_parts) < num_parts: - raise ValueError( - f"Version '{self}' does not have enough numeric components for requested format (expected at least {num_parts}).", + msg = ( + f"Version '{self}' does not have enough numeric components " + f"for requested format (expected at least {num_parts})." 
) + raise ValueError(msg) return self def format(self, format_string: str) -> str: diff --git a/onetl/base/__init__.py b/onetl/base/__init__.py index ccb858a70..3285e16a5 100644 --- a/onetl/base/__init__.py +++ b/onetl/base/__init__.py @@ -18,3 +18,25 @@ from onetl.base.path_stat_protocol import PathStatProtocol from onetl.base.pure_path_protocol import PurePathProtocol from onetl.base.supports_rename_dir import SupportsRenameDir + +__all__ = [ + "BaseConnection", + "BaseDBConnection", + "BaseDBDialect", + "BaseFileConnection", + "BaseFileDFConnection", + "BaseFileFilter", + "BaseFileLimit", + "BaseReadableFileFormat", + "BaseWritableFileFormat", + "ContainsException", + "ContainsGetDFSchemaMethod", + "ContainsGetMinMaxValues", + "FileDFReadOptions", + "FileDFWriteOptions", + "PathProtocol", + "PathStatProtocol", + "PathWithStatsProtocol", + "PurePathProtocol", + "SupportsRenameDir", +] diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py index cc9a4c952..77380e880 100644 --- a/onetl/base/base_db_connection.py +++ b/onetl/base/base_db_connection.py @@ -121,7 +121,7 @@ def instance_url(self) -> str: # Some implementations may have a different number of parameters. # For example, the 'options' parameter may be present. This is fine. @abstractmethod - def read_source_as_df( + def read_source_as_df( # noqa: PLR0913 self, source: str, columns: list[str] | None = None, diff --git a/onetl/base/base_file_connection.py b/onetl/base/base_file_connection.py index 268fe8fa8..01d7a8705 100644 --- a/onetl/base/base_file_connection.py +++ b/onetl/base/base_file_connection.py @@ -243,7 +243,8 @@ def remove_file(self, path: os.PathLike | str) -> bool: .. warning:: - Supports only one file removal per call. Directory removal is **NOT** supported, use :obj:`~remove_dir` instead. + Supports only one file removal per call. + Directory removal is **NOT** supported, use :obj:`~remove_dir` instead. .. versionadded:: 0.8.0 @@ -273,7 +274,7 @@ def remove_file(self, path: os.PathLike | str) -> bool: """ @abstractmethod - def remove_dir(self, path: os.PathLike | str, recursive: bool = False) -> bool: + def remove_dir(self, path: os.PathLike | str, *, recursive: bool = False) -> bool: """ Remove directory or directory tree. |support_hooks| @@ -325,6 +326,7 @@ def rename_file( self, source_file_path: os.PathLike | str, target_file_path: os.PathLike | str, + *, replace: bool = False, ) -> PathWithStatsProtocol: """ @@ -427,6 +429,7 @@ def list_dir( def walk( self, root: os.PathLike | str, + *, topdown: bool = True, filters: Iterable[BaseFileFilter] | None = None, limits: Iterable[BaseFileLimit] | None = None, @@ -490,6 +493,7 @@ def download_file( self, remote_file_path: os.PathLike | str, local_file_path: os.PathLike | str, + *, replace: bool = True, ) -> PathWithStatsProtocol: """ @@ -497,7 +501,8 @@ def download_file( .. warning:: - Supports only one file download per call. Directory download is **NOT** supported, use :ref:`file-downloader` instead. + Supports only one file download per call. + Directory download is **NOT** supported, use :ref:`file-downloader` instead. .. versionadded:: 0.8.0 @@ -552,6 +557,7 @@ def upload_file( self, local_file_path: os.PathLike | str, remote_file_path: os.PathLike | str, + *, replace: bool = False, ) -> PathWithStatsProtocol: """ @@ -559,7 +565,8 @@ def upload_file( .. warning:: - Supports only one file upload per call. Directory upload is **NOT** supported, use :ref:`file-uploader` instead. + Supports only one file upload per call. 
+ Directory upload is **NOT** supported, use :ref:`file-uploader` instead. .. versionadded:: 0.8.0 diff --git a/onetl/base/base_file_df_connection.py b/onetl/base/base_file_df_connection.py index 463c32752..0dedfaa05 100644 --- a/onetl/base/base_file_df_connection.py +++ b/onetl/base/base_file_df_connection.py @@ -77,7 +77,7 @@ class BaseFileDFConnection(BaseConnection): @abstractmethod def check_if_format_supported( self, - format: BaseReadableFileFormat | BaseWritableFileFormat, # noqa: WPS125 + format: BaseReadableFileFormat | BaseWritableFileFormat, ) -> None: """ Validate if specific file format is supported. |support_hooks| @@ -110,7 +110,7 @@ def instance_url(self) -> str: def read_files_as_df( self, paths: list[PurePathProtocol], - format: BaseReadableFileFormat, # noqa: WPS125 + format: BaseReadableFileFormat, root: PurePathProtocol | None = None, df_schema: StructType | None = None, options: FileDFReadOptions | None = None, @@ -126,7 +126,7 @@ def write_df_as_files( self, df: DataFrame, path: PurePathProtocol, - format: BaseWritableFileFormat, # noqa: WPS125 + format: BaseWritableFileFormat, options: FileDFWriteOptions | None = None, ) -> None: """ diff --git a/onetl/base/pure_path_protocol.py b/onetl/base/pure_path_protocol.py index d55516a73..7dc13c5a2 100644 --- a/onetl/base/pure_path_protocol.py +++ b/onetl/base/pure_path_protocol.py @@ -6,7 +6,7 @@ from typing_extensions import Protocol, runtime_checkable -T = TypeVar("T", bound="PurePathProtocol", covariant=True) +T = TypeVar("T", bound="PurePathProtocol", covariant=True) # noqa: PLC0105 @runtime_checkable diff --git a/onetl/base/supports_rename_dir.py b/onetl/base/supports_rename_dir.py index e36b59b26..3269d7475 100644 --- a/onetl/base/supports_rename_dir.py +++ b/onetl/base/supports_rename_dir.py @@ -21,5 +21,6 @@ def rename_dir( self, source_dir_path: str | os.PathLike, target_dir_path: str | os.PathLike, + *, replace: bool = False, ) -> PathWithStatsProtocol: ... 
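Several of the base file-connection signatures above gain a bare `*`, turning boolean flags such as `replace`, `recursive` and `topdown` into keyword-only parameters. A minimal, self-contained sketch of what that means for call sites — the function below is an illustrative stand-in, not onetl's actual implementation:

    def remove_dir(path: str, *, recursive: bool = False) -> bool:
        # Toy stand-in mirroring the keyword-only signatures introduced above.
        print(f"removing {path} (recursive={recursive})")
        return True

    remove_dir("/data/stage", recursive=True)   # OK: flag passed by keyword
    # remove_dir("/data/stage", True)           # TypeError once the bare `*` is added

In other words, positional boolean arguments such as `connection.remove_dir(path, True)` stop working after this change; callers have to spell out the flag name.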
diff --git a/onetl/connection/__init__.py b/onetl/connection/__init__.py index b5a429a05..47a6eb702 100644 --- a/onetl/connection/__init__.py +++ b/onetl/connection/__init__.py @@ -62,6 +62,32 @@ "SparkS3": "spark_s3", } +__all__ = [ + "FTP", + "FTPS", + "HDFS", + "MSSQL", + "S3", + "SFTP", + "Clickhouse", + "DBConnection", + "FileConnection", + "Greenplum", + "Hive", + "Iceberg", + "Kafka", + "MongoDB", + "MySQL", + "Oracle", + "Postgres", + "Samba", + "SparkFileDFConnection", + "SparkHDFS", + "SparkLocalFS", + "SparkS3", + "WebDAV", +] + def __getattr__(name: str): if name in db_connection_modules: diff --git a/onetl/connection/db_connection/clickhouse/__init__.py b/onetl/connection/db_connection/clickhouse/__init__.py index 95ebc973d..58f03cf6a 100644 --- a/onetl/connection/db_connection/clickhouse/__init__.py +++ b/onetl/connection/db_connection/clickhouse/__init__.py @@ -2,6 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 from onetl.connection.db_connection.clickhouse.connection import ( Clickhouse, - ClickhouseExtra, ) from onetl.connection.db_connection.clickhouse.dialect import ClickhouseDialect +from onetl.connection.db_connection.clickhouse.options import ( + ClickhouseExecuteOptions, + ClickhouseFetchOptions, + ClickhouseReadOptions, + ClickhouseSQLOptions, + ClickhouseWriteOptions, +) + +__all__ = [ + "Clickhouse", + "ClickhouseDialect", + "ClickhouseExecuteOptions", + "ClickhouseFetchOptions", + "ClickhouseReadOptions", + "ClickhouseSQLOptions", + "ClickhouseWriteOptions", +] diff --git a/onetl/connection/db_connection/clickhouse/connection.py b/onetl/connection/db_connection/clickhouse/connection.py index 6dcd7a584..9d63b2e5b 100644 --- a/onetl/connection/db_connection/clickhouse/connection.py +++ b/onetl/connection/db_connection/clickhouse/connection.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +import warnings from typing import ClassVar, Optional from etl_entities.instance import Host @@ -133,7 +134,9 @@ def get_packages( apache_http_client_version: str | None = None, ) -> list[str]: """ - Get package names to be downloaded by Spark. Allows specifying custom JDBC and Apache HTTP Client versions. |support_hooks| + Get package names to be downloaded by Spark. |support_hooks| + + Allows specifying custom JDBC and Apache HTTP Client versions. .. 
versionadded:: 0.9.0 @@ -193,7 +196,13 @@ def get_packages( @classproperty def package(self) -> str: """Get a single string of package names to be downloaded by Spark for establishing a Clickhouse connection.""" - return "com.clickhouse:clickhouse-jdbc:0.7.2,com.clickhouse:clickhouse-http-client:0.7.2,org.apache.httpcomponents.client5:httpclient5:5.4.2" + msg = "`Clickhouse.package` will be removed in 1.0.0, use `Clickhouse.get_packages()` instead" + warnings.warn(msg, UserWarning, stacklevel=3) + return ( + "com.clickhouse:clickhouse-jdbc:0.7.2," + "com.clickhouse:clickhouse-http-client:0.7.2," + "org.apache.httpcomponents.client5:httpclient5:5.4.2" + ) @property def jdbc_url(self) -> str: diff --git a/onetl/connection/db_connection/db_connection/__init__.py b/onetl/connection/db_connection/db_connection/__init__.py index df4257c9d..d319c4556 100644 --- a/onetl/connection/db_connection/db_connection/__init__.py +++ b/onetl/connection/db_connection/db_connection/__init__.py @@ -2,3 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from onetl.connection.db_connection.db_connection.connection import DBConnection from onetl.connection.db_connection.db_connection.dialect import DBDialect + +__all__ = [ + "DBConnection", + "DBDialect", +] diff --git a/onetl/connection/db_connection/db_connection/connection.py b/onetl/connection/db_connection/db_connection/connection.py index f5e1baf33..015da823d 100644 --- a/onetl/connection/db_connection/db_connection/connection.py +++ b/onetl/connection/db_connection/db_connection/connection.py @@ -31,7 +31,7 @@ class DBConnection(BaseDBConnection, FrozenModel): def _forward_refs(cls) -> dict[str, type]: try_import_pyspark() - from pyspark.sql import SparkSession # noqa: WPS442 + from pyspark.sql import SparkSession # avoid importing pyspark unless user called the constructor, # as we allow user to use `Connection.get_packages()` for creating Spark session @@ -44,7 +44,7 @@ def _check_spark_session_alive(cls, spark): # https://stackoverflow.com/a/36044685 msg = "Spark session is stopped. Please recreate Spark session." 
try: - if not spark._jsc.sc().isStopped(): + if not spark._jsc.sc().isStopped(): # noqa: SLF001 return spark except Exception as e: # None has no attribute "something" diff --git a/onetl/connection/db_connection/db_connection/dialect.py b/onetl/connection/db_connection/db_connection/dialect.py index 1fa2cd026..0f88be3fc 100644 --- a/onetl/connection/db_connection/db_connection/dialect.py +++ b/onetl/connection/db_connection/db_connection/dialect.py @@ -17,11 +17,12 @@ class DBDialect(BaseDBDialect): def detect_hwm_class(self, field: StructField) -> type[HWM] | None: - return SparkTypeToHWM.get(field.dataType) # type: ignore + return SparkTypeToHWM.get(field.dataType) - def get_sql_query( + def get_sql_query( # noqa: PLR0913 self, table: str, + *, columns: list[str] | None = None, where: str | list[str] | None = None, hint: str | None = None, @@ -32,11 +33,7 @@ def get_sql_query( Generates a SQL query using input arguments """ - if compact: - indent = " " - else: - indent = os.linesep + " " * 7 - + indent = " " if compact else (os.linesep + " " * 7) hint = f" /*+ {hint} */" if hint else "" columns_str = indent + "*" diff --git a/onetl/connection/db_connection/dialect_mixins/__init__.py b/onetl/connection/db_connection/dialect_mixins/__init__.py index b64060bc5..bd6403b9b 100644 --- a/onetl/connection/db_connection/dialect_mixins/__init__.py +++ b/onetl/connection/db_connection/dialect_mixins/__init__.py @@ -33,3 +33,17 @@ from onetl.connection.db_connection.dialect_mixins.support_where_str import ( SupportWhereStr, ) + +__all__ = [ + "NotSupportColumns", + "NotSupportDFSchema", + "NotSupportHint", + "NotSupportWhere", + "RequiresDFSchema", + "SupportColumns", + "SupportHWMExpressionStr", + "SupportHintStr", + "SupportNameAny", + "SupportNameWithSchemaOnly", + "SupportWhereStr", +] diff --git a/onetl/connection/db_connection/dialect_mixins/not_support_columns.py b/onetl/connection/db_connection/dialect_mixins/not_support_columns.py index eb73099fe..4577887e6 100644 --- a/onetl/connection/db_connection/dialect_mixins/not_support_columns.py +++ b/onetl/connection/db_connection/dialect_mixins/not_support_columns.py @@ -15,4 +15,5 @@ def validate_columns( columns: Any, ) -> None: if columns is not None: - raise ValueError(f"'columns' parameter is not supported by {self.connection.__class__.__name__}") + msg = f"'columns' parameter is not supported by {self.connection.__class__.__name__}" + raise ValueError(msg) diff --git a/onetl/connection/db_connection/dialect_mixins/not_support_df_schema.py b/onetl/connection/db_connection/dialect_mixins/not_support_df_schema.py index 82da7034a..dc72f1dfb 100644 --- a/onetl/connection/db_connection/dialect_mixins/not_support_df_schema.py +++ b/onetl/connection/db_connection/dialect_mixins/not_support_df_schema.py @@ -15,4 +15,5 @@ def validate_df_schema( df_schema: Any, ) -> None: if df_schema: - raise ValueError(f"'df_schema' parameter is not supported by {self.connection.__class__.__name__}") + msg = f"'df_schema' parameter is not supported by {self.connection.__class__.__name__}" + raise ValueError(msg) diff --git a/onetl/connection/db_connection/dialect_mixins/not_support_hint.py b/onetl/connection/db_connection/dialect_mixins/not_support_hint.py index 04ab2fd3a..da50f7ed6 100644 --- a/onetl/connection/db_connection/dialect_mixins/not_support_hint.py +++ b/onetl/connection/db_connection/dialect_mixins/not_support_hint.py @@ -15,4 +15,5 @@ def validate_hint( hint: Any, ) -> None: if hint is not None: - raise TypeError(f"'hint' parameter is not supported by 
{self.connection.__class__.__name__}") + msg = f"'hint' parameter is not supported by {self.connection.__class__.__name__}" + raise TypeError(msg) diff --git a/onetl/connection/db_connection/dialect_mixins/not_support_where.py b/onetl/connection/db_connection/dialect_mixins/not_support_where.py index 54303c6e4..9a836dbc6 100644 --- a/onetl/connection/db_connection/dialect_mixins/not_support_where.py +++ b/onetl/connection/db_connection/dialect_mixins/not_support_where.py @@ -15,4 +15,5 @@ def validate_where( where: Any, ) -> None: if where is not None: - raise TypeError(f"'where' parameter is not supported by {self.connection.__class__.__name__}") + msg = f"'where' parameter is not supported by {self.connection.__class__.__name__}" + raise TypeError(msg) diff --git a/onetl/connection/db_connection/dialect_mixins/requires_df_schema.py b/onetl/connection/db_connection/dialect_mixins/requires_df_schema.py index 0922c5429..d766a8cc0 100644 --- a/onetl/connection/db_connection/dialect_mixins/requires_df_schema.py +++ b/onetl/connection/db_connection/dialect_mixins/requires_df_schema.py @@ -19,4 +19,5 @@ def validate_df_schema( ) -> StructType: if df_schema: return df_schema - raise ValueError(f"'df_schema' parameter is mandatory for {self.connection.__class__.__name__}") + msg = f"'df_schema' parameter is mandatory for {self.connection.__class__.__name__}" + raise ValueError(msg) diff --git a/onetl/connection/db_connection/dialect_mixins/support_hint_str.py b/onetl/connection/db_connection/dialect_mixins/support_hint_str.py index 79ccbf5bb..9eb00573e 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_hint_str.py +++ b/onetl/connection/db_connection/dialect_mixins/support_hint_str.py @@ -18,9 +18,10 @@ def validate_hint( return None if not isinstance(hint, str): - raise TypeError( + msg = ( f"{self.connection.__class__.__name__} requires 'hint' parameter type to be 'str', " - f"got {hint.__class__.__name__!r}", + f"got {hint.__class__.__name__!r}" ) + raise TypeError(msg) return hint diff --git a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py b/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py index 89c860cb1..af0259912 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py +++ b/onetl/connection/db_connection/dialect_mixins/support_hwm_expression_str.py @@ -15,9 +15,10 @@ def validate_hwm(self, hwm: HWM | None) -> HWM | None: return hwm if not isinstance(hwm.expression, str): - raise TypeError( + msg = ( f"{self.connection.__class__.__name__} requires 'hwm.expression' parameter type to be 'str', " - f"got {hwm.expression.__class__.__name__!r}", + f"got {hwm.expression.__class__.__name__!r}" ) + raise TypeError(msg) return hwm diff --git a/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py b/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py index 56985c3b3..7c18aa6bb 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py +++ b/onetl/connection/db_connection/dialect_mixins/support_name_with_schema_only.py @@ -6,8 +6,7 @@ class SupportNameWithSchemaOnly: def validate_name(self, value: str) -> str: if "." 
not in value: - raise ValueError( - f"Name should be passed in `schema.name` format, got '{value}'", - ) + msg = f"Name should be passed in `schema.name` format, got '{value}'" + raise ValueError(msg) return value diff --git a/onetl/connection/db_connection/dialect_mixins/support_where_str.py b/onetl/connection/db_connection/dialect_mixins/support_where_str.py index 44657d396..fccfe5e61 100644 --- a/onetl/connection/db_connection/dialect_mixins/support_where_str.py +++ b/onetl/connection/db_connection/dialect_mixins/support_where_str.py @@ -18,9 +18,10 @@ def validate_where( return None if not isinstance(where, str): - raise TypeError( + msg = ( f"{self.connection.__class__.__name__} requires 'where' parameter type to be 'str', " - f"got {where.__class__.__name__!r}", + f"got {where.__class__.__name__!r}" ) + raise TypeError(msg) return where diff --git a/onetl/connection/db_connection/greenplum/__init__.py b/onetl/connection/db_connection/greenplum/__init__.py index ce7690b3d..2f9628dcb 100644 --- a/onetl/connection/db_connection/greenplum/__init__.py +++ b/onetl/connection/db_connection/greenplum/__init__.py @@ -3,7 +3,21 @@ from onetl.connection.db_connection.greenplum.connection import Greenplum from onetl.connection.db_connection.greenplum.dialect import GreenplumDialect from onetl.connection.db_connection.greenplum.options import ( + GreenplumExecuteOptions, + GreenplumFetchOptions, GreenplumReadOptions, + GreenplumSQLOptions, GreenplumTableExistBehavior, GreenplumWriteOptions, ) + +__all__ = [ + "Greenplum", + "GreenplumDialect", + "GreenplumExecuteOptions", + "GreenplumFetchOptions", + "GreenplumReadOptions", + "GreenplumSQLOptions", + "GreenplumTableExistBehavior", + "GreenplumWriteOptions", +] diff --git a/onetl/connection/db_connection/greenplum/connection.py b/onetl/connection/db_connection/greenplum/connection.py index cdc743ca7..aa5115a2b 100644 --- a/onetl/connection/db_connection/greenplum/connection.py +++ b/onetl/connection/db_connection/greenplum/connection.py @@ -17,7 +17,7 @@ try: from pydantic.v1 import SecretStr, validator except (ImportError, AttributeError): - from pydantic import validator, SecretStr # type: ignore[no-redef, assignment] + from pydantic import SecretStr, validator # type: ignore[no-redef, assignment] from onetl._util.java import try_import_java_class from onetl._util.scala import get_default_scala_version @@ -41,13 +41,11 @@ GreenplumTableExistBehavior, GreenplumWriteOptions, ) -from onetl.connection.db_connection.jdbc_mixin import JDBCMixin -from onetl.connection.db_connection.jdbc_mixin.options import ( +from onetl.connection.db_connection.jdbc_mixin import ( JDBCExecuteOptions, JDBCFetchOptions, -) -from onetl.connection.db_connection.jdbc_mixin.options import ( - JDBCOptions as JDBCMixinOptions, + JDBCMixin, + JDBCMixinOptions, ) from onetl.exception import MISSING_JVM_CLASS_MSG, TooManyParallelJobsError from onetl.hooks import slot, support_hooks @@ -73,7 +71,7 @@ class GreenplumExtra(GenericOptions): # avoid closing connections from server side # while connector is moving data to executors before insert - tcpKeepAlive: str = "true" # noqa: N815 + tcpKeepAlive: str = "true" class Config: extra = "allow" @@ -81,7 +79,7 @@ class Config: @support_hooks -class Greenplum(JDBCMixin, DBConnection): # noqa: WPS338 +class Greenplum(JDBCMixin, DBConnection): """Greenplum connection. 
|support_hooks| Based on package ``io.pivotal:greenplum-spark:2.2.0`` @@ -122,7 +120,8 @@ class Greenplum(JDBCMixin, DBConnection): # noqa: WPS338 Supported options are: * All `Postgres JDBC driver properties `_ - * Properties from `Greenplum connector for Spark documentation `_ page, but only starting with ``server.`` or ``pool.`` + * Properties from `Greenplum connector for Spark documentation `_ page, + but only starting with ``server.`` or ``pool.`` Examples -------- @@ -162,7 +161,7 @@ class Greenplum(JDBCMixin, DBConnection): # noqa: WPS338 extra=extra, spark=spark, ).check() - """ + """ # noqa: E501 host: Host user: str @@ -239,27 +238,29 @@ def get_packages( """ # Connector version is fixed, so we can perform checks for Scala/Spark version - if package_version: - package_ver = Version(package_version) - else: - package_ver = Version("2.2.0") + package_ver = Version(package_version or "2.2.0") if scala_version: scala_ver = Version(scala_version).min_digits(2) elif spark_version: spark_ver = Version(spark_version).min_digits(2) if spark_ver >= Version("3.3"): - raise ValueError(f"Spark version must be 3.2.x or less, got {spark_ver}") + msg = f"Spark version must be 3.2.x or less, got {spark_ver}" + raise ValueError(msg) scala_ver = get_default_scala_version(spark_ver) else: - raise ValueError("You should pass either `scala_version` or `spark_version`") + msg = "You should pass either `scala_version` or `spark_version`" + raise ValueError(msg) return [f"io.pivotal:greenplum-spark_{scala_ver.format('{0}.{1}')}:{package_ver}"] @classproperty def package_spark_3_2(cls) -> str: """Get package name to be downloaded by Spark 3.2.""" - msg = "`Greenplum.package_3_2` will be removed in 1.0.0, use `Greenplum.get_packages(spark_version='3.2')` instead" + msg = ( + "`Greenplum.package_3_2` will be removed in 1.0.0, " + "use `Greenplum.get_packages(spark_version='3.2')` instead" + ) warnings.warn(msg, UserWarning, stacklevel=3) return "io.pivotal:greenplum-spark_2.12:2.2.0" @@ -279,7 +280,7 @@ def jdbc_custom_params(self) -> dict: result = { key: value for key, value in self.extra.dict(by_alias=True).items() - if not (key.startswith("server.") or key.startswith("pool.")) + if not key.startswith(("server.", "pool.")) } # https://www.postgresql.org/docs/current/runtime-config-logging.html#GUC-APPLICATION-NAME result["ApplicationName"] = result.get("ApplicationName", get_client_info(self.spark, limit=64)) @@ -294,7 +295,7 @@ def jdbc_params(self) -> dict: @slot def check(self): log.info("|%s| Checking connection availability...", self.__class__.__name__) - self._log_parameters() # type: ignore + self._log_parameters() log.debug("|%s| Executing SQL query:", self.__class__.__name__) log_lines(log, self._CHECK_QUERY, level=logging.DEBUG) @@ -313,12 +314,13 @@ def check(self): log.info("|%s| Connection is available.", self.__class__.__name__) except Exception as e: log.exception("|%s| Connection is unavailable", self.__class__.__name__) - raise RuntimeError("Connection is unavailable") from e + msg = "Connection is unavailable" + raise RuntimeError(msg) from e return self @slot - def read_source_as_df( + def read_source_as_df( # noqa: PLR0913 self, source: str, columns: list[str] | None = None, @@ -458,10 +460,10 @@ def _get_connector_params( self, table: str, ) -> dict: - schema, table_name = table.split(".") # noqa: WPS414 + schema, table_name = table.split(".") extra = self.extra.dict(by_alias=True, exclude_none=True) greenplum_connector_options = { - key: value for key, value in extra.items() if 
key.startswith("server.") or key.startswith("pool.") + key: value for key, value in extra.items() if key.startswith(("server.", "pool.")) } # Greenplum connector requires all JDBC params to be passed via JDBC URL: @@ -481,23 +483,27 @@ def _get_connector_params( **greenplum_connector_options, } - def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions, read_only: bool): + def _get_jdbc_connection( + self, + options: JDBCFetchOptions | JDBCExecuteOptions, + *, + read_only: bool, + ): if read_only: # To properly support pgbouncer, we have to create connection with readOnly option set. # See https://github.com/pgjdbc/pgjdbc/issues/848 options = options.copy(update={"readOnly": True}) connection_properties = self._options_to_connection_properties(options) - driver_manager = self.spark._jvm.java.sql.DriverManager # type: ignore - # avoid calling .setReadOnly(True) here - return driver_manager.getConnection(self.jdbc_url, connection_properties) + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 + return jvm.java.sql.DriverManager.getConnection(self.jdbc_url, connection_properties) # type: ignore[union-attr] def _get_server_setting(self, name: str) -> Any: query = f""" SELECT setting FROM pg_settings WHERE name = '{name}' - """ + """ # noqa: S608 log.debug("|%s| Executing SQL query (on driver):") log_lines(log, query, level=logging.DEBUG) @@ -506,6 +512,7 @@ def _get_server_setting(self, name: str) -> Any: log.debug( "|%s| Query succeeded, resulting in-memory dataframe contains %d rows", + self.__class__.__name__, len(result), ) if result: @@ -527,6 +534,7 @@ def _get_occupied_connections_count(self) -> int: log.debug( "|%s| Query succeeded, resulting in-memory dataframe contains %d rows", + self.__class__.__name__, len(result), ) return int(result[0][0]) diff --git a/onetl/connection/db_connection/greenplum/dialect.py b/onetl/connection/db_connection/greenplum/dialect.py index aca3ba424..a1d66c22e 100644 --- a/onetl/connection/db_connection/greenplum/dialect.py +++ b/onetl/connection/db_connection/greenplum/dialect.py @@ -15,7 +15,7 @@ ) -class GreenplumDialect( # noqa: WPS215 +class GreenplumDialect( SupportNameWithSchemaOnly, SupportColumns, NotSupportDFSchema, diff --git a/onetl/connection/db_connection/greenplum/options.py b/onetl/connection/db_connection/greenplum/options.py index a94dc6b9b..3807f2ea0 100644 --- a/onetl/connection/db_connection/greenplum/options.py +++ b/onetl/connection/db_connection/greenplum/options.py @@ -12,11 +12,11 @@ from pydantic import Field, root_validator # type: ignore[no-redef, assignment] from onetl._util.alias import avoid_alias -from onetl.connection.db_connection.jdbc_connection.options import JDBCSQLOptions -from onetl.connection.db_connection.jdbc_mixin import JDBCOptions -from onetl.connection.db_connection.jdbc_mixin.options import ( +from onetl.connection.db_connection.jdbc_connection import JDBCSQLOptions +from onetl.connection.db_connection.jdbc_mixin import ( JDBCExecuteOptions, JDBCFetchOptions, + JDBCMixinOptions, ) # options from which are populated by Greenplum class methods @@ -59,8 +59,8 @@ class GreenplumTableExistBehavior(str, Enum): def __str__(self) -> str: return str(self.value) - @classmethod # noqa: WPS120 - def _missing_(cls, value: object): # noqa: WPS120 + @classmethod + def _missing_(cls, value: object): if str(value) == "overwrite": warnings.warn( "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. 
" @@ -69,15 +69,17 @@ def _missing_(cls, value: object): # noqa: WPS120 stacklevel=4, ) return cls.REPLACE_ENTIRE_TABLE + return None -class GreenplumReadOptions(JDBCOptions): +class GreenplumReadOptions(JDBCMixinOptions): """VMware's Greenplum Spark connector reading options. .. warning:: Some options, like ``url``, ``dbtable``, ``server.*``, ``pool.*``, - etc are populated from connection attributes, and cannot be overridden by the user in ``ReadOptions`` to avoid issues. + etc are populated from connection attributes, + and cannot be overridden by the user in ``ReadOptions`` to avoid issues. Examples -------- @@ -102,7 +104,7 @@ class GreenplumReadOptions(JDBCOptions): class Config: known_options = READ_OPTIONS | READ_WRITE_OPTIONS - prohibited_options = JDBCOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | WRITE_OPTIONS + prohibited_options = JDBCMixinOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | WRITE_OPTIONS partition_column: Optional[str] = Field(alias="partitionColumn") """Column used to parallelize reading from a table. @@ -205,13 +207,14 @@ class Config: """ -class GreenplumWriteOptions(JDBCOptions): +class GreenplumWriteOptions(JDBCMixinOptions): """VMware's Greenplum Spark connector writing options. .. warning:: Some options, like ``url``, ``dbtable``, ``server.*``, ``pool.*``, etc - are populated from connection attributes, and cannot be overridden by the user in ``WriteOptions`` to avoid issues. + are populated from connection attributes, and cannot be overridden + by the user in ``WriteOptions`` to avoid issues. Examples -------- @@ -237,7 +240,7 @@ class GreenplumWriteOptions(JDBCOptions): class Config: known_options = WRITE_OPTIONS | READ_WRITE_OPTIONS - prohibited_options = JDBCOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | READ_OPTIONS + prohibited_options = JDBCMixinOptions.Config.prohibited_options | GENERIC_PROHIBITED_OPTIONS | READ_OPTIONS if_exists: GreenplumTableExistBehavior = Field( # type: ignore[literal-required] default=GreenplumTableExistBehavior.APPEND, diff --git a/onetl/connection/db_connection/hive/__init__.py b/onetl/connection/db_connection/hive/__init__.py index d4f5f8dff..35efb13a7 100644 --- a/onetl/connection/db_connection/hive/__init__.py +++ b/onetl/connection/db_connection/hive/__init__.py @@ -2,9 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from onetl.connection.db_connection.hive.connection import Hive from onetl.connection.db_connection.hive.dialect import HiveDialect -from onetl.connection.db_connection.hive.options import ( - HiveLegacyOptions, - HiveTableExistBehavior, - HiveWriteOptions, -) -from onetl.connection.db_connection.hive.slots import HiveSlots +from onetl.connection.db_connection.hive.options import HiveLegacyOptions, HiveTableExistBehavior, HiveWriteOptions + +__all__ = [ + "Hive", + "HiveDialect", + "HiveLegacyOptions", + "HiveTableExistBehavior", + "HiveWriteOptions", +] diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py index 555bfa754..f26ceb500 100644 --- a/onetl/connection/db_connection/hive/connection.py +++ b/onetl/connection/db_connection/hive/connection.py @@ -145,15 +145,16 @@ def get_current(cls, spark: SparkSession): # injecting current cluster name via hooks mechanism hive = Hive.get_current(spark=spark) - """ + """ # noqa: E501 log.info("|%s| Detecting current cluster...", cls.__name__) current_cluster = cls.Slots.get_current_cluster() if not current_cluster: - raise RuntimeError( + msg = ( 
f"{cls.__name__}.get_current() can be used only if there are " - f"some hooks bound to {cls.__name__}.Slots.get_current_cluster", + f"some hooks bound to {cls.__name__}.Slots.get_current_cluster" ) + raise RuntimeError(msg) log.info("|%s| Got %r", cls.__name__, current_cluster) return cls(cluster=current_cluster, spark=spark) # type: ignore[arg-type] @@ -170,7 +171,8 @@ def check(self): log.debug("|%s| Detecting current cluster...", self.__class__.__name__) current_cluster = self.Slots.get_current_cluster() if current_cluster and self.cluster != current_cluster: - raise ValueError("You can connect to a Hive cluster only from the same cluster") + msg = "You can connect to a Hive cluster only from the same cluster" + raise ValueError(msg) log.info("|%s| Checking connection availability...", self.__class__.__name__) self._log_parameters() @@ -184,7 +186,8 @@ def check(self): log.info("|%s| Connection is available.", self.__class__.__name__) except Exception as e: log.exception("|%s| Connection is unavailable", self.__class__.__name__) - raise RuntimeError("Connection is unavailable") from e + msg = "Connection is unavailable" + raise RuntimeError(msg) from e return self @@ -223,7 +226,7 @@ def sql( with override_job_description(self.spark, f"{self}.sql()"): df = self._execute_sql(query) except Exception: - log.error("|%s| Query failed", self.__class__.__name__) + log.exception("|%s| Query failed", self.__class__.__name__) metrics = recorder.metrics() if log.isEnabledFor(logging.DEBUG) and not metrics.is_empty: @@ -271,7 +274,8 @@ def execute( with override_job_description(self.spark, f"{self}.execute()"): self._execute_sql(statement).collect() except Exception: - log.error("|%s| Execution failed", self.__class__.__name__) + log.exception("|%s| Execution failed", self.__class__.__name__) + metrics = recorder.metrics() if log.isEnabledFor(logging.DEBUG) and not metrics.is_empty: # as SparkListener results are not guaranteed to be received in time, @@ -306,7 +310,8 @@ def write_df_to_target( return if write_options.if_exists == HiveTableExistBehavior.ERROR: - raise ValueError("Operation stopped due to Hive.WriteOptions(if_exists='error')") + msg = "Operation stopped due to Hive.WriteOptions(if_exists='error')" + raise ValueError(msg) if write_options.if_exists == HiveTableExistBehavior.IGNORE: log.info( @@ -320,7 +325,7 @@ def write_df_to_target( self._insert_into(df, target, options) @slot - def read_source_as_df( + def read_source_as_df( # noqa: PLR0913 self, source: str, columns: list[str] | None = None, @@ -401,14 +406,13 @@ def _validate_cluster_name(cls, cluster): log.debug("|%s| Normalizing cluster %r name...", cls.__name__, cluster) validated_cluster = cls.Slots.normalize_cluster_name(cluster) or cluster if validated_cluster != cluster: - log.debug("|%s| Got %r", cls.__name__) + log.debug("|%s| Got %r", cls.__name__, validated_cluster) log.debug("|%s| Checking if cluster %r is a known cluster...", cls.__name__, validated_cluster) known_clusters = cls.Slots.get_known_clusters() if known_clusters and validated_cluster not in known_clusters: - raise ValueError( - f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}", - ) + msg = f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}" + raise ValueError(msg) return validated_cluster @@ -539,7 +543,7 @@ def _save_as_table( write_options = self.WriteOptions.parse(options) writer = df.write - for method, value in write_options.dict( # noqa: WPS352 + for method, 
value in write_options.dict( by_alias=True, exclude_none=True, exclude={"if_exists", "format", "table_properties"}, diff --git a/onetl/connection/db_connection/hive/dialect.py b/onetl/connection/db_connection/hive/dialect.py index 4b79675dd..e310c2e79 100644 --- a/onetl/connection/db_connection/hive/dialect.py +++ b/onetl/connection/db_connection/hive/dialect.py @@ -13,7 +13,7 @@ ) -class HiveDialect( # noqa: WPS215 +class HiveDialect( SupportNameWithSchemaOnly, SupportColumns, NotSupportDFSchema, diff --git a/onetl/connection/db_connection/hive/options.py b/onetl/connection/db_connection/hive/options.py index aef2da921..17f70cc24 100644 --- a/onetl/connection/db_connection/hive/options.py +++ b/onetl/connection/db_connection/hive/options.py @@ -29,8 +29,8 @@ class HiveTableExistBehavior(str, Enum): def __str__(self): return str(self.value) - @classmethod # noqa: WPS120 - def _missing_(cls, value: object): # noqa: WPS120 + @classmethod + def _missing_(cls, value: object): if str(value) == "overwrite": warnings.warn( "Mode `overwrite` is deprecated since v0.4.0 and will be removed in v1.0.0. " @@ -57,6 +57,7 @@ def _missing_(cls, value: object): # noqa: WPS120 stacklevel=4, ) return cls.REPLACE_ENTIRE_TABLE + return None class HiveWriteOptions(GenericOptions): @@ -114,14 +115,17 @@ class Config: * Table exists, but not partitioned, :obj:`~partition_by` is set Data is appended to a table. Table is still not partitioned (DDL is unchanged). - * Table exists and partitioned, but has different partitioning schema than :obj:`~partition_by` + * Table exists and partitioned, + but has different partitioning schema than :obj:`~partition_by` Partition is created based on table's ``PARTITIONED BY (...)`` options. Explicit :obj:`~partition_by` value is ignored. - * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in dataframe + * Table exists and partitioned according :obj:`~partition_by`, + but partition is present only in dataframe Partition is created. - * Table exists and partitioned according :obj:`~partition_by`, partition is present in both dataframe and table + * Table exists and partitioned according :obj:`~partition_by`, + partition is present in both dataframe and table Data is appended to existing partition. .. warning:: @@ -132,7 +136,8 @@ class Config: To implement deduplication, write data to staging table first, and then perform some deduplication logic using :obj:`~sql`. - * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in table, not dataframe + * Table exists and partitioned according :obj:`~partition_by`, + but partition is present only in table, not dataframe Existing partition is left intact. * ``replace_overlapping_partitions`` @@ -152,19 +157,24 @@ class Config: Table is created using options provided by user (``format``, ``compression``, etc). * Table exists, but not partitioned, :obj:`~partition_by` is set - Data is **overwritten in all the table**. Table is still not partitioned (DDL is unchanged). + Data is **overwritten in all the table**. + Table is still not partitioned (DDL is unchanged). - * Table exists and partitioned, but has different partitioning schema than :obj:`~partition_by` + * Table exists and partitioned, + but has different partitioning schema than :obj:`~partition_by` Partition is created based on table's ``PARTITIONED BY (...)`` options. Explicit :obj:`~partition_by` value is ignored. 
- * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in dataframe + * Table exists and partitioned according :obj:`~partition_by`, + but partition is present only in dataframe Partition is created. - * Table exists and partitioned according :obj:`~partition_by`, partition is present in both dataframe and table + * Table exists and partitioned according :obj:`~partition_by`, + partition is present in both dataframe and table Existing partition **replaced** with data from dataframe. - * Table exists and partitioned according :obj:`~partition_by`, but partition is present only in table, not dataframe + * Table exists and partitioned according :obj:`~partition_by`, + but partition is present only in table, not dataframe Existing partition is left intact. * ``replace_entire_table`` @@ -257,10 +267,11 @@ class Config: Used **only** while **creating new table**, or in case of ``if_exists=replace_entire_table`` """ - bucket_by: Optional[Tuple[int, Union[List[str], str]]] = Field(default=None, alias="bucketBy") # noqa: WPS234 + bucket_by: Optional[Tuple[int, Union[List[str], str]]] = Field(default=None, alias="bucketBy") """Number of buckets plus bucketing columns. ``None`` means bucketing is disabled. - Each bucket is created as a set of files with name containing result of calculation ``hash(columns) mod num_buckets``. + Each bucket is created as a set of files with name containing result of + calculation ``hash(columns) mod num_buckets``. This allows to remove shuffle from queries containing ``GROUP BY`` or ``JOIN`` or using ``=`` / ``IN`` predicates on specific columns. @@ -335,7 +346,8 @@ def _sort_by_cannot_be_used_without_bucket_by(cls, sort_by, values): options = values.copy() bucket_by = options.pop("bucket_by", None) if sort_by and not bucket_by: - raise ValueError("`sort_by` option can only be used with non-empty `bucket_by`") + msg = "`sort_by` option can only be used with non-empty `bucket_by`" + raise ValueError(msg) return sort_by @@ -347,15 +359,15 @@ def _partition_overwrite_mode_is_not_allowed(cls, values): recommend_mode = "replace_entire_table" else: recommend_mode = "replace_overlapping_partitions" - raise ValueError( - f"`partitionOverwriteMode` option should be replaced with if_exists='{recommend_mode}'", - ) + msg = f"`partitionOverwriteMode` option should be replaced with if_exists='{recommend_mode}'" + raise ValueError(msg) if values.get("insert_into") is not None or values.get("insertInto") is not None: - raise ValueError( + msg = ( "`insertInto` option was removed in onETL 0.4.0, " - "now df.write.insertInto or df.write.saveAsTable is selected based on table existence", + "now df.write.insertInto or df.write.saveAsTable is selected based on table existence" ) + raise ValueError(msg) return values diff --git a/onetl/connection/db_connection/iceberg/__init__.py b/onetl/connection/db_connection/iceberg/__init__.py index 9f3c1ee57..409c6ee0b 100644 --- a/onetl/connection/db_connection/iceberg/__init__.py +++ b/onetl/connection/db_connection/iceberg/__init__.py @@ -4,6 +4,7 @@ from onetl.connection.db_connection.iceberg.dialect import IcebergDialect from onetl.connection.db_connection.iceberg.extra import IcebergExtra from onetl.connection.db_connection.iceberg.options import ( + IcebergTableExistBehavior, IcebergWriteOptions, ) @@ -11,5 +12,6 @@ "Iceberg", "IcebergDialect", "IcebergExtra", + "IcebergTableExistBehavior", "IcebergWriteOptions", ] diff --git a/onetl/connection/db_connection/iceberg/catalog/__init__.py 
b/onetl/connection/db_connection/iceberg/catalog/__init__.py index b7a2fdef3..83a4cacfc 100644 --- a/onetl/connection/db_connection/iceberg/catalog/__init__.py +++ b/onetl/connection/db_connection/iceberg/catalog/__init__.py @@ -5,3 +5,9 @@ IcebergFilesystemCatalog, ) from onetl.connection.db_connection.iceberg.catalog.rest import IcebergRESTCatalog + +__all__ = [ + "IcebergCatalog", + "IcebergFilesystemCatalog", + "IcebergRESTCatalog", +] diff --git a/onetl/connection/db_connection/iceberg/catalog/auth/base.py b/onetl/connection/db_connection/iceberg/catalog/auth/base.py index 28a6762b5..bcbf6ffac 100644 --- a/onetl/connection/db_connection/iceberg/catalog/auth/base.py +++ b/onetl/connection/db_connection/iceberg/catalog/auth/base.py @@ -3,7 +3,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict class IcebergRESTCatalogAuth(ABC): @@ -14,5 +13,5 @@ class IcebergRESTCatalogAuth(ABC): """ @abstractmethod - def get_config(self) -> Dict[str, str]: + def get_config(self) -> dict[str, str]: """Return REST catalog auth configuration.""" diff --git a/onetl/connection/db_connection/iceberg/catalog/auth/basic.py b/onetl/connection/db_connection/iceberg/catalog/auth/basic.py index 1785f25ec..1074dd0e2 100644 --- a/onetl/connection/db_connection/iceberg/catalog/auth/basic.py +++ b/onetl/connection/db_connection/iceberg/catalog/auth/basic.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2025-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from typing import Dict +from __future__ import annotations try: from pydantic.v1 import SecretStr @@ -44,7 +44,7 @@ class IcebergRESTCatalogBasicAuth(IcebergRESTCatalogAuth, FrozenModel): user: str password: SecretStr - def get_config(self) -> Dict[str, str]: + def get_config(self) -> dict[str, str]: return { "rest.auth.type": "basic", "rest.auth.basic.username": self.user, diff --git a/onetl/connection/db_connection/iceberg/catalog/auth/bearer.py b/onetl/connection/db_connection/iceberg/catalog/auth/bearer.py index e17164039..8861ef961 100644 --- a/onetl/connection/db_connection/iceberg/catalog/auth/bearer.py +++ b/onetl/connection/db_connection/iceberg/catalog/auth/bearer.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2025-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from typing import Dict +from __future__ import annotations try: from pydantic.v1 import SecretStr @@ -39,7 +39,7 @@ class IcebergRESTCatalogBearerAuth(IcebergRESTCatalogAuth, FrozenModel): # https://github.com/apache/iceberg/blob/720ef99720a1c59e4670db983c951243dffc4f3e/core/src/main/java/org/apache/iceberg/rest/auth/OAuth2Properties.java#L24-L25 access_token: SecretStr - def get_config(self) -> Dict[str, str]: + def get_config(self) -> dict[str, str]: return { "rest.auth.type": "oauth2", "token": self.access_token.get_secret_value(), diff --git a/onetl/connection/db_connection/iceberg/catalog/auth/oauth2_client_credentials.py b/onetl/connection/db_connection/iceberg/catalog/auth/oauth2_client_credentials.py index a44265341..df3165a60 100644 --- a/onetl/connection/db_connection/iceberg/catalog/auth/oauth2_client_credentials.py +++ b/onetl/connection/db_connection/iceberg/catalog/auth/oauth2_client_credentials.py @@ -1,12 +1,14 @@ # SPDX-FileCopyrightText: 2025-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from datetime import timedelta -from typing import Dict, List, Optional +from typing import List, Optional try: from pydantic.v1 import AnyUrl, Field, SecretStr except (ImportError, 
AttributeError): - from pydantic import Field, SecretStr, AnyUrl # type: ignore[no-redef, assignment] + from pydantic import AnyUrl, Field, SecretStr # type: ignore[no-redef, assignment] from onetl._util.spark import stringify from onetl.connection.db_connection.iceberg.catalog.auth import IcebergRESTCatalogAuth @@ -101,7 +103,7 @@ class IcebergRESTCatalogOAuth2ClientCredentials(IcebergRESTCatalogAuth, FrozenMo audience: Optional[str] = None resource: Optional[str] = None - def get_config(self) -> Dict[str, str]: + def get_config(self) -> dict[str, str]: config = { "rest.auth.type": "oauth2", "token-exchange-enabled": "false", diff --git a/onetl/connection/db_connection/iceberg/catalog/filesystem.py b/onetl/connection/db_connection/iceberg/catalog/filesystem.py index 0fa0c148e..4da292936 100644 --- a/onetl/connection/db_connection/iceberg/catalog/filesystem.py +++ b/onetl/connection/db_connection/iceberg/catalog/filesystem.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2025-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from typing import Dict +from __future__ import annotations from onetl.connection.db_connection.iceberg.catalog import IcebergCatalog from onetl.impl.frozen_model import FrozenModel @@ -32,7 +32,7 @@ class IcebergFilesystemCatalog(IcebergCatalog, FrozenModel): catalog = Iceberg.FilesystemCatalog() """ - def get_config(self) -> Dict[str, str]: + def get_config(self) -> dict[str, str]: return { "type": "hadoop", } diff --git a/onetl/connection/db_connection/iceberg/catalog/rest.py b/onetl/connection/db_connection/iceberg/catalog/rest.py index 748a7cb37..0abab8156 100644 --- a/onetl/connection/db_connection/iceberg/catalog/rest.py +++ b/onetl/connection/db_connection/iceberg/catalog/rest.py @@ -7,7 +7,7 @@ try: from pydantic.v1 import AnyUrl, Field except (ImportError, AttributeError): - from pydantic import Field, AnyUrl # type: ignore[no-redef, assignment] + from pydantic import AnyUrl, Field # type: ignore[no-redef, assignment] from onetl._util.spark import stringify from onetl.connection.db_connection.iceberg.catalog import IcebergCatalog diff --git a/onetl/connection/db_connection/iceberg/connection.py b/onetl/connection/db_connection/iceberg/connection.py index 3568779d7..fd0481ceb 100644 --- a/onetl/connection/db_connection/iceberg/connection.py +++ b/onetl/connection/db_connection/iceberg/connection.py @@ -228,14 +228,14 @@ def __init__( catalog_name: str, catalog: IcebergCatalog, warehouse: Optional[IcebergWarehouse] = None, - extra: Union[IcebergExtra, Dict[str, Any]] = IcebergExtra(), # noqa: B008, WPS404 + extra: Union[IcebergExtra, Dict[str, Any], None] = None, ): super().__init__( spark=spark, catalog_name=catalog_name, # type: ignore[call-arg] catalog=catalog, # type: ignore[call-arg] warehouse=warehouse, # type: ignore[call-arg] - extra=extra, # type: ignore[call-arg] + extra=extra or IcebergExtra(), # type: ignore[call-arg] ) for k, v in self._get_spark_config().items(): self.spark.conf.set(k, v) @@ -339,7 +339,8 @@ def check(self): log.info("|%s| Connection is available.", self.__class__.__name__) except Exception as e: log.exception("|%s| Connection is unavailable", self.__class__.__name__) - raise RuntimeError("Connection is unavailable") from e + msg = "Connection is unavailable" + raise RuntimeError(msg) from e return self @@ -376,7 +377,7 @@ def sql( with override_job_description(self.spark, f"{self}.sql()"): df = self._execute_sql(query) except Exception: - log.error("|%s| Query failed", self.__class__.__name__) + log.exception("|%s| Query 
failed", self.__class__.__name__) metrics = recorder.metrics() if log.isEnabledFor(logging.DEBUG) and not metrics.is_empty: @@ -422,7 +423,7 @@ def execute( with override_job_description(self.spark, f"{self}.execute()"): self._execute_sql(statement).collect() except Exception: - log.error("|%s| Execution failed", self.__class__.__name__) + log.exception("|%s| Execution failed", self.__class__.__name__) metrics = recorder.metrics() if log.isEnabledFor(logging.DEBUG) and not metrics.is_empty: # as SparkListener results are not guaranteed to be received in time, @@ -457,7 +458,8 @@ def write_df_to_target( return if write_options.if_exists == IcebergTableExistBehavior.ERROR: - raise ValueError("Operation stopped due to Iceberg.WriteOptions(if_exists='error')") + msg = "Operation stopped due to Iceberg.WriteOptions(if_exists='error')" + raise ValueError(msg) if write_options.if_exists == IcebergTableExistBehavior.IGNORE: log.info( @@ -469,7 +471,7 @@ def write_df_to_target( self._insert_into(df, target, options) @slot - def read_source_as_df( + def read_source_as_df( # noqa: PLR0913 self, source: str, columns: list[str] | None = None, diff --git a/onetl/connection/db_connection/iceberg/dialect.py b/onetl/connection/db_connection/iceberg/dialect.py index b9fce96dd..58156a684 100644 --- a/onetl/connection/db_connection/iceberg/dialect.py +++ b/onetl/connection/db_connection/iceberg/dialect.py @@ -13,7 +13,7 @@ ) -class IcebergDialect( # noqa: WPS215 +class IcebergDialect( SupportNameWithSchemaOnly, SupportColumns, NotSupportDFSchema, diff --git a/onetl/connection/db_connection/iceberg/warehouse/delegated.py b/onetl/connection/db_connection/iceberg/warehouse/delegated.py index 81018db70..b02c36b64 100644 --- a/onetl/connection/db_connection/iceberg/warehouse/delegated.py +++ b/onetl/connection/db_connection/iceberg/warehouse/delegated.py @@ -67,7 +67,7 @@ class IcebergDelegatedWarehouse(IcebergWarehouse, FrozenModel): # other params passed to S3 client (optional) extra={"client.region": "us-east-1"}, ) - """ + """ # noqa: E501 name: Optional[str] = None access_delegation: Literal["vended-credentials", "remote-signing"] diff --git a/onetl/connection/db_connection/iceberg/warehouse/filesystem.py b/onetl/connection/db_connection/iceberg/warehouse/filesystem.py index f4e0d986a..183c8341b 100644 --- a/onetl/connection/db_connection/iceberg/warehouse/filesystem.py +++ b/onetl/connection/db_connection/iceberg/warehouse/filesystem.py @@ -94,14 +94,14 @@ class IcebergFilesystemWarehouse(IcebergWarehouse, FrozenModel): @slot def get_config(self) -> dict[str, str]: config = { - "warehouse": self.connection._convert_to_url(self.path), # noqa: WPS437 + "warehouse": self.connection._convert_to_url(self.path), # noqa: SLF001 "io-impl": "org.apache.iceberg.hadoop.HadoopFileIO", } if isinstance(self.connection, SparkS3): - prefix = self.connection._get_hadoop_config_prefix() # noqa: WPS437 + prefix = self.connection._get_hadoop_config_prefix() # noqa: SLF001 hadoop_config = { "hadoop." 
+ k: v - for k, v in self.connection._get_expected_hadoop_config(prefix).items() # noqa: WPS437 + for k, v in self.connection._get_expected_hadoop_config(prefix).items() # noqa: SLF001 } config.update(hadoop_config) diff --git a/onetl/connection/db_connection/iceberg/warehouse/s3.py b/onetl/connection/db_connection/iceberg/warehouse/s3.py index e5fd0c224..c36483a81 100644 --- a/onetl/connection/db_connection/iceberg/warehouse/s3.py +++ b/onetl/connection/db_connection/iceberg/warehouse/s3.py @@ -14,7 +14,7 @@ try: from pydantic.v1 import Field, SecretStr, validator except (ImportError, AttributeError): - from pydantic import validator, Field, SecretStr # type: ignore[no-redef, assignment] + from pydantic import Field, SecretStr, validator # type: ignore[no-redef, assignment] from onetl._util.spark import stringify from onetl.base import PurePathProtocol diff --git a/onetl/connection/db_connection/jdbc_connection/__init__.py b/onetl/connection/db_connection/jdbc_connection/__init__.py index 163189f4d..44b9f2085 100644 --- a/onetl/connection/db_connection/jdbc_connection/__init__.py +++ b/onetl/connection/db_connection/jdbc_connection/__init__.py @@ -3,8 +3,23 @@ from onetl.connection.db_connection.jdbc_connection.connection import JDBCConnection from onetl.connection.db_connection.jdbc_connection.dialect import JDBCDialect from onetl.connection.db_connection.jdbc_connection.options import ( + JDBCFetchOptions, + JDBCLegacyOptions, JDBCPartitioningMode, JDBCReadOptions, + JDBCSQLOptions, JDBCTableExistBehavior, JDBCWriteOptions, ) + +__all__ = [ + "JDBCConnection", + "JDBCDialect", + "JDBCFetchOptions", + "JDBCLegacyOptions", + "JDBCPartitioningMode", + "JDBCReadOptions", + "JDBCSQLOptions", + "JDBCTableExistBehavior", + "JDBCWriteOptions", +] diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py index 6b7ba6af4..eb05709c4 100644 --- a/onetl/connection/db_connection/jdbc_connection/connection.py +++ b/onetl/connection/db_connection/jdbc_connection/connection.py @@ -58,7 +58,7 @@ @support_hooks -class JDBCConnection(JDBCMixin, DBConnection): # noqa: WPS338 +class JDBCConnection(JDBCMixin, DBConnection): user: str password: SecretStr @@ -90,7 +90,7 @@ def _check_java_class_imported(cls, spark): @slot def check(self): log.info("|%s| Checking connection availability...", self.__class__.__name__) - self._log_parameters() # type: ignore + self._log_parameters() log.debug("|%s| Executing SQL query:", self.__class__.__name__) log_lines(log, self._CHECK_QUERY, level=logging.DEBUG) @@ -102,7 +102,8 @@ def check(self): log.info("|%s| Connection is available.", self.__class__.__name__) except Exception as e: log.exception("|%s| Connection is unavailable", self.__class__.__name__) - raise RuntimeError("Connection is unavailable") from e + msg = "Connection is unavailable" + raise RuntimeError(msg) from e return self @@ -152,14 +153,14 @@ def sql( with override_job_description(self.spark, f"{self}.sql()"): df = self._query_on_executor(query, self.SQLOptions.parse(options)) except Exception: - log.error("|%s| Query failed!", self.__class__.__name__) + log.exception("|%s| Query failed!", self.__class__.__name__) raise log.info("|Spark| DataFrame successfully created from SQL statement") return df @slot - def read_source_as_df( + def read_source_as_df( # noqa: PLR0913 self, source: str, columns: list[str] | None = None, @@ -199,7 +200,7 @@ def read_source_as_df( else: partition_column = read_options.partition_column - # hack 
to avoid column name verification + # avoid column name verification # in the spark, the expression in the partitioning of the column must # have the same name as the field in the table ( 2.4 version ) # https://github.com/apache/spark/pull/21379 diff --git a/onetl/connection/db_connection/jdbc_connection/dialect.py b/onetl/connection/db_connection/jdbc_connection/dialect.py index de5cb4827..5b2f0ebe6 100644 --- a/onetl/connection/db_connection/jdbc_connection/dialect.py +++ b/onetl/connection/db_connection/jdbc_connection/dialect.py @@ -15,7 +15,7 @@ ) -class JDBCDialect( # noqa: WPS215 +class JDBCDialect( SupportNameWithSchemaOnly, SupportColumns, NotSupportDFSchema, diff --git a/onetl/connection/db_connection/jdbc_connection/options.py b/onetl/connection/db_connection/jdbc_connection/options.py index 7f43dd246..7b7b8c7ae 100644 --- a/onetl/connection/db_connection/jdbc_connection/options.py +++ b/onetl/connection/db_connection/jdbc_connection/options.py @@ -89,8 +89,8 @@ class JDBCTableExistBehavior(str, Enum): def __str__(self) -> str: return str(self.value) - @classmethod # noqa: WPS120 - def _missing_(cls, value: object): # noqa: WPS120 + @classmethod + def _missing_(cls, value: object): if str(value) == "overwrite": warnings.warn( "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. " @@ -99,6 +99,7 @@ def _missing_(cls, value: object): # noqa: WPS120 stacklevel=4, ) return cls.REPLACE_ENTIRE_TABLE + return None class JDBCPartitioningMode(str, Enum): @@ -155,7 +156,8 @@ class Config: .. note:: Column type depends on :obj:`~partitioning_mode`. - * ``partitioning_mode="range"`` requires column to be an integer, date or timestamp (can be NULL, but not recommended). + * ``partitioning_mode="range"`` requires column to be an integer, + date or timestamp (can be NULL, but not recommended). * ``partitioning_mode="hash"`` accepts any column type (NOT NULL). * ``partitioning_mode="mod"`` requires column to be an integer (NOT NULL). 
@@ -166,10 +168,10 @@ class Config: See documentation for :obj:`~partitioning_mode` for more details""" lower_bound: Optional[int] = Field(default=None, alias="lowerBound") - """See documentation for :obj:`~partitioning_mode` for more details""" # noqa: WPS322 + """See documentation for :obj:`~partitioning_mode` for more details""" upper_bound: Optional[int] = Field(default=None, alias="upperBound") - """See documentation for :obj:`~partitioning_mode` for more details""" # noqa: WPS322 + """See documentation for :obj:`~partitioning_mode` for more details""" session_init_statement: Optional[str] = Field(default=None, alias="sessionInitStatement") '''After each database session is opened to the remote DB and before starting to read data, @@ -386,10 +388,12 @@ def _partitioning_mode_actions(cls, values): if num_partitions == 1: return values - raise ValueError("You should set partition_column to enable partitioning") + msg = "You should set partition_column to enable partitioning" + raise ValueError(msg) - elif num_partitions == 1: - raise ValueError("You should set num_partitions > 1 to enable partitioning") + if num_partitions == 1: + msg = "You should set num_partitions > 1 to enable partitioning" + raise ValueError(msg) if mode == JDBCPartitioningMode.RANGE: return values @@ -626,13 +630,13 @@ class JDBCSQLOptions(GenericOptions): """ num_partitions: Optional[int] = Field(default=None, alias="numPartitions") - """Number of jobs created by Spark to read the table content in parallel.""" # noqa: WPS322 + """Number of jobs created by Spark to read the table content in parallel.""" lower_bound: Optional[int] = Field(default=None, alias="lowerBound") - """Defines the starting boundary for partitioning the query's data. Mandatory if :obj:`~partition_column` is set""" # noqa: WPS322 + """Defines the starting boundary for partitioning the query's data. Mandatory if :obj:`~partition_column` is set""" upper_bound: Optional[int] = Field(default=None, alias="upperBound") - """Sets the ending boundary for data partitioning. Mandatory if :obj:`~partition_column` is set""" # noqa: WPS322 + """Sets the ending boundary for data partitioning. 
Mandatory if :obj:`~partition_column` is set""" session_init_statement: Optional[str] = Field(default=None, alias="sessionInitStatement") '''After each database session is opened to the remote DB and before starting to read data, @@ -690,9 +694,9 @@ def _check_partition_fields(cls, values): lower_bound = values.get("lower_bound") upper_bound = values.get("upper_bound") - if num_partitions is not None and num_partitions > 1: - if lower_bound is None or upper_bound is None: - raise ValueError("lowerBound and upperBound must be set if numPartitions > 1") + if num_partitions is not None and num_partitions > 1 and (lower_bound is None or upper_bound is None): + msg = "lowerBound and upperBound must be set if numPartitions > 1" + raise ValueError(msg) return values diff --git a/onetl/connection/db_connection/jdbc_mixin/__init__.py b/onetl/connection/db_connection/jdbc_mixin/__init__.py index 3fb814f93..010ff7ee1 100644 --- a/onetl/connection/db_connection/jdbc_mixin/__init__.py +++ b/onetl/connection/db_connection/jdbc_mixin/__init__.py @@ -1,7 +1,18 @@ # SPDX-FileCopyrightText: 2022-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from onetl.connection.db_connection.jdbc_mixin.connection import ( - JDBCMixin, - JDBCStatementType, +from onetl.connection.db_connection.jdbc_mixin.connection import JDBCMixin, JDBCStatementType +from onetl.connection.db_connection.jdbc_mixin.options import ( + JDBCExecuteOptions, + JDBCFetchOptions, ) -from onetl.connection.db_connection.jdbc_mixin.options import JDBCOptions +from onetl.connection.db_connection.jdbc_mixin.options import ( + JDBCOptions as JDBCMixinOptions, +) + +__all__ = [ + "JDBCExecuteOptions", + "JDBCFetchOptions", + "JDBCMixin", + "JDBCMixinOptions", + "JDBCStatementType", +] diff --git a/onetl/connection/db_connection/jdbc_mixin/connection.py b/onetl/connection/db_connection/jdbc_mixin/connection.py index b81c41666..83a52462f 100644 --- a/onetl/connection/db_connection/jdbc_mixin/connection.py +++ b/onetl/connection/db_connection/jdbc_mixin/connection.py @@ -137,7 +137,7 @@ def close(self): def __enter__(self): return self - def __exit__(self, _exc_type, _exc_value, _traceback): # noqa: U101 + def __exit__(self, _exc_type, _exc_value, _traceback): self.close() @slot @@ -186,7 +186,7 @@ def fetch( log_lines(log, query) call_options = ( - self.FetchOptions.parse(options.dict()) # type: ignore + self.FetchOptions.parse(options.dict()) if isinstance(options, JDBCMixinOptions) else self.FetchOptions.parse(options) ) @@ -195,7 +195,7 @@ def fetch( try: df = self._query_on_driver(query, call_options) except Exception: - log.error("|%s| Query failed!", self.__class__.__name__) + log.exception("|%s| Query failed!", self.__class__.__name__) raise log.info("|%s| Query succeeded, created in-memory dataframe.", self.__class__.__name__) @@ -254,7 +254,7 @@ def execute( log_lines(log, statement) call_options = ( - self.ExecuteOptions.parse(options.dict()) # type: ignore + self.ExecuteOptions.parse(options.dict()) if isinstance(options, JDBCMixinOptions) else self.ExecuteOptions.parse(options) ) @@ -263,7 +263,7 @@ def execute( try: df = self._call_on_driver(statement, call_options) except Exception: - log.error("|%s| Execution failed!", self.__class__.__name__) + log.exception("|%s| Execution failed!", self.__class__.__name__) raise if not df: @@ -344,20 +344,26 @@ def _options_to_connection_properties(self, options: JDBCFetchOptions | JDBCExec """ jdbc_properties = self._get_jdbc_properties(options, exclude_none=True) - jdbc_utils_package = 
self.spark._jvm.org.apache.spark.sql.execution.datasources.jdbc # type: ignore - jdbc_options = jdbc_utils_package.JDBCOptions( + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 + JdbcUtils = jvm.org.apache.spark.sql.execution.datasources.jdbc # type: ignore[union-attr] # noqa: N806 + jdbc_options = JdbcUtils.JDBCOptions( self.jdbc_url, # JDBCOptions class requires `table` argument to be passed, but it is not used in asConnectionProperties "table", - self.spark._jvm.PythonUtils.toScalaMap(jdbc_properties), # type: ignore + jvm.PythonUtils.toScalaMap(jdbc_properties), # type: ignore[union-attr] ) return jdbc_options.asConnectionProperties() - def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions, read_only: bool): + def _get_jdbc_connection( + self, + options: JDBCFetchOptions | JDBCExecuteOptions, + *, + read_only: bool, + ): connection_properties = self._options_to_connection_properties(options) - driver_manager = self.spark._jvm.java.sql.DriverManager # type: ignore - connection = driver_manager.getConnection(self.jdbc_url, connection_properties) - connection.setReadOnly(read_only) # type: ignore + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 + connection = jvm.java.sql.DriverManager.getConnection(self.jdbc_url, connection_properties) # type: ignore[union-attr] + connection.setReadOnly(read_only) return connection def _get_spark_dialect_class_name(self) -> str: @@ -372,11 +378,11 @@ def _get_spark_dialect_class_name(self) -> str: return dialect.getCanonicalName().split("$")[0] def _get_spark_dialect(self): - jdbc_dialects_package = self.spark._jvm.org.apache.spark.sql.jdbc # type: ignore - return jdbc_dialects_package.JdbcDialects.get(self.jdbc_url) + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 + return jvm.org.apache.spark.sql.jdbc.JdbcDialects.get(self.jdbc_url) def _get_statement_args(self) -> tuple[int, ...]: - resultset = self.spark._jvm.java.sql.ResultSet # type: ignore + resultset = self.spark._jvm.java.sql.ResultSet # type: ignore[attr-defined, union-attr] # noqa: SLF001 return resultset.TYPE_FORWARD_ONLY, resultset.CONCUR_READ_ONLY def _execute_on_driver( @@ -385,7 +391,7 @@ def _execute_on_driver( statement_type: JDBCStatementType, callback: Callable[..., T], options: JDBCFetchOptions | JDBCExecuteOptions, - read_only: bool, + read_only: bool, # noqa: FBT001 ) -> T: """ Actually execute statement on driver. @@ -397,19 +403,19 @@ def _execute_on_driver( """ statement_args = self._get_statement_args() - jdbc_connection = self._get_jdbc_connection(options, read_only) + jdbc_connection = self._get_jdbc_connection(options, read_only=read_only) with closing(jdbc_connection): jdbc_statement = self._build_statement(statement, statement_type, jdbc_connection, statement_args) return self._execute_statement(jdbc_connection, jdbc_statement, statement, options, callback, read_only) - def _execute_statement( + def _execute_statement( # noqa: PLR0913 self, jdbc_connection, jdbc_statement, statement: str, options: JDBCFetchOptions | JDBCExecuteOptions, callback: Callable[..., T], - read_only: bool, + read_only: bool, # noqa: FBT001 ) -> T: """ Executes ``java.sql.Statement`` or child class and passes it into the callback function. 
@@ -422,8 +428,10 @@ def _execute_statement( from py4j.java_gateway import is_instance_of gateway = get_java_gateway(self.spark) - prepared_statement = self.spark._jvm.java.sql.PreparedStatement # type: ignore - callable_statement = self.spark._jvm.java.sql.CallableStatement # type: ignore + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 + + is_prepared = is_instance_of(gateway, jdbc_statement, jvm.java.sql.PreparedStatement) # type: ignore[union-attr] + is_callable = is_instance_of(gateway, jdbc_statement, jvm.java.sql.CallableStatement) # type: ignore[union-attr] with closing(jdbc_statement): if options.fetchsize is not None: @@ -433,9 +441,7 @@ def _execute_statement( jdbc_statement.setQueryTimeout(options.query_timeout) # Java SQL classes are not consistent.. - if is_instance_of(gateway, jdbc_statement, prepared_statement): - jdbc_statement.execute() - elif is_instance_of(gateway, jdbc_statement, callable_statement): + if is_prepared or is_callable: jdbc_statement.execute() elif read_only: jdbc_statement.executeQuery(statement) @@ -505,37 +511,36 @@ def _resultset_to_dataframe(self, jdbc_connection, result_set) -> DataFrame: * https://github.com/apache/spark/blob/v3.2.0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala#L337-L343 """ - from pyspark.sql import DataFrame # noqa: WPS442 + from pyspark.sql import DataFrame jdbc_dialect = self._get_spark_dialect() - jdbc_utils_package = self.spark._jvm.org.apache.spark.sql.execution.datasources.jdbc # type: ignore - jdbc_utils = jdbc_utils_package.JdbcUtils - - java_converters = self.spark._jvm.scala.collection.JavaConverters # type: ignore + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 + JdbcUtils = jvm.org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils # type: ignore[union-attr] # noqa: N806 + JavaConverters = jvm.scala.collection.JavaConverters # type: ignore[union-attr] # noqa: N806 spark_version = get_spark_version(self.spark) if spark_version >= Version("4.0"): - result_schema = jdbc_utils.getSchema( + result_schema = JdbcUtils.getSchema( jdbc_connection, result_set, jdbc_dialect, - False, # noqa: WPS425 - False, # noqa: WPS425 + False, # noqa: FBT003 + False, # noqa: FBT003 ) elif spark_version >= Version("3.4"): # https://github.com/apache/spark/commit/2349175e1b81b0a61e1ed90c2d051c01cf78de9b - result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False, False) # noqa: WPS425 + result_schema = JdbcUtils.getSchema(result_set, jdbc_dialect, False, False) # noqa: FBT003 else: - result_schema = jdbc_utils.getSchema(result_set, jdbc_dialect, False) # noqa: WPS425 + result_schema = JdbcUtils.getSchema(result_set, jdbc_dialect, False) # noqa: FBT003 - if spark_version.major >= 4: - result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema, jdbc_dialect) + if spark_version.major >= 4: # noqa: PLR2004 + result_iterator = JdbcUtils.resultSetToRows(result_set, result_schema, jdbc_dialect) else: - result_iterator = jdbc_utils.resultSetToRows(result_set, result_schema) + result_iterator = JdbcUtils.resultSetToRows(result_set, result_schema) - result_list = java_converters.seqAsJavaListConverter(result_iterator.toSeq()).asJava() - jdf = self.spark._jsparkSession.createDataFrame(result_list, result_schema) # type: ignore + result_list = JavaConverters.seqAsJavaListConverter(result_iterator.toSeq()).asJava() + jdf = self.spark._jsparkSession.createDataFrame(result_list, result_schema) # type: ignore[attr-defined] # noqa: SLF001 # DataFrame 
constructor in Spark 2.3 and 2.4 required second argument to be a SQLContext class # E.g. spark._wrapped = SQLContext(spark). @@ -543,4 +548,4 @@ def _resultset_to_dataframe(self, jdbc_connection, result_set) -> DataFrame: # attribute was removed from SparkSession spark_context = getattr(self.spark, "_wrapped", self.spark) - return DataFrame(jdf, spark_context) # type: ignore + return DataFrame(jdf, spark_context) diff --git a/onetl/connection/db_connection/kafka/__init__.py b/onetl/connection/db_connection/kafka/__init__.py index b97e83948..f38c753f7 100644 --- a/onetl/connection/db_connection/kafka/__init__.py +++ b/onetl/connection/db_connection/kafka/__init__.py @@ -1,3 +1,35 @@ # SPDX-FileCopyrightText: 2023-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.connection.db_connection.kafka.connection import Kafka +from onetl.connection.db_connection.kafka.dialect import KafkaDialect +from onetl.connection.db_connection.kafka.extra import KafkaExtra +from onetl.connection.db_connection.kafka.kafka_auth import KafkaAuth +from onetl.connection.db_connection.kafka.kafka_basic_auth import KafkaBasicAuth +from onetl.connection.db_connection.kafka.kafka_kerberos_auth import KafkaKerberosAuth +from onetl.connection.db_connection.kafka.kafka_plaintext_protocol import KafkaPlaintextProtocol +from onetl.connection.db_connection.kafka.kafka_protocol import KafkaProtocol +from onetl.connection.db_connection.kafka.kafka_scram_auth import KafkaScramAuth +from onetl.connection.db_connection.kafka.kafka_ssl_protocol import KafkaSSLProtocol +from onetl.connection.db_connection.kafka.options import ( + KafkaReadOptions, + KafkaTopicExistBehaviorKafka, + KafkaWriteOptions, +) +from onetl.connection.db_connection.kafka.slots import KafkaSlots + +__all__ = [ + "Kafka", + "KafkaAuth", + "KafkaBasicAuth", + "KafkaDialect", + "KafkaExtra", + "KafkaKerberosAuth", + "KafkaPlaintextProtocol", + "KafkaProtocol", + "KafkaReadOptions", + "KafkaSSLProtocol", + "KafkaScramAuth", + "KafkaSlots", + "KafkaTopicExistBehaviorKafka", + "KafkaWriteOptions", +] diff --git a/onetl/connection/db_connection/kafka/connection.py b/onetl/connection/db_connection/kafka/connection.py index 6d1cac121..701a5f6ba 100644 --- a/onetl/connection/db_connection/kafka/connection.py +++ b/onetl/connection/db_connection/kafka/connection.py @@ -202,8 +202,7 @@ class Kafka(DBConnection): extra={"max.request.size": 1024 * 1024}, # <-- spark=spark, ).check() - - """ + """ # noqa: E501 BasicAuth = KafkaBasicAuth KerberosAuth = KafkaKerberosAuth @@ -239,11 +238,12 @@ def check(self): log.info("|%s| Connection is available.", self.__class__.__name__) except Exception as e: log.exception("|%s| Connection is unavailable", self.__class__.__name__) - raise RuntimeError("Connection is unavailable") from e + msg = "Connection is unavailable" + raise RuntimeError(msg) from e return self @slot - def read_source_as_df( # noqa: WPS231 + def read_source_as_df( # noqa: PLR0913 self, source: str, columns: list[str] | None = None, @@ -252,14 +252,16 @@ def read_source_as_df( # noqa: WPS231 df_schema: StructType | None = None, window: Window | None = None, limit: int | None = None, - options: KafkaReadOptions = KafkaReadOptions(), # noqa: B008, WPS404 + options: KafkaReadOptions | None = None, ) -> DataFrame: log.info("|%s| Reading data from topic %r", self.__class__.__name__, source) if source not in self._get_topics(): - raise ValueError(f"Topic {source!r} doesn't exist") + msg = f"Topic {source!r} doesn't exist" + raise ValueError(msg) 
result_options = {f"kafka.{key}": value for key, value in self._get_connection_properties().items()} - result_options.update(options.dict(by_alias=True, exclude_none=True)) + if options: + result_options.update(options.dict(by_alias=True, exclude_none=True)) result_options["subscribe"] = source if window and window.expression == "offset": @@ -293,7 +295,7 @@ def write_df_to_target( self, df: DataFrame, target: str, - options: KafkaWriteOptions = KafkaWriteOptions(), # noqa: B008, WPS404 + options: KafkaWriteOptions | None = None, ) -> None: # Check that the DataFrame doesn't contain any columns not in the schema required_columns = {"value"} @@ -302,15 +304,19 @@ def write_df_to_target( df_columns = set(df.columns) if not df_columns.issubset(allowed_columns): invalid_columns = df_columns - allowed_columns - raise ValueError( + msg = ( f"Invalid column names: {sorted(invalid_columns)}. " f"Expected columns: {sorted(required_columns)} (required)," - f" {sorted(optional_columns)} (optional)", + f" {sorted(optional_columns)} (optional)" ) + raise ValueError(msg) + + options = options or KafkaWriteOptions() # Check that the DataFrame doesn't contain a 'headers' column with includeHeaders=False if not options.include_headers and "headers" in df.columns: - raise ValueError("Cannot write 'headers' column with kafka.WriteOptions(include_headers=False)") + msg = "Cannot write 'headers' column with kafka.WriteOptions(include_headers=False)" + raise ValueError(msg) if "topic" in df.columns: log.warning("The 'topic' column in the DataFrame will be overridden with value %r", target) @@ -324,7 +330,8 @@ def write_df_to_target( # https://issues.apache.org/jira/browse/SPARK-44774 mode = options.if_exists if mode == KafkaTopicExistBehaviorKafka.ERROR and target in self._get_topics(): - raise TargetAlreadyExistsError(f"Topic {target} already exists") + msg = f"Topic {target} already exists" + raise TargetAlreadyExistsError(msg) log.info("|%s| Saving data to a topic %r", self.__class__.__name__, target) df.write.format("kafka").mode(mode).options(**write_options).save() @@ -335,9 +342,9 @@ def get_df_schema( self, source: str, columns: list[str] | None = None, - options: KafkaReadOptions = KafkaReadOptions(), # noqa: WPS404 + options: KafkaReadOptions | None = None, ) -> StructType: - from pyspark.sql.types import ( # noqa: WPS442 + from pyspark.sql.types import ( ArrayType, BinaryType, IntegerType, @@ -510,7 +517,7 @@ def get_min_max_values( # https://kafka.apache.org/22/javadoc/org/apache/kafka/clients/consumer/KafkaConsumer.html#partitionsFor-java.lang.String- partition_infos = consumer.partitionsFor(source) - jvm = self.spark._jvm # type: ignore[attr-defined] + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 topic_partitions = [ jvm.org.apache.kafka.common.TopicPartition(source, p.partition()) # type: ignore[union-attr] for p in partition_infos @@ -569,7 +576,8 @@ def _get_addresses_by_cluster(cls, values): log.debug("|%s| Set cluster %r addresses: %r", cls.__name__, cluster, cluster_addresses) values["addresses"] = cluster_addresses else: - raise ValueError("Passed empty parameter 'addresses'") + msg = "Passed empty parameter 'addresses'" + raise ValueError(msg) return values @validator("cluster") @@ -582,9 +590,8 @@ def _validate_cluster_name(cls, cluster): log.debug("|%s| Checking if cluster %r is a known cluster...", cls.__name__, validated_cluster) known_clusters = cls.Slots.get_known_clusters() if known_clusters and validated_cluster not in known_clusters: - raise ValueError( - 
f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}", - ) + msg = f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}" + raise ValueError(msg) return validated_cluster @@ -601,7 +608,8 @@ def _validate_addresses(cls, value, values): cluster_addresses = set(cls.Slots.get_cluster_addresses(cluster) or []) unknown_addresses = set(validated_addresses) - cluster_addresses if cluster_addresses and unknown_addresses: - raise ValueError(f"Cluster {cluster!r} does not contain addresses {unknown_addresses!r}") + msg = f"Cluster {cluster!r} does not contain addresses {unknown_addresses!r}" + raise ValueError(msg) return validated_addresses @@ -641,12 +649,12 @@ def _get_java_consumer(self): "value.deserializer": "org.apache.kafka.common.serialization.ByteArrayDeserializer", }, ) - jvm = self.spark._jvm + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 consumer_class = jvm.org.apache.kafka.clients.consumer.KafkaConsumer return consumer_class(connection_properties) def _get_topics(self, timeout: int = 10) -> set[str]: - jvm = self.spark._jvm # type: ignore[attr-defined] + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 # Maybe we should not pass explicit timeout at all, # and instead use default.api.timeout.ms which is configurable via self.extra. # Think about this next time if someone see issues in real use diff --git a/onetl/connection/db_connection/kafka/dialect.py b/onetl/connection/db_connection/kafka/dialect.py index 9e9719e22..bbde1bdc0 100644 --- a/onetl/connection/db_connection/kafka/dialect.py +++ b/onetl/connection/db_connection/kafka/dialect.py @@ -21,21 +21,22 @@ log = logging.getLogger(__name__) -class KafkaDialect( # noqa: WPS215 +class KafkaDialect( NotSupportColumns, NotSupportDFSchema, NotSupportHint, NotSupportWhere, DBDialect, ): - SUPPORTED_HWM_COLUMNS = {"offset"} + SUPPORTED_HWM_COLUMNS = frozenset(("offset",)) def validate_name(self, value: str) -> str: if "*" in value or "," in value: - raise ValueError( + msg = ( f"source/target={value} is not supported by {self.connection.__class__.__name__}. " - f"Provide a singular topic.", + f"Provide a singular topic." ) + raise ValueError(msg) return value def validate_hwm( @@ -46,10 +47,11 @@ def validate_hwm( return None if hwm.expression not in self.SUPPORTED_HWM_COLUMNS: - raise ValueError( + msg = ( f"hwm.expression={hwm.expression!r} is not supported by {self.connection.__class__.__name__}. 
" - f"Valid values are: {self.SUPPORTED_HWM_COLUMNS}", + f"Valid values are: {self.SUPPORTED_HWM_COLUMNS}" ) + raise ValueError(msg) return hwm def detect_hwm_class(self, field: StructField) -> type[KeyValueIntHWM] | None: diff --git a/onetl/connection/db_connection/kafka/extra.py b/onetl/connection/db_connection/kafka/extra.py index ef2b6b68e..d3589cec3 100644 --- a/onetl/connection/db_connection/kafka/extra.py +++ b/onetl/connection/db_connection/kafka/extra.py @@ -33,6 +33,6 @@ class KafkaExtra(GenericOptions): """ class Config: - strip_prefixes = ["kafka."] + strip_prefixes = ("kafka.",) prohibited_options = PROHIBITED_OPTIONS extra = "allow" diff --git a/onetl/connection/db_connection/kafka/kafka_basic_auth.py b/onetl/connection/db_connection/kafka/kafka_basic_auth.py index abac06759..f001ce74b 100644 --- a/onetl/connection/db_connection/kafka/kafka_basic_auth.py +++ b/onetl/connection/db_connection/kafka/kafka_basic_auth.py @@ -43,7 +43,7 @@ class KafkaBasicAuth(KafkaAuth, GenericOptions): user: str = Field(alias="username") password: SecretStr - def get_jaas_conf(self) -> str: # noqa: WPS473 + def get_jaas_conf(self) -> str: return ( "org.apache.kafka.common.security.plain.PlainLoginModule required " f'username="{self.user}" ' diff --git a/onetl/connection/db_connection/kafka/kafka_kerberos_auth.py b/onetl/connection/db_connection/kafka/kafka_kerberos_auth.py index 58791a221..45e29a044 100644 --- a/onetl/connection/db_connection/kafka/kafka_kerberos_auth.py +++ b/onetl/connection/db_connection/kafka/kafka_kerberos_auth.py @@ -133,7 +133,7 @@ class KafkaKerberosAuth(KafkaAuth, GenericOptions): class Config: prohibited_options = PROHIBITED_OPTIONS known_options = KNOWN_OPTIONS - strip_prefixes = ["kafka."] + strip_prefixes = ("kafka.",) extra = "allow" def get_jaas_conf(self, kafka: Kafka) -> str: @@ -180,7 +180,8 @@ def _use_keytab(cls, values): keytab = values.get("keytab") use_keytab = values.get("use_keytab") if use_keytab and not keytab: - raise ValueError("keytab is required if useKeytab is True") + msg = "keytab is required if useKeytab is True" + raise ValueError(msg) return values def _prepare_keytab(self, kafka: Kafka) -> str: diff --git a/onetl/connection/db_connection/kafka/kafka_scram_auth.py b/onetl/connection/db_connection/kafka/kafka_scram_auth.py index dea4715e8..af9d8483e 100644 --- a/onetl/connection/db_connection/kafka/kafka_scram_auth.py +++ b/onetl/connection/db_connection/kafka/kafka_scram_auth.py @@ -64,13 +64,13 @@ class KafkaScramAuth(KafkaAuth, GenericOptions): digest: Literal["SHA-256", "SHA-512"] class Config: - strip_prefixes = ["kafka."] + strip_prefixes = ("kafka.",) # https://kafka.apache.org/documentation/#producerconfigs_sasl.login.class - known_options = {"sasl.login.*"} - prohibited_options = {"sasl.mechanism", "sasl.jaas.config"} + known_options = frozenset(("sasl.login.*",)) + prohibited_options = frozenset(("sasl.mechanism", "sasl.jaas.config")) extra = "allow" - def get_jaas_conf(self) -> str: # noqa: WPS473 + def get_jaas_conf(self) -> str: return ( "org.apache.kafka.common.security.scram.ScramLoginModule required " f'username="{self.user}" ' diff --git a/onetl/connection/db_connection/kafka/kafka_ssl_protocol.py b/onetl/connection/db_connection/kafka/kafka_ssl_protocol.py index f2670aef6..64639d9d7 100644 --- a/onetl/connection/db_connection/kafka/kafka_ssl_protocol.py +++ b/onetl/connection/db_connection/kafka/kafka_ssl_protocol.py @@ -56,10 +56,10 @@ class KafkaSSLProtocol(KafkaProtocol, GenericOptions): protocol = Kafka.SSLProtocol( 
keystore_type="PEM", - keystore_certificate_chain="-----BEGIN CERTIFICATE-----\\nMIIDZjC...\\n-----END CERTIFICATE-----", - keystore_key="-----BEGIN PRIVATE KEY-----\\nMIIEvg..\\n-----END PRIVATE KEY-----", + keystore_certificate_chain="-----BEGIN CERTIFICATE-----\\n...\\n-----END CERTIFICATE-----", + keystore_key="-----BEGIN PRIVATE KEY-----\\...\\n-----END PRIVATE KEY-----", truststore_type="PEM", - truststore_certificates="-----BEGIN CERTIFICATE-----\\nMICC...\\n-----END CERTIFICATE-----", + truststore_certificates="-----BEGIN CERTIFICATE-----\\n...\\n-----END CERTIFICATE-----", ) Pass custom options: @@ -70,10 +70,10 @@ class KafkaSSLProtocol(KafkaProtocol, GenericOptions): { # Just the same options as above, but using Kafka config naming with dots "ssl.keystore.type": "PEM", - "ssl.keystore.certificate_chain": "-----BEGIN CERTIFICATE-----\\nMIIDZjC...\\n-----END CERTIFICATE-----", - "ssl.keystore.key": "-----BEGIN PRIVATE KEY-----\\nMIIEvg..\\n-----END PRIVATE KEY-----", + "ssl.keystore.certificate_chain": "-----BEGIN CERTIFICATE-----\\n...\\n-----END CERTIFICATE-----", + "ssl.keystore.key": "-----BEGIN PRIVATE KEY-----\\n...\\n-----END PRIVATE KEY-----", "ssl.truststore.type": "PEM", - "ssl.truststore.certificates": "-----BEGIN CERTIFICATE-----\\nMICC...\\n-----END CERTIFICATE-----", + "ssl.truststore.certificates": "-----BEGIN CERTIFICATE-----\\n...\\n-----END CERTIFICATE-----", # Any option starting from "ssl." is passed to Kafka client as-is "ssl.protocol": "TLSv1.3", } @@ -133,8 +133,8 @@ class KafkaSSLProtocol(KafkaProtocol, GenericOptions): truststore_certificates: Optional[str] = Field(default=None, alias="ssl.truststore.certificates", repr=False) class Config: - known_options = {"ssl.*"} - strip_prefixes = ["kafka."] + known_options = frozenset(("ssl.*",)) + strip_prefixes = ("kafka.",) extra = "allow" def get_options(self, kafka: Kafka) -> dict: diff --git a/onetl/connection/db_connection/kafka/options.py b/onetl/connection/db_connection/kafka/options.py index f103d1b57..22381519a 100644 --- a/onetl/connection/db_connection/kafka/options.py +++ b/onetl/connection/db_connection/kafka/options.py @@ -66,7 +66,8 @@ class KafkaReadOptions(GenericOptions): * ``subscribe`` * ``subscribePattern`` - are populated from connection attributes, and cannot be overridden by the user in ``ReadOptions`` to avoid issues. + are populated from connection attributes, + and cannot be overridden by the user in ``ReadOptions`` to avoid issues. .. versionadded:: 0.9.0 @@ -113,7 +114,8 @@ class KafkaWriteOptions(GenericOptions): * ``kafka.*`` * ``topic`` - are populated from connection attributes, and cannot be overridden by the user in ``WriteOptions`` to avoid issues. + are populated from connection attributes, + and cannot be overridden by the user in ``WriteOptions`` to avoid issues. .. versionadded:: 0.9.0 @@ -163,5 +165,6 @@ class Config: @root_validator(pre=True) def _mode_is_restricted(cls, values): if "mode" in values: - raise ValueError("Parameter `mode` is not allowed. Please use `if_exists` parameter instead.") + msg = "Parameter `mode` is not allowed. Please use `if_exists` parameter instead." 
+ raise ValueError(msg) return values diff --git a/onetl/connection/db_connection/kafka/slots.py b/onetl/connection/db_connection/kafka/slots.py index 12143ee83..66bf09904 100644 --- a/onetl/connection/db_connection/kafka/slots.py +++ b/onetl/connection/db_connection/kafka/slots.py @@ -135,7 +135,8 @@ def get_cluster_addresses(cluster: str) -> list[str] | None: Returns ------- list[str] | None - A collection of broker addresses for the specified Kafka cluster. If the hook cannot be applied, return ``None``. + A collection of broker addresses for the specified Kafka cluster. + If the hook cannot be applied, return ``None``. Examples -------- diff --git a/onetl/connection/db_connection/mongodb/__init__.py b/onetl/connection/db_connection/mongodb/__init__.py index 26c4f7572..46dda5012 100644 --- a/onetl/connection/db_connection/mongodb/__init__.py +++ b/onetl/connection/db_connection/mongodb/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2023-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from onetl.connection.db_connection.mongodb.connection import MongoDB, MongoDBExtra +from onetl.connection.db_connection.mongodb.connection import MongoDB from onetl.connection.db_connection.mongodb.dialect import MongoDBDialect from onetl.connection.db_connection.mongodb.options import ( MongoDBCollectionExistBehavior, @@ -8,3 +8,12 @@ MongoDBReadOptions, MongoDBWriteOptions, ) + +__all__ = [ + "MongoDB", + "MongoDBCollectionExistBehavior", + "MongoDBDialect", + "MongoDBPipelineOptions", + "MongoDBReadOptions", + "MongoDBWriteOptions", +] diff --git a/onetl/connection/db_connection/mongodb/connection.py b/onetl/connection/db_connection/mongodb/connection.py index a2f3c7a03..420d59a60 100644 --- a/onetl/connection/db_connection/mongodb/connection.py +++ b/onetl/connection/db_connection/mongodb/connection.py @@ -146,7 +146,9 @@ def get_packages( package_version: str | None = None, ) -> list[str]: """ - Get package names to be downloaded by Spark. Allows specifying custom MongoDB Spark connector versions. |support_hooks| + Get package names to be downloaded by Spark. |support_hooks| + + Allows specifying custom MongoDB Spark connector versions. .. versionadded:: 0.9.0 @@ -185,7 +187,8 @@ def get_packages( spark_ver = Version(spark_version) scala_ver = get_default_scala_version(spark_ver) else: - raise ValueError("You should pass either `scala_version` or `spark_version`") + msg = "You should pass either `scala_version` or `spark_version`" + raise ValueError(msg) connector_ver = Version(package_version or default_package_version).min_digits(2) return [f"org.mongodb.spark:mongo-spark-connector_{scala_ver.format('{0}.{1}')}:{connector_ver}"] @@ -261,7 +264,8 @@ def pipeline( Schema describing the resulting DataFrame. options : PipelineOptions | dict, optional - Additional pipeline options, see :obj:`MongoDB.PipelineOptions `. + Additional pipeline options, + see :obj:`MongoDB.PipelineOptions `. 
Examples -------- @@ -371,7 +375,7 @@ def check(self): self._log_parameters() try: - jvm = self.spark._jvm # type: ignore + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 client = jvm.com.mongodb.client.MongoClients.create(self.connection_url) list(client.listDatabaseNames().iterator()) @@ -382,7 +386,8 @@ def check(self): log.info("|%s| Connection is available.", self.__class__.__name__) except Exception as e: log.exception("|%s| Connection is unavailable", self.__class__.__name__) - raise RuntimeError("Connection is unavailable") from e + msg = "Connection is unavailable" + raise RuntimeError(msg) from e return self @@ -440,7 +445,7 @@ def get_min_max_values( return min_value, max_value @slot - def read_source_as_df( + def read_source_as_df( # noqa: PLR0913 self, source: str, columns: list[str] | None = None, @@ -501,8 +506,9 @@ def write_df_to_target( if self._collection_exists(target): # MongoDB connector does not support mode=ignore and mode=error if write_options.if_exists == MongoDBCollectionExistBehavior.ERROR: - raise ValueError("Operation stopped due to MongoDB.WriteOptions(if_exists='error')") - elif write_options.if_exists == MongoDBCollectionExistBehavior.IGNORE: + msg = "Operation stopped due to MongoDB.WriteOptions(if_exists='error')" + raise ValueError(msg) + if write_options.if_exists == MongoDBCollectionExistBehavior.IGNORE: log.info( "|%s| Skip writing to existing collection because of MongoDB.WriteOptions(if_exists='ignore')", self.__class__.__name__, @@ -559,16 +565,16 @@ def _get_server_version(self) -> Version: if self._server_version: return self._server_version - jvm = self.spark._jvm # type: ignore[attr-defined] - client = jvm.com.mongodb.client.MongoClients.create(self.connection_url) # type: ignore + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 + client = jvm.com.mongodb.client.MongoClients.create(self.connection_url) # type: ignore[union-attr] db = client.getDatabase(self.database) - command = jvm.org.bson.BsonDocument("buildinfo", jvm.org.bson.BsonString("")) # type: ignore + command = jvm.org.bson.BsonDocument("buildinfo", jvm.org.bson.BsonString("")) # type: ignore[union-attr] self._server_version = Version(db.runCommand(command).get("version")) return self._server_version def _collection_exists(self, source: str) -> bool: - jvm = self.spark._jvm # type: ignore[attr-defined] - client = jvm.com.mongodb.client.MongoClients.create(self.connection_url) # type: ignore + jvm = self.spark._jvm # type: ignore[attr-defined] # noqa: SLF001 + client = jvm.com.mongodb.client.MongoClients.create(self.connection_url) # type: ignore[union-attr] collections = set(client.getDatabase(self.database).listCollectionNames().iterator()) if source in collections: log.info("|%s| Collection %r exists", self.__class__.__name__, source) diff --git a/onetl/connection/db_connection/mongodb/dialect.py b/onetl/connection/db_connection/mongodb/dialect.py index b9906723d..38210239d 100644 --- a/onetl/connection/db_connection/mongodb/dialect.py +++ b/onetl/connection/db_connection/mongodb/dialect.py @@ -14,7 +14,7 @@ ) from onetl.hwm import Edge, Window -_upper_level_operators = frozenset( # noqa: WPS527 +_upper_level_operators = frozenset( [ "$addFields", "$bucket", @@ -58,7 +58,7 @@ ) -class MongoDBDialect( # noqa: WPS215 +class MongoDBDialect( SupportNameAny, NotSupportColumns, RequiresDFSchema, @@ -73,10 +73,11 @@ def validate_where( return None if not isinstance(where, dict): - raise ValueError( + msg = ( f"{self.connection.__class__.__name__} 
requires 'where' parameter type to be 'dict', " - f"got {where.__class__.__name__!r}", + f"got {where.__class__.__name__!r}" ) + raise TypeError(msg) for key in where: self._validate_top_level_keys_in_where_parameter(key) @@ -90,10 +91,11 @@ def validate_hint( return None if not isinstance(hint, dict): - raise ValueError( + msg = ( f"{self.connection.__class__.__name__} requires 'hint' parameter type to be 'dict', " - f"got {hint.__class__.__name__!r}", + f"got {hint.__class__.__name__!r}" ) + raise TypeError(msg) return hint def prepare_pipeline( @@ -166,14 +168,20 @@ def _validate_top_level_keys_in_where_parameter(self, key: str): """ if key.startswith("$"): if key == "$match": - raise ValueError( + msg = ( "'$match' operator not allowed at the top level of the 'where' parameter dictionary. " "This error most likely occurred due to the fact that you used the MongoDB format for the " "pipeline {'$match': {'column': ...}}. In the onETL paradigm, you do not need to specify the " - "'$match' keyword, but write the filtering condition right away, like {'column': ...}", + "'$match' keyword, but write the filtering condition right away, like {'column': ...}" + ) + raise ValueError( + msg, ) - if key in _upper_level_operators: # noqa: WPS220 - raise ValueError( # noqa: WPS220 + if key in _upper_level_operators: + msg = ( f"An invalid parameter {key!r} was specified in the 'where' " - "field. You cannot use aggregations or 'groupBy' clauses in 'where'", + "field. You cannot use aggregations or 'groupBy' clauses in 'where'" + ) + raise ValueError( + msg, ) diff --git a/onetl/connection/db_connection/mongodb/options.py b/onetl/connection/db_connection/mongodb/options.py index 2f47251c2..9b59f8bcc 100644 --- a/onetl/connection/db_connection/mongodb/options.py +++ b/onetl/connection/db_connection/mongodb/options.py @@ -80,8 +80,8 @@ class MongoDBCollectionExistBehavior(str, Enum): def __str__(self) -> str: return str(self.value) - @classmethod # noqa: WPS120 - def _missing_(cls, value: object): # noqa: WPS120 + @classmethod + def _missing_(cls, value: object): if str(value) == "overwrite": warnings.warn( "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. " @@ -90,12 +90,14 @@ def _missing_(cls, value: object): # noqa: WPS120 stacklevel=4, ) return cls.REPLACE_ENTIRE_COLLECTION + return None class MongoDBPipelineOptions(GenericOptions): """Aggregation pipeline options for MongoDB connector. - The only difference from :obj:`MongoDB.ReadOptions ` that latter does not allow to pass the ``hint`` parameter. + The only difference from :obj:`MongoDB.ReadOptions ` + that latter does not allow to pass the ``hint`` parameter. .. 
warning:: diff --git a/onetl/connection/db_connection/mssql/__init__.py b/onetl/connection/db_connection/mssql/__init__.py index c6d9dac03..db03e386f 100644 --- a/onetl/connection/db_connection/mssql/__init__.py +++ b/onetl/connection/db_connection/mssql/__init__.py @@ -1,4 +1,21 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from onetl.connection.db_connection.mssql.connection import MSSQL, MSSQLExtra +from onetl.connection.db_connection.mssql.connection import MSSQL from onetl.connection.db_connection.mssql.dialect import MSSQLDialect +from onetl.connection.db_connection.mssql.options import ( + MSSQLExecuteOptions, + MSSQLFetchOptions, + MSSQLReadOptions, + MSSQLSQLOptions, + MSSQLWriteOptions, +) + +__all__ = [ + "MSSQL", + "MSSQLDialect", + "MSSQLExecuteOptions", + "MSSQLFetchOptions", + "MSSQLReadOptions", + "MSSQLSQLOptions", + "MSSQLWriteOptions", +] diff --git a/onetl/connection/db_connection/mssql/connection.py b/onetl/connection/db_connection/mssql/connection.py index 69aba817e..8d2f0761c 100644 --- a/onetl/connection/db_connection/mssql/connection.py +++ b/onetl/connection/db_connection/mssql/connection.py @@ -162,8 +162,10 @@ class MSSQL(JDBCConnection): user="user", password="*****", extra={ - "applicationIntent": "ReadOnly", # driver will open read-only connection, to avoid writing to the database - "trustServerCertificate": "true", # add this to avoid SSL certificate issues + # driver will open read-only connection, to avoid writing to the database + "applicationIntent": "ReadOnly", + # add this to avoid SSL certificate issues + "trustServerCertificate": "true", }, spark=spark, ).check() @@ -194,7 +196,9 @@ def get_packages( package_version: str | None = None, ) -> list[str]: """ - Get package names to be downloaded by Spark. Allows specifying custom JDBC driver versions for MSSQL. |support_hooks| + Get package names to be downloaded by Spark. |support_hooks| + + Allows specifying custom JDBC driver versions for MSSQL. .. 
versionadded:: 0.9.0 @@ -220,10 +224,11 @@ def get_packages( default_package_version = "13.2.1" java_ver = Version(java_version or default_java_version) - if java_ver.major < 8: - raise ValueError(f"Java version must be at least 8, got {java_ver}") + if java_ver.major < 8: # noqa: PLR2004 + msg = f"Java version must be at least 8, got {java_ver}" + raise ValueError(msg) - jre_ver = "8" if java_ver.major < 11 else "11" + jre_ver = "8" if java_ver.major < 11 else "11" # noqa: PLR2004 full_package_version = Version(package_version or default_package_version).min_digits(3) # check if a JRE suffix is already included @@ -278,7 +283,12 @@ def __str__(self): port = self.port or 1433 return f"{self.__class__.__name__}[{self.host}:{port}/{self.database}]" - def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions, read_only: bool): + def _get_jdbc_connection( + self, + options: JDBCFetchOptions | JDBCExecuteOptions, + *, + read_only: bool, + ): if read_only: # connection.setReadOnly() is no-op in MSSQL: # https://learn.microsoft.com/en-us/sql/connect/jdbc/reference/setreadonly-method-sqlserverconnection?view=sql-server-ver16 @@ -286,4 +296,4 @@ def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions, r # https://github.com/microsoft/mssql-jdbc/issues/484 options = options.copy(update={"ApplicationIntent": "ReadOnly"}) - return super()._get_jdbc_connection(options, read_only) + return super()._get_jdbc_connection(options, read_only=read_only) diff --git a/onetl/connection/db_connection/mssql/dialect.py b/onetl/connection/db_connection/mssql/dialect.py index 1713785a2..cff3d2673 100644 --- a/onetl/connection/db_connection/mssql/dialect.py +++ b/onetl/connection/db_connection/mssql/dialect.py @@ -18,9 +18,10 @@ def get_partition_column_mod(self, partition_column: str, num_partitions: int) - # Return positive value even for negative input return f"ABS({partition_column} % {num_partitions})" - def get_sql_query( + def get_sql_query( # noqa: PLR0913 self, table: str, + *, columns: list[str] | None = None, where: str | list[str] | None = None, hint: str | None = None, diff --git a/onetl/connection/db_connection/mysql/__init__.py b/onetl/connection/db_connection/mysql/__init__.py index 0f66e5abb..827a5ff54 100644 --- a/onetl/connection/db_connection/mysql/__init__.py +++ b/onetl/connection/db_connection/mysql/__init__.py @@ -1,4 +1,21 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from onetl.connection.db_connection.mysql.connection import MySQL, MySQLExtra +from onetl.connection.db_connection.mysql.connection import MySQL from onetl.connection.db_connection.mysql.dialect import MySQLDialect +from onetl.connection.db_connection.mysql.options import ( + MySQLExecuteOptions, + MySQLFetchOptions, + MySQLReadOptions, + MySQLSQLOptions, + MySQLWriteOptions, +) + +__all__ = [ + "MySQL", + "MySQLDialect", + "MySQLExecuteOptions", + "MySQLFetchOptions", + "MySQLReadOptions", + "MySQLSQLOptions", + "MySQLWriteOptions", +] diff --git a/onetl/connection/db_connection/mysql/connection.py b/onetl/connection/db_connection/mysql/connection.py index 25638d55b..edad851c1 100644 --- a/onetl/connection/db_connection/mysql/connection.py +++ b/onetl/connection/db_connection/mysql/connection.py @@ -31,8 +31,8 @@ class MySQLExtra(GenericOptions): - useUnicode: str = "yes" # noqa: N815 - characterEncoding: str = "UTF-8" # noqa: N815 + useUnicode: str = "yes" + characterEncoding: str = "UTF-8" class Config: extra = "allow" @@ -131,7 +131,8 @@ 
class MySQL(JDBCConnection): @classmethod def get_packages(cls, package_version: str | None = None) -> list[str]: """ - Get package names to be downloaded by Spark. Allows specifying a custom JDBC driver version for MySQL. |support_hooks| + Get package names to be downloaded by Spark. + Allows specifying a custom JDBC driver version for MySQL. |support_hooks| .. versionadded:: 0.9.0 @@ -193,8 +194,13 @@ def instance_url(self) -> str: def __str__(self): return f"{self.__class__.__name__}[{self.host}:{self.port}]" - def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions, read_only: bool): - connection = super()._get_jdbc_connection(options, read_only) + def _get_jdbc_connection( + self, + options: JDBCFetchOptions | JDBCExecuteOptions, + *, + read_only: bool, + ): + connection = super()._get_jdbc_connection(options, read_only=read_only) # connection.setReadOnly() is no-op in MySQL JDBC driver. Session type can be changed by statement: # https://stackoverflow.com/questions/10240890/sql-open-connection-in-read-only-mode#comment123789248_48959180 diff --git a/onetl/connection/db_connection/oracle/__init__.py b/onetl/connection/db_connection/oracle/__init__.py index 3e3831282..f0b2c4f7a 100644 --- a/onetl/connection/db_connection/oracle/__init__.py +++ b/onetl/connection/db_connection/oracle/__init__.py @@ -1,4 +1,21 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from onetl.connection.db_connection.oracle.connection import Oracle, OracleExtra +from onetl.connection.db_connection.oracle.connection import Oracle from onetl.connection.db_connection.oracle.dialect import OracleDialect +from onetl.connection.db_connection.oracle.options import ( + OracleExecuteOptions, + OracleFetchOptions, + OracleReadOptions, + OracleSQLOptions, + OracleWriteOptions, +) + +__all__ = [ + "Oracle", + "OracleDialect", + "OracleExecuteOptions", + "OracleFetchOptions", + "OracleReadOptions", + "OracleSQLOptions", + "OracleWriteOptions", +] diff --git a/onetl/connection/db_connection/oracle/connection.py b/onetl/connection/db_connection/oracle/connection.py index c772f9e9e..98dbfd952 100644 --- a/onetl/connection/db_connection/oracle/connection.py +++ b/onetl/connection/db_connection/oracle/connection.py @@ -203,7 +203,9 @@ def get_packages( package_version: str | None = None, ) -> list[str]: """ - Get package names to be downloaded by Spark. Allows specifying custom JDBC driver versions for Oracle. |support_hooks| + Get package names to be downloaded by Spark. |support_hooks| + + Allows specifying custom JDBC driver versions for Oracle. 
Parameters ---------- @@ -229,10 +231,11 @@ def get_packages( default_package_version = "23.26.0.0.0" java_ver = Version(java_version or default_java_version) - if java_ver.major < 8: - raise ValueError(f"Java version must be at least 8, got {java_ver.major}") + if java_ver.major < 8: # noqa: PLR2004 + msg = f"Java version must be at least 8, got {java_ver.major}" + raise ValueError(msg) - jre_ver = "8" if java_ver.major < 11 else "11" + jre_ver = "8" if java_ver.major < 11 else "11" # noqa: PLR2004 jdbc_version = Version(package_version or default_package_version).min_digits(4) return [f"com.oracle.database.jdbc:ojdbc{jre_ver}:{jdbc_version}"] @@ -302,10 +305,12 @@ def _only_one_of_sid_or_service_name(cls, values): service_name = values.get("service_name") if sid and service_name: - raise ValueError("Only one of parameters ``sid``, ``service_name`` can be set, got both") + msg = "Only one of parameters ``sid``, ``service_name`` can be set, got both" + raise ValueError(msg) if not sid and not service_name: - raise ValueError("One of parameters ``sid``, ``service_name`` should be set, got none") + msg = "One of parameters ``sid``, ``service_name`` should be set, got none" + raise ValueError(msg) return values @@ -367,7 +372,7 @@ def _get_compile_errors( SEQUENCE, LINE, POSITION - """ + """ # noqa: S608 errors = self._query_on_driver(show_errors, options).collect() if not errors: return [] diff --git a/onetl/connection/db_connection/oracle/dialect.py b/onetl/connection/db_connection/oracle/dialect.py index 30754e060..37eacf902 100644 --- a/onetl/connection/db_connection/oracle/dialect.py +++ b/onetl/connection/db_connection/oracle/dialect.py @@ -8,9 +8,10 @@ class OracleDialect(JDBCDialect): - def get_sql_query( + def get_sql_query( # noqa: PLR0913 self, table: str, + *, columns: list[str] | None = None, where: str | list[str] | None = None, hint: str | None = None, diff --git a/onetl/connection/db_connection/postgres/__init__.py b/onetl/connection/db_connection/postgres/__init__.py index e26684783..016327290 100644 --- a/onetl/connection/db_connection/postgres/__init__.py +++ b/onetl/connection/db_connection/postgres/__init__.py @@ -1,4 +1,21 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from onetl.connection.db_connection.postgres.connection import Postgres, PostgresExtra +from onetl.connection.db_connection.postgres.connection import Postgres from onetl.connection.db_connection.postgres.dialect import PostgresDialect +from onetl.connection.db_connection.postgres.options import ( + PostgresExecuteOptions, + PostgresFetchOptions, + PostgresReadOptions, + PostgresSQLOptions, + PostgresWriteOptions, +) + +__all__ = [ + "Postgres", + "PostgresDialect", + "PostgresExecuteOptions", + "PostgresFetchOptions", + "PostgresReadOptions", + "PostgresSQLOptions", + "PostgresWriteOptions", +] diff --git a/onetl/connection/db_connection/postgres/connection.py b/onetl/connection/db_connection/postgres/connection.py index b6c22260a..18e148f0a 100644 --- a/onetl/connection/db_connection/postgres/connection.py +++ b/onetl/connection/db_connection/postgres/connection.py @@ -35,7 +35,7 @@ class PostgresExtra(GenericOptions): # avoid closing connections from server side # while connector is moving data to executors before insert - tcpKeepAlive: str = "true" # noqa: N815 + tcpKeepAlive: str = "true" class Config: extra = "allow" @@ -202,13 +202,18 @@ def instance_url(self) -> str: def __str__(self): return 
f"{self.__class__.__name__}[{self.host}:{self.port}/{self.database}]" - def _get_jdbc_connection(self, options: JDBCFetchOptions | JDBCExecuteOptions, read_only: bool): + def _get_jdbc_connection( + self, + options: JDBCFetchOptions | JDBCExecuteOptions, + *, + read_only: bool, + ): if read_only: # To properly support pgbouncer, we have to create connection with readOnly option set. # See https://github.com/pgjdbc/pgjdbc/issues/848 options = options.copy(update={"readOnly": True}) connection_properties = self._options_to_connection_properties(options) - driver_manager = self.spark._jvm.java.sql.DriverManager # type: ignore + driver_manager = self.spark._jvm.java.sql.DriverManager # type: ignore[attr-defined, union-attr] # noqa: SLF001 # avoid calling .setReadOnly(True) here return driver_manager.getConnection(self.jdbc_url, connection_properties) diff --git a/onetl/connection/file_connection/file_connection.py b/onetl/connection/file_connection/file_connection.py index 1619ac0ad..1dae3ffff 100644 --- a/onetl/connection/file_connection/file_connection.py +++ b/onetl/connection/file_connection/file_connection.py @@ -123,7 +123,7 @@ def __enter__(self): def __exit__(self, _exc_type, _exc_value, _traceback): self.close() - def __del__(self): # noqa: WPS603 + def __del__(self): # If current object is collected by GC, close opened connection self.close() @@ -141,7 +141,8 @@ def check(self): raise except Exception as e: log.exception("|%s| Connection is unavailable", self.__class__.__name__) - raise RuntimeError("Connection is unavailable") from e + msg = "Connection is unavailable" + raise RuntimeError(msg) from e return self @@ -150,7 +151,8 @@ def is_file(self, path: os.PathLike | str) -> bool: remote_path = RemotePath(path) if not self.path_exists(remote_path): - raise FileNotFoundError(f"File '{remote_path}' does not exist") + msg = f"File '{remote_path}' does not exist" + raise FileNotFoundError(msg) return self._is_file(remote_path) @@ -159,7 +161,8 @@ def is_dir(self, path: os.PathLike | str) -> bool: remote_path = RemotePath(path) if not self.path_exists(remote_path): - raise DirectoryNotFoundError(f"Directory '{remote_path}' does not exist") + msg = f"Directory '{remote_path}' does not exist" + raise DirectoryNotFoundError(msg) return self._is_dir(remote_path) @@ -176,7 +179,8 @@ def resolve_dir(self, path: os.PathLike | str) -> RemoteDirectory: if not is_dir: remote_file = RemoteFile(path=remote_path, stats=stat) - raise NotADirectoryError(f"{path_repr(remote_file)} is not a directory") + msg = f"{path_repr(remote_file)} is not a directory" + raise NotADirectoryError(msg) return RemoteDirectory(path=remote_path, stats=stat) @@ -188,7 +192,8 @@ def resolve_file(self, path: os.PathLike | str) -> RemoteFile: if not is_file: remote_directory = RemoteDirectory(path=remote_path, stats=stat) - raise NotAFileError(f"{path_repr(remote_directory)} is not a file") + msg = f"{path_repr(remote_directory)} is not a file" + raise NotAFileError(msg) return RemoteFile(path=remote_path, stats=stat) @@ -215,7 +220,8 @@ def read_bytes(self, path: os.PathLike | str, **kwargs) -> bytes: @slot def write_text(self, path: os.PathLike | str, content: str, encoding: str = "utf-8", **kwargs) -> RemoteFile: if not isinstance(content, str): - raise TypeError(f"content must be str, not '{content.__class__.__name__}'") + msg = f"content must be str, not '{content.__class__.__name__}'" + raise TypeError(msg) log.debug( "|%s| Writing string size %d with encoding %r and options %r to '%s'", @@ -244,7 +250,8 @@ def 
write_text(self, path: os.PathLike | str, content: str, encoding: str = "utf @slot def write_bytes(self, path: os.PathLike | str, content: bytes, **kwargs) -> RemoteFile: if not isinstance(content, bytes): - raise TypeError(f"content must be bytes, not '{content.__class__.__name__}'") + msg = f"content must be bytes, not '{content.__class__.__name__}'" + raise TypeError(msg) log.debug( "|%s| Writing %s with options %e to '%s'", @@ -274,6 +281,7 @@ def download_file( self, remote_file_path: os.PathLike | str, local_file_path: os.PathLike | str, + *, replace: bool = True, ) -> LocalPath: log.debug( @@ -288,10 +296,12 @@ def download_file( if local_file.exists(): if not local_file.is_file(): - raise NotAFileError(f"{path_repr(local_file)} is not a file") + msg = f"{path_repr(local_file)} is not a file" + raise NotAFileError(msg) if not replace: - raise FileExistsError(f"File {path_repr(local_file)} already exists") + msg = f"File {path_repr(local_file)} already exists" + raise FileExistsError(msg) log.warning("|Local FS| File %s already exists, overwriting", path_repr(local_file)) local_file.unlink() @@ -301,10 +311,11 @@ def download_file( self._download_file(remote_file, local_file) if local_file.stat().st_size != remote_file.stat().st_size: - raise FileSizeMismatchError( + msg = ( f"The size of the downloaded file ({naturalsize(local_file.stat().st_size)}) does not match " - f"the size of the file on the source ({naturalsize(remote_file.stat().st_size)})", + f"the size of the file on the source ({naturalsize(remote_file.stat().st_size)})" ) + raise FileSizeMismatchError(msg) log.info("|Local FS| Successfully downloaded file '%s'", local_file) return local_file @@ -341,22 +352,26 @@ def upload_file( self, local_file_path: os.PathLike | str, remote_file_path: os.PathLike | str, + *, replace: bool = False, ) -> RemoteFile: log.debug("|%s| Uploading local file '%s' to '%s'", self.__class__.__name__, local_file_path, remote_file_path) local_file = LocalPath(local_file_path) if not local_file.exists(): - raise FileNotFoundError(f"File '{local_file}' does not exist") + msg = f"File '{local_file}' does not exist" + raise FileNotFoundError(msg) if not local_file.is_file(): - raise NotAFileError(f"{path_repr(local_file)} is not a file") + msg = f"{path_repr(local_file)} is not a file" + raise NotAFileError(msg) remote_file = RemotePath(remote_file_path) if self.path_exists(remote_file): file = self.resolve_file(remote_file_path) if not replace: - raise FileExistsError(f"File {path_repr(file)} already exists") + msg = f"File {path_repr(file)} already exists" + raise FileExistsError(msg) log.warning("|%s| File %s already exists, overwriting", self.__class__.__name__, path_repr(file)) self._remove_file(remote_file) @@ -367,10 +382,11 @@ def upload_file( result = self.resolve_file(remote_file) if result.stat().st_size != local_file.stat().st_size: - raise FileSizeMismatchError( + msg = ( f"The size of the uploaded file ({naturalsize(result.stat().st_size)}) does not match " - f"the size of the file on the source ({naturalsize(local_file.stat().st_size)})", + f"the size of the file on the source ({naturalsize(local_file.stat().st_size)})" ) + raise FileSizeMismatchError(msg) log.info("|%s| Successfully uploaded file '%s'", self.__class__.__name__, remote_file) return result @@ -380,6 +396,7 @@ def rename_file( self, source_file_path: os.PathLike | str, target_file_path: os.PathLike | str, + *, replace: bool = False, ) -> RemoteFile: log.debug("|%s| Renaming file '%s' to '%s'", self.__class__.__name__, 
source_file_path, target_file_path) @@ -390,7 +407,8 @@ def rename_file( if self.path_exists(target_file): file = self.resolve_file(target_file) if not replace: - raise FileExistsError(f"File {path_repr(file)} already exists") + msg = f"File {path_repr(file)} already exists" + raise FileExistsError(msg) log.warning("|%s| File %s already exists, overwriting", self.__class__.__name__, path_repr(file)) self._remove_file(target_file) @@ -438,6 +456,7 @@ def list_dir( def walk( self, root: os.PathLike | str, + *, topdown: bool = True, filters: Iterable[BaseFileFilter] | None = None, limits: Iterable[BaseFileLimit] | None = None, @@ -449,7 +468,7 @@ def walk( yield from self._walk(root_dir, topdown=topdown, filters=filters, limits=limits) @slot - def remove_dir(self, path: os.PathLike | str, recursive: bool = False) -> bool: + def remove_dir(self, path: os.PathLike | str, *, recursive: bool = False) -> bool: description = "RECURSIVELY" if recursive else "NON-recursively" log.debug("|%s| %s removing directory '%s'", self.__class__.__name__, description, path) remote_dir = RemotePath(path) @@ -466,9 +485,10 @@ def remove_dir(self, path: os.PathLike | str, recursive: bool = False) -> bool: if not recursive: # self.list_dir may return large list # self._scan_entries return an iterator, which have to be iterated at least once - for _entry in self._scan_entries(remote_dir): # noqa: WPS122, WPS328 + for _entry in self._scan_entries(remote_dir): + msg = "|%s| Cannot delete non-empty directory %s" raise DirectoryNotEmptyError( - "|%s| Cannot delete non-empty directory %s", + msg, self.__class__.__name__, directory_info, ) @@ -482,9 +502,10 @@ def remove_dir(self, path: os.PathLike | str, recursive: bool = False) -> bool: log.info("|%s| Successfully removed directory '%s'", self.__class__.__name__, remote_dir) return True - def _walk( # noqa: WPS231 + def _walk( # noqa: C901 self, root: RemoteDirectory, + *, topdown: bool, filters: Iterable[BaseFileFilter], limits: Iterable[BaseFileLimit], diff --git a/onetl/connection/file_connection/ftp.py b/onetl/connection/file_connection/ftp.py index af21abf01..1d46141be 100644 --- a/onetl/connection/file_connection/ftp.py +++ b/onetl/connection/file_connection/ftp.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -import ftplib # noqa: S402 # nosec +import ftplib # nosec import os import textwrap from logging import getLogger diff --git a/onetl/connection/file_connection/ftps.py b/onetl/connection/file_connection/ftps.py index 64ddb3545..a20ff2aff 100644 --- a/onetl/connection/file_connection/ftps.py +++ b/onetl/connection/file_connection/ftps.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 -import ftplib # noqa: S402 # nosec +import ftplib # nosec import textwrap from ftputil import FTPHost @@ -24,14 +24,14 @@ ) from e -class TLSfix(ftplib.FTP_TLS): # noqa: N801 +class TLSfix(ftplib.FTP_TLS): """ Fix for python 3.6+ https://stackoverflow.com/questions/14659154/ftpes-session-reuse-required """ def ntransfercmd(self, cmd, rest=None): - conn, size = ftplib.FTP.ntransfercmd(self, cmd, rest) # noqa: S321 # nosec + conn, size = ftplib.FTP.ntransfercmd(self, cmd, rest) # noqa: S321 if self._prot_p: conn = self.context.wrap_socket( conn, diff --git a/onetl/connection/file_connection/hdfs/__init__.py b/onetl/connection/file_connection/hdfs/__init__.py index d0e82955a..7990bbe8c 100644 --- a/onetl/connection/file_connection/hdfs/__init__.py +++ 
b/onetl/connection/file_connection/hdfs/__init__.py @@ -1,4 +1,5 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.connection.file_connection.hdfs.connection import HDFS -from onetl.connection.file_connection.hdfs.slots import HDFSSlots + +__all__ = ["HDFS"] diff --git a/onetl/connection/file_connection/hdfs/connection.py b/onetl/connection/file_connection/hdfs/connection.py index 00afd5f72..db63c5199 100644 --- a/onetl/connection/file_connection/hdfs/connection.py +++ b/onetl/connection/file_connection/hdfs/connection.py @@ -20,7 +20,14 @@ validator, ) except (ImportError, AttributeError): - from pydantic import Field, FilePath, SecretStr, PrivateAttr, root_validator, validator # type: ignore[no-redef, assignment] + from pydantic import ( # type: ignore[no-redef, assignment] + Field, + FilePath, + PrivateAttr, + SecretStr, + root_validator, + validator, + ) from onetl._util.alias import avoid_alias from onetl.base import PathStatProtocol @@ -32,10 +39,10 @@ from onetl.impl import LocalPath, RemotePath, RemotePathStat try: - from hdfs import Client, InsecureClient + from hdfs import Client, InsecureClient # noqa: F401 if TYPE_CHECKING: - from hdfs.ext.kerberos import KerberosClient + from hdfs.ext.kerberos import KerberosClient # noqa: F401 except (ImportError, NameError) as err: raise ImportError( textwrap.dedent( @@ -191,7 +198,7 @@ class HDFS(FileConnection, RenameDirMixin): user="someuser", password="*****", ).check() - """ + """ # noqa: E501 cluster: Optional[Cluster] = None host: Optional[Host] = None @@ -245,10 +252,11 @@ def get_current(cls, **kwargs): log.info("|%s| Detecting current cluster...", cls.__name__) current_cluster = cls.Slots.get_current_cluster() if not current_cluster: - raise RuntimeError( + msg = ( f"{cls.__name__}.get_current() can be used only if there are " - f"some hooks bound to {cls.__name__}.Slots.get_current_cluster", + f"some hooks bound to {cls.__name__}.Slots.get_current_cluster" ) + raise RuntimeError(msg) log.info("|%s| Got %r", cls.__name__, current_cluster) return cls(cluster=current_cluster, **kwargs) @@ -280,7 +288,7 @@ def close(self): def _validate_packages(cls, user): if user: try: - from hdfs.ext.kerberos import KerberosClient as CheckForKerberosSupport + from hdfs.ext.kerberos import KerberosClient as CheckForKerberosSupport # noqa: F401 except (ImportError, NameError) as e: raise ImportError( textwrap.dedent( @@ -304,7 +312,8 @@ def _validate_cluster_or_hostname_set(cls, values): cluster = values.get("cluster") if not cluster and not host: - raise ValueError("You should pass either host or cluster name") + msg = "You should pass either host or cluster name" + raise ValueError(msg) return values @@ -318,9 +327,8 @@ def _validate_cluster_name(cls, cluster): log.debug("|%s| Checking if cluster %r is a known cluster...", cls.__name__, validated_cluster) known_clusters = cls.Slots.get_known_clusters() if known_clusters and validated_cluster not in known_clusters: - raise ValueError( - f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}", - ) + msg = f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}" + raise ValueError(msg) return validated_cluster @@ -337,9 +345,12 @@ def _validate_host_name(cls, host, values): log.debug("|%s| Checking if %r is a known namenode of cluster %r ...", cls.__name__, namenode, cluster) known_namenodes = cls.Slots.get_cluster_namenodes(cluster) if known_namenodes and namenode not in 
known_namenodes: - raise ValueError( + msg = ( f"Namenode {namenode!r} is not in the known nodes list of cluster {cluster!r}: " - f"{sorted(known_namenodes)!r}", + f"{sorted(known_namenodes)!r}" + ) + raise ValueError( + msg, ) return namenode @@ -362,10 +373,12 @@ def _validate_credentials(cls, values): password = values.get("password") keytab = values.get("keytab") if password and keytab: - raise ValueError("Please provide either `keytab` or `password` for kinit, not both") + msg = "Please provide either `keytab` or `password` for kinit, not both" + raise ValueError(msg) if (password or keytab) and not user: - raise ValueError("`keytab` or `password` should be used only with `user`") + msg = "`keytab` or `password` should be used only with `user`" + raise ValueError(msg) return values @@ -375,7 +388,8 @@ def _get_active_namenode(self) -> str: namenodes = self.Slots.get_cluster_namenodes(self.cluster) if not namenodes: - raise RuntimeError(f"Cannot get list of namenodes for a cluster {self.cluster!r}") + msg = f"Cannot get list of namenodes for a cluster {self.cluster!r}" + raise RuntimeError(msg) nodes_len = len(namenodes) for i, namenode in enumerate(namenodes, start=1): @@ -385,7 +399,8 @@ def _get_active_namenode(self) -> str: return namenode log.debug("|%s| Node %r is not active, skipping", class_name, namenode) - raise RuntimeError(f"Cannot detect active namenode for cluster {self.cluster!r}") + msg = f"Cannot detect active namenode for cluster {self.cluster!r}" + raise RuntimeError(msg) def _get_host(self) -> str: if not self.host and self.cluster: @@ -408,9 +423,11 @@ def _get_host(self) -> str: return self.host if self.cluster: - raise RuntimeError(f"Host {self.host!r} is not an active namenode of cluster {self.cluster!r}") + msg = f"Host {self.host!r} is not an active namenode of cluster {self.cluster!r}" + raise RuntimeError(msg) - raise RuntimeError(f"Host {self.host!r} is not an active namenode") + msg = f"Host {self.host!r} is not an active namenode" + raise RuntimeError(msg) def _get_conn_str(self) -> str: # cache active host to reduce number of requests. 
@@ -420,7 +437,7 @@ def _get_conn_str(self) -> str: def _get_client(self) -> Client: if self.user and (self.keytab or self.password): - from hdfs.ext.kerberos import KerberosClient # noqa: F811 + from hdfs.ext.kerberos import KerberosClient kinit( self.user, @@ -431,7 +448,7 @@ def _get_client(self) -> Client: conn_str = self._get_conn_str() client = KerberosClient(conn_str, timeout=self.timeout) else: - from hdfs import InsecureClient # noqa: F401, WPS442, F811 + from hdfs import InsecureClient conn_str = self._get_conn_str() client = InsecureClient(conn_str, user=self.user) diff --git a/onetl/connection/file_connection/mixins/__init__.py b/onetl/connection/file_connection/mixins/__init__.py index 1089632a9..65d724e9c 100644 --- a/onetl/connection/file_connection/mixins/__init__.py +++ b/onetl/connection/file_connection/mixins/__init__.py @@ -1,3 +1,5 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.connection.file_connection.mixins.rename_dir_mixin import RenameDirMixin + +__all__ = ["RenameDirMixin"] diff --git a/onetl/connection/file_connection/mixins/rename_dir_mixin.py b/onetl/connection/file_connection/mixins/rename_dir_mixin.py index 749b7ce92..4c00a3fc7 100644 --- a/onetl/connection/file_connection/mixins/rename_dir_mixin.py +++ b/onetl/connection/file_connection/mixins/rename_dir_mixin.py @@ -18,6 +18,7 @@ def rename_dir( self, source_dir_path: os.PathLike | str, target_dir_path: os.PathLike | str, + *, replace: bool = False, ) -> RemoteDirectory: """ @@ -71,7 +72,8 @@ def rename_dir( if self.path_exists(target_dir): directory = self.resolve_dir(target_dir) if not replace: - raise DirectoryExistsError(f"Directory {path_repr(directory)} already exists") + msg = f"Directory {path_repr(directory)} already exists" + raise DirectoryExistsError(msg) log.warning("|%s| Directory %s already exists, removing", self.__class__.__name__, path_repr(directory)) self.remove_dir(target_dir, recursive=True) diff --git a/onetl/connection/file_connection/s3.py b/onetl/connection/file_connection/s3.py index 4c8a76ec0..d8835fea9 100644 --- a/onetl/connection/file_connection/s3.py +++ b/onetl/connection/file_connection/s3.py @@ -170,7 +170,7 @@ def create_dir(self, path: os.PathLike | str) -> RemoteDirectory: return RemoteDirectory(path=remote_directory, stats=RemotePathStat()) @slot - def remove_dir(self, path: os.PathLike | str, recursive: bool = False) -> bool: # noqa: WPS321 + def remove_dir(self, path: os.PathLike | str, *, recursive: bool = False) -> bool: # optimize S3 directory recursive removal by using batch object deletion description = "RECURSIVELY" if recursive else "NON-recursively" log.debug("|%s| %s removing directory '%s'", self.__class__.__name__, description, path) @@ -179,7 +179,7 @@ def remove_dir(self, path: os.PathLike | str, recursive: bool = False) -> bool: is_empty = True - def _scan_entries_recursive(root: RemotePath) -> Iterable[Object]: # noqa: WPS430 + def _scan_entries_recursive(root: RemotePath) -> Iterable[Object]: nonlocal is_empty directory_path_str = self._delete_absolute_path_slash(root) + "/" entries = self.client.list_objects( @@ -188,7 +188,7 @@ def _scan_entries_recursive(root: RemotePath) -> Iterable[Object]: # noqa: WPS4 recursive=True, ) for entry in entries: - is_empty = False # noqa: WPS442 + is_empty = False name = self._extract_name_from_entry(entry) stat = self._extract_stat_from_entry(root, entry) @@ -208,9 +208,10 @@ def _scan_entries_recursive(root: RemotePath) -> Iterable[Object]: # noqa: WPS4 if not 
recursive: # self.list_dir may return large list # self._scan_entries return an iterator, which have to be iterated at least once - for _entry in self._scan_entries(remote_dir): # noqa: WPS122, WPS328 + for _entry in self._scan_entries(remote_dir): + msg = "|%s| Cannot delete non-empty directory %s" raise DirectoryNotEmptyError( - "|%s| Cannot delete non-empty directory %s", + msg, self.__class__.__name__, directory_info, ) @@ -251,7 +252,7 @@ def path_exists(self, path: os.PathLike | str) -> bool: return True remote_path_str = self._delete_absolute_path_slash(remote_path) - for component in self.client.list_objects( # noqa: WPS352 + for component in self.client.list_objects( bucket_name=self.bucket, prefix=remote_path_str, ): @@ -454,9 +455,10 @@ def _is_dir(self, path: RemotePath) -> bool: prefix=directory_path_str, ), ) - return True except StopIteration: return False + else: + return True def _is_file(self, path: RemotePath) -> bool: path_str = self._delete_absolute_path_slash(path) @@ -465,8 +467,9 @@ def _is_file(self, path: RemotePath) -> bool: bucket_name=self.bucket, object_name=path_str, ) - return True except S3Error as err: if err.code == "NoSuchKey": return False raise + else: + return True diff --git a/onetl/connection/file_connection/samba.py b/onetl/connection/file_connection/samba.py index 1a76badef..33e4c1f53 100644 --- a/onetl/connection/file_connection/samba.py +++ b/onetl/connection/file_connection/samba.py @@ -7,7 +7,7 @@ from io import BytesIO from logging import getLogger from pathlib import Path -from typing import Optional, Union +from typing import Optional from etl_entities.instance import Host from typing_extensions import Literal @@ -113,10 +113,10 @@ class Samba(FileConnection): host: Host share: str - protocol: Union[Literal["SMB"], Literal["NetBIOS"]] = "SMB" + protocol: Literal["SMB", "NetBIOS"] = "SMB" port: Optional[int] = None domain: str = "" - auth_type: Union[Literal["NTLMv1"], Literal["NTLMv2"]] = "NTLMv2" + auth_type: Literal["NTLMv1", "NTLMv2"] = "NTLMv2" user: Optional[str] = None password: Optional[SecretStr] = None @@ -142,10 +142,12 @@ def check(self): self.share, available_shares, ) - raise ConnectionError("Failed to connect to the Samba server.") + msg = "Failed to connect to the Samba server." + raise ConnectionError(msg) # noqa: TRY301 except Exception as exc: log.exception("|%s| Connection is unavailable", self.__class__.__name__) - raise RuntimeError("Connection is unavailable") from exc + msg = "Connection is unavailable" + raise RuntimeError(msg) from exc return self @@ -153,9 +155,10 @@ def check(self): def path_exists(self, path: os.PathLike | str) -> bool: try: self.client.getAttributes(self.share, os.fspath(path)) - return True except OperationFailure: return False + else: + return True def _scan_entries(self, path: RemotePath) -> list: if self._is_dir(path): @@ -228,7 +231,7 @@ def _close_client(self, client: SMBConnection) -> None: client.close() def _download_file(self, remote_file_path: RemotePath, local_file_path: LocalPath) -> None: - with open(local_file_path, "wb") as local_file: + with local_file_path.open("wb") as local_file: self.client.retrieveFile( self.share, os.fspath(remote_file_path), @@ -241,13 +244,13 @@ def _create_dir(self, path: RemotePath) -> None: # create dirs sequentially as .createDirectory(...) 
cannot create nested dirs try: self.client.getAttributes(self.share, os.fspath(parent)) - except OperationFailure: + except OperationFailure: # noqa: PERF203 self.client.createDirectory(self.share, os.fspath(parent)) self.client.createDirectory(self.share, os.fspath(path)) def _upload_file(self, local_file_path: LocalPath, remote_file_path: RemotePath) -> None: - with open(local_file_path, "rb") as file_obj: + with local_file_path.open("rb") as file_obj: self.client.storeFile( self.share, os.fspath(remote_file_path), diff --git a/onetl/connection/file_connection/sftp.py b/onetl/connection/file_connection/sftp.py index 2cfec7ca8..52a52582f 100644 --- a/onetl/connection/file_connection/sftp.py +++ b/onetl/connection/file_connection/sftp.py @@ -129,9 +129,10 @@ def __str__(self): def path_exists(self, path: os.PathLike | str) -> bool: try: self.client.stat(os.fspath(path)) - return True except FileNotFoundError: return False + else: + return True def _get_client(self) -> SFTPClient: host_proxy, key_file = self._parse_user_ssh_config() @@ -185,11 +186,11 @@ def _parse_user_ssh_config(self) -> tuple[str | None, str | None]: def _create_dir(self, path: RemotePath) -> None: try: self.client.stat(os.fspath(path)) - except Exception: + except OSError: for parent in reversed(path.parents): - try: # noqa: WPS505 + try: self.client.stat(os.fspath(parent)) - except Exception: + except OSError: # noqa: PERF203 self.client.mkdir(os.fspath(parent)) self.client.mkdir(os.fspath(path)) diff --git a/onetl/connection/file_connection/webdav.py b/onetl/connection/file_connection/webdav.py index c5f637bd1..386c613a7 100644 --- a/onetl/connection/file_connection/webdav.py +++ b/onetl/connection/file_connection/webdav.py @@ -167,9 +167,10 @@ def _get_stat(self, path: RemotePath) -> RemotePathStat: if self.client.is_dir(os.fspath(path)): return RemotePathStat() + mtime = datetime.datetime.strptime(info["modified"], DATA_MODIFIED_FORMAT) # noqa: DTZ007 return RemotePathStat( st_size=info["size"], - st_mtime=datetime.datetime.strptime(info["modified"], DATA_MODIFIED_FORMAT).timestamp(), + st_mtime=mtime.timestamp(), st_uid=info["name"], ) @@ -244,8 +245,9 @@ def _extract_stat_from_entry(self, top: RemotePath, entry: dict) -> RemotePathSt if entry["isdir"]: return RemotePathStat() + mtime = datetime.datetime.strptime(entry["modified"], DATA_MODIFIED_FORMAT) # noqa: DTZ007 return RemotePathStat( st_size=entry["size"], - st_mtime=datetime.datetime.strptime(entry["modified"], DATA_MODIFIED_FORMAT).timestamp(), + st_mtime=mtime.timestamp(), st_uid=entry["name"], ) diff --git a/onetl/connection/file_df_connection/spark_file_df_connection.py b/onetl/connection/file_df_connection/spark_file_df_connection.py index 362b90483..68a396c61 100644 --- a/onetl/connection/file_df_connection/spark_file_df_connection.py +++ b/onetl/connection/file_df_connection/spark_file_df_connection.py @@ -53,13 +53,14 @@ def check(self): fs.exists(path) log.info("|%s| Connection is available.", self.__class__.__name__) except Exception as e: - raise RuntimeError("Connection is unavailable") from e + msg = "Connection is unavailable" + raise RuntimeError(msg) from e return self def check_if_format_supported( self, - format: BaseReadableFileFormat | BaseWritableFileFormat, # noqa: WPS125 + format: BaseReadableFileFormat | BaseWritableFileFormat, ) -> None: format.check_if_supported(self.spark) @@ -67,7 +68,7 @@ def check_if_format_supported( def read_files_as_df( self, paths: list[PurePathProtocol], - format: BaseReadableFileFormat, # noqa: WPS125 + 
format: BaseReadableFileFormat, root: PurePathProtocol | None = None, df_schema: StructType | None = None, options: FileDFReadOptions | None = None, @@ -96,14 +97,14 @@ def read_files_as_df( df = reader.load(urls) log.info("|%s| DataFrame successfully created", self.__class__.__name__) - return df # type: ignore + return df @slot def write_df_as_files( self, df: DataFrame, path: PurePathProtocol, - format: BaseWritableFileFormat, # noqa: WPS125 + format: BaseWritableFileFormat, options: FileDFWriteOptions | None = None, ) -> None: log.info("|%s| Saving data to '%s' ...", self.__class__.__name__, path) @@ -144,7 +145,7 @@ def _get_spark_default_path(self): Return object of ``org.apache.hadoop.fs.Path`` class for :obj:`~_get_default_path`. """ url = self._convert_to_url(self._get_default_path()) - jvm = self.spark._jvm # noqa: WPS437 + jvm = self.spark._jvm # noqa: SLF001 return jvm.org.apache.hadoop.fs.Path(url) # type: ignore[union-attr] def _get_spark_fs(self): @@ -159,7 +160,7 @@ def _get_spark_fs(self): def _forward_refs(cls) -> dict[str, type]: try_import_pyspark() - from pyspark.sql import SparkSession # noqa: WPS442 + from pyspark.sql import SparkSession # avoid importing pyspark unless user called the constructor, # as we allow user to use `Connection.get_packages()` for creating Spark session @@ -172,7 +173,7 @@ def _check_spark_session_alive(cls, spark): # https://stackoverflow.com/a/36044685 msg = "Spark session is stopped. Please recreate Spark session." try: - if not spark._jsc.sc().isStopped(): + if not spark._jsc.sc().isStopped(): # noqa: SLF001 return spark except Exception as e: # None has no attribute "something" diff --git a/onetl/connection/file_df_connection/spark_hdfs/__init__.py b/onetl/connection/file_df_connection/spark_hdfs/__init__.py index 736754a2a..a28a64422 100644 --- a/onetl/connection/file_df_connection/spark_hdfs/__init__.py +++ b/onetl/connection/file_df_connection/spark_hdfs/__init__.py @@ -1,4 +1,5 @@ # SPDX-FileCopyrightText: 2023-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.connection.file_df_connection.spark_hdfs.connection import SparkHDFS -from onetl.connection.file_df_connection.spark_hdfs.slots import SparkHDFSSlots + +__all__ = ["SparkHDFS"] diff --git a/onetl/connection/file_df_connection/spark_hdfs/connection.py b/onetl/connection/file_df_connection/spark_hdfs/connection.py index 31c37095c..0450bd710 100644 --- a/onetl/connection/file_df_connection/spark_hdfs/connection.py +++ b/onetl/connection/file_df_connection/spark_hdfs/connection.py @@ -150,7 +150,7 @@ class SparkHDFS(SparkFileDFConnection): # Create connection hdfs = SparkHDFS(cluster="rnd-dwh", spark=spark).check() - """ + """ # noqa: E501 Slots = SparkHDFSSlots @@ -248,15 +248,16 @@ def get_current(cls, spark: SparkSession): # injecting current cluster name via hooks mechanism hdfs = SparkHDFS.get_current(spark=spark) - """ + """ # noqa: E501 log.info("|%s| Detecting current cluster...", cls.__name__) current_cluster = cls.Slots.get_current_cluster() if not current_cluster: - raise RuntimeError( + msg = ( f"{cls.__name__}.get_current() can be used only if there are " - f"some hooks bound to {cls.__name__}.Slots.get_current_cluster", + f"some hooks bound to {cls.__name__}.Slots.get_current_cluster" ) + raise RuntimeError(msg) log.info("|%s| Got %r", cls.__name__, current_cluster) return cls(cluster=current_cluster, spark=spark) @@ -271,9 +272,8 @@ def _validate_cluster_name(cls, cluster): log.debug("|%s| Checking if cluster %r is a known cluster...", cls.__name__, 
validated_cluster) known_clusters = cls.Slots.get_known_clusters() if known_clusters and validated_cluster not in known_clusters: - raise ValueError( - f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}", - ) + msg = f"Cluster {validated_cluster!r} is not in the known clusters list: {sorted(known_clusters)!r}" + raise ValueError(msg) return validated_cluster @@ -289,10 +289,11 @@ def _validate_host_name(cls, host, values): log.debug("|%s| Checking if %r is a known namenode of cluster %r ...", cls.__name__, namenode, cluster) known_namenodes = cls.Slots.get_cluster_namenodes(cluster) if known_namenodes and namenode not in known_namenodes: - raise ValueError( + msg = ( f"Namenode {namenode!r} is not in the known nodes list of cluster {cluster!r}: " - f"{sorted(known_namenodes)!r}", + f"{sorted(known_namenodes)!r}" ) + raise ValueError(msg) return namenode @@ -314,7 +315,8 @@ def _get_active_namenode(self) -> str: namenodes = self.Slots.get_cluster_namenodes(self.cluster) if not namenodes: - raise RuntimeError(f"Cannot get list of namenodes for a cluster {self.cluster!r}") + msg = f"Cannot get list of namenodes for a cluster {self.cluster!r}" + raise RuntimeError(msg) nodes_len = len(namenodes) for i, namenode in enumerate(namenodes, start=1): @@ -324,7 +326,8 @@ def _get_active_namenode(self) -> str: return namenode log.debug("|%s| Node %r is not active, skipping", class_name, namenode) - raise RuntimeError(f"Cannot detect active namenode for cluster {self.cluster!r}") + msg = f"Cannot detect active namenode for cluster {self.cluster!r}" + raise RuntimeError(msg) def _get_host(self) -> str: if not self.host: @@ -343,7 +346,8 @@ def _get_host(self) -> str: log.debug("|%s| No hooks, skip validation", class_name) return self.host - raise RuntimeError(f"Host {self.host!r} is not an active namenode of cluster {self.cluster!r}") + msg = f"Host {self.host!r} is not an active namenode of cluster {self.cluster!r}" + raise RuntimeError(msg) def _get_conn_str(self) -> str: # cache active host to reduce number of requests. 
diff --git a/onetl/connection/file_df_connection/spark_local_fs.py b/onetl/connection/file_df_connection/spark_local_fs.py index 7a421cf53..df1c36804 100644 --- a/onetl/connection/file_df_connection/spark_local_fs.py +++ b/onetl/connection/file_df_connection/spark_local_fs.py @@ -82,7 +82,8 @@ def __str__(self): def _validate_spark(cls, spark): master = spark.conf.get("spark.master") if not master.startswith("local"): - raise ValueError(f"Currently supports only spark.master='local', got {master!r}") + msg = f"Currently supports only spark.master='local', got {master!r}" + raise ValueError(msg) return spark def _convert_to_url(self, path: PurePathProtocol) -> str: @@ -92,4 +93,4 @@ def _convert_to_url(self, path: PurePathProtocol) -> str: return "file:///" + path.as_posix().lstrip("/") def _get_default_path(self): - return LocalPath(os.getcwd()) + return LocalPath.cwd() diff --git a/onetl/connection/file_df_connection/spark_s3/__init__.py b/onetl/connection/file_df_connection/spark_s3/__init__.py index 01f17de02..a8caa36af 100644 --- a/onetl/connection/file_df_connection/spark_s3/__init__.py +++ b/onetl/connection/file_df_connection/spark_s3/__init__.py @@ -2,3 +2,5 @@ # SPDX-License-Identifier: Apache-2.0 from onetl.connection.file_df_connection.spark_s3.connection import SparkS3 from onetl.connection.file_df_connection.spark_s3.extra import SparkS3Extra + +__all__ = ["SparkS3", "SparkS3Extra"] diff --git a/onetl/connection/file_df_connection/spark_s3/connection.py b/onetl/connection/file_df_connection/spark_s3/connection.py index 6eec0d129..f485e6361 100644 --- a/onetl/connection/file_df_connection/spark_s3/connection.py +++ b/onetl/connection/file_df_connection/spark_s3/connection.py @@ -358,7 +358,7 @@ def check(self): def read_files_as_df( self, paths: list[PurePathProtocol], - format: BaseReadableFileFormat, # noqa: WPS125 + format: BaseReadableFileFormat, root: PurePathProtocol | None = None, df_schema: StructType | None = None, options: FileDFReadOptions | None = None, @@ -371,7 +371,7 @@ def write_df_as_files( self, df: DataFrame, path: PurePathProtocol, - format: BaseWritableFileFormat, # noqa: WPS125 + format: BaseWritableFileFormat, options: FileDFWriteOptions | None = None, ) -> None: self._patch_hadoop_conf() @@ -509,7 +509,7 @@ def _patch_hadoop_conf(self) -> None: self._get_spark_fs().close() log.debug("Set Hadoop configuration") - for key in real_values.keys(): + for key in real_values: hadoop_config.unset(key) for key, value in expected_values.items(): diff --git a/onetl/connection/file_df_connection/spark_s3/extra.py b/onetl/connection/file_df_connection/spark_s3/extra.py index 033bf776e..6df8cf119 100644 --- a/onetl/connection/file_df_connection/spark_s3/extra.py +++ b/onetl/connection/file_df_connection/spark_s3/extra.py @@ -30,6 +30,6 @@ class SparkS3Extra(GenericOptions): """ class Config: - strip_prefixes = ["spark.hadoop.", "fs.s3a.", re.compile(r"bucket\.[^.]+\.")] + strip_prefixes = ("spark.hadoop.", "fs.s3a.", re.compile(r"bucket\.[^.]+\.")) prohibited_options = PROHIBITED_OPTIONS extra = "allow" diff --git a/onetl/connection/kerberos_helpers.py b/onetl/connection/kerberos_helpers.py index e8af017d1..375af9807 100644 --- a/onetl/connection/kerberos_helpers.py +++ b/onetl/connection/kerberos_helpers.py @@ -19,25 +19,24 @@ def kinit_keytab(user: str, keytab: str | os.PathLike) -> None: with _kinit_lock: cmd = ["kinit", user, "-k", "-t", os.fspath(path)] log.info("|onETL| Executing kerberos auth command: %s", " ".join(cmd)) - subprocess.check_call(cmd) + 
subprocess.check_call(cmd) # noqa: S603 def kinit_password(user: str, password: str) -> None: cmd = ["kinit", user] log.info("|onETL| Executing kerberos auth command: %s", " ".join(cmd)) - with _kinit_lock: - with subprocess.Popen( - cmd, - stdin=subprocess.PIPE, - # do not show user 'Please enter password' banner - stdout=subprocess.PIPE, - # do not capture stderr, immediately show all errors to user - ) as proc: - proc.communicate(password.encode("utf-8")) - exit_code = proc.poll() - if exit_code: - raise subprocess.CalledProcessError(exit_code, cmd) + with _kinit_lock, subprocess.Popen( # noqa: S603 + cmd, + stdin=subprocess.PIPE, + # do not show user 'Please enter password' banner + stdout=subprocess.PIPE, + # do not capture stderr, immediately show all errors to user + ) as proc: + proc.communicate(password.encode("utf-8")) + exit_code = proc.poll() + if exit_code: + raise subprocess.CalledProcessError(exit_code, cmd) def kinit(user: str, keytab: os.PathLike | None = None, password: str | None = None) -> None: diff --git a/onetl/core/__init__.py b/onetl/core/__init__.py index 52e4a427a..e3eb5e840 100644 --- a/onetl/core/__init__.py +++ b/onetl/core/__init__.py @@ -4,8 +4,8 @@ import warnings from importlib import import_module -from onetl.core.file_filter import * -from onetl.core.file_limit import * +from onetl.core.file_filter import FileFilter +from onetl.core.file_limit import FileLimit module_for_class = { "DBReader": "db", @@ -19,6 +19,20 @@ "FileSet": "file.file_set", } +__all__ = [ + "DBReader", + "DBWriter", + "DownloadResult", + "FileDownloader", + "FileFilter", + "FileLimit", + "FileResult", + "FileSet", + "FileUploader", + "MoveResult", + "UploadResult", +] + def __getattr__(name: str): if name in module_for_class: diff --git a/onetl/core/file_filter/__init__.py b/onetl/core/file_filter/__init__.py index ce60b9e83..290c5edeb 100644 --- a/onetl/core/file_filter/__init__.py +++ b/onetl/core/file_filter/__init__.py @@ -1,3 +1,5 @@ # SPDX-FileCopyrightText: 2022-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.core.file_filter.file_filter import FileFilter + +__all__ = ["FileFilter"] diff --git a/onetl/core/file_filter/file_filter.py b/onetl/core/file_filter/file_filter.py index 815b3d751..57534630c 100644 --- a/onetl/core/file_filter/file_filter.py +++ b/onetl/core/file_filter/file_filter.py @@ -107,7 +107,8 @@ class Config: @validator("glob", pre=True) def check_glob(cls, value: str) -> str: if not glob.has_magic(value): - raise ValueError("Invalid glob") + msg = "Invalid glob" + raise ValueError(msg) return value @@ -125,14 +126,16 @@ def check_exclude_dir(cls, value: Union[str, os.PathLike]) -> RemotePath: @root_validator def disallow_empty_fields(cls, value: dict) -> dict: if value.get("glob") is None and value.get("regexp") is None and not value.get("exclude_dirs"): - raise ValueError("One of the following fields must be set: `glob`, `regexp`, `exclude_dirs`") + msg = "One of the following fields must be set: `glob`, `regexp`, `exclude_dirs`" + raise ValueError(msg) return value @root_validator def disallow_both_glob_and_regexp(cls, value: dict) -> dict: if value.get("glob") and value.get("regexp"): - raise ValueError("Only one of `glob`, `regexp` fields can passed, not both") + msg = "Only one of `glob`, `regexp` fields can passed, not both" + raise ValueError(msg) return value @@ -141,7 +144,7 @@ def log_deprecated(cls, value: dict) -> dict: imports = [] old_filters = [] new_filters = [] - glob = value.get("glob") # noqa: WPS442 + glob = 
value.get("glob") if glob is not None: imports.append("Glob") old_filters.append(f"glob={glob!r}") diff --git a/onetl/core/file_limit/__init__.py b/onetl/core/file_limit/__init__.py index 70336fe08..27080a0d9 100644 --- a/onetl/core/file_limit/__init__.py +++ b/onetl/core/file_limit/__init__.py @@ -1,3 +1,5 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.core.file_limit.file_limit import FileLimit + +__all__ = ["FileLimit"] diff --git a/onetl/db/__init__.py b/onetl/db/__init__.py index 66889cfaf..bd6713804 100644 --- a/onetl/db/__init__.py +++ b/onetl/db/__init__.py @@ -2,3 +2,5 @@ # SPDX-License-Identifier: Apache-2.0 from onetl.db.db_reader import DBReader from onetl.db.db_writer import DBWriter + +__all__ = ["DBReader", "DBWriter"] diff --git a/onetl/db/db_reader/__init__.py b/onetl/db/db_reader/__init__.py index 9bd9cdcb7..6e5e4997c 100644 --- a/onetl/db/db_reader/__init__.py +++ b/onetl/db/db_reader/__init__.py @@ -1,3 +1,5 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.db.db_reader.db_reader import DBReader + +__all__ = ["DBReader"] diff --git a/onetl/db/db_reader/db_reader.py b/onetl/db/db_reader/db_reader.py index 369d7ef88..167c02f0f 100644 --- a/onetl/db/db_reader/db_reader.py +++ b/onetl/db/db_reader/db_reader.py @@ -357,7 +357,7 @@ def validate_df_schema(cls, df_schema: StructType | None, values: dict) -> Struc return connection.dialect.validate_df_schema(df_schema) @root_validator(skip_on_failure=True) - def validate_hwm(cls, values: dict) -> dict: # noqa: WPS231 + def validate_hwm(cls, values: dict) -> dict: connection: BaseDBConnection = values["connection"] source: str = values["source"] hwm_column: str | tuple[str, str] | None = values.get("hwm_column") @@ -366,10 +366,11 @@ def validate_hwm(cls, values: dict) -> dict: # noqa: WPS231 if hwm_column is not None: if hwm: - raise ValueError("Please pass either DBReader(hwm=...) or DBReader(hwm_column=...), not both") + msg = "Please pass either DBReader(hwm=...) or DBReader(hwm_column=...), not both" + raise ValueError(msg) if not hwm_expression and isinstance(hwm_column, tuple): - hwm_column, hwm_expression = hwm_column # noqa: WPS434 + hwm_column, hwm_expression = hwm_column if not hwm_expression: error_message = textwrap.dedent( @@ -409,7 +410,8 @@ def validate_hwm(cls, values: dict) -> dict: # noqa: WPS231 ) if hwm and not hwm.expression: - raise ValueError("`hwm.expression` cannot be None") + msg = "`hwm.expression` cannot be None" + raise ValueError(msg) if hwm and not hwm.entity: hwm = hwm.copy(update={"entity": source}) @@ -443,9 +445,8 @@ def validate_options(cls, options, values): return read_options_class.parse(options) if options: - raise ValueError( - f"{connection.__class__.__name__} does not implement ReadOptions, but {options!r} is passed", - ) + msg = f"{connection.__class__.__name__} does not implement ReadOptions, but {options!r} is passed" + raise ValueError(msg) return None @@ -459,7 +460,9 @@ def has_data(self) -> bool: .. warning:: - If :etl-entities:`hwm ` is used, then method should be called inside :ref:`strategy` context. And vise-versa, if HWM is not used, this method should not be called within strategy. + If :etl-entities:`hwm ` is used, + then method should be called inside :ref:`strategy` context. + And vise-versa, if HWM is not used, this method should not be called within strategy. .. versionadded:: 0.10.0 @@ -521,7 +524,9 @@ def raise_if_no_data(self) -> None: .. 
warning:: - If :etl-entities:`hwm ` is used, then method should be called inside :ref:`strategy` context. And vise-versa, if HWM is not used, this method should not be called within strategy. + If :etl-entities:`hwm ` is used, + then method should be called inside :ref:`strategy` context. + And vise-versa, if HWM is not used, this method should not be called within strategy. .. versionadded:: 0.10.0 @@ -545,7 +550,8 @@ def raise_if_no_data(self) -> None: """ if not self.has_data(): - raise NoDataError(f"No data in the source: {self.source}") + msg = f"No data in the source: {self.source}" + raise NoDataError(msg) @slot def run(self) -> DataFrame: @@ -558,7 +564,9 @@ def run(self) -> DataFrame: .. warning:: - If :etl-entities:`hwm ` is used, then method should be called inside :ref:`strategy` context. And vise-versa, if HWM is not used, this method should not be called within strategy. + If :etl-entities:`hwm ` is used, + then method should be called inside :ref:`strategy` context. + And vise-versa, if HWM is not used, this method should not be called within strategy. .. versionadded:: 0.1.0 @@ -615,13 +623,17 @@ def _check_strategy(self): if self.hwm: if not isinstance(strategy, HWMStrategy): - raise RuntimeError( - f"{class_name}(hwm=...) cannot be used with {strategy_name}. Check documentation DBReader.has_data(): https://onetl.readthedocs.io/en/stable/db/db_reader.html#onetl.db.db_reader.db_reader.DBReader.has_data.", + msg = ( + f"{class_name}(hwm=...) cannot be used with {strategy_name}. " + "Check documentation DBReader.has_data(): " + "https://onetl.readthedocs.io/en/stable/db/db_reader.html#onetl.db.db_reader.db_reader.DBReader.has_data." ) + raise RuntimeError(msg) self._prepare_hwm(strategy, self.hwm) elif isinstance(strategy, HWMStrategy): - raise RuntimeError(f"{strategy_name} cannot be used without {class_name}(hwm=...)") + msg = f"{strategy_name} cannot be used without {class_name}(hwm=...)" + raise RuntimeError(msg) def _prepare_hwm(self, strategy: HWMStrategy, hwm: ColumnHWM): if not strategy.hwm: @@ -696,7 +708,8 @@ def _get_hwm_field(self, hwm: HWM) -> StructField: schema = {field.name.casefold(): field for field in self.df_schema} column = hwm.expression.casefold() if column not in schema: - raise ValueError(f"HWM column {column!r} not found in dataframe schema") + msg = f"HWM column {column!r} not found in dataframe schema" + raise ValueError(msg) result = schema[column] elif isinstance(self.connection, ContainsGetDFSchemaMethod): @@ -707,15 +720,16 @@ def _get_hwm_field(self, hwm: HWM) -> StructField: ) result = df_schema[0] else: - raise ValueError( + msg = ( "You should specify `df_schema` field to use DBReader with " - f"{self.connection.__class__.__name__} connection", + f"{self.connection.__class__.__name__} connection" ) + raise ValueError(msg) log.info("|%s| Got Spark field: %s", self.__class__.__name__, result) return result - def _calculate_window_and_limit(self) -> tuple[Window | None, int | None]: # noqa: WPS231 + def _calculate_window_and_limit(self) -> tuple[Window | None, int | None]: if not self.hwm: # SnapshotStrategy - always select all the data from source return None, None @@ -731,9 +745,8 @@ def _calculate_window_and_limit(self) -> tuple[Window | None, int | None]: # no return window, None if not isinstance(self.connection, ContainsGetMinMaxValues): - raise ValueError( - f"{self.connection.__class__.__name__} connection does not support {strategy.__class__.__name__}", - ) + msg = f"{self.connection.__class__.__name__} connection does not support 
{strategy.__class__.__name__}" + raise TypeError(msg) # strategy does not have start/stop/current value - use min/max values from source to fill them up min_value, max_value = self.connection.get_min_max_values( @@ -784,7 +797,8 @@ def _calculate_window_and_limit(self) -> tuple[Window | None, int | None]: # no window = Window(self.hwm.expression, start_from=strategy.current, stop_at=strategy.next) else: - # for IncrementalStrategy fix only max value to avoid difference between real dataframe content and HWM value + # for IncrementalStrategy fix only max value + # to avoid difference between real dataframe content and HWM value window = Window( self.hwm.expression, start_from=strategy.current, @@ -807,7 +821,7 @@ def _log_parameters(self) -> None: log_json(log, self.where, "where") if self.df_schema: - empty_df = self.connection.spark.createDataFrame([], self.df_schema) # type: ignore + empty_df = self.connection.spark.createDataFrame([], self.df_schema) log_dataframe_schema(log, empty_df) if self.hwm: @@ -825,7 +839,7 @@ def _get_read_kwargs(self) -> dict: @classmethod def _forward_refs(cls) -> dict[str, type]: try_import_pyspark() - from pyspark.sql.types import StructType # noqa: WPS442 + from pyspark.sql.types import StructType # avoid importing pyspark unless user called the constructor, # as we allow user to use `Connection.get_packages()` for creating Spark session diff --git a/onetl/db/db_writer/__init__.py b/onetl/db/db_writer/__init__.py index 96eaa7de8..7f8315286 100644 --- a/onetl/db/db_writer/__init__.py +++ b/onetl/db/db_writer/__init__.py @@ -1,3 +1,5 @@ # SPDX-FileCopyrightText: 2021-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.db.db_writer.db_writer import DBWriter + +__all__ = ["DBWriter"] diff --git a/onetl/db/db_writer/db_writer.py b/onetl/db/db_writer/db_writer.py index 124c6280e..b33400260 100644 --- a/onetl/db/db_writer/db_writer.py +++ b/onetl/db/db_writer/db_writer.py @@ -122,9 +122,8 @@ def validate_options(cls, options, values): return write_options_class.parse(options) if options: - raise ValueError( - f"{connection.__class__.__name__} does not implement WriteOptions, but {options!r} is passed", - ) + msg = f"{connection.__class__.__name__} does not implement WriteOptions, but {options!r} is passed" + raise ValueError(msg) return None @@ -152,7 +151,8 @@ def run(self, df: DataFrame) -> None: writer.run(df) """ if df.isStreaming: - raise ValueError(f"DataFrame is streaming. {self.__class__.__name__} supports only batch DataFrames.") + msg = f"DataFrame is streaming. {self.__class__.__name__} supports only batch DataFrames." + raise ValueError(msg) entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() starts") @@ -176,12 +176,12 @@ def run(self, df: DataFrame) -> None: # SparkListener is not a reliable source of information, metrics may or may not be present. # Because of this we also do not return these metrics as method result if metrics.output.is_empty: - log.error( + log.error( # noqa: TRY400 "|%s| Error while writing dataframe.", self.__class__.__name__, ) else: - log.error( + log.error( # noqa: TRY400 "|%s| Error while writing dataframe. 
Target MAY contain partially written data!", self.__class__.__name__, ) diff --git a/onetl/file/__init__.py b/onetl/file/__init__.py index 03d890cce..9c12bab49 100644 --- a/onetl/file/__init__.py +++ b/onetl/file/__init__.py @@ -5,3 +5,14 @@ from onetl.file.file_downloader import DownloadResult, FileDownloader from onetl.file.file_mover import FileMover, MoveResult from onetl.file.file_uploader import FileUploader, UploadResult + +__all__ = [ + "DownloadResult", + "FileDFReader", + "FileDFWriter", + "FileDownloader", + "FileMover", + "FileUploader", + "MoveResult", + "UploadResult", +] diff --git a/onetl/file/file_df_reader/__init__.py b/onetl/file/file_df_reader/__init__.py index 09955ba95..6a2a7c540 100644 --- a/onetl/file/file_df_reader/__init__.py +++ b/onetl/file/file_df_reader/__init__.py @@ -1,3 +1,5 @@ # SPDX-FileCopyrightText: 2023-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.file.file_df_reader.file_df_reader import FileDFReader + +__all__ = ["FileDFReader"] diff --git a/onetl/file/file_df_reader/file_df_reader.py b/onetl/file/file_df_reader/file_df_reader.py index cc55985b7..3e9845a84 100644 --- a/onetl/file/file_df_reader/file_df_reader.py +++ b/onetl/file/file_df_reader/file_df_reader.py @@ -206,7 +206,8 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> DataFrame: entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() starts") if files is None and not self.source_path: - raise ValueError("Neither file list nor `source_path` are passed") + msg = "Neither file list nor `source_path` are passed" + raise ValueError(msg) if not self._connection_checked: self._log_parameters(files) @@ -272,7 +273,7 @@ def _validate_source_path(cls, source_path, values): return source_path @validator("format") - def _validate_format(cls, format, values): # noqa: WPS125 + def _validate_format(cls, format, values): connection = values.get("connection") if isinstance(connection, BaseFileDFConnection): connection.check_if_format_supported(format) @@ -282,7 +283,7 @@ def _validate_format(cls, format, values): # noqa: WPS125 def _validate_options(cls, value): return cls.Options.parse(value) - def _validate_files( # noqa: WPS231 + def _validate_files( self, files: Iterable[os.PathLike | str], ) -> OrderedSet[PurePathProtocol]: @@ -293,9 +294,11 @@ def _validate_files( # noqa: WPS231 if not self.source_path: if not file_path.is_absolute(): - raise ValueError("Cannot pass relative file path with empty `source_path`") + msg = "Cannot pass relative file path with empty `source_path`" + raise ValueError(msg) elif file_path.is_absolute() and self.source_path not in file_path.parents: - raise ValueError(f"File path '{file_path}' does not match source_path '{self.source_path}'") + msg = f"File path '{file_path}' does not match source_path '{self.source_path}'" + raise ValueError(msg) elif not file_path.is_absolute(): # Make file path absolute file_path = self.source_path / file @@ -307,7 +310,7 @@ def _validate_files( # noqa: WPS231 @classmethod def _forward_refs(cls) -> dict[str, type]: try_import_pyspark() - from pyspark.sql.types import StructType # noqa: WPS442 + from pyspark.sql.types import StructType # avoid importing pyspark unless user called the constructor, # as we allow user to use `Connection.get_packages()` for creating Spark session diff --git a/onetl/file/file_df_writer/__init__.py b/onetl/file/file_df_writer/__init__.py index 897a202b4..15219aab9 100644 --- a/onetl/file/file_df_writer/__init__.py +++ b/onetl/file/file_df_writer/__init__.py @@ -1,3 
+1,5 @@ # SPDX-FileCopyrightText: 2023-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.file.file_df_writer.file_df_writer import FileDFWriter + +__all__ = ["FileDFWriter"] diff --git a/onetl/file/file_df_writer/file_df_writer.py b/onetl/file/file_df_writer/file_df_writer.py index 3b491974a..2dd2c28c3 100644 --- a/onetl/file/file_df_writer/file_df_writer.py +++ b/onetl/file/file_df_writer/file_df_writer.py @@ -121,7 +121,8 @@ def run(self, df: DataFrame) -> None: entity_boundary_log(log, f"{self.__class__.__name__}.run() starts") if df.isStreaming: - raise ValueError(f"DataFrame is streaming. {self.__class__.__name__} supports only batch DataFrames.") + msg = f"DataFrame is streaming. {self.__class__.__name__} supports only batch DataFrames." + raise ValueError(msg) if not self._connection_checked: self._log_parameters(df) @@ -143,12 +144,12 @@ def run(self, df: DataFrame) -> None: if metrics.output.is_empty: # SparkListener is not a reliable source of information, metrics may or may not be present. # Because of this we also do not return these metrics as method result - log.error( + log.error( # noqa: TRY400 "|%s| Error while writing dataframe.", self.__class__.__name__, ) else: - log.error( + log.error( # noqa: TRY400 "|%s| Error while writing dataframe. Target MAY contain partially written data!", self.__class__.__name__, ) @@ -181,7 +182,7 @@ def _validate_target_path(cls, target_path, values): return target_path @validator("format") - def _validate_format(cls, format, values): # noqa: WPS125 + def _validate_format(cls, format, values): connection = values.get("connection") if isinstance(connection, BaseFileDFConnection): connection.check_if_format_supported(format) diff --git a/onetl/file/file_df_writer/options.py b/onetl/file/file_df_writer/options.py index a947f364a..29bd07597 100644 --- a/onetl/file/file_df_writer/options.py +++ b/onetl/file/file_df_writer/options.py @@ -93,17 +93,20 @@ class Config: .. warning:: - Existing files still present in the root of directory, but Spark will ignore those files while reading, + Existing files still present in the root of directory, + but Spark will ignore those files while reading, unless using ``recursive=True``. * Directory exists and contains partitions, but :obj:`~partition_by` is not set - Data is appended to a directory, but to the root of directory instead of nested partition directories. + Data is appended to a directory, but to the root of + directory instead of nested partition directories. .. warning:: Spark will ignore such files while reading, unless using ``recursive=True``. - * Directory exists and contains partitions, but with different partitioning schema than :obj:`~partition_by` + * Directory exists and contains partitions, + but with different partitioning schema than :obj:`~partition_by` Data is appended to a directory with new partitioning schema. .. warning:: @@ -111,13 +114,16 @@ class Config: Spark cannot read directory with multiple partitioning schemas, unless using ``recursive=True`` to disable partition scanning. - * Directory exists and partitioned according :obj:`~partition_by`, but partition is present only in dataframe + * Directory exists and partitioned according :obj:`~partition_by`, + but partition is present only in dataframe New partition directory is created. 
- * Directory exists and partitioned according :obj:`~partition_by`, partition is present in both dataframe and directory + * Directory exists and partitioned according :obj:`~partition_by`, + partition is present in both dataframe and directory New files are added to existing partition directory, existing files are sill present. - * Directory exists and partitioned according :obj:`~partition_by`, but partition is present only in directory, not dataframe + * Directory exists and partitioned according :obj:`~partition_by`, + but partition is present only in directory, not dataframe Existing partition is left intact. * ``replace_overlapping_partitions`` @@ -133,7 +139,8 @@ class Config: .. dropdown:: Behavior in details * Directory does not exist - Directory is created using all the provided options (``format``, ``partition_by``, etc). + Directory is created using all the provided options + (``format``, ``partition_by``, etc). * Directory exists, does not contain partitions, but :obj:`~partition_by` is set Directory **will be deleted**, and will be created with partitions. @@ -141,7 +148,8 @@ class Config: * Directory exists and contains partitions, but :obj:`~partition_by` is not set Directory **will be deleted**, and will be created with partitions. - * Directory exists and contains partitions, but with different partitioning schema than :obj:`~partition_by` + * Directory exists and contains partitions, + but with different partitioning schema than :obj:`~partition_by` Data is appended to a directory with new partitioning schema. .. warning:: @@ -149,13 +157,17 @@ class Config: Spark cannot read directory with multiple partitioning schemas, unless using ``recursive=True`` to disable partition scanning. - * Directory exists and partitioned according :obj:`~partition_by`, but partition is present only in dataframe + * Directory exists and partitioned according :obj:`~partition_by`, + but partition is present only in dataframe New partition directory is created. - * Directory exists and partitioned according :obj:`~partition_by`, partition is present in both dataframe and directory - Partition directory **will be deleted**, and new one is created with files containing data from dataframe. + * Directory exists and partitioned according :obj:`~partition_by`, + partition is present in both dataframe and directory + Partition directory **will be deleted**, + and new one is created with files containing data from dataframe. - * Directory exists and partitioned according :obj:`~partition_by`, but partition is present only in directory, not dataframe + * Directory exists and partitioned according :obj:`~partition_by`, + but partition is present only in directory, not dataframe Existing partition is left intact. * ``replace_entire_directory`` @@ -207,9 +219,9 @@ def apply_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter: # format orc, parquet methods and format simultaneously if hasattr(writer, method): if isinstance(value, Iterable) and not isinstance(value, str): - writer = getattr(writer, method)(*value) # noqa: WPS220 + writer = getattr(writer, method)(*value) else: - writer = getattr(writer, method)(value) # noqa: WPS220 + writer = getattr(writer, method)(value) else: writer = writer.option(method, value) @@ -229,7 +241,8 @@ def apply_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter: @root_validator(pre=True) def _mode_is_restricted(cls, values): if "mode" in values: - raise ValueError("Parameter `mode` is not allowed. 
Please use `if_exists` parameter instead.") + msg = "Parameter `mode` is not allowed. Please use `if_exists` parameter instead." + raise ValueError(msg) return values @root_validator @@ -241,7 +254,6 @@ def _partition_overwrite_mode_is_not_allowed(cls, values): else: recommended_mode = "replace_overlapping_partitions" - raise ValueError( - f"`partitionOverwriteMode` option should be replaced with if_exists='{recommended_mode}'", - ) + msg = f"`partitionOverwriteMode` option should be replaced with if_exists='{recommended_mode}'" + raise ValueError(msg) return values diff --git a/onetl/file/file_downloader/__init__.py b/onetl/file/file_downloader/__init__.py index 9c66b3a37..05a2ea9ad 100644 --- a/onetl/file/file_downloader/__init__.py +++ b/onetl/file/file_downloader/__init__.py @@ -3,3 +3,9 @@ from onetl.file.file_downloader.file_downloader import FileDownloader from onetl.file.file_downloader.options import FileDownloaderOptions from onetl.file.file_downloader.result import DownloadResult + +__all__ = [ + "DownloadResult", + "FileDownloader", + "FileDownloaderOptions", +] diff --git a/onetl/file/file_downloader/file_downloader.py b/onetl/file/file_downloader/file_downloader.py index 8939dfc27..e39b04e66 100644 --- a/onetl/file/file_downloader/file_downloader.py +++ b/onetl/file/file_downloader/file_downloader.py @@ -138,13 +138,15 @@ class FileDownloader(FrozenModel): Renamed ``limit`` → ``limits`` options : :obj:`~FileDownloader.Options` | dict | None, default: ``None`` - File downloading options. See :obj:`FileDownloader.Options ` + File downloading options. + See :obj:`FileDownloader.Options ` .. versionadded:: 0.3.0 hwm : type[HWM] | None, default: ``None`` - HWM class to detect changes in incremental run. See :etl-entities:`File HWM ` + HWM class to detect changes in incremental run. + See :etl-entities:`File HWM ` .. warning :: Used only in :obj:`IncrementalStrategy `. @@ -271,7 +273,7 @@ class FileDownloader(FrozenModel): _connection_checked: bool = PrivateAttr(default=False) @slot - def run(self, files: Iterable[str | os.PathLike] | None = None) -> DownloadResult: # noqa: WPS231 + def run(self, files: Iterable[str | os.PathLike] | None = None) -> DownloadResult: # noqa: C901 """ Method for downloading files from source to local directory. 
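(Aside: a rough usage sketch of the ``FileDownloader.run()`` flow documented above. The connection class is real, but the host, credentials and paths are placeholders, not values taken from this change.)

.. code:: python

    from onetl.connection import SFTP
    from onetl.file import FileDownloader

    sftp = SFTP(host="sftp.example.com", user="onetl", password="***")

    downloader = FileDownloader(
        connection=sftp,
        source_path="/remote/data",
        local_path="/tmp/data",
    )

    # DownloadResult groups files into successful/failed/skipped/missing sets;
    # raise_if_empty() is the FileResult helper touched later in this diff.
    result = downloader.run()
    result.raise_if_empty()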
|support_hooks| @@ -394,7 +396,8 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> DownloadResul self._check_strategy() if files is None and not self.source_path: - raise ValueError("Neither file list nor `source_path` are passed") + msg = "Neither file list nor `source_path` are passed" + raise ValueError(msg) # Check everything if not self._connection_checked: @@ -480,7 +483,8 @@ def view_files(self) -> FileSet[RemoteFile]: """ if not self.source_path: - raise ValueError("Cannot call `.view_files()` without `source_path`") + msg = "Cannot call `.view_files()` without `source_path`" + raise ValueError(msg) log.debug("|%s| Getting files list from path '%s'", self.connection.__class__.__name__, self.source_path) @@ -498,9 +502,8 @@ def view_files(self) -> FileSet[RemoteFile]: result.append(file) except Exception as e: - raise RuntimeError( - f"Couldn't read directory tree from remote dir '{self.source_path}'", - ) from e + msg = f"Couldn't read directory tree from remote dir '{self.source_path}'" + raise RuntimeError(msg) from e return result @@ -524,7 +527,8 @@ def _validate_hwm(cls, values): hwm = values.get("hwm") if (hwm or hwm_type) and not source_path: - raise ValueError("If `hwm` is passed, `source_path` must be specified") + msg = "If `hwm` is passed, `source_path` must be specified" + raise ValueError(msg) if hwm_type and (hwm_type == "file_list" or issubclass(hwm_type, OldFileListHWM)): remote_file_folder = RemoteFolder(name=source_path, instance=connection.instance_url) @@ -617,14 +621,17 @@ def _check_strategy(self): if self.hwm: if not isinstance(strategy, HWMStrategy): - raise ValueError(f"{class_name}(hwm=...) cannot be used with {strategy_name}") + msg = f"{class_name}(hwm=...) cannot be used with {strategy_name}" + raise ValueError(msg) offset = getattr(strategy, "offset", None) if offset is not None: - raise ValueError(f"{class_name}(hwm=...) cannot be used with {strategy_name}(offset={offset}, ...)") + msg = f"{class_name}(hwm=...) cannot be used with {strategy_name}(offset={offset}, ...)" + raise ValueError(msg) if isinstance(strategy, BatchHWMStrategy): - raise ValueError(f"{class_name}(hwm=...) cannot be used with {strategy_name}") + msg = f"{class_name}(hwm=...) 
cannot be used with {strategy_name}" + raise ValueError(msg) def _init_hwm(self, hwm: FileHWM) -> FileHWM: strategy: HWMStrategy = StrategyManager.get_current() @@ -687,7 +694,7 @@ def _log_parameters(self, files: Iterable[str | os.PathLike] | None = None) -> N self.__class__.__name__, ) - def _validate_files( # noqa: WPS231 + def _validate_files( self, remote_files: Iterable[os.PathLike | str], current_temp_dir: LocalPath | None, @@ -702,29 +709,30 @@ def _validate_files( # noqa: WPS231 if not self.source_path: # Download into a flat structure if not remote_file_path.is_absolute(): - raise ValueError("Cannot pass relative file path with empty `source_path`") + msg = "Cannot pass relative file path with empty `source_path`" + raise ValueError(msg) filename = remote_file_path.name local_file = self.local_path / filename if current_temp_dir: - tmp_file = current_temp_dir / filename # noqa: WPS220 + tmp_file = current_temp_dir / filename + # Download according to source folder structure + elif self.source_path in remote_file_path.parents: + # Make relative local path + local_file = self.local_path / remote_file_path.relative_to(self.source_path) + if current_temp_dir: + tmp_file = current_temp_dir / remote_file_path.relative_to(self.source_path) + + elif not remote_file_path.is_absolute(): + # Passed path is already relative + local_file = self.local_path / remote_file_path + remote_file = self.source_path / remote_file_path + if current_temp_dir: + tmp_file = current_temp_dir / remote_file_path else: - # Download according to source folder structure - if self.source_path in remote_file_path.parents: - # Make relative local path - local_file = self.local_path / remote_file_path.relative_to(self.source_path) - if current_temp_dir: - tmp_file = current_temp_dir / remote_file_path.relative_to(self.source_path) # noqa: WPS220 - - elif not remote_file_path.is_absolute(): - # Passed path is already relative - local_file = self.local_path / remote_file_path - remote_file = self.source_path / remote_file_path - if current_temp_dir: - tmp_file = current_temp_dir / remote_file_path # noqa: WPS220 - else: - # Wrong path (not relative path and source path not in the path to the file) - raise ValueError(f"File path '{remote_file}' does not match source_path '{self.source_path}'") + # Wrong path (not relative path and source path not in the path to the file) + msg = f"File path '{remote_file}' does not match source_path '{self.source_path}'" + raise ValueError(msg) if not isinstance(remote_file, PathProtocol) and self.connection.path_exists(remote_file): remote_file = self.connection.resolve_file(remote_file) @@ -738,11 +746,12 @@ def _check_source_path(self): def _check_local_path(self): if self.local_path.exists() and not self.local_path.is_dir(): - raise NotADirectoryError(f"{path_repr(self.local_path)} is not a directory") + msg = f"{path_repr(self.local_path)} is not a directory" + raise NotADirectoryError(msg) self.local_path.mkdir(exist_ok=True, parents=True) - def _download_files( # noqa: WPS231 + def _download_files( self, to_download: DOWNLOAD_ITEMS_TYPE, ) -> DownloadResult: @@ -757,7 +766,7 @@ def _download_files( # noqa: WPS231 strategy = StrategyManager.get_current() result = DownloadResult() source_files: list[RemotePath] = [] - try: # noqa: WPS501, WPS243 + try: for status, source_file, target_file in self._bulk_download(to_download): if status == FileDownloadStatus.SUCCESSFUL: result.successful.add(target_file) @@ -827,7 +836,7 @@ def _bulk_download( for source_file, target_file, tmp_file 
in to_download ) - def _download_file( # noqa: WPS231, WPS213 + def _download_file( # noqa: PLR0912, C901 self, source_file: RemotePath, local_file: LocalPath, @@ -854,7 +863,8 @@ def _download_file( # noqa: WPS231, WPS213 replace = False if local_file.exists(): if self.options.if_exists == FileExistBehavior.ERROR: - raise FileExistsError(f"File {path_repr(local_file)} already exists") + msg = f"File {path_repr(local_file)} already exists" + raise FileExistsError(msg) # noqa: TRY301 if self.options.if_exists == FileExistBehavior.IGNORE: log.warning("|Local FS| File %s already exists, skipping", path_repr(local_file)) @@ -884,8 +894,6 @@ def _download_file( # noqa: WPS231, WPS213 if self.options.delete_source: self.connection.remove_file(remote_file) - return FileDownloadStatus.SUCCESSFUL, remote_file, local_file - except Exception as e: if log.isEnabledFor(logging.DEBUG): log.exception( @@ -894,10 +902,10 @@ def _download_file( # noqa: WPS231, WPS213 exc_info=e, ) else: - log.exception( + log.exception( # noqa: LOG007 "|%s| Couldn't download file from source dir: %s", self.__class__.__name__, - e, + e, # noqa: TRY401 exc_info=False, ) failed_file = FailedRemoteFile( @@ -907,6 +915,9 @@ def _download_file( # noqa: WPS231, WPS213 ) return FileDownloadStatus.FAILED, failed_file, None + else: + return FileDownloadStatus.SUCCESSFUL, remote_file, local_file + def _remove_temp_dir(self, temp_dir: LocalPath) -> None: log.info("|Local FS| Removing temp directory '%s'", temp_dir) diff --git a/onetl/file/file_mover/__init__.py b/onetl/file/file_mover/__init__.py index 002959c0d..faa039623 100644 --- a/onetl/file/file_mover/__init__.py +++ b/onetl/file/file_mover/__init__.py @@ -3,3 +3,9 @@ from onetl.file.file_mover.file_mover import FileMover from onetl.file.file_mover.options import FileMoverOptions from onetl.file.file_mover.result import MoveResult + +__all__ = [ + "FileMover", + "FileMoverOptions", + "MoveResult", +] diff --git a/onetl/file/file_mover/file_mover.py b/onetl/file/file_mover/file_mover.py index b2764035f..f66d4cf20 100644 --- a/onetl/file/file_mover/file_mover.py +++ b/onetl/file/file_mover/file_mover.py @@ -160,7 +160,7 @@ class FileMover(FrozenModel): _connection_checked: bool = PrivateAttr(default=False) @slot - def run(self, files: Iterable[str | os.PathLike] | None = None) -> MoveResult: # noqa: WPS231 + def run(self, files: Iterable[str | os.PathLike] | None = None) -> MoveResult: """ Method for moving files from source to target directory. 
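(Aside: a minimal, self-contained sketch, with made-up names, of the try/except/else shape that ``_download_file``, ``_move_file`` and ``_upload_file`` now share in this diff: the success ``return`` moves into ``else``, so it only runs when the ``try`` block did not raise, while the ``except`` branch logs and returns the failure.)

.. code:: python

    from __future__ import annotations

    import logging

    log = logging.getLogger(__name__)


    def process_one(source: str, target: str) -> tuple[str, str]:
        try:
            # any step here may raise
            with open(source, "rb") as src, open(target, "wb") as dst:
                dst.write(src.read())
        except Exception:
            log.exception("Couldn't process file: %s", source)
            return "FAILED", source
        else:
            return "SUCCESSFUL", target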
|support_hooks| @@ -272,7 +272,8 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> MoveResult: entity_boundary_log(log, f"{self.__class__.__name__}.run() starts") if files is None and not self.source_path: - raise ValueError("Neither file list nor `source_path` are passed") + msg = "Neither file list nor `source_path` are passed" + raise ValueError(msg) if not self._connection_checked: self._log_parameters(files) @@ -345,7 +346,8 @@ def view_files(self) -> FileSet[RemoteFile]: """ if not self.source_path: - raise ValueError("Cannot call `.view_files()` without `source_path`") + msg = "Cannot call `.view_files()` without `source_path`" + raise ValueError(msg) log.debug("|%s| Getting files list from path '%s'", self.connection.__class__.__name__, self.source_path) @@ -360,9 +362,8 @@ def view_files(self) -> FileSet[RemoteFile]: result.append(file) except Exception as e: - raise RuntimeError( - f"Couldn't read directory tree from remote dir '{self.source_path}'", - ) from e + msg = f"Couldn't read directory tree from remote dir '{self.source_path}'" + raise RuntimeError(msg) from e return result @@ -384,7 +385,7 @@ def _log_parameters(self, files: Iterable[str | os.PathLike] | None = None) -> N self.__class__.__name__, ) - def _validate_files( # noqa: WPS231 + def _validate_files( self, remote_files: Iterable[os.PathLike | str], ) -> MOVE_ITEMS_TYPE: @@ -397,23 +398,24 @@ def _validate_files( # noqa: WPS231 if not self.source_path: # Move into a flat structure if not remote_file_path.is_absolute(): - raise ValueError("Cannot pass relative file path with empty `source_path`") + msg = "Cannot pass relative file path with empty `source_path`" + raise ValueError(msg) filename = remote_file_path.name new_file = self.target_path / filename + # Move according to source folder structure + elif self.source_path in remote_file_path.parents: + # Make relative local path + new_file = self.target_path / remote_file_path.relative_to(self.source_path) + + elif not remote_file_path.is_absolute(): + # Passed path is already relative + new_file = self.target_path / remote_file_path + old_file = self.source_path / remote_file_path else: - # Move according to source folder structure - if self.source_path in remote_file_path.parents: - # Make relative local path - new_file = self.target_path / remote_file_path.relative_to(self.source_path) - - elif not remote_file_path.is_absolute(): - # Passed path is already relative - new_file = self.target_path / remote_file_path - old_file = self.source_path / remote_file_path - else: - # Wrong path (not relative path and source path not in the path to the file) - raise ValueError(f"File path '{old_file}' does not match source_path '{self.source_path}'") + # Wrong path (not relative path and source path not in the path to the file) + msg = f"File path '{old_file}' does not match source_path '{self.source_path}'" + raise ValueError(msg) if not isinstance(old_file, PathProtocol) and self.connection.path_exists(old_file): old_file = self.connection.resolve_file(old_file) @@ -500,8 +502,7 @@ def _bulk_move( futures = [ executor.submit(self._move_file, source_file, target_file) for source_file, target_file in to_move ] - for future in as_completed(futures): - result.append(future.result()) + result = [future.result() for future in as_completed(futures)] else: log.debug("|%s| Using plain old for-loop", self.__class__.__name__) for source_file, target_file in to_move: @@ -514,7 +515,7 @@ def _bulk_move( return result - def _move_file( # noqa: WPS231, WPS213 + 
def _move_file( self, source_file: RemotePath, target_file: RemotePath, @@ -531,7 +532,8 @@ def _move_file( # noqa: WPS231, WPS213 new_file = self.connection.resolve_file(target_file) if self.options.if_exists == FileExistBehavior.ERROR: - raise FileExistsError(f"File {path_repr(new_file)} already exists") + msg = f"File {path_repr(new_file)} already exists" + raise FileExistsError(msg) # noqa: TRY301 if self.options.if_exists == FileExistBehavior.IGNORE: log.warning( @@ -544,7 +546,6 @@ def _move_file( # noqa: WPS231, WPS213 replace = True new_file = self.connection.rename_file(source_file, target_file, replace=replace) - return FileMoveStatus.SUCCESSFUL, new_file except Exception as e: if log.isEnabledFor(logging.DEBUG): @@ -554,14 +555,17 @@ def _move_file( # noqa: WPS231, WPS213 exc_info=e, ) else: - log.exception( + log.exception( # noqa: LOG007 "|%s| Couldn't move file to target dir: %s", self.__class__.__name__, - e, + e, # noqa: TRY401 exc_info=False, ) return FileMoveStatus.FAILED, FailedRemoteFile(path=source_file.path, stats=source_file.stats, exception=e) + else: + return FileMoveStatus.SUCCESSFUL, new_file + def _log_result(self, result: MoveResult) -> None: log_with_indent(log, "") log.info("|%s| Move result:", self.__class__.__name__) diff --git a/onetl/file/file_result.py b/onetl/file/file_result.py index 4523e5988..cebed5de7 100644 --- a/onetl/file/file_result.py +++ b/onetl/file/file_result.py @@ -420,7 +420,8 @@ def raise_if_empty(self) -> None: """ if self.is_empty: - raise EmptyFilesError("There are no files in the result") + msg = "There are no files in the result" + raise EmptyFilesError(msg) @property def details(self) -> str: diff --git a/onetl/file/file_set.py b/onetl/file/file_set.py index d9b8bccd3..cbea077c6 100644 --- a/onetl/file/file_set.py +++ b/onetl/file/file_set.py @@ -66,7 +66,8 @@ def raise_if_empty(self) -> None: """ if not self: - raise EmptyFilesError("There are no files in the set") + msg = "There are no files in the set" + raise EmptyFilesError(msg) def raise_if_contains_zero_size(self) -> None: """ @@ -145,7 +146,7 @@ def summary(self) -> str: return f"{file_number_str} (size='{naturalsize(self.total_size)}')" @property - def details(self) -> str: # noqa: WPS473 + def details(self) -> str: """ Return detailed information about files in the set diff --git a/onetl/file/file_uploader/__init__.py b/onetl/file/file_uploader/__init__.py index 55cb9eeeb..e953b5e8d 100644 --- a/onetl/file/file_uploader/__init__.py +++ b/onetl/file/file_uploader/__init__.py @@ -3,3 +3,9 @@ from onetl.file.file_uploader.file_uploader import FileUploader from onetl.file.file_uploader.options import FileUploaderOptions from onetl.file.file_uploader.result import UploadResult + +__all__ = [ + "FileUploader", + "FileUploaderOptions", + "UploadResult", +] diff --git a/onetl/file/file_uploader/file_uploader.py b/onetl/file/file_uploader/file_uploader.py index 13b924082..f4ff85441 100644 --- a/onetl/file/file_uploader/file_uploader.py +++ b/onetl/file/file_uploader/file_uploader.py @@ -267,7 +267,8 @@ def run(self, files: Iterable[str | os.PathLike] | None = None) -> UploadResult: entity_boundary_log(log, f"{self.__class__.__name__}.run() starts") if files is None and not self.local_path: - raise ValueError("Neither file list nor `local_path` are passed") + msg = "Neither file list nor `local_path` are passed" + raise ValueError(msg) if not self._connection_checked: self._log_parameters(files) @@ -348,7 +349,8 @@ def view_files(self) -> FileSet[LocalPath]: """ if not 
self.local_path: - raise ValueError("Cannot call `.view_files()` without `local_path`") + msg = "Cannot call `.view_files()` without `local_path`" + raise ValueError(msg) log.debug("|Local FS| Getting files list from path '%s'", self.local_path) @@ -362,9 +364,8 @@ def view_files(self) -> FileSet[LocalPath]: log.debug("|Local FS| Listing dir '%s': %d dirs, %d files", root, len(dirs), len(files)) result.update(LocalPath(root) / file for file in files) except Exception as e: - raise RuntimeError( - f"Couldn't read directory tree from local dir '{self.local_path}'", - ) from e + msg = f"Couldn't read directory tree from local dir '{self.local_path}'" + raise RuntimeError(msg) from e return result @@ -399,7 +400,7 @@ def _log_parameters(self, files: Iterable[str | os.PathLike] | None = None) -> N self.__class__.__name__, ) - def _validate_files( # noqa: WPS231 + def _validate_files( self, local_files: Iterable[os.PathLike | str], current_temp_dir: RemotePath | None, @@ -414,31 +415,33 @@ def _validate_files( # noqa: WPS231 if not self.local_path: # Upload into a flat structure if not local_file_path.is_absolute(): - raise ValueError("Cannot pass relative file path with empty `local_path`") + msg = "Cannot pass relative file path with empty `local_path`" + raise ValueError(msg) filename = local_file_path.name target_file = self.target_path / filename if current_temp_dir: - tmp_file = current_temp_dir / filename # noqa: WPS220 + tmp_file = current_temp_dir / filename + # Upload according to source folder structure + elif self.local_path in local_file_path.parents: + # Make relative remote path + target_file = self.target_path / local_file_path.relative_to(self.local_path) + if current_temp_dir: + tmp_file = current_temp_dir / local_file_path.relative_to(self.local_path) + elif not local_file_path.is_absolute(): + # Passed path is already relative + local_file = self.local_path / local_file_path + target_file = self.target_path / local_file_path + if current_temp_dir: + tmp_file = current_temp_dir / local_file_path else: - # Upload according to source folder structure - if self.local_path in local_file_path.parents: - # Make relative remote path - target_file = self.target_path / local_file_path.relative_to(self.local_path) - if current_temp_dir: - tmp_file = current_temp_dir / local_file_path.relative_to(self.local_path) # noqa: WPS220 - elif not local_file_path.is_absolute(): - # Passed path is already relative - local_file = self.local_path / local_file_path - target_file = self.target_path / local_file_path - if current_temp_dir: - tmp_file = current_temp_dir / local_file_path # noqa: WPS220 - else: - # Wrong path (not relative path and source path not in the path to the file) - raise ValueError(f"File path '{local_file}' does not match source_path '{self.local_path}'") + # Wrong path (not relative path and source path not in the path to the file) + msg = f"File path '{local_file}' does not match source_path '{self.local_path}'" + raise ValueError(msg) if local_file.exists() and not local_file.is_file(): - raise NotAFileError(f"{path_repr(local_file)} is not a file") + msg = f"{path_repr(local_file)} is not a file" + raise NotAFileError(msg) result.add((local_file, target_file, tmp_file)) @@ -446,10 +449,12 @@ def _validate_files( # noqa: WPS231 def _check_local_path(self): if not self.local_path.exists(): - raise DirectoryNotFoundError(f"'{self.local_path}' does not exist") + msg = f"'{self.local_path}' does not exist" + raise DirectoryNotFoundError(msg) if not self.local_path.is_dir(): - 
raise NotADirectoryError(f"{path_repr(self.local_path)} is not a directory") + msg = f"{path_repr(self.local_path)} is not a directory" + raise NotADirectoryError(msg) def _upload_files(self, to_upload: UPLOAD_ITEMS_TYPE) -> UploadResult: files = FileSet(item[0] for item in to_upload) @@ -518,8 +523,7 @@ def _bulk_upload( executor.submit(self._upload_file, local_file, target_file, tmp_file) for local_file, target_file, tmp_file in to_upload ] - for future in as_completed(futures): - result.append(future.result()) + result = [future.result() for future in as_completed(futures)] else: log.debug("|%s| Using plain old for-loop", self.__class__.__name__) for local_file, target_file, tmp_file in to_upload: @@ -533,7 +537,7 @@ def _bulk_upload( return result - def _upload_file( # noqa: WPS231 + def _upload_file( # noqa: PLR0912, C901 self, local_file: LocalPath, target_file: RemotePath, @@ -559,7 +563,8 @@ def _upload_file( # noqa: WPS231 if self.connection.path_exists(target_file): file = self.connection.resolve_file(target_file) if self.options.if_exists == FileExistBehavior.ERROR: - raise FileExistsError(f"File {path_repr(file)} already exists") + msg = f"File {path_repr(file)} already exists" + raise FileExistsError(msg) # noqa: TRY301 if self.options.if_exists == FileExistBehavior.IGNORE: log.warning("|%s| File %s already exists, skipping", self.__class__.__name__, path_repr(file)) @@ -581,16 +586,22 @@ def _upload_file( # noqa: WPS231 local_file.unlink() log.warning("|Local FS| Successfully removed file %s", local_file) - return FileUploadStatus.SUCCESSFUL, uploaded_file - except Exception as e: if log.isEnabledFor(logging.DEBUG): log.exception("|%s| Couldn't upload file to target dir", self.__class__.__name__, exc_info=e) else: - log.exception("|%s| Couldn't upload file to target dir: %s", self.__class__.__name__, e, exc_info=False) + log.exception( # noqa: LOG007 + "|%s| Couldn't upload file to target dir: %s", + self.__class__.__name__, + e, # noqa: TRY401 + exc_info=False, + ) return FileUploadStatus.FAILED, FailedLocalFile(path=local_file, exception=e) + else: + return FileUploadStatus.SUCCESSFUL, uploaded_file + def _remove_temp_dir(self, temp_dir: RemotePath) -> None: try: self.connection.remove_dir(temp_dir, recursive=True) diff --git a/onetl/file/filter/__init__.py b/onetl/file/filter/__init__.py index d0401bf9e..4ccc4556e 100644 --- a/onetl/file/filter/__init__.py +++ b/onetl/file/filter/__init__.py @@ -14,6 +14,6 @@ "FileModifiedTime", "FileSizeRange", "Glob", - "match_all_filters", "Regexp", + "match_all_filters", ] diff --git a/onetl/file/filter/exclude_dir.py b/onetl/file/filter/exclude_dir.py index cee13abf5..fb6f04b97 100644 --- a/onetl/file/filter/exclude_dir.py +++ b/onetl/file/filter/exclude_dir.py @@ -45,7 +45,7 @@ class Config: def __init__(self, path: str | os.PathLike): # this is only to allow passing glob as positional argument - super().__init__(path=path) # type: ignore + super().__init__(path=path) def __repr__(self): return f"{self.__class__.__name__}('{self.path}')" diff --git a/onetl/file/filter/file_mtime.py b/onetl/file/filter/file_mtime.py index 0291f0007..0f51df43a 100644 --- a/onetl/file/filter/file_mtime.py +++ b/onetl/file/filter/file_mtime.py @@ -74,11 +74,13 @@ def _validate_since_until(cls, values): until = values.get("until") if since is None and until is None: - raise ValueError("Either since or until must be specified") + msg = "Either since or until must be specified" + raise ValueError(msg) # since and until can be tz-naive and tz-aware, which 
are cannot be compared. if since and until and since.timestamp() > until.timestamp(): - raise ValueError("since cannot be greater than until") + msg = "since cannot be greater than until" + raise ValueError(msg) return values diff --git a/onetl/file/filter/file_size.py b/onetl/file/filter/file_size.py index a38184705..77962e6d0 100644 --- a/onetl/file/filter/file_size.py +++ b/onetl/file/filter/file_size.py @@ -77,17 +77,20 @@ def _validate_min_max(cls, values): max_value = values.get("max") if min_value is None and max_value is None: - raise ValueError("Either min or max must be specified") + msg = "Either min or max must be specified" + raise ValueError(msg) if min_value and max_value and min_value > max_value: - raise ValueError("Min size cannot be greater than max size") + msg = "Min size cannot be greater than max size" + raise ValueError(msg) return values @validator("min", "max") def _validate_min(cls, value): if value is not None and value < 0: - raise ValueError("size cannot be negative") + msg = "size cannot be negative" + raise ValueError(msg) return value def __repr__(self): diff --git a/onetl/file/filter/glob.py b/onetl/file/filter/glob.py index 0ba620aaa..2694e1256 100644 --- a/onetl/file/filter/glob.py +++ b/onetl/file/filter/glob.py @@ -45,7 +45,7 @@ class Config: def __init__(self, pattern: str): # this is only to allow passing glob as positional argument - super().__init__(pattern=pattern) # type: ignore + super().__init__(pattern=pattern) def __repr__(self): return f"{self.__class__.__name__}({self.pattern!r})" @@ -59,6 +59,7 @@ def match(self, path: PathProtocol) -> bool: @validator("pattern", pre=True) def _validate_pattern(cls, value: str) -> str: if not glob.has_magic(value): - raise ValueError(f"Invalid glob: {value!r}") + msg = f"Invalid glob: {value!r}" + raise ValueError(msg) return value diff --git a/onetl/file/filter/regexp.py b/onetl/file/filter/regexp.py index 15fa633f6..5ab85fac0 100644 --- a/onetl/file/filter/regexp.py +++ b/onetl/file/filter/regexp.py @@ -58,7 +58,7 @@ class Config: def __init__(self, pattern: str): # this is only to allow passing regexp as positional argument - super().__init__(pattern=pattern) # type: ignore + super().__init__(pattern=pattern) def __repr__(self): return f"{self.__class__.__name__}({self.pattern!r})" @@ -75,6 +75,7 @@ def _validate_pattern(cls, value: re.Pattern | str) -> re.Pattern: try: return re.compile(value, re.IGNORECASE | re.DOTALL) except re.error as e: - raise ValueError(f"Invalid regexp: {value!r}") from e + msg = f"Invalid regexp: {value!r}" + raise ValueError(msg) from e return value diff --git a/onetl/file/format/__init__.py b/onetl/file/format/__init__.py index 71632d191..28cdf740e 100644 --- a/onetl/file/format/__init__.py +++ b/onetl/file/format/__init__.py @@ -10,12 +10,12 @@ from onetl.file.format.xml import XML __all__ = [ - "Avro", "CSV", - "Excel", "JSON", - "JSONLine", "ORC", - "Parquet", "XML", + "Avro", + "Excel", + "JSONLine", + "Parquet", ] diff --git a/onetl/file/format/avro.py b/onetl/file/format/avro.py index 354097356..4490a13c6 100644 --- a/onetl/file/format/avro.py +++ b/onetl/file/format/avro.py @@ -12,7 +12,7 @@ try: from pydantic.v1 import Field, root_validator, validator except (ImportError, AttributeError): - from pydantic import Field, validator, root_validator # type: ignore[no-redef, assignment] + from pydantic import Field, root_validator, validator # type: ignore[no-redef, assignment] from onetl._util.java import try_import_java_class from onetl._util.scala import 
get_default_scala_version @@ -238,7 +238,8 @@ class Avro(ReadWriteFileFormat): Avro schema may contain union types, which are not supported by Spark. Different variants of union are split to separated DataFrame columns with respective type. - If option value is ``True``, DataFrame column names are based on Avro variant names, e.g. ``member_int``, ``member_string``. + If option value is ``True``, DataFrame column names are based on Avro variant names, + e.g. ``member_int``, ``member_string``. If ``False``, DataFrame column names are generated using field position, e.g. ``member0``, ``member1``. Default is ``False``. @@ -352,12 +353,13 @@ def parse_column(self, column: str | Column) -> Column: Returns ------- - Column with deserialized data. Schema is matching the provided Avro schema. Column name is the same as input column. + Column with deserialized data. Schema is matching the provided Avro schema. + Column name is the same as input column. Raises ------ ValueError - If the Spark version is less than 3.x or if neither ``avroSchema`` nor ``avroSchemaUrl`` are defined. + If neither ``avroSchema`` nor ``avroSchemaUrl`` are defined. ImportError If ``schema_url`` is used and the ``requests`` library is not installed. @@ -406,24 +408,24 @@ def parse_column(self, column: str | Column) -> Column: |-- value: struct (nullable = true) | |-- name: string (nullable = true) | |-- age: integer (nullable = true) - """ - from pyspark.sql import Column, SparkSession # noqa: WPS442 + """ # noqa: E501 + from pyspark.sql import Column, SparkSession from pyspark.sql.functions import col - spark = SparkSession._instantiatedSession # noqa: WPS437 - self.check_if_supported(spark) + self.check_if_supported(SparkSession._instantiatedSession) # noqa: SLF001 self._check_unsupported_parse_options() from pyspark.sql.avro.functions import from_avro if isinstance(column, Column): - column_name = column._jc.toString() # noqa: WPS437 + column_name = column._jc.toString() # noqa: SLF001 else: column_name, column = column, col(column).cast("binary") schema = self._get_schema_json() if not schema: - raise ValueError("Avro.parse_column can be used only with defined `avroSchema` or `avroSchemaUrl`") + msg = "Avro.parse_column can be used only with defined `avroSchema` or `avroSchemaUrl`" + raise ValueError(msg) return from_avro(column, schema).alias(column_name) @@ -438,7 +440,8 @@ def serialize_column(self, column: str | Column) -> Column: .. warning:: - If ``schema_url`` is provided, ``requests`` library is used to fetch the schema from the URL. It should be installed manually, like this: + If ``schema_url`` is provided, ``requests`` library is used to fetch the schema from the URL. + It should be installed manually, like this: .. 
code:: bash @@ -503,18 +506,17 @@ def serialize_column(self, column: str | Column) -> Column: root |-- key: string (nullable = true) |-- value: binary (nullable = true) - """ - from pyspark.sql import Column, SparkSession # noqa: WPS442 + """ # noqa: E501 + from pyspark.sql import Column, SparkSession from pyspark.sql.functions import col - spark = SparkSession._instantiatedSession # noqa: WPS437 - self.check_if_supported(spark) + self.check_if_supported(SparkSession._instantiatedSession) # noqa: SLF001 self._check_unsupported_serialization_options() from pyspark.sql.avro.functions import to_avro if isinstance(column, Column): - column_name = column._jc.toString() # noqa: WPS437 + column_name = column._jc.toString() # noqa: SLF001 else: column_name, column = column, col(column) @@ -557,22 +559,24 @@ def _check_schema(cls, values): schema_dict = values.get("schema_dict") schema_url = values.get("schema_url") if schema_dict and schema_url: - raise ValueError("Parameters `avroSchema` and `avroSchemaUrl` are mutually exclusive.") + msg = "Parameters `avroSchema` and `avroSchemaUrl` are mutually exclusive." + raise ValueError(msg) return values def _get_schema_json(self) -> str: if self.schema_dict: return json.dumps(self.schema_dict) - elif self.schema_url: + if self.schema_url: try: import requests - - response = requests.get(self.schema_url) # noqa: S113 - return response.text except ImportError as e: - raise ImportError( + msg = ( "The 'requests' library is required to use 'schema_url' but is not installed. " - "Install it with 'pip install requests' or avoid using 'schema_url'.", - ) from e + "Install it with 'pip install requests' or avoid using 'schema_url'." + ) + raise ImportError(msg) from e + else: + response = requests.get(self.schema_url) # noqa: S113 + return response.text else: return "" diff --git a/onetl/file/format/csv.py b/onetl/file/format/csv.py index 61b0d4043..4ce820869 100644 --- a/onetl/file/format/csv.py +++ b/onetl/file/format/csv.py @@ -475,7 +475,8 @@ def check_if_supported(cls, spark: SparkSession) -> None: def parse_column(self, column: str | Column, schema: StructType) -> Column: """ Parses a CSV string column to a structured Spark SQL column using Spark's - `from_csv `_ function, based on the provided schema. + `from_csv `_ function, + based on the provided schema. .. note:: @@ -489,11 +490,13 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column: The name of the column or the column object containing CSV strings/bytes to parse. schema : StructType - The schema to apply when parsing the CSV data. This defines the structure of the output DataFrame column. + The schema to apply when parsing the CSV data. + This defines the structure of the output DataFrame column. Returns ------- - Column with deserialized data, with the same structure as the provided schema. Column name is the same as input column. + Column with deserialized data, with the same structure as the provided schema. + Column name is the same as input column. 
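(Aside, referring back to ``Avro._get_schema_json`` a few hunks above: the ``# noqa: S113`` marker there suppresses the rule that flags ``requests`` calls made without a timeout. A hedged variant that would not need the suppression simply passes one explicitly; the timeout value below is illustrative, and ``raise_for_status()`` is an extra check not present in the original helper.)

.. code:: python

    import requests


    def fetch_schema(url: str) -> str:
        # explicit timeout avoids hanging forever on an unresponsive schema registry
        response = requests.get(url, timeout=30)
        # extra safety check, not in the original helper
        response.raise_for_status()
        return response.text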
Examples -------- @@ -532,18 +535,17 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column: |-- value: struct (nullable = true) | |-- name: string (nullable = true) | |-- age: integer (nullable = true) - """ + """ # noqa: E501 - from pyspark.sql import Column, SparkSession # noqa: WPS442 + from pyspark.sql import Column, SparkSession - spark = SparkSession._instantiatedSession # noqa: WPS437 - self.check_if_supported(spark) + self.check_if_supported(SparkSession._instantiatedSession) # noqa: SLF001 self._check_unsupported_serialization_options() from pyspark.sql.functions import col, from_csv if isinstance(column, Column): - column_name = column._jc.toString() # noqa: WPS437 + column_name = column._jc.toString() # noqa: SLF001 else: column_name, column = column, col(column).cast("string") @@ -603,18 +605,17 @@ def serialize_column(self, column: str | Column) -> Column: root |-- id: integer (nullable = true) |-- value: string (nullable = true) - """ + """ # noqa: E501 - from pyspark.sql import Column, SparkSession # noqa: WPS442 + from pyspark.sql import Column, SparkSession - spark = SparkSession._instantiatedSession # noqa: WPS437 - self.check_if_supported(spark) + self.check_if_supported(SparkSession._instantiatedSession) # noqa: SLF001 self._check_unsupported_serialization_options() from pyspark.sql.functions import col, to_csv if isinstance(column, Column): - column_name = column._jc.toString() # noqa: WPS437 + column_name = column._jc.toString() # noqa: SLF001 else: column_name, column = column, col(column) @@ -626,7 +627,8 @@ def _check_unsupported_serialization_options(self): unsupported_options = current_options.keys() & PARSE_COLUMN_UNSUPPORTED_OPTIONS if unsupported_options: warnings.warn( - f"Options `{sorted(unsupported_options)}` are set but not supported in `CSV.parse_column` or `CSV.serialize_column`.", + f"Options `{sorted(unsupported_options)}` are set but not supported " + "in `CSV.parse_column` or `CSV.serialize_column`.", UserWarning, stacklevel=2, ) diff --git a/onetl/file/format/excel.py b/onetl/file/format/excel.py index bfb688d0e..c7e0547e9 100644 --- a/onetl/file/format/excel.py +++ b/onetl/file/format/excel.py @@ -274,7 +274,8 @@ def get_packages( version = Version(package_version) if version < Version("0.30"): - raise ValueError(f"Package version should be at least 0.30, got {package_version}") + msg = f"Package version should be at least 0.30, got {package_version}" + raise ValueError(msg) spark_ver = Version(spark_version).min_digits(3) scala_ver = Version(scala_version).min_digits(2) if scala_version else get_default_scala_version(spark_ver) diff --git a/onetl/file/format/json.py b/onetl/file/format/json.py index 30655e2e2..6baf97036 100644 --- a/onetl/file/format/json.py +++ b/onetl/file/format/json.py @@ -327,7 +327,9 @@ def check_if_supported(self, spark: SparkSession) -> None: def parse_column(self, column: str | Column, schema: StructType | ArrayType | MapType) -> Column: """ - Parses a JSON string column to a structured Spark SQL column using Spark's `from_json `_ function, based on the provided schema. + Parses a JSON string column to a structured Spark SQL column using Spark's + `from_json `_ + function, based on the provided schema. .. versionadded:: 0.11.0 @@ -337,11 +339,13 @@ def parse_column(self, column: str | Column, schema: StructType | ArrayType | Ma The name of the column or the column object containing JSON strings/bytes to parse. 
schema : StructType | ArrayType | MapType - The schema to apply when parsing the JSON data. This defines the structure of the output DataFrame column. + The schema to apply when parsing the JSON data. + This defines the structure of the output DataFrame column. Returns ------- - Column with deserialized data, with the same structure as the provided schema. Column name is the same as input column. + Column with deserialized data, with the same structure as the provided schema. + Column name is the same as input column. Examples -------- @@ -387,14 +391,14 @@ def parse_column(self, column: str | Column, schema: StructType | ArrayType | Ma | |-- name: string (nullable = true) | |-- age: integer (nullable = true) """ - from pyspark.sql import Column, SparkSession # noqa: WPS442 + from pyspark.sql import Column, SparkSession from pyspark.sql.functions import col, from_json - self.check_if_supported(SparkSession._instantiatedSession) # noqa: WPS437 + self.check_if_supported(SparkSession._instantiatedSession) # noqa: SLF001 self._check_unsupported_serialization_options() if isinstance(column, Column): - column_name, column = column._jc.toString(), column.cast("string") # noqa: WPS437 + column_name, column = column._jc.toString(), column.cast("string") # noqa: SLF001 else: column_name, column = column, col(column).cast("string") @@ -449,15 +453,15 @@ def serialize_column(self, column: str | Column) -> Column: root |-- key: string (nullable = true) |-- value: string (nullable = true) - """ - from pyspark.sql import Column, SparkSession # noqa: WPS442 + """ # noqa: E501 + from pyspark.sql import Column, SparkSession from pyspark.sql.functions import col, to_json - self.check_if_supported(SparkSession._instantiatedSession) # noqa: WPS437 + self.check_if_supported(SparkSession._instantiatedSession) # noqa: SLF001 self._check_unsupported_serialization_options() if isinstance(column, Column): - column_name = column._jc.toString() # noqa: WPS437 + column_name = column._jc.toString() # noqa: SLF001 else: column_name, column = column, col(column) diff --git a/onetl/file/format/orc.py b/onetl/file/format/orc.py index a18c76000..2eb00af4e 100644 --- a/onetl/file/format/orc.py +++ b/onetl/file/format/orc.py @@ -46,8 +46,10 @@ class ORC(ReadWriteFileFormat): The set of supported options depends on Spark version. - You may also set options mentioned `orc-java documentation `_. - They are prefixed with ``orc.`` with dots in names, so instead of calling constructor ``ORC(orc.option=True)`` (invalid in Python) + You may also set options mentioned + `orc-java documentation `_. + They are prefixed with ``orc.`` with dots in names, + so instead of calling constructor ``ORC(orc.option=True)`` (invalid in Python) you should call method ``ORC.parse({"orc.option": True})``. .. tabs:: @@ -118,7 +120,7 @@ def check_if_supported(self, spark: SparkSession) -> None: def __repr__(self): options_dict = self.dict(by_alias=True, exclude_none=True) options_dict = dict(sorted(options_dict.items())) - if any("." in field for field in options_dict.keys()): + if any("." in field for field in options_dict): return f"{self.__class__.__name__}.parse({options_dict})" options_kwargs = ", ".join(f"{k}={v!r}" for k, v in options_dict.items()) diff --git a/onetl/file/format/parquet.py b/onetl/file/format/parquet.py index d102eac39..b27d435c5 100644 --- a/onetl/file/format/parquet.py +++ b/onetl/file/format/parquet.py @@ -46,8 +46,10 @@ class Parquet(ReadWriteFileFormat): The set of supported options depends on Spark version. 
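(Aside on the dotted-option convention described in the ORC docstring above; the Parquet docstring below follows the same pattern. The option keys here are illustrative orc-java settings, not values taken from this change, and the sketch assumes plain Spark options such as ``compression`` are accepted by the constructor.)

.. code:: python

    from onetl.file.format import ORC

    # Dot-free Spark options can go through the constructor...
    orc = ORC(compression="snappy")

    # ...but orc-java options contain dots, which Python keyword syntax cannot
    # express, so they are passed via .parse(), as the docstring describes:
    orc = ORC.parse({"compression": "snappy", "orc.bloom.filter.columns": "id"})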
- You may also set options mentioned `parquet-hadoop documentation `_. - They are prefixed with ``parquet.`` with dots in names, so instead of calling constructor ``Parquet(parquet.option=True)`` (invalid in Python) + You may also set options mentioned + `parquet-hadoop documentation `_. + They are prefixed with ``parquet.`` with dots in names, + so instead of calling constructor ``Parquet(parquet.option=True)`` (invalid in Python) you should call method ``Parquet.parse({"parquet.option": True})``. .. tabs:: @@ -115,7 +117,7 @@ def check_if_supported(self, spark: SparkSession) -> None: def __repr__(self): options_dict = self.dict(by_alias=True, exclude_none=True) options_dict = dict(sorted(options_dict.items())) - if any("." in field for field in options_dict.keys()): + if any("." in field for field in options_dict): return f"{self.__class__.__name__}.parse({options_dict})" options_kwargs = ", ".join(f"{k}={v!r}" for k, v in options_dict.items()) diff --git a/onetl/file/format/xml.py b/onetl/file/format/xml.py index 752d7eb82..41dbb5aef 100644 --- a/onetl/file/format/xml.py +++ b/onetl/file/format/xml.py @@ -80,7 +80,8 @@ class XML(ReadWriteFileFormat): .. warning:: - Due to `bug `_ written files currently does not have ``.xml`` extension. + Due to `bug `_ written files + currently do not have ``.xml`` extension. .. code:: python @@ -336,7 +337,7 @@ class Config: @slot @classmethod - def get_packages( # noqa: WPS231 + def get_packages( cls, spark_version: str, scala_version: str | None = None, @@ -393,14 +394,15 @@ def get_packages( # noqa: WPS231 """ spark_ver = Version(spark_version) - if spark_ver.major >= 4: + if spark_ver.major >= 4: # noqa: PLR2004 # since Spark 4.0, XML is bundled with Spark return [] if package_version: version = Version(package_version).min_digits(3) if version < Version("0.14"): - raise ValueError(f"Package version must be above 0.13, got {version}") + msg = f"Package version must be above 0.13, got {version}" + raise ValueError(msg) log.warning("Passed custom package version %r, it is not guaranteed to be supported", package_version) else: version = Version("0.18.0") @@ -411,7 +413,7 @@ def get_packages( # noqa: WPS231 @slot def check_if_supported(self, spark: SparkSession) -> None: version = get_spark_version(spark) - if version.major >= 4: + if version.major >= 4: # noqa: PLR2004 # since Spark 4.0, XML is bundled with Spark return @@ -439,8 +441,11 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column: .. note:: - This method parses each DataFrame row individually. Therefore, for a specific column, each row must contain exactly one occurrence of the ``rowTag`` specified. - If your XML data includes a root tag that encapsulates multiple row tags, you can adjust the schema to use an ``ArrayType`` to keep all child elements under the single root. + This method parses each DataFrame row individually. Therefore, for a specific column, + each row must contain exactly one occurrence of the ``rowTag`` specified. + + If your XML data includes a root tag that encapsulates multiple row tags, you can adjust the schema + to use an ``ArrayType`` to keep all child elements under the single root. .. code-block:: xml @@ -484,11 +489,13 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column: The name of the column or the column object containing XML strings/bytes to parse. schema : StructType - The schema to apply when parsing the XML data. This defines the structure of the output DataFrame column. 
+ The schema to apply when parsing the XML data. + This defines the structure of the output DataFrame column. Returns ------- - Column with deserialized data, with the same structure as the provided schema. Column name is the same as input column. + Column with deserialized data, with the same structure as the provided schema. + Column name is the same as input column. Examples -------- @@ -528,34 +535,34 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column: | |-- name: string (nullable = true) | |-- age: integer (nullable = true) """ - from pyspark.sql import Column, SparkSession # noqa: WPS442 + from pyspark.sql import Column, SparkSession - spark = SparkSession._instantiatedSession # noqa: WPS437 + spark = SparkSession._instantiatedSession # noqa: SLF001 self.check_if_supported(spark) self._check_unsupported_serialization_options() from pyspark.sql.functions import col if isinstance(column, Column): - column_name, column = column._jc.toString(), column.cast("string") # noqa: WPS437 + column_name, column = column._jc.toString(), column.cast("string") # noqa: SLF001 else: column_name, column = column, col(column).cast("string") options = self.dict(by_alias=True, exclude_none=True) version = get_spark_version(spark) - if version.major >= 4: - from pyspark.sql.functions import from_xml # noqa: WPS450 + if version.major >= 4: # noqa: PLR2004 + from pyspark.sql.functions import from_xml return from_xml(column, schema, stringify(options)).alias(column_name) - from pyspark.sql.column import _to_java_column # noqa: WPS450 + from pyspark.sql.column import _to_java_column java_column = _to_java_column(column) - java_schema = spark._jsparkSession.parseDataType(schema.json()) # noqa: WPS437 - scala_options = spark._jvm.org.apache.spark.api.python.PythonUtils.toScalaMap( # noqa: WPS219, WPS437 + java_schema = spark._jsparkSession.parseDataType(schema.json()) # noqa: SLF001 + scala_options = spark._jvm.org.apache.spark.api.python.PythonUtils.toScalaMap( # noqa: SLF001 stringify(options), ) - jc = spark._jvm.com.databricks.spark.xml.functions.from_xml( # noqa: WPS219, WPS437 + jc = spark._jvm.com.databricks.spark.xml.functions.from_xml( # noqa: SLF001 java_column, java_schema, scala_options, diff --git a/onetl/file/limit/__init__.py b/onetl/file/limit/__init__.py index bfc0cc1c9..1a01e377f 100644 --- a/onetl/file/limit/__init__.py +++ b/onetl/file/limit/__init__.py @@ -7,9 +7,9 @@ from onetl.file.limit.total_files_size import TotalFilesSize __all__ = [ - "limits_reached", - "limits_stop_at", "MaxFilesCount", "TotalFilesSize", + "limits_reached", + "limits_stop_at", "reset_limits", ] diff --git a/onetl/file/limit/limits_reached.py b/onetl/file/limit/limits_reached.py index aba6ec612..31536d391 100644 --- a/onetl/file/limit/limits_reached.py +++ b/onetl/file/limit/limits_reached.py @@ -3,9 +3,10 @@ from __future__ import annotations import logging -from typing import Iterable +from typing import TYPE_CHECKING, Iterable -from onetl.base import BaseFileLimit +if TYPE_CHECKING: + from onetl.base import BaseFileLimit log = logging.getLogger(__name__) diff --git a/onetl/file/limit/limits_stop_at.py b/onetl/file/limit/limits_stop_at.py index 893b96207..9891ab76f 100644 --- a/onetl/file/limit/limits_stop_at.py +++ b/onetl/file/limit/limits_stop_at.py @@ -3,9 +3,10 @@ from __future__ import annotations import logging -from typing import Iterable +from typing import TYPE_CHECKING, Iterable -from onetl.base import BaseFileLimit, PathProtocol +if TYPE_CHECKING: + from onetl.base import 
BaseFileLimit, PathProtocol log = logging.getLogger(__name__) @@ -43,10 +44,7 @@ def limits_stop_at(path: PathProtocol, limits: Iterable[BaseFileLimit]) -> bool: >>> limits_stop_at(LocalPath("/path/to/file3.csv"), limits) True """ - reached = [] - for limit in limits: - if limit.stops_at(path): - reached.append(limit) + reached = [limit for limit in limits if limit.stops_at(path)] if reached: log.debug("|FileLimit| Limits %r are reached", reached) diff --git a/onetl/file/limit/max_files_count.py b/onetl/file/limit/max_files_count.py index 4dffb30ea..240b0a113 100644 --- a/onetl/file/limit/max_files_count.py +++ b/onetl/file/limit/max_files_count.py @@ -48,7 +48,7 @@ class MaxFilesCount(BaseFileLimit, FrozenModel): def __init__(self, limit: int): # this is only to allow passing glob as positional argument - super().__init__(limit=limit) # type: ignore + super().__init__(limit=limit) def __repr__(self): return f"{self.__class__.__name__}({self.limit})" @@ -56,7 +56,8 @@ def __repr__(self): @validator("limit") def _limit_cannot_be_negative(cls, value): if value <= 0: - raise ValueError("Limit should be positive number") + msg = "Limit should be positive number" + raise ValueError(msg) return value def reset(self): diff --git a/onetl/file/limit/reset_limits.py b/onetl/file/limit/reset_limits.py index 181114cab..4bafccbf4 100644 --- a/onetl/file/limit/reset_limits.py +++ b/onetl/file/limit/reset_limits.py @@ -3,9 +3,10 @@ from __future__ import annotations import logging -from typing import Iterable +from typing import TYPE_CHECKING, Iterable -from onetl.base import BaseFileLimit +if TYPE_CHECKING: + from onetl.base import BaseFileLimit log = logging.getLogger(__name__) diff --git a/onetl/file/limit/total_files_size.py b/onetl/file/limit/total_files_size.py index af1830064..9874d785b 100644 --- a/onetl/file/limit/total_files_size.py +++ b/onetl/file/limit/total_files_size.py @@ -58,7 +58,7 @@ class TotalFilesSize(BaseFileLimit, FrozenModel): def __init__(self, limit: int | str): # this is only to allow passing glob as positional argument - super().__init__(limit=limit) # type: ignore + super().__init__(limit=limit) def __repr__(self): return f'{self.__class__.__name__}("{self.limit.human_readable()}")' @@ -66,7 +66,8 @@ def __repr__(self): @validator("limit") def _limit_cannot_be_negative(cls, value): if value <= 0: - raise ValueError("Limit should be positive number") + msg = "Limit should be positive number" + raise ValueError(msg) return value def reset(self): diff --git a/onetl/hooks/__init__.py b/onetl/hooks/__init__.py index a7136a44c..5ac122147 100644 --- a/onetl/hooks/__init__.py +++ b/onetl/hooks/__init__.py @@ -4,3 +4,13 @@ from onetl.hooks.hooks_state import resume_all_hooks, skip_all_hooks, stop_all_hooks from onetl.hooks.slot import slot from onetl.hooks.support_hooks import support_hooks + +__all__ = [ + "HookPriority", + "hook", + "resume_all_hooks", + "skip_all_hooks", + "slot", + "stop_all_hooks", + "support_hooks", +] diff --git a/onetl/hooks/hook.py b/onetl/hooks/hook.py index 0d5c92335..5bf894099 100644 --- a/onetl/hooks/hook.py +++ b/onetl/hooks/hook.py @@ -4,7 +4,7 @@ import logging import sys -from contextlib import contextmanager +from contextlib import contextmanager, suppress from dataclasses import dataclass from enum import Enum from functools import wraps @@ -38,8 +38,8 @@ class HookPriority(int, Enum): "Hooks with this priority will run last." 
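The ``limits_stop_at`` rewrite above folds an append loop into a list comprehension; both build the same list in the same order. A rough equivalence check, using a hypothetical ``StubLimit`` in place of a real file limit:

    class StubLimit:
        """Hypothetical stand-in for a file limit with a stops_at() check."""

        def __init__(self, threshold: int) -> None:
            self.threshold = threshold

        def stops_at(self, value: int) -> bool:
            return value >= self.threshold


    limits = [StubLimit(1), StubLimit(100)]

    # loop form (before)
    reached = []
    for limit in limits:
        if limit.stops_at(5):
            reached.append(limit)

    # comprehension form (after): same elements, same order
    assert reached == [limit for limit in limits if limit.stops_at(5)]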
-@dataclass # noqa: WPS338 -class Hook(Generic[T]): # noqa: WPS338 +@dataclass +class Hook(Generic[T]): """ Hook representation. @@ -282,14 +282,12 @@ def __enter__(self): Just remember this output and return it in :obj:`~process_result` as is. """ - try: + with suppress(StopIteration): self.first_yield_result = self.gen.send(None) - except StopIteration: - pass return self - def __exit__(self, exc_type, value, traceback): # noqa: WPS231 + def __exit__(self, exc_type, value, traceback): """ Copy of :obj:`contextlib._GeneratorContextManager.__exit__` """ @@ -299,7 +297,8 @@ def __exit__(self, exc_type, value, traceback): # noqa: WPS231 next(self.gen) except StopIteration: return False - raise RuntimeError("generator didn't stop") + msg = "generator didn't stop" + raise RuntimeError(msg) if value is None: # Need to force instantiation so we can reliably @@ -323,7 +322,7 @@ def __exit__(self, exc_type, value, traceback): # noqa: WPS231 if exc_type is StopIteration and exc.__cause__ is value: return False raise - except: # noqa: E722, B001 + except: # only re-raise if it's *not* the exception that was # passed to throw(), because __exit__() must not raise # an exception unless __exit__() itself failed. But throw() @@ -338,7 +337,8 @@ def __exit__(self, exc_type, value, traceback): # noqa: WPS231 if sys.exc_info()[1] is value: return False raise - raise RuntimeError("generator didn't stop after throw()") + msg = "generator didn't stop after throw()" + raise RuntimeError(msg) def process_result(self, result): """ @@ -371,7 +371,7 @@ def process_result(self, result): return None -def hook(inp: Callable[..., T] | None = None, enabled: bool = True, priority: HookPriority = HookPriority.NORMAL): +def hook(inp: Callable[..., T] | None = None, *, enabled: bool = True, priority: HookPriority = HookPriority.NORMAL): """ Initialize hook from callable/context manager. @@ -435,9 +435,10 @@ def process_result(self, result): ... 
""" - def inner_wrapper(callback: Callable[..., T]): # noqa: WPS430 + def inner_wrapper(callback: Callable[..., T]): if isinstance(callback, Hook): - raise TypeError("@hook decorator can be applied only once") + msg = "@hook decorator can be applied only once" + raise TypeError(msg) result = Hook(callback=callback, enabled=enabled, priority=priority) return wraps(callback)(result) diff --git a/onetl/hooks/hook_collection.py b/onetl/hooks/hook_collection.py index 955e1c5ee..6873fdc24 100644 --- a/onetl/hooks/hook_collection.py +++ b/onetl/hooks/hook_collection.py @@ -4,11 +4,13 @@ import logging from contextlib import contextmanager -from typing import Iterable +from typing import TYPE_CHECKING, Iterable -from onetl.hooks.hook import Hook from onetl.log import NOTICE +if TYPE_CHECKING: + from onetl.hooks.hook import Hook + logger = logging.getLogger(__name__) diff --git a/onetl/hooks/method_inheritance_stack.py b/onetl/hooks/method_inheritance_stack.py index 1f5501ed4..a56d585ba 100644 --- a/onetl/hooks/method_inheritance_stack.py +++ b/onetl/hooks/method_inheritance_stack.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import defaultdict +from typing import ClassVar class MethodInheritanceStack: @@ -44,7 +45,7 @@ class MethodInheritanceStack: BaseClass 1 """ - _stack: dict[type, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + _stack: ClassVar[dict[type, dict[str, int]]] = defaultdict(lambda: defaultdict(int)) def __init__(self, klass: type, method_name: str): self.klass = klass @@ -52,16 +53,15 @@ def __init__(self, klass: type, method_name: str): def __enter__(self): for cls in self.klass.mro(): - self.__class__._stack[cls][self.method_name] += 1 # noqa: WPS437 + self._stack[cls][self.method_name] += 1 return self def __exit__(self, _exc_type, _exc_value, _traceback): for cls in self.klass.mro(): - self.__class__._stack[cls][self.method_name] -= 1 # noqa: WPS437 + self._stack[cls][self.method_name] -= 1 return False @property def level(self) -> int: """Get level of inheritance""" - stack = self.__class__._stack # noqa: WPS437 - return stack[self.klass][self.method_name] - 1 # exclude current class + return self._stack[self.klass][self.method_name] - 1 # exclude current class diff --git a/onetl/hooks/slot.py b/onetl/hooks/slot.py index 46ebbad1b..412eeb0fc 100644 --- a/onetl/hooks/slot.py +++ b/onetl/hooks/slot.py @@ -102,11 +102,10 @@ def another_callable(self, arg): obj.method(1) # will call both callable(obj, 1) and another_callable(obj, 1) """ - def inner_wrapper(hook): # noqa: WPS430 + def inner_wrapper(hook): if not isinstance(hook, Hook): - raise TypeError( - f"@{method.__qualname__}.bind decorator can be used only on top function marked with @hook", - ) + msg = f"@{method.__qualname__}.bind decorator can be used only on top function marked with @hook" + raise TypeError(msg) method.__hooks__.add(hook) @@ -248,7 +247,7 @@ def _handle_context_result(result: Any, context: CanProcessResult, hook: Hook): raise -def register_slot(cls: type, method_name: str): # noqa: WPS231, WPS213, WPS212 +def register_slot(cls: type, method_name: str): # noqa: C901, PLR0915 """ Internal callback to register ``SomeClass.some_method`` as a slot. 
@@ -292,8 +291,8 @@ def static_method(arg1, arg2): # original_method = _unwrap_method(method_or_descriptor) - @wraps(original_method) # noqa: WPS231, WPS213 - def wrapper(*args, **kwargs): # noqa: WPS231, WPS213 + @wraps(original_method) + def wrapper(*args, **kwargs): # noqa: C901 with MethodInheritanceStack(cls, method_name) as stack_manager, ExitStack() as context_stack: if not HooksState.enabled(): logger.log(NOTICE, "|Hooks| All hooks are disabled") @@ -422,14 +421,14 @@ def wrapper(*args, **kwargs): # noqa: WPS231, WPS213 ) if context_result is not None: - call_result = "(None)" if result is None else "(*NOT* None)" # noqa: WPS220 - logger.log( # noqa: WPS220 + call_result = "(None)" if result is None else "(*NOT* None)" + logger.log( NOTICE, "|Hooks| %sMethod call result %s is modified by hook!", " " * indent, call_result, ) - result = context_result # noqa: WPS220 + result = context_result indent += 2 return result @@ -445,7 +444,7 @@ def wrapper(*args, **kwargs): # noqa: WPS231, WPS213 # wrap result back to @classmethod and @staticmethod, if was used if isinstance(method_or_descriptor, classmethod): return classmethod(wrapper) - elif isinstance(method_or_descriptor, staticmethod): + if isinstance(method_or_descriptor, staticmethod): return staticmethod(wrapper) return wrapper @@ -690,15 +689,18 @@ def callback3(arg): ... """ if hasattr(method, "__hooks__"): - raise SyntaxError("Cannot place @slot hook twice on the same method") + msg = "Cannot place @slot hook twice on the same method" + raise SyntaxError(msg) original_method = getattr(method, "__wrapped__", method) if not _is_method(original_method): - raise TypeError(f"@slot decorator could be applied to only to methods of class, got {type(original_method)}") + msg = f"@slot decorator could be applied to only to methods of class, got {type(original_method)}" + raise TypeError(msg) if _is_private(original_method): - raise ValueError(f"@slot decorator could be applied to public methods only, got '{original_method.__name__}'") + msg = f"@slot decorator could be applied to public methods only, got '{original_method.__name__}'" + raise ValueError(msg) method.__hooks__ = HookCollection() # type: ignore[attr-defined] return method diff --git a/onetl/hooks/support_hooks.py b/onetl/hooks/support_hooks.py index ad6d65a40..bac37eb06 100644 --- a/onetl/hooks/support_hooks.py +++ b/onetl/hooks/support_hooks.py @@ -213,7 +213,8 @@ def callback(self, arg): ... 
setattr(cls, method_name, register_slot(cls, method_name)) if not has_slots: - raise SyntaxError("@support_hooks can be used only with @slot decorator on some of class methods") + msg = "@support_hooks can be used only with @slot decorator on some of class methods" + raise SyntaxError(msg) cls.skip_hooks = partial(skip_hooks, cls) # type: ignore[attr-defined] cls.suspend_hooks = partial(suspend_hooks, cls) # type: ignore[attr-defined] diff --git a/onetl/hwm/__init__.py b/onetl/hwm/__init__.py index f01655863..edc816fa2 100644 --- a/onetl/hwm/__init__.py +++ b/onetl/hwm/__init__.py @@ -2,3 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from onetl.hwm.auto_hwm import AutoDetectHWM from onetl.hwm.window import Edge, Window + +__all__ = [ + "AutoDetectHWM", + "Edge", + "Window", +] diff --git a/onetl/hwm/auto_hwm.py b/onetl/hwm/auto_hwm.py index e3e6866d3..057b4c4d8 100644 --- a/onetl/hwm/auto_hwm.py +++ b/onetl/hwm/auto_hwm.py @@ -31,7 +31,8 @@ def handle_aliases(cls, values): def update(self: AutoDetectHWM, value: Any) -> AutoDetectHWM: """Update current HWM value with some implementation-specific logic, and return HWM""" - raise NotImplementedError("update method should be implemented in auto detected subclasses") + msg = "update method should be implemented in auto detected subclasses" + raise NotImplementedError(msg) def reset(self: AutoDetectHWM) -> AutoDetectHWM: raise NotImplementedError diff --git a/onetl/hwm/store/__init__.py b/onetl/hwm/store/__init__.py index 9213e042c..e07f41891 100644 --- a/onetl/hwm/store/__init__.py +++ b/onetl/hwm/store/__init__.py @@ -19,6 +19,13 @@ "register_hwm_store_class", } +__all__ = [ + "SparkTypeToHWM", + "YAMLHWMStore", + "default_hwm_store_class", + "register_spark_type_to_hwm_type_mapping", +] + def __getattr__(name: str): if name in deprecated_imports: diff --git a/onetl/hwm/store/hwm_class_registry.py b/onetl/hwm/store/hwm_class_registry.py index 83b1b000f..51356acc2 100644 --- a/onetl/hwm/store/hwm_class_registry.py +++ b/onetl/hwm/store/hwm_class_registry.py @@ -34,7 +34,7 @@ class SparkTypeToHWM: @classmethod def get(cls, spark_type: DataType) -> type[HWM] | None: # avoid importing pyspark in the module - from pyspark.sql.types import ( # noqa: WPS235 + from pyspark.sql.types import ( ByteType, DateType, DecimalType, diff --git a/onetl/hwm/store/yaml_hwm_store.py b/onetl/hwm/store/yaml_hwm_store.py index 43b3f2cba..e89c0b289 100644 --- a/onetl/hwm/store/yaml_hwm_store.py +++ b/onetl/hwm/store/yaml_hwm_store.py @@ -173,19 +173,19 @@ def validate_path(cls, path): return path @slot - def get_hwm(self, name: str) -> HWM | None: # type: ignore + def get_hwm(self, name: str) -> HWM | None: data = self._load(name) if not data: return None latest = sorted(data, key=operator.itemgetter("modified_time"))[-1] - return HWMTypeRegistry.parse(latest) # type: ignore + return HWMTypeRegistry.parse(latest) @slot - def set_hwm(self, hwm: HWM) -> LocalPath: # type: ignore + def set_hwm(self, hwm: HWM) -> LocalPath: data = self._load(hwm.name) - self._dump(hwm.name, [hwm.serialize()] + data) + self._dump(hwm.name, [hwm.serialize(), *data]) return self.get_file_path(hwm.name) @classmethod diff --git a/onetl/impl/__init__.py b/onetl/impl/__init__.py index d616d2e3b..f604fc985 100644 --- a/onetl/impl/__init__.py +++ b/onetl/impl/__init__.py @@ -11,3 +11,18 @@ from onetl.impl.remote_file import FailedRemoteFile, RemoteFile from onetl.impl.remote_path import RemotePath from onetl.impl.remote_path_stat import RemotePathStat + +__all__ = [ + "BaseModel", +
"FailedLocalFile", + "FailedRemoteFile", + "FileExistBehavior", + "FrozenModel", + "GenericOptions", + "LocalPath", + "RemoteDirectory", + "RemoteFile", + "RemotePath", + "RemotePathStat", + "path_repr", +] diff --git a/onetl/impl/base_model.py b/onetl/impl/base_model.py index a633d4cf5..d6db29a4a 100644 --- a/onetl/impl/base_model.py +++ b/onetl/impl/base_model.py @@ -29,14 +29,14 @@ def __init__(self, **kwargs): # when first object instance is being created refs = self._forward_refs() self.__class__.update_forward_refs(**refs) - self.__class__._forward_refs_updated = True # noqa: WPS437 + self.__class__._forward_refs_updated = True # noqa: SLF001 super().__init__(**kwargs) @classmethod def _forward_refs(cls) -> dict[str, type]: refs: dict[str, type] = {} for item in dir(cls): - if item.startswith("_") or item.startswith("package"): + if item.startswith(("_", "package")): continue value = getattr(cls, item) diff --git a/onetl/impl/failed_local_file.py b/onetl/impl/failed_local_file.py index 6e16900ef..acad346ab 100644 --- a/onetl/impl/failed_local_file.py +++ b/onetl/impl/failed_local_file.py @@ -19,7 +19,7 @@ class FailedLocalFile(PathContainer[LocalPath]): def __post_init__(self): # frozen=True does not allow to change any field in __post_init__, small hack here - object.__setattr__(self, "path", LocalPath(self.path)) # noqa: WPS609 + object.__setattr__(self, "path", LocalPath(self.path)) def __repr__(self) -> str: return f"{self.__class__.__name__}({os.fspath(self.path)!r}, {self.exception!r})" diff --git a/onetl/impl/file_exist_behavior.py b/onetl/impl/file_exist_behavior.py index 03fcd0be5..2dc4372df 100644 --- a/onetl/impl/file_exist_behavior.py +++ b/onetl/impl/file_exist_behavior.py @@ -16,8 +16,8 @@ class FileExistBehavior(str, Enum): def __str__(self): return str(self.value) - @classmethod # noqa: WPS120 - def _missing_(cls, value: object): # noqa: WPS120 + @classmethod + def _missing_(cls, value: object): if str(value) == "overwrite": warnings.warn( "Mode `overwrite` is deprecated since v0.9.0 and will be removed in v1.0.0. Use `replace_file` instead", @@ -34,3 +34,5 @@ def _missing_(cls, value: object): # noqa: WPS120 stacklevel=4, ) return cls.REPLACE_ENTIRE_DIRECTORY + + return None diff --git a/onetl/impl/generic_options.py b/onetl/impl/generic_options.py index 29c561116..fa87f713b 100644 --- a/onetl/impl/generic_options.py +++ b/onetl/impl/generic_options.py @@ -12,6 +12,8 @@ except (ImportError, AttributeError): from pydantic import root_validator # type: ignore[no-redef, assignment] +from typing_extensions import Self + from onetl.impl.frozen_model import FrozenModel log = logging.getLogger(__name__) @@ -20,15 +22,15 @@ class GenericOptions(FrozenModel): class Config: - strip_prefixes: list[str | re.Pattern] = [] + strip_prefixes: tuple[str | re.Pattern, ...] = () known_options: frozenset[str] | None = None prohibited_options: frozenset[str] = frozenset() @classmethod def parse( - cls: type[T], + cls, options: GenericOptions | dict | None, - ) -> T: + ) -> Self: """ If a parameter inherited from the ReadOptions class was passed, then it will be returned unchanged. If a Dict object was passed it will be converted to ReadOptions. 
@@ -43,9 +45,8 @@ def parse( return cls.parse_obj(options) if not isinstance(options, cls): - raise TypeError( - f"{options.__class__.__name__} is not a {cls.__name__} instance", - ) + msg = f"{options.__class__.__name__} is not a {cls.__name__} instance" + raise TypeError(msg) return options @@ -67,16 +68,16 @@ def _strip_prefixes(cls, values): new_key, ) if new_key in values: - log.warning("Overwriting existing value of key %r with %r", key, new_key) # noqa: WPS220 + log.warning("Overwriting existing value of key %r with %r", key, new_key) values[new_key] = value - key = new_key + key = new_key # noqa: PLW2901 return values @staticmethod def _strip_prefix(key: str, prefix: str | re.Pattern) -> tuple[str, str | None]: if isinstance(prefix, str) and key.startswith(prefix): return key.replace(prefix, "", 1), prefix - elif isinstance(prefix, re.Pattern) and prefix.match(key): + if isinstance(prefix, re.Pattern) and prefix.match(key): return prefix.sub("", key, 1), prefix.pattern return key, None @@ -96,7 +97,8 @@ def _check_options_allowed( matching_options = sorted(cls._get_matching_options(unknown_options, prohibited)) if matching_options: class_name = cls.__name__ # type: ignore[attr-defined] - raise ValueError(f"Options {matching_options!r} are not allowed to use in a {class_name}") + msg = f"Options {matching_options!r} are not allowed to use in a {class_name}" + raise ValueError(msg) return values diff --git a/onetl/impl/local_path.py b/onetl/impl/local_path.py index 36e025c8a..4de5067f0 100644 --- a/onetl/impl/local_path.py +++ b/onetl/impl/local_path.py @@ -8,11 +8,10 @@ class LocalPath(Path): def __new__(cls, *args, **kwargs): if cls is LocalPath: - cls = LocalWindowsPath if os.name == "nt" else LocalPosixPath + cls = LocalWindowsPath if os.name == "nt" else LocalPosixPath # noqa: PLW0642 if sys.version_info < (3, 12): return cls._from_parts(args) - else: - return object.__new__(cls) # noqa: WPS503 + return object.__new__(cls) class LocalPosixPath(LocalPath, PurePosixPath): diff --git a/onetl/impl/path_repr.py b/onetl/impl/path_repr.py index 747990848..50d21c3b6 100644 --- a/onetl/impl/path_repr.py +++ b/onetl/impl/path_repr.py @@ -70,7 +70,7 @@ def detect_kind(path: PathProtocol) -> str: # try to detect mode based on stats for detector, file_kind in FILE_TYPE_DETECTORS.items(): if detector(mode): - return file_kind # noqa: WPS220 + return file_kind if path.is_dir(): return "directory" @@ -126,8 +126,9 @@ def repr_exception(self) -> str: exception_formatted = textwrap.indent(exception, prefix) return os.linesep + exception_formatted + os.linesep - def info( + def info( # noqa: PLR0913 self, + *, with_kind: bool = True, with_size: bool = True, with_mode: bool = True, @@ -163,8 +164,9 @@ def info( return result_str -def path_repr( +def path_repr( # noqa: PLR0913 path: os.PathLike | str, + *, with_kind: bool = True, with_size: bool = True, with_mode: bool = True, diff --git a/onetl/impl/remote_directory.py b/onetl/impl/remote_directory.py index ef87ceb4e..2403d5bc1 100644 --- a/onetl/impl/remote_directory.py +++ b/onetl/impl/remote_directory.py @@ -4,12 +4,15 @@ import os from dataclasses import dataclass, field +from typing import TYPE_CHECKING -from onetl.base import PathStatProtocol from onetl.impl.path_container import PathContainer from onetl.impl.remote_path import RemotePath from onetl.impl.remote_path_stat import RemotePathStat +if TYPE_CHECKING: + from onetl.base import PathStatProtocol + @dataclass(eq=False, frozen=True) class RemoteDirectory(PathContainer[RemotePath]): @@ 
-21,7 +24,7 @@ class RemoteDirectory(PathContainer[RemotePath]): def __post_init__(self): # frozen=True does not allow to change any field in __post_init__, small hack here - object.__setattr__(self, "path", RemotePath(self.path)) # noqa: WPS609 + object.__setattr__(self, "path", RemotePath(self.path)) def is_dir(self) -> bool: return True diff --git a/onetl/impl/remote_file.py b/onetl/impl/remote_file.py index f69d53287..0871f9052 100644 --- a/onetl/impl/remote_file.py +++ b/onetl/impl/remote_file.py @@ -4,12 +4,15 @@ import os from dataclasses import dataclass +from typing import TYPE_CHECKING -from onetl.base import PathStatProtocol from onetl.impl.path_container import PathContainer from onetl.impl.remote_directory import RemoteDirectory from onetl.impl.remote_path import RemotePath +if TYPE_CHECKING: + from onetl.base import PathStatProtocol + @dataclass(eq=False, frozen=True) class RemoteFile(PathContainer[RemotePath]): @@ -21,7 +24,7 @@ class RemoteFile(PathContainer[RemotePath]): def __post_init__(self): # frozen=True does not allow to change any field in __post_init__, small hack here - object.__setattr__(self, "path", RemotePath(self.path)) # noqa: WPS609 + object.__setattr__(self, "path", RemotePath(self.path)) def __repr__(self) -> str: return f"{self.__class__.__name__}({os.fspath(self.path)!r})" diff --git a/onetl/impl/remote_path_stat.py b/onetl/impl/remote_path_stat.py index cf269c80f..5054c0fbe 100644 --- a/onetl/impl/remote_path_stat.py +++ b/onetl/impl/remote_path_stat.py @@ -2,11 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union -from onetl.base.path_stat_protocol import PathStatProtocol from onetl.impl.frozen_model import FrozenModel +if TYPE_CHECKING: + from onetl.base.path_stat_protocol import PathStatProtocol + class RemotePathStat(FrozenModel): st_size: int = 0 diff --git a/onetl/log.py b/onetl/log.py index 1451e72e8..87014c2d9 100644 --- a/onetl/log.py +++ b/onetl/log.py @@ -68,7 +68,7 @@ def setup_notebook_logging(level: int | str = logging.INFO) -> None: setup_logging(level) -def setup_logging(level: int | str = logging.INFO, enable_clients: bool = False) -> None: +def setup_logging(level: int | str = logging.INFO, *, enable_clients: bool = False) -> None: """Set up onETL logging. 
What this function does: @@ -174,7 +174,7 @@ def set_default_logging_format() -> None: def _log(logger: logging.Logger, msg: str, *args, level: int = logging.INFO, stacklevel: int = 1, **kwargs) -> None: if sys.version_info >= (3, 8): # https://github.com/python/cpython/pull/7424 - logger.log(level, msg, *args, stacklevel=stacklevel + 1, **kwargs) # noqa: WPS204 + logger.log(level, msg, *args, stacklevel=stacklevel + 1, **kwargs) else: logger.log(level, msg, *args, **kwargs) @@ -215,7 +215,7 @@ def log_with_indent( _log(logger, "%s" + inp, " " * (BASE_LOG_INDENT + indent), *args, level=level, stacklevel=stacklevel + 1, **kwargs) -def log_lines( +def log_lines( # noqa: PLR0913 logger: logging.Logger, inp: str, name: str | None = None, @@ -257,7 +257,7 @@ def log_lines( _log(logger, "%s%s", base_indent, line, level=level, stacklevel=stacklevel) -def log_json( +def log_json( # noqa: PLR0913 logger: logging.Logger, inp: Any, name: str | None = None, @@ -295,7 +295,7 @@ def log_json( log_lines(logger, json.dumps(inp, indent=4), name, indent, level, stacklevel=stacklevel + 1) -def log_collection( +def log_collection( # noqa: PLR0913 logger: logging.Logger, name: str, collection: Iterable, diff --git a/onetl/plugins/__init__.py b/onetl/plugins/__init__.py index 63d3bdcc5..aa8f768ce 100644 --- a/onetl/plugins/__init__.py +++ b/onetl/plugins/__init__.py @@ -1,3 +1,5 @@ # SPDX-FileCopyrightText: 2023-present MTS PJSC # SPDX-License-Identifier: Apache-2.0 from onetl.plugins.import_plugins import import_plugins + +__all__ = ["import_plugins"] diff --git a/onetl/plugins/import_plugins.py b/onetl/plugins/import_plugins.py index 69c75d380..618cb28b2 100644 --- a/onetl/plugins/import_plugins.py +++ b/onetl/plugins/import_plugins.py @@ -63,7 +63,7 @@ def import_plugin(entrypoint: EntryPoint): raise ImportError(error_msg) from e -def import_plugins(group: str, whitelist: list[str] | None = None, blacklist: list[str] | None = None): # noqa: WPS213 +def import_plugins(group: str, whitelist: list[str] | None = None, blacklist: list[str] | None = None): """ Import all plugins registered for onETL """ diff --git a/onetl/strategy/__init__.py b/onetl/strategy/__init__.py index 5eb8e0786..510ea6e26 100644 --- a/onetl/strategy/__init__.py +++ b/onetl/strategy/__init__.py @@ -7,3 +7,12 @@ ) from onetl.strategy.snapshot_strategy import SnapshotBatchStrategy, SnapshotStrategy from onetl.strategy.strategy_manager import StrategyManager + +__all__ = [ + "BaseStrategy", + "IncrementalBatchStrategy", + "IncrementalStrategy", + "SnapshotBatchStrategy", + "SnapshotStrategy", + "StrategyManager", +] diff --git a/onetl/strategy/base_strategy.py b/onetl/strategy/base_strategy.py index 236171a9f..e94467e68 100644 --- a/onetl/strategy/base_strategy.py +++ b/onetl/strategy/base_strategy.py @@ -13,7 +13,7 @@ class BaseStrategy(BaseModel): def __enter__(self): - # hack to avoid circular imports + # avoid circular imports from onetl.strategy.strategy_manager import StrategyManager log.debug("|%s| Entered stack at level %d", self.__class__.__name__, StrategyManager.get_current_level()) @@ -49,7 +49,7 @@ def next(self) -> Edge: def enter_hook(self) -> None: pass - def exit_hook(self, failed: bool = False) -> None: + def exit_hook(self, *, failed: bool = False) -> None: pass def _log_parameters(self) -> None: diff --git a/onetl/strategy/batch_hwm_strategy.py b/onetl/strategy/batch_hwm_strategy.py index 699560e72..a6147c5a4 100644 --- a/onetl/strategy/batch_hwm_strategy.py +++ b/onetl/strategy/batch_hwm_strategy.py @@ -30,7 +30,8 @@ 
class BatchHWMStrategy(HWMStrategy): @validator("step", always=True) def step_is_not_none(cls, step): if not step: - raise ValueError(f"'step' argument of {cls.__name__} cannot be empty!") + msg = f"'step' argument of {cls.__name__} cannot be empty!" + raise ValueError(msg) return step @@ -118,23 +119,26 @@ def check_hwm_increased(self, next_value: Any) -> None: if next_value is not None and self.current.value >= next_value: # negative or zero step - exception # DateHWM with step value less than one day - exception - raise ValueError( - f"HWM value is not increasing, please check options passed to {self.__class__.__name__}!", - ) + msg = f"HWM value is not increasing, please check options passed to {self.__class__.__name__}!" + raise ValueError(msg) if self.stop is not None: expected_iterations = int((self.stop - self.current.value) / self.step) if expected_iterations >= self.MAX_ITERATIONS: - raise ValueError( + msg = ( f"step={self.step!r} parameter of {self.__class__.__name__} leads to " - f"generating too many iterations ({expected_iterations}+)", + f"generating too many iterations ({expected_iterations}+)" + ) + raise ValueError( + msg, ) @property def next(self) -> Edge: if self.current.is_set(): if not hasattr(self.current.value, "__add__"): - raise RuntimeError(f"HWM: {self.hwm!r} cannot be used with Batch strategies") + msg = f"HWM: {self.hwm!r} cannot be used with Batch strategies" + raise RuntimeError(msg) result = Edge(value=self.current.value + self.step) else: diff --git a/onetl/strategy/hwm_store/__init__.py b/onetl/strategy/hwm_store/__init__.py index 0f5aa37c8..fbb4b4902 100644 --- a/onetl/strategy/hwm_store/__init__.py +++ b/onetl/strategy/hwm_store/__init__.py @@ -11,13 +11,11 @@ from etl_entities.hwm_store import ( BaseHWMStore, HWMStoreClassRegistry, - ) - from etl_entities.hwm_store import HWMStoreStackManager as HWMStoreManager - from etl_entities.hwm_store import ( MemoryHWMStore, detect_hwm_store, register_hwm_store_class, ) + from etl_entities.hwm_store import HWMStoreStackManager as HWMStoreManager from onetl.hwm.store import ( SparkTypeToHWM, @@ -28,15 +26,15 @@ __all__ = [ "BaseHWMStore", - "SparkTypeToHWM", - "register_spark_type_to_hwm_type_mapping", "HWMStoreClassRegistry", - "default_hwm_store_class", - "detect_hwm_store", - "register_hwm_store_class", "HWMStoreManager", "MemoryHWMStore", + "SparkTypeToHWM", "YAMLHWMStore", + "default_hwm_store_class", + "detect_hwm_store", + "register_hwm_store_class", + "register_spark_type_to_hwm_type_mapping", ] diff --git a/onetl/strategy/hwm_strategy.py b/onetl/strategy/hwm_strategy.py index 74623dd71..c1f016039 100644 --- a/onetl/strategy/hwm_strategy.py +++ b/onetl/strategy/hwm_strategy.py @@ -112,7 +112,7 @@ def validate_hwm_attributes(self, current_hwm: HWM, new_hwm: HWM, origin: str): warnings.warn(message, UserWarning, stacklevel=2) - def exit_hook(self, failed: bool = False) -> None: + def exit_hook(self, *, failed: bool = False) -> None: if not failed: self.save_hwm() @@ -128,7 +128,7 @@ def save_hwm(self) -> None: log.info("|%s| Saving HWM to %r:", class_name, hwm_store.__class__.__name__) log_hwm(log, self.hwm) - location = hwm_store.set_hwm(self.hwm) # type: ignore + location = hwm_store.set_hwm(self.hwm) log.info("|%s| HWM has been saved", class_name) if location: diff --git a/onetl/strategy/incremental_strategy.py b/onetl/strategy/incremental_strategy.py index c7045cd49..76648f01f 100644 --- a/onetl/strategy/incremental_strategy.py +++ b/onetl/strategy/incremental_strategy.py @@ -17,7 +17,8 @@ class 
IncrementalStrategy(HWMStrategy): by filtering items not covered by the previous :ref:`HWM` value. For :ref:`db-reader`: - First incremental run is just the same as :obj:`SnapshotStrategy `: + First incremental run is just the same as + :obj:`SnapshotStrategy `: .. code:: sql @@ -54,8 +55,8 @@ class IncrementalStrategy(HWMStrategy): .. tab:: FileListHWM - First incremental run is just the same as :obj:`SnapshotStrategy ` - - all files are downloaded: + First incremental run is just the same as + :obj:`SnapshotStrategy ` - all files are downloaded: .. code:: bash @@ -123,8 +124,8 @@ class IncrementalStrategy(HWMStrategy): .. tab:: FileModifiedTimeHWM - First incremental run is just the same as :obj:`SnapshotStrategy ` - - all files are downloaded: + First incremental run is just the same as + :obj:`SnapshotStrategy ` - all files are downloaded: .. code:: bash @@ -143,7 +144,8 @@ class IncrementalStrategy(HWMStrategy): }, ) - Then the maximum modified time of original files is saved as ``FileModifiedTimeHWM`` object into :ref:`HWM Store `: + Then the maximum modified time of original files is saved as + ``FileModifiedTimeHWM`` object into :ref:`HWM Store `: .. code:: python @@ -153,7 +155,8 @@ class IncrementalStrategy(HWMStrategy): value=datetime.datetime(2025, 1, 1, 11, 22, 33, 456789, tzinfo=timezone.utc), ) - Next incremental run will download only files from the source which were modified or created since previous run: + Next incremental run will download only files from the source + which were modified or created since previous run: .. code:: bash @@ -189,8 +192,11 @@ class IncrementalStrategy(HWMStrategy): **NOT** while exiting strategy context. This is because: * FileDownloader does not raise exceptions if some file cannot be downloaded. - * FileDownloader creates files on local filesystem, and file content may differ for different :obj:`modes `. - * It can remove files from the source if :obj:`delete_source ` is set to ``True``. + * FileDownloader creates files on local filesystem, and file content may differ for different + :obj:`modes `. + * It can remove files from the source + if :obj:`delete_source ` + is set to ``True``. .. 
versionadded:: 0.1.0 diff --git a/onetl/strategy/strategy_manager.py b/onetl/strategy/strategy_manager.py index 3dc52147e..a94dfa0de 100644 --- a/onetl/strategy/strategy_manager.py +++ b/onetl/strategy/strategy_manager.py @@ -3,16 +3,18 @@ from __future__ import annotations import logging -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar -from onetl.strategy.base_strategy import BaseStrategy from onetl.strategy.snapshot_strategy import SnapshotStrategy +if TYPE_CHECKING: + from onetl.strategy.base_strategy import BaseStrategy + log = logging.getLogger(__name__) class StrategyManager: - default_strategy: ClassVar[type] = SnapshotStrategy + default_strategy: ClassVar[type[BaseStrategy]] = SnapshotStrategy _stack: ClassVar[list[BaseStrategy]] = [] diff --git a/pyproject.toml b/pyproject.toml index 2162c7d3e..80d64ef6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,7 +185,6 @@ dev = [ "prek~=0.2.25", "types-Deprecated~=1.3.1", "types-PyYAML~=6.0.12", - "wemake-python-styleguide~=1.5.0", ] docs = [ # TODO: remove version limit after upgrading all Pydantic models to v2 @@ -232,6 +231,67 @@ target-version = "py37" line-length = 120 extend-exclude = ["docs/", "Makefile"] +[tool.ruff.lint] +select = ["ALL"] +ignore = [ + "A002", + "ANN", + "ARG", + "COM812", + "D", + "FIX002", + "N805", + "N815", + "PLC0415", + "PYI051", + "TC001", + "TC002", + "TC003", + "TD", + "UP006", + "UP007", + "UP045", +] + +[tool.ruff.lint.per-file-ignores] +"onetl/connection/file_connection/*" = ["ERA001"] + +"tests/*" = [ + "S", + "A", + "PLC0415", + "TC", + "ICN001", + "PT", + "PLR2004", + "PLR0913", + "SLF001", + "FBT", + "SIM117", + "SIM108", + "PLR0915", + "DTZ", + "B017", + "PERF401", + "BLE001", + "PERF203", +] +"tests/tests_unit/test_hooks/*" = ["C901"] +"tests/libs/*" = ["INP001"] + +[tool.mypy] +python_version = "3.7" +plugins = ["pydantic.mypy"] +exclude = "^(?=.*file).*" +strict_optional = true +ignore_missing_imports = true +follow_imports = "silent" +show_error_codes = true + +[tool.codespell] +ignore-words-list = ["INOUT", "inout", "thirdparty"] + + [tool.towncrier] name = "onETL" package = "onetl" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index a145b2192..000000000 --- a/setup.cfg +++ /dev/null @@ -1,462 +0,0 @@ -# WeMakePythonStyleGuide: -# https://wemake-python-stylegui.de/en/latest/index.html -# https://wemake-python-stylegui.de/en/latest/pages/usage/configuration.html -# https://wemake-python-stylegui.de/en/latest/pages/usage/violations/index.html -# http://pycodestyle.pycqa.org/en/latest/intro.html -# http://flake8.pycqa.org/en/latest/user/configuration.html -# http://flake8.pycqa.org/en/latest/user/options.html -# http://flake8.pycqa.org/en/latest/user/error-codes.html -# http://flake8.pycqa.org/en/latest/user/violations.html -# https://wemake-python-stylegui.de/en/latest/pages/usage/formatter.html -# https://wemake-python-stylegui.de/en/latest/pages/usage/integrations/plugins.html -# http://flake8.pycqa.org/en/latest/user/options.html?highlight=per-file-ignores#cmdoption-flake8-per-file-ignores - -[autoflake] -imports = onetl,pydantic,etl_entities,pyspark,tests -ignore-init-module-imports = true -remove-unused-variables = true - - -[flake8] -# Wemake Python Style Guide Configuration - -jobs = 4 - -min-name-length = 1 -# We don't control ones who use our code -i-control-code = False -nested-classes-whitelist = - Meta, - NewDate, -# Max of noqa in a module -max-noqa-comments = 10 -max-annotation-complexity = 4 -max-returns = 5 -max-awaits = 5 
-max-local-variables = 20 -max-name-length = 60 -# Max of expressions in a function -max-expressions = 15 -# Max args in a function -max-arguments = 15 -# Max classes and functions in a single module -max-module-members = 35 -max-methods = 25 -# Max line complexity measured in AST nodes -max-line-complexity = 24 -# Max Jones Score for a module: the median of all lines complexity sum -max-jones-score = 15 -# Max amount of cognitive complexity per function -max-cognitive-score = 20 -# Max amount of cognitive complexity per module -max-cognitive-average = 25 -max-imports = 25 -max-imported-names = 60 -# Max of expression usages in a module -max-module-expressions = 15 -# Max of expression usages in a function -max-function-expressions = 15 -max-base-classes = 5 -max-decorators = 6 -# Max of repeated string constants in your modules -max-string-usages = 15 -max-try-body-length = 15 -max-asserts = 15 -# Max number of access level in an expression -max-access-level = 6 -# maximum number of public instance attributes -max-attributes = 20 - -max-line-length = 120 -max-doc-length = 120 - -# https://pypi.org/project/flake8-quotes/ -inline-quotes = double -multiline-quotes = double -docstring-quotes = double - -# https://wemake-python-stylegui.de/en/latest/pages/usage/formatter.html -#format = '%(path)s:%(row)d:%(col)d: %(code)s %(text)s' -;format = wemake -show-source = True -# Print total number of errors -count = True -statistics = True -# benchmark = True - -exclude = - .tox, - migrations, - dist, - build, - hadoop_archive_plugin, - virtualenv, - venv, - venv36, - ve, - .venv, - tox.ini, - docker, - Jenkinsfile, - dags, - setup.py, - docs, - -# https://github.com/peterjc/flake8-rst-docstrings/pull/16 -rst-directives = - # These are sorted alphabetically - but that does not matter - autosummary,data,currentmodule,deprecated, - glossary,moduleauthor,plot,testcode, - versionadded,versionchanged, - -rst-roles = - attr,class,func,meth,mod,obj,ref,term, - # Python programming language: - py:func,py:mod, - -# https://wemake-python-stylegui.de/en/latest/pages/usage/violations/index.html -# http://pycodestyle.pycqa.org/en/latest/intro.html -ignore = -# Import at the wrong position -# [buggy with trailing commas and "as " imports] -# [too much hassle] -# [sometimes flask imports cannot be placed alphabetically] -#FIXME: change where can be done, later switch on - I, - ANN, -# Found name reserved for first argument: cls [opinionated] - WPS117, -# __future__ import "division" missing - FI10, -# __future__ import "absolute_import" missing - FI11, -# __future__ import "with_statement" missing - FI12, -# __future__ import "print_function" missing - FI13, -# __future__ import "unicode_literals" missing - FI14, -# __future__ import "generator_stop" missing - FI15, -# __future__ import "nested_scopes" missing - FI16, -# __future__ import "generators" missing - FI17, -# __future__ import "annotations" present - FI58, -# Found `f` string [opinionated] - WPS305, -# Found explicit string concat [opinionated] - WPS336, -# Found using `@staticmethod` [opinionated] - WPS602, -# Found wrong variable name ("data", "handler", "params") [opinionated] - WPS110, -# Found upper-case . 
constant in a class (flask config requires uppercase consts) [opinionated] - WPS115, -# WPS223: Found too many `elif` branches - WPS223, -# Found class without a base class (goes against PEP8) [opinionated] - WPS306, -# Found line break before binary operator [goes against PEP8] [opinionated] - W503, -# Found multiline conditions [opinionated] - WPS337, -# Found mutable module constant [opinionated] - WPS407, -# WPS411 Found empty module: - WPS411, -# Found nested import [opinionated] - WPS433, -# Found negated condition [opinionated] - WPS504, -# WPS529:Found implicit `.get()` dict usage - WPS529, -# FIXME: handle with docstring later -# Docstrings [opinionated] - D, -# P101 and P103 string does contain unindexed parameters' - P101, - P103, -# WPS237:Found a too complex `f` string - WPS237, -# WPS316 Found context manager with too many assignments - WPS316, -# WPS326 Found implicit string concatenation [optional] - WPS326, -# WPS347 Found vague import that may cause confusion - WPS347, -# WPS421 Found wrong function call: locals' - WPS421, -# WPS348 Found a line that starts with a dot - WPS348, -# WPS440 Found block variables overlap - WPS440, -# WPS459 Found comparison with float or complex number [buggy] - WPS459, -# S108 Probable insecure usage of temp file/directory. - S108, -# S404 Consider possible security implications associated with check_call module - S404, -# S603 subprocess call - check for execution of untrusted input - S603, -# S607 Starting a process with a partial executable path - S607, -# S608 Possible SQL injection vector through string-based query construction. - S608, -# E402 module level import not at top of file - E402, -# RST399: Document or section may not begin with a transition. - RST399, -# WPS432 Found magic number - WPS432, -# WPS615 Found unpythonic getter or setter - WPS615, -# RST213: Inline emphasis start-string without end-string. - RST213, -# RST304: Unknown interpreted text role - RST304, -# RST307: Error in "code" directive - RST307, -# WPS428 Found statement that has no effect - WPS428, -# WPS462 Wrong multiline string usage - WPS462, -# WPS303 Found underscored number: - WPS303, -# WPS431 Found nested class - WPS431, -# WPS317 Invalid multiline string usage - WPS317, -# WPS226 Found string literal over-use: | [bug] - WPS226, -# WPS323 Found `%` string formatting - WPS323, -# RST305 Undefined substitution referenced: support_hooks - RST305, -# RST303 Unknown directive type tabs - RST303, -# WPS402 Found `noqa` comments overuse - WPS402, -# WPS214 Found too many methods - WPS214, -# WPS605 Found method without arguments - WPS605, -# N805 first argument of a method should be named 'self' - N805, -# WPS238 Found too many raises in a function - WPS238, -# W505: doc line too long - W505, -# E501: line too long - E501, -# WPS114 Found underscored number name pattern: package_spark_2_3 - WPS114, -# WPS420 Found wrong keyword pass - WPS420 -# WPS600 Found subclassing a builtin: str - WPS600, -# WPS601 Found shadowed class attribute - WPS601, -# WPS604 Found incorrect node inside `class` body: pass - WPS604, -# WPS100 Found wrong module name: util - WPS100, -# WPS436 Found protected module import: onetl._util -# https://github.com/wemake-services/wemake-python-styleguide/issues/1441 - WPS436, -# WPS201 Found module with too many imports: 26 > 25 - WPS201, -# WPS429 Found multiple assign targets - WPS429, -# https://github.com/wemake-services/wemake-python-styleguide/issues/2847 -# E704 multiple statements on one line: def func(): ... 
- E704, -# WPS474 Found import object collision - WPS474, -# WPS318 Found extra indentation - WPS318, -# WPS410 Found wrong metadata variable: __all__ - WPS410, -# WPS412 Found `__init__.py` module with logic - WPS412, -# WPS413 Found bad magic module function: __getattr__ - WPS413, -# WPS338 Found incorrect order of methods in a class - WPS338, -# P102 docstring does contain unindexed parameters - P102 - -# http://flake8.pycqa.org/en/latest/user/options.html?highlight=per-file-ignores#cmdoption-flake8-per-file-ignores -per-file-ignores = - __init__.py: -# * imports are valid for __init__.py scripts - F403, - WPS347, - WPS440, -# __init__.py scripts may require a lot of imports - WPS235, -# F401 imported but unused - F401, - conftest.py: -# E800 Found commented out code - E800, -# S105 Possible hardcoded password [test usage] - S105, -# WPS442 Found outer scope names shadowing - WPS442, -# WPS432 Found magic number: 2020 - WPS432, -# WPS235 Found too many imported names from a module - WPS235, -# WPS202 Found too many module members: 36 > 35 - WPS202, - file_result.py: -# E800 Found commented out code - E800, - yaml_hwm_store.py: -# E800 Found commented out code - E800, -# E800 Found commented out code - E800, - kafka.py: -# too few type annotations - TAE001, - *connection.py: -# WPS437 Found protected attribute usage: spark._sc._gateway - WPS437, - onetl/connection/db_connection/jdbc_mixin/connection.py: -# too few type annotations - TAE001, -# WPS219 :Found too deep access level - WPS219, -# WPS437: Found protected attribute usage: spark._jvm - WPS437, - onetl/connection/db_connection/kafka/connection.py: -# WPS342: Found implicit raw string \\n - WPS342, -# WPS437 Found protected attribute usage: self._jvm - WPS437, - onetl/connection/db_connection/kafka/kafka_ssl_protocol.py: -# WPS342: Found implicit raw string \\n - WPS342, - onetl/_util/*: -# WPS437 Found protected attribute usage: spark._jvm - WPS437, - file_filter.py: - onetl/connection/file_connection/file_connection.py: -# WPS220: Found too deep nesting - WPS220, - onetl/connection/file_connection/hdfs/connection.py: -# E800 Found commented out code - E800, -# F401 'hdfs.ext.kerberos.KerberosClient as CheckForKerberosSupport' imported but unused - F401, -# WPS442 Found outer scope names shadowing: KerberosClient - WPS442, - onetl/file/format/*.py: -# N815 variable 'rootTag' in class scope should not be mixedCase - N815, -# WPS342 Found implicit raw string - WPS342, - onetl/hooks/slot.py: -# WPS210 Found too many local variables - WPS210, - onetl/_metrics/listener/*: -# N802 function name 'onJobStart' should be lowercase - N802, - tests/*: -# Found too many empty lines in `def` - WPS473, -# TAE001 too few type annotations - TAE001, -# U100 Unused argument - U100, -# WPS220: Found too deep nesting - WPS220, -# WPS231 Found function with too much cognitive complexity - WPS231, -# FI18 __future__ import "annotations" missing - FI18, -# S101 Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. 
- S101, -# S105 Possible hardcoded password - S105, -# WPS122:Found all unused variables definition - WPS122, -# WPS125 Found builtin shadowing: globals [test setup] - WPS125, -# WPS204:Found overused expression [ok, for test purpose] - WPS204, -# WPS218 Found too many `assert` statements - WPS218, -# WPS219 Found too deep access level - WPS219, -# WPS425 Found boolean non-keyword argument: False - WPS425, -# WPS430 Found nested function - WPS430, -# WPS432 Found magic number - WPS432, -# WPS437 Found protected attribute usage - WPS437, -# WPS442 Found outer scope names shadowing [ok for test usage] - WPS442, -# WPS517 Found pointless starred expression - WPS517, -# WPS609 Found direct magic attribute usage - WPS609, -# WPS325 Found inconsistent `yield` statement - WPS325, -# WPS360 Found an unnecessary use of a raw string - WPS360 -# S106 Possible hardcoded password [test usage] - S106, -# WPS118 Found too long name - WPS118, -# WPS235 Found too many imported names from a module - WPS235, -# WPS213 Found too many expressions - WPS213, -# WPS212 Found too many return statements: 7 > 5 - WPS212, -# F401 'onetl' imported but unused - F401, -# F811 redefinition of unused 'onetl' from line 72 - F811, -# F821: undefined name - F821, -# WPS429: Found multiple assign targets a = b = 'c' - WPS429, -# WPS342: Found implicit raw string - WPS342, -# WPS520 Found compare with falsy constant: == [] - WPS520, -# B017 `pytest.raises(Exception)` should be considered evil - B017, -# WPS202 Found too many module members: 40 > 35 - WPS202, -# WPS210 Found too many local variables: 21 > 20 - WPS210, -# WPS441 Found control variable used after block: file - WPS441, -# WPS333 Found implicit complex compare - WPS333 - - -[darglint] -docstring_style = sphinx - -[mypy] -python_version = 3.8 -# TODO: remove later -exclude = ^(?=.*file).* -strict_optional = True -# ignore typing in third-party packages -ignore_missing_imports = True -follow_imports = silent -show_error_codes = True -disable_error_code = name-defined, misc - -[codespell] -ignore-words-list = INOUT, inout, thirdparty diff --git a/tests/fixtures/connections/ftp.py b/tests/fixtures/connections/ftp.py index 7c08bb37c..afc31c255 100644 --- a/tests/fixtures/connections/ftp.py +++ b/tests/fixtures/connections/ftp.py @@ -1,6 +1,6 @@ import os -from collections import namedtuple from pathlib import PurePosixPath +from typing import NamedTuple import pytest @@ -14,7 +14,11 @@ ], ) def ftp_server(): - FTPServer = namedtuple("FTPServer", ["host", "port", "user", "password"]) + class FTPServer(NamedTuple): + host: str + port: str + user: str + password: str return FTPServer( host=os.getenv("ONETL_FTP_HOST"), diff --git a/tests/fixtures/connections/ftps.py b/tests/fixtures/connections/ftps.py index 9ca637a4c..ba89df935 100644 --- a/tests/fixtures/connections/ftps.py +++ b/tests/fixtures/connections/ftps.py @@ -1,6 +1,6 @@ import os -from collections import namedtuple from pathlib import PurePosixPath +from typing import NamedTuple import pytest @@ -14,7 +14,11 @@ ], ) def ftps_server(): - FTPSServer = namedtuple("FTPSServer", ["host", "port", "user", "password"]) + class FTPSServer(NamedTuple): + host: str + port: str + user: str + password: str return FTPSServer( host=os.getenv("ONETL_FTPS_HOST"), diff --git a/tests/fixtures/connections/hdfs.py b/tests/fixtures/connections/hdfs.py index 568943598..239e38433 100644 --- a/tests/fixtures/connections/hdfs.py +++ b/tests/fixtures/connections/hdfs.py @@ -1,6 +1,6 @@ import os -from collections import namedtuple from 
pathlib import PurePosixPath +from typing import NamedTuple import pytest @@ -14,7 +14,11 @@ ], ) def hdfs_server(): - HDFSServer = namedtuple("HDFSServer", ["host", "webhdfs_port", "ipc_port"]) + class HDFSServer(NamedTuple): + host: str + webhdfs_port: str + ipc_port: str + return HDFSServer( host=os.getenv("ONETL_HDFS_HOST"), webhdfs_port=os.getenv("ONETL_HDFS_WEBHDFS_PORT"), diff --git a/tests/fixtures/connections/s3.py b/tests/fixtures/connections/s3.py index e1ccc7c31..ebd3f9948 100644 --- a/tests/fixtures/connections/s3.py +++ b/tests/fixtures/connections/s3.py @@ -1,7 +1,7 @@ import os -from collections import namedtuple from contextlib import suppress from pathlib import PurePosixPath +from typing import NamedTuple import pytest @@ -15,7 +15,14 @@ ], ) def s3_server(): - S3Server = namedtuple("S3Server", ["host", "port", "bucket", "access_key", "secret_key", "protocol", "region"]) + class S3Server(NamedTuple): + host: str + port: str + bucket: str + access_key: str + secret_key: str + protocol: str + region: str return S3Server( host=os.getenv("ONETL_S3_HOST"), diff --git a/tests/fixtures/connections/samba.py b/tests/fixtures/connections/samba.py index 076fd3c92..73e3fb1c9 100644 --- a/tests/fixtures/connections/samba.py +++ b/tests/fixtures/connections/samba.py @@ -1,6 +1,6 @@ import os -from collections import namedtuple from pathlib import PurePosixPath +from typing import NamedTuple import pytest @@ -14,7 +14,13 @@ ], ) def samba_server(): - SambaServer = namedtuple("SambaServer", ["host", "protocol", "port", "share", "user", "password"]) + class SambaServer(NamedTuple): + host: str + protocol: str + port: str + share: str + user: str + password: str return SambaServer( host=os.getenv("ONETL_SAMBA_HOST"), diff --git a/tests/fixtures/connections/sftp.py b/tests/fixtures/connections/sftp.py index 605637e60..d52dda8b6 100644 --- a/tests/fixtures/connections/sftp.py +++ b/tests/fixtures/connections/sftp.py @@ -1,6 +1,6 @@ import os -from collections import namedtuple from pathlib import PurePosixPath +from typing import NamedTuple import pytest @@ -14,7 +14,11 @@ ], ) def sftp_server(): - SFTPServer = namedtuple("SFTPServer", ["host", "port", "user", "password"]) + class SFTPServer(NamedTuple): + host: str + port: str + user: str + password: str return SFTPServer( host=os.getenv("ONETL_SFTP_HOST"), diff --git a/tests/fixtures/connections/webdav.py b/tests/fixtures/connections/webdav.py index d6975013f..33f03ceb1 100644 --- a/tests/fixtures/connections/webdav.py +++ b/tests/fixtures/connections/webdav.py @@ -1,6 +1,6 @@ import os -from collections import namedtuple from pathlib import PurePosixPath +from typing import NamedTuple import pytest @@ -17,7 +17,13 @@ ], ) def webdav_server(): - WebDAVServer = namedtuple("WebDAVServer", ["host", "port", "user", "password", "ssl_verify", "protocol"]) + class WebDAVServer(NamedTuple): + host: str + port: str + user: str + password: str + ssl_verify: bool + protocol: str return WebDAVServer( host=os.getenv("ONETL_WEBDAV_HOST"), diff --git a/tests/fixtures/create_keytab.py b/tests/fixtures/create_keytab.py index c44cb1a62..41136ea95 100644 --- a/tests/fixtures/create_keytab.py +++ b/tests/fixtures/create_keytab.py @@ -14,4 +14,4 @@ def create_keytab(tmp_path_factory): @pytest.fixture def keytab_md5(): - return hashlib.md5(b"content").hexdigest() # noqa: S324 # nosec + return hashlib.md5(b"content").hexdigest() # nosec diff --git a/tests/fixtures/global_hwm_store.py b/tests/fixtures/global_hwm_store.py index 2e006b923..4edd32cb2 100644 --- 
a/tests/fixtures/global_hwm_store.py +++ b/tests/fixtures/global_hwm_store.py @@ -3,7 +3,7 @@ @pytest.fixture(scope="function", autouse=True) -def global_hwm_store(request): # noqa: WPS325 +def global_hwm_store(request): test_function = request.function entities = set(test_function.__name__.split("_")) if test_function else set() diff --git a/tests/fixtures/processing/base_processing.py b/tests/fixtures/processing/base_processing.py index 5bf845aa3..3d258d05c 100644 --- a/tests/fixtures/processing/base_processing.py +++ b/tests/fixtures/processing/base_processing.py @@ -6,7 +6,7 @@ from datetime import date, datetime, timedelta from logging import getLogger from random import randint -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar import pandas @@ -22,9 +22,9 @@ class BaseProcessing(ABC): _df_max_length: int = 100 - _column_types_and_names_matching: dict[str, str] = {} + _column_types_and_names_matching: ClassVar[dict[str, str]] = {} - column_names: list[str] = ["id_int", "text_string", "hwm_int", "hwm_date", "hwm_datetime", "float_value"] + column_names: ClassVar[list[str]] = ["id_int", "text_string", "hwm_int", "hwm_date", "hwm_datetime", "float_value"] def create_schema_ddl( self, @@ -140,10 +140,10 @@ def create_pandas_df( elif "text" in column_name: values[column].append(secrets.token_hex(16)) elif "datetime" in column_name: - rand_second = randint(0, i * time_multiplier) # noqa: S311 + rand_second = randint(0, i * time_multiplier) values[column].append(self.current_datetime() + timedelta(seconds=rand_second)) elif "date" in column_name: - rand_second = randint(0, i * time_multiplier) # noqa: S311 + rand_second = randint(0, i * time_multiplier) values[column].append(self.current_date() + timedelta(seconds=rand_second)) return pandas.DataFrame(data=values) @@ -175,7 +175,8 @@ def assert_equal_df( if other_frame is None: if schema is None or table is None: - raise TypeError("Cannot use assert_equal_df without schema and table") + msg = "Cannot use assert_equal_df without schema and table" + raise TypeError(msg) other_frame = self.get_expected_dataframe(schema=schema, table=table, order_by=order_by) left_df = self.fix_pandas_df(to_pandas(df)) @@ -194,7 +195,8 @@ def assert_subset_df( if other_frame is None: if schema is None or table is None: - raise TypeError("Cannot use assert_equal_df without schema and table") + msg = "Cannot use assert_equal_df without schema and table" + raise TypeError(msg) other_frame = self.get_expected_dataframe(schema=schema, table=table) small_df = self.fix_pandas_df(to_pandas(df)) diff --git a/tests/fixtures/processing/clickhouse.py b/tests/fixtures/processing/clickhouse.py index 5a851ad51..c7ace6f70 100644 --- a/tests/fixtures/processing/clickhouse.py +++ b/tests/fixtures/processing/clickhouse.py @@ -6,6 +6,7 @@ from datetime import date, datetime, timedelta from logging import getLogger from random import randint +from typing import ClassVar import clickhouse_driver import pandas @@ -16,7 +17,7 @@ class ClickhouseProcessing(BaseProcessing): - _column_types_and_names_matching = { + _column_types_and_names_matching: ClassVar[dict[str, str]] = { "id_int": "Int32", "text_string": "String", "hwm_int": "Int32", @@ -62,13 +63,13 @@ def client_port(self) -> int: return int(os.environ["ONETL_CH_PORT_CLIENT"]) def create_pandas_df(self, min_id: int = 1, max_id: int | None = None) -> pandas.DataFrame: - max_id = self._df_max_length if not max_id else max_id + max_id = max_id if max_id else self._df_max_length time_multiplier = 100000 
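The connection fixtures above (FTP, FTPS, HDFS, S3, Samba, SFTP, WebDAV), and PreparedDbInfo plus IcebergRESTCatalogServer further below, all replace collections.namedtuple factories with typing.NamedTuple subclasses, so every field gets a name and a type annotation that mypy can check. A minimal standalone sketch of the pattern; apart from ONETL_FTP_HOST, the environment variable names and all default values are illustrative guesses, not taken from the repository:

    import os
    from typing import NamedTuple


    class FTPServer(NamedTuple):
        # Unlike namedtuple("FTPServer", ["host", "port", ...]), each field carries a type,
        # so a type checker can flag a missing field or a misspelled attribute.
        host: str
        port: str
        user: str
        password: str


    def ftp_server_from_env() -> FTPServer:
        # Illustrative helper, not a fixture from the test suite.
        return FTPServer(
            host=os.getenv("ONETL_FTP_HOST", "localhost"),
            port=os.getenv("ONETL_FTP_PORT", "2121"),
            user=os.getenv("ONETL_FTP_USER", "onetl"),
            password=os.getenv("ONETL_FTP_PASSWORD", ""),
        )


    server = ftp_server_from_env()
    assert server.host == server[0]  # a NamedTuple instance still behaves like a plain tuple

The instances stay immutable and tuple-compatible, which is why the fixtures' call sites did not have to change.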
         values = defaultdict(list)
 
         for i in range(min_id, max_id + 1):
-            for column_name in self.column_names:
-                column_name = column_name.lower()
+            for raw_column_name in self.column_names:
+                column_name = raw_column_name.lower()
 
                 if "int" in column_name:
                     values[column_name].append(i)
@@ -77,11 +78,11 @@ def create_pandas_df(self, min_id: int = 1, max_id: int | None = None) -> pandas
                 elif "text" in column_name:
                     values[column_name].append(secrets.token_hex(16))
                 elif "datetime" in column_name:
-                    rand_second = randint(0, i * time_multiplier)  # noqa: S311
+                    rand_second = randint(0, i * time_multiplier)
                     # Clickhouse DATETIME format has time range: 00:00:00 through 23:59:59
                     values[column_name].append(datetime.now().replace(microsecond=0) + timedelta(seconds=rand_second))
                 elif "date" in column_name:
-                    rand_second = randint(0, i * time_multiplier)  # noqa: S311
+                    rand_second = randint(0, i * time_multiplier)
                     values[column_name].append(date.today() + timedelta(seconds=rand_second))
 
         return pandas.DataFrame(data=values)
@@ -105,7 +106,7 @@ def create_table_ddl(
         schema: str,
     ) -> str:
         str_fields = ", ".join([f"{key} {value}" for key, value in fields.items()])
-        first_field = list(fields.keys())[0]
+        first_field = next(iter(fields.keys()))
         return f"""
             CREATE TABLE IF NOT EXISTS {schema}.{table} ({str_fields})
diff --git a/tests/fixtures/processing/fixtures/iceberg.py b/tests/fixtures/processing/fixtures/iceberg.py
index f96ba13ad..f4d89aabc 100644
--- a/tests/fixtures/processing/fixtures/iceberg.py
+++ b/tests/fixtures/processing/fixtures/iceberg.py
@@ -1,5 +1,5 @@
 import os
-from collections import namedtuple
+from typing import NamedTuple
 
 import pytest
 from pytest_lazyfixture import lazy_fixture
@@ -48,7 +48,9 @@ def iceberg_connection_fs_catalog_hdfs_warehouse(spark, iceberg_warehouse_dir, h
     ],
 )
 def iceberg_rest_catalog_server():
-    IcebergRESTCatalogServer = namedtuple("IcebergRESTCatalogServer", ["host", "port"])
+    class IcebergRESTCatalogServer(NamedTuple):
+        host: str
+        port: str
 
     return IcebergRESTCatalogServer(
         host=os.getenv("ONETL_ICEBERG_REST_CATALOG_HOST"),
diff --git a/tests/fixtures/processing/fixtures/processing.py b/tests/fixtures/processing/fixtures/processing.py
index ccd3c57ad..17dce00f7 100644
--- a/tests/fixtures/processing/fixtures/processing.py
+++ b/tests/fixtures/processing/fixtures/processing.py
@@ -1,10 +1,15 @@
 import secrets
-from collections import namedtuple
+from contextlib import suppress
 from importlib import import_module
+from typing import NamedTuple
 
 import pytest
 
-PreparedDbInfo = namedtuple("PreparedDbInfo", ["full_name", "schema", "table"])
+
+class PreparedDbInfo(NamedTuple):
+    full_name: str
+    schema: str
+    table: str
 
 
 @pytest.fixture()
@@ -25,8 +30,12 @@ def processing(request, spark):
     test_name_parts = set(request.function.__name__.split("_"))
     matches = set(processing_classes.keys()) & test_name_parts
     if not matches or len(matches) > 1:
+        msg = (
+            f"Test name {request.function.__name__} should have one "
+            f"of these components: {list(processing_classes.keys())}"
+        )
         raise ValueError(
-            f"Test name {request.function.__name__} should have one of these components: {list(processing_classes.keys())}",
+            msg,
         )
 
     db_storage_name = matches.pop()
@@ -51,13 +60,11 @@ def get_schema_table(processing, worker_id):
 
     yield PreparedDbInfo(full_name=full_name, schema=schema, table=table)
 
-    try:
+    with suppress(Exception):
         processing.drop_table(
             table=table,
             schema=schema,
         )
-    except Exception:  # noqa: S110
-        pass
 
 
 @pytest.fixture
diff --git a/tests/fixtures/processing/hive.py
b/tests/fixtures/processing/hive.py index fea84e540..50b53e378 100644 --- a/tests/fixtures/processing/hive.py +++ b/tests/fixtures/processing/hive.py @@ -3,7 +3,7 @@ import os from collections import defaultdict from logging import getLogger -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar import pandas from pytest import FixtureRequest @@ -17,7 +17,7 @@ class HiveProcessing(BaseProcessing): - _column_types_and_names_matching = { + _column_types_and_names_matching: ClassVar[dict[str, str]] = { "id_int": "int", "text_string": "string", "hwm_int": "int", diff --git a/tests/fixtures/processing/iceberg.py b/tests/fixtures/processing/iceberg.py index 0b274dd50..fa3f99479 100644 --- a/tests/fixtures/processing/iceberg.py +++ b/tests/fixtures/processing/iceberg.py @@ -2,7 +2,7 @@ import os from collections import defaultdict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar import pandas from pytest import FixtureRequest @@ -14,7 +14,7 @@ class IcebergProcessing(BaseProcessing): - _column_types_and_names_matching = { + _column_types_and_names_matching: ClassVar[dict[str, str]] = { "id_int": "int", "text_string": "string", "hwm_int": "int", @@ -35,8 +35,9 @@ def catalog(self) -> str: if catalog in markers: return f"my_{catalog}" + msg = f"One of possible catalog types should be in markers: {self._supported_catalog_types}" raise ValueError( - f"One of possible catalog types should be in markers: {self._supported_catalog_types}", + msg, ) @property diff --git a/tests/fixtures/processing/kafka.py b/tests/fixtures/processing/kafka.py index bddb3490b..ad21c4cb7 100644 --- a/tests/fixtures/processing/kafka.py +++ b/tests/fixtures/processing/kafka.py @@ -2,7 +2,7 @@ import json import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar import pandas @@ -16,7 +16,7 @@ class KafkaProcessing(BaseProcessing): - column_names: list[str] = ["id_int", "text_string", "hwm_int", "float_value"] + column_names: ClassVar[list[str]] = ["id_int", "text_string", "hwm_int", "float_value"] def __enter__(self): return self @@ -87,7 +87,8 @@ def delivery_report(err, msg): from confluent_kafka import KafkaException if err is not None: - raise KafkaException(f"Message {msg} delivery failed: {err}") + msg = f"Message {msg} delivery failed: {err}" + raise KafkaException(msg) def send_message(self, topic, message, timeout: float = DEFAULT_TIMEOUT): from confluent_kafka import KafkaException @@ -96,7 +97,8 @@ def send_message(self, topic, message, timeout: float = DEFAULT_TIMEOUT): producer.produce(topic, message, callback=self.delivery_report) messages_left = producer.flush(timeout) if messages_left: - raise KafkaException(f"{messages_left} messages were not delivered") + msg = f"{messages_left} messages were not delivered" + raise KafkaException(msg) def get_expected_df(self, topic: str, num_messages: int = 1, timeout: float = DEFAULT_TIMEOUT) -> pandas.DataFrame: from confluent_kafka import KafkaException @@ -110,12 +112,11 @@ def get_expected_df(self, topic: str, num_messages: int = 1, timeout: float = DE for msg in messages: if msg.error(): raise KafkaException(msg.error()) - else: - key = msg.key().decode("utf-8") if msg.key() else None - value = msg.value().decode("utf-8") if msg.value() else None - partition = msg.partition() - headers = msg.headers() - result.append((key, value, partition, headers, topic)) + key = msg.key().decode("utf-8") if msg.key() else None + value = msg.value().decode("utf-8") if msg.value() else None + partition 
= msg.partition() + headers = msg.headers() + result.append((key, value, partition, headers, topic)) consumer.close() return pandas.DataFrame(result, columns=["key", "value", "partition", "headers", "topic"]) @@ -135,11 +136,12 @@ def change_topic_partitions(self, topic: str, num_partitions: int, timeout: floa # change the number of partitions fs = admin_client.create_partitions(new_partitions, request_timeout=timeout) - for topic, f in fs.items(): + for topic_name, f in fs.items(): try: f.result() except Exception as e: - raise Exception(f"Failed to update number of partitions for topic '{topic}': {e}") # noqa: WPS454 + msg = f"Failed to update number of partitions for topic '{topic_name}': {e}" + raise RuntimeError(msg) from e def create_topic(self, topic: str, num_partitions: int, timeout: float = DEFAULT_TIMEOUT): from confluent_kafka.admin import KafkaException, NewTopic @@ -148,11 +150,12 @@ def create_topic(self, topic: str, num_partitions: int, timeout: float = DEFAULT topic_config = NewTopic(topic, num_partitions=num_partitions, replication_factor=1) fs = admin_client.create_topics([topic_config], request_timeout=timeout) - for topic, f in fs.items(): + for topic_name, f in fs.items(): try: f.result() except Exception as e: - raise KafkaException(f"Error creating topic '{topic}': {e}") + msg = f"Error creating topic '{topic_name}': {e}" + raise KafkaException(msg) from e def delete_topic(self, topics: list[str], timeout: float = DEFAULT_TIMEOUT): admin = self.get_admin_client() @@ -177,7 +180,7 @@ def get_num_partitions(self, topic: str, timeout: float = DEFAULT_TIMEOUT) -> in # Return the number of partitions return len(topic_metadata.partitions) - def get_expected_dataframe( # noqa: WPS463 + def get_expected_dataframe( self, schema: str, table: str, @@ -193,17 +196,13 @@ def json_deserialize( """Deserializes dataframe's "value" column from JSON to struct""" from pyspark.sql.functions import col, from_json - df = df.select( + return df.select( from_json(col=col("value").cast("string"), schema=df_schema).alias("value"), ).select("value.*") - return df # noqa: WPS331 - def json_serialize(self, df: SparkDataFrame) -> SparkDataFrame: """Serializes dataframe's columns into JSON "value" field""" from pyspark.sql.functions import col, struct, to_json - df = df.select(struct(*df.columns).alias("value")) - df = df.select(to_json(col("value")).alias("value")) - - return df # noqa: WPS331 + intermediate_df = df.select(struct(*df.columns).alias("value")) + return intermediate_df.select(to_json(col("value")).alias("value")) diff --git a/tests/fixtures/processing/mongodb.py b/tests/fixtures/processing/mongodb.py index 61007e191..fd442febf 100644 --- a/tests/fixtures/processing/mongodb.py +++ b/tests/fixtures/processing/mongodb.py @@ -6,6 +6,7 @@ from datetime import datetime, timedelta from logging import getLogger from random import randint +from typing import ClassVar from urllib.parse import quote import pandas @@ -17,7 +18,7 @@ class MongoDBProcessing(BaseProcessing): - _column_types_and_names_matching = { + _column_types_and_names_matching: ClassVar[dict[str, str]] = { "_id": "", "text_string": "", "hwm_int": "", @@ -28,7 +29,7 @@ class MongoDBProcessing(BaseProcessing): "float_value": "", } - column_names: list = ["_id", "text_string", "hwm_int", "hwm_datetime", "float_value"] + column_names: ClassVar[list] = ["_id", "text_string", "hwm_int", "hwm_datetime", "float_value"] def __enter__(self): self.connection = self.get_conn() @@ -137,13 +138,13 @@ def current_datetime() -> datetime: 
return datetime.now() def create_pandas_df(self, min_id: int = 1, max_id: int | None = None) -> pandas.DataFrame: - max_id = self._df_max_length if not max_id else max_id + max_id = max_id if max_id else self._df_max_length time_multiplier = 100000 values = defaultdict(list) for i in range(min_id, max_id + 1): - for column_name in self.column_names: - column_name = column_name.lower() + for raw_column_name in self.column_names: + column_name = raw_column_name.lower() if column_name == "_id" or "int" in column_name: values[column_name].append(i) @@ -152,7 +153,7 @@ def create_pandas_df(self, min_id: int = 1, max_id: int | None = None) -> pandas elif "text" in column_name: values[column_name].append(secrets.token_hex(16)) elif "datetime" in column_name: - rand_second = randint(0, i * time_multiplier) # noqa: S311 + rand_second = randint(0, i * time_multiplier) now = self.current_datetime() + timedelta(seconds=rand_second) # In the case that after rounding the result # will not be in the range from 0 to 999999 diff --git a/tests/fixtures/processing/mssql.py b/tests/fixtures/processing/mssql.py index 47030fb76..089507991 100644 --- a/tests/fixtures/processing/mssql.py +++ b/tests/fixtures/processing/mssql.py @@ -3,6 +3,7 @@ import os from datetime import datetime from logging import getLogger +from typing import ClassVar from urllib.parse import quote import pandas @@ -15,7 +16,7 @@ class MSSQLProcessing(BaseProcessing): - _column_types_and_names_matching = { + _column_types_and_names_matching: ClassVar[dict[str, str]] = { "id_int": "INT", "text_string": "VARCHAR(50)", "hwm_int": "INT", diff --git a/tests/fixtures/processing/mysql.py b/tests/fixtures/processing/mysql.py index b8b015bf9..519811673 100644 --- a/tests/fixtures/processing/mysql.py +++ b/tests/fixtures/processing/mysql.py @@ -2,6 +2,7 @@ import os from logging import getLogger +from typing import ClassVar from urllib.parse import quote import pandas @@ -14,7 +15,7 @@ class MySQLProcessing(BaseProcessing): - _column_types_and_names_matching = { + _column_types_and_names_matching: ClassVar[dict[str, str]] = { "id_int": "INT NOT NULL", "text_string": "VARCHAR(50)", "hwm_int": "INT", diff --git a/tests/fixtures/processing/oracle.py b/tests/fixtures/processing/oracle.py index 0af8ecf39..e013887fe 100644 --- a/tests/fixtures/processing/oracle.py +++ b/tests/fixtures/processing/oracle.py @@ -2,6 +2,7 @@ import os from logging import getLogger +from typing import ClassVar from urllib.parse import quote import cx_Oracle @@ -14,7 +15,7 @@ class OracleProcessing(BaseProcessing): - _column_types_and_names_matching = { + _column_types_and_names_matching: ClassVar[dict[str, str]] = { "id_int": "INTEGER NOT NULL", "text_string": "VARCHAR2(50) NOT NULL", "hwm_int": "INTEGER", diff --git a/tests/fixtures/processing/postgres.py b/tests/fixtures/processing/postgres.py index 3bc7fb61f..2e9d66e12 100644 --- a/tests/fixtures/processing/postgres.py +++ b/tests/fixtures/processing/postgres.py @@ -2,6 +2,7 @@ import os from logging import getLogger +from typing import ClassVar from urllib.parse import quote import pandas @@ -15,7 +16,7 @@ class PostgresProcessing(BaseProcessing): - _column_types_and_names_matching = { + _column_types_and_names_matching: ClassVar[dict[str, str]] = { "id_int": "serial primary key", "text_string": "text", "hwm_int": "bigint", diff --git a/tests/fixtures/spark.py b/tests/fixtures/spark.py index 1fc1d9b8b..c4096af28 100644 --- a/tests/fixtures/spark.py +++ b/tests/fixtures/spark.py @@ -32,7 +32,7 @@ def ivysettings_path(): 
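The processing classes above (BaseProcessing and its Clickhouse, Hive, Iceberg, Kafka, MongoDB, MSSQL, MySQL, Oracle and Postgres subclasses) annotate their mutable class-level defaults with typing.ClassVar, presumably to satisfy Ruff's mutable-class-attribute check (RUF012 is my inference; the diff itself names no rule). The annotation states that the dict or list is shared class state rather than an implicit per-instance field. A small illustrative sketch, with class names invented for the example:

    from __future__ import annotations

    from typing import ClassVar


    class BaseProcessingSketch:
        # Shared by all instances; ClassVar makes that explicit to type checkers.
        _column_types_and_names_matching: ClassVar[dict[str, str]] = {}

        def __init__(self) -> None:
            # Truly per-instance state still belongs in __init__.
            self.inserted_rows: list[dict] = []


    class ClickhouseProcessingSketch(BaseProcessingSketch):
        # Subclasses override the shared mapping with their own ClassVar.
        _column_types_and_names_matching: ClassVar[dict[str, str]] = {
            "id_int": "Int32",
            "text_string": "String",
        }


    assert ClickhouseProcessingSketch()._column_types_and_names_matching["id_int"] == "Int32"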
@pytest.fixture(scope="session") -def maven_packages(request): +def maven_packages(request): # noqa: C901, PLR0912 import pyspark from onetl.connection import ( @@ -163,7 +163,7 @@ def spark( from pyspark.sql import SparkSession spark_builder = ( - SparkSession.builder.config("spark.app.name", "onetl") # noqa: WPS221 + SparkSession.builder.config("spark.app.name", "onetl") .config("spark.master", "local[*]") .config("spark.jars.packages", ",".join(maven_packages)) .config("spark.jars.excludes", ",".join(excluded_packages)) diff --git a/tests/libs/failing/failing.py b/tests/libs/failing/failing.py index 40d3912ae..328bd6fcd 100644 --- a/tests/libs/failing/failing.py +++ b/tests/libs/failing/failing.py @@ -1 +1,2 @@ -raise RuntimeError("something went wrong") +msg = "something went wrong" +raise RuntimeError(msg) diff --git a/tests/resources/file_df_connection/__init__.py b/tests/resources/file_df_connection/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/resources/file_df_connection/generate_files.py b/tests/resources/file_df_connection/generate_files.py index 2417cb15d..0345d8a88 100755 --- a/tests/resources/file_df_connection/generate_files.py +++ b/tests/resources/file_df_connection/generate_files.py @@ -6,7 +6,7 @@ import gzip import io import json -import os +import logging import random import shutil import sys @@ -15,17 +15,20 @@ from datetime import date, datetime, timezone from pathlib import Path from tempfile import gettempdir -from typing import TYPE_CHECKING, Any, Iterator, TextIO -from xml.etree import ElementTree # noqa: S405 +from typing import TYPE_CHECKING, Any, TextIO +from xml.etree import ElementTree as ET from zipfile import ZipFile if TYPE_CHECKING: + from collections.abc import Iterator + from avro.schema import Schema as AvroSchema from pandas import DataFrame as PandasDataFrame from pyarrow import Schema as ArrowSchema from pyarrow import Table as ArrowTable SEED = 42 +logger = logging.getLogger(__name__) def get_data() -> list[dict]: @@ -144,7 +147,7 @@ def _to_string(obj): def _write_csv(data: list[dict], file: TextIO, header: bool = False, **kwargs) -> None: columns = list(data[0].keys()) - writer = csv.DictWriter(file, fieldnames=columns, **kwargs) + writer = csv.DictWriter(file, fieldnames=columns, lineterminator="\n", **kwargs) if header: writer.writeheader() @@ -155,19 +158,19 @@ def _write_csv(data: list[dict], file: TextIO, header: bool = False, **kwargs) - def save_as_csv_without_header(data: list[dict], path: Path) -> None: path.mkdir(parents=True, exist_ok=True) - with open(path / "file.csv", "w", newline="") as file: + with path.joinpath("file.csv").open("w", newline="") as file: _write_csv(data, file) def save_as_csv_with_header(data: list[dict], path: Path) -> None: path.mkdir(parents=True, exist_ok=True) - with open(path / "file.csv", "w", newline="") as file: + with path.joinpath("file.csv").open("w", newline="") as file: _write_csv(data, file, header=True) def save_as_csv_with_delimiter(data: list[dict], path: Path) -> None: path.mkdir(parents=True, exist_ok=True) - with open(path / "file.csv", "w", newline="") as file: + with path.joinpath("file.csv").open("w", newline="") as file: _write_csv(data, file, delimiter=";") @@ -178,9 +181,8 @@ def save_as_csv_gz(data: list[dict], path: Path) -> None: # Instead of just writing data to file we write it to a buffer, and then compress with fixed mtime buffer = io.StringIO() _write_csv(data, buffer) - with open(path / "file.csv.gz", "wb") as file: - with 
gzip.GzipFile(fileobj=file, mode="w", mtime=0) as gzfile: - gzfile.write(buffer.getvalue().encode("utf-8")) + with path.joinpath("file.csv.gz").open("wb") as file, gzip.GzipFile(fileobj=file, mode="w", mtime=0) as gzfile: + gzfile.write(buffer.getvalue().encode("utf-8")) def save_as_csv_nested(data: list[dict], path: Path) -> None: @@ -188,15 +190,15 @@ def save_as_csv_nested(data: list[dict], path: Path) -> None: path.joinpath("some/path/more").mkdir(parents=True, exist_ok=True) path.joinpath("some/path/more/even_more").mkdir(parents=True, exist_ok=True) - with open(path / "some/path/for_val1.csv", "w", newline="") as file: + with path.joinpath("some/path/for_val1.csv").open("w", newline="") as file: data_for_val1 = [row for row in data if row["str_value"] == "val1"] _write_csv(data_for_val1, file) - with open(path / "some/path/more/for_val2.csv", "w", newline="") as file: + with path.joinpath("some/path/more/for_val2.csv").open("w", newline="") as file: data_for_val2 = [row for row in data if row["str_value"] == "val2"] _write_csv(data_for_val2, file) - with open(path / "some/path/more/even_more/for_val3.csv", "w", newline="") as file: + with path.joinpath("some/path/more/even_more/for_val3.csv").open("w", newline="") as file: data_for_val3 = [row for row in data if row["str_value"] == "val3"] _write_csv(data_for_val3, file) @@ -217,15 +219,15 @@ def filter_and_drop(rows: list[dict], column: str, value: Any) -> list[dict]: columns = list(data[0].keys()) columns.remove("str_value") - with open(path / "str_value=val1/file.csv", "w", newline="") as file: + with path.joinpath("str_value=val1/file.csv").open("w", newline="") as file: data_for_val1 = filter_and_drop(data, "str_value", "val1") _write_csv(data_for_val1, file) - with open(path / "str_value=val2/file.csv", "w", newline="") as file: + with path.joinpath("str_value=val2/file.csv").open("w", newline="") as file: data_for_val2 = filter_and_drop(data, "str_value", "val2") _write_csv(data_for_val2, file) - with open(path / "str_value=val3/file.csv", "w", newline="") as file: + with path.joinpath("str_value=val3/file.csv").open("w", newline="") as file: data_for_val3 = filter_and_drop(data, "str_value", "val3") _write_csv(data_for_val3, file) @@ -251,9 +253,8 @@ def save_as_json_gz(data: list[dict], path: Path) -> None: path.mkdir(parents=True, exist_ok=True) buffer = io.StringIO() json.dump(data, buffer, default=_to_string) - with open(path / "file.json.gz", "wb") as file: - with gzip.GzipFile(fileobj=file, mode="w", mtime=0) as gzfile: - gzfile.write(buffer.getvalue().encode("utf-8")) + with path.joinpath("file.json.gz").open("wb") as file, gzip.GzipFile(fileobj=file, mode="w", mtime=0) as gzfile: + gzfile.write(buffer.getvalue().encode("utf-8")) def save_as_json(data: list[dict], path: Path) -> None: @@ -266,10 +267,10 @@ def save_as_json(data: list[dict], path: Path) -> None: def save_as_jsonline_plain(data: list[dict], path: Path) -> None: path.mkdir(parents=True, exist_ok=True) - with open(path / "file.jsonl", "w") as file: + with path.joinpath("file.jsonl").open("w") as file: for row in data: row_str = json.dumps(row, default=_to_string) - file.write(row_str + os.linesep) + file.write(row_str + "\n") def save_as_jsonline_gz(data: list[dict], path: Path) -> None: @@ -278,11 +279,10 @@ def save_as_jsonline_gz(data: list[dict], path: Path) -> None: buffer = io.StringIO() for row in data: row_str = json.dumps(row, default=_to_string) - buffer.write(row_str + os.linesep) + buffer.write(row_str + "\n") - with open(path / 
"file.jsonl.gz", "wb") as file: - with gzip.GzipFile(fileobj=file, mode="w", mtime=0) as gzfile: - gzfile.write(buffer.getvalue().encode("utf-8")) + with path.joinpath("file.jsonl.gz").open("wb") as file, gzip.GzipFile(fileobj=file, mode="w", mtime=0) as gzfile: + gzfile.write(buffer.getvalue().encode("utf-8")) def save_as_jsonline(data: list[dict], path: Path) -> None: @@ -347,7 +347,7 @@ def save_as_parquet(data: list[dict], path: Path) -> None: def temporary_set_seed(seed: int) -> Iterator[int]: """Set random.seed to expected value, and return previous value after exit""" state = random.getstate() - try: # noqa: WPS501 + try: random.seed(seed) yield seed finally: @@ -355,33 +355,35 @@ def temporary_set_seed(seed: int) -> Iterator[int]: def save_as_avro_plain(data: list[dict], path: Path) -> None: - from avro.datafile import DataFileDFWriter + from avro.datafile import DataFileWriter from avro.io import DatumWriter path.mkdir(parents=True, exist_ok=True) schema = get_avro_schema() - with open(path / "file.avro", "wb") as file: - # DataFileDFWriter.sync_marker is initialized with randbytes - # temporary set seed to avoid generating files with different hashes - with temporary_set_seed(SEED): - with DataFileDFWriter(file, DatumWriter(), schema) as writer: - for row in data: - writer.append(row) + + # DataFileDFWriter.sync_marker is initialized with randbytes + # temporary set seed to avoid generating files with different hashes + with temporary_set_seed(SEED), path.joinpath("file.avro").open("wb") as file, DataFileWriter( + file, DatumWriter(), schema + ) as writer: + for row in data: + writer.append(row) def save_as_avro_snappy(data: list[dict], path: Path) -> None: - from avro.datafile import DataFileDFWriter + from avro.datafile import DataFileWriter from avro.io import DatumWriter path.mkdir(parents=True, exist_ok=True) schema = get_avro_schema() - with open(path / "file.snappy.avro", "wb") as file: - # DataFileDFWriter.sync_marker is initialized with randbytes - # temporary set seed to avoid generating files with different hashes - with temporary_set_seed(SEED): - with DataFileDFWriter(file, DatumWriter(), schema, codec="snappy") as writer: - for row in data: - writer.append(row) + + # DataFileDFWriter.sync_marker is initialized with randbytes + # temporary set seed to avoid generating files with different hashes + with temporary_set_seed(SEED), path.joinpath("file.snappy.avro").open("wb") as file, DataFileWriter( + file, DatumWriter(), schema, codec="snappy" + ) as writer: + for row in data: + writer.append(row) def save_as_avro(data: list[dict], path: Path) -> None: @@ -399,7 +401,7 @@ def save_as_xls_with_options( **kwargs, ) -> None: # required to register xlwt writer which supports generating .xls files - import pandas_xlwt + import pandas_xlwt # noqa: F401 path.mkdir(parents=True, exist_ok=True) file = path / "file.xls" @@ -413,15 +415,14 @@ def make_zip_deterministic(path: Path) -> None: temp_dir = gettempdir() file_copy = Path(shutil.copy(path, temp_dir)) - with ZipFile(file_copy, "r") as original_file: - with ZipFile(path, "w") as new_file: - for item in original_file.infolist(): - if item.filename == "docProps/core.xml": - # this file contains modification time, which produces files with different hashes - continue - # reset modification time of all files - item.date_time = (1980, 1, 1, 0, 0, 0) - new_file.writestr(item, original_file.read(item.filename)) + with ZipFile(file_copy, "r") as original_file, ZipFile(path, "w") as new_file: + for item in original_file.infolist(): 
+ if item.filename == "docProps/core.xml": + # this file contains modification time, which produces files with different hashes + continue + # reset modification time of all files + item.date_time = (1980, 1, 1, 0, 0, 0) + new_file.writestr(item, original_file.read(item.filename)) def save_as_xlsx_with_options( @@ -475,56 +476,56 @@ def save_as_xls(data: list[dict], path: Path) -> None: def save_as_xml_plain(data: list[dict], path: Path) -> None: path.mkdir(parents=True, exist_ok=True) - root = ElementTree.Element("root") + root = ET.Element("root") for record in data: - item = ElementTree.SubElement(root, "item") + item = ET.SubElement(root, "item") for key, value in record.items(): - child = ElementTree.SubElement(item, key) + child = ET.SubElement(item, key) if isinstance(value, datetime): child.text = value.isoformat() else: child.text = str(value) - tree = ElementTree.ElementTree(root) + tree = ET.ElementTree(root) tree.write(path / "file.xml") def save_as_xml_with_attributes(data: list[dict], path: Path) -> None: path.mkdir(parents=True, exist_ok=True) - root = ElementTree.Element("root") + root = ET.Element("root") for record in data: str_attributes = { key: value.isoformat() if isinstance(value, datetime) else str(value) for key, value in record.items() } - item = ElementTree.SubElement(root, "item", attrib=str_attributes) + item = ET.SubElement(root, "item", attrib=str_attributes) for key, value in record.items(): - child = ElementTree.SubElement(item, key) + child = ET.SubElement(item, key) if isinstance(value, datetime): child.text = value.isoformat() else: child.text = str(value) - tree = ElementTree.ElementTree(root) + tree = ET.ElementTree(root) tree.write(str(path / "file_with_attributes.xml")) def save_as_xml_gz(data: list[dict], path: Path) -> None: path.mkdir(parents=True, exist_ok=True) - root = ElementTree.Element("root") + root = ET.Element("root") for record in data: - item = ElementTree.SubElement(root, "item") + item = ET.SubElement(root, "item") for key, value in record.items(): - child = ElementTree.SubElement(item, key) + child = ET.SubElement(item, key) if isinstance(value, datetime): child.text = value.isoformat() else: child.text = str(value) - ElementTree.ElementTree(root) - xml_string = ElementTree.tostring(root, encoding="utf-8") + ET.ElementTree(root) + xml_string = ET.tostring(root, encoding="utf-8") with gzip.open(path / "file.xml.gz", "wb", compresslevel=9) as f: f.write(xml_string) @@ -565,7 +566,8 @@ def main(argv: list[str] | None = None) -> None: args = parser.parse_args(argv or sys.argv[1:]) if args.format not in format_mapping and args.format != "all": - raise ValueError(f"Format {args.format} is not supported") + msg = f"Format {args.format} is not supported" + raise ValueError(msg) data = get_data() if args.format == "all": diff --git a/tests/tests_integration/__init__.py b/tests/tests_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/test_file_df_connection_integration/__init__.py b/tests/tests_integration/test_file_df_connection_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/test_file_df_connection_integration/test_spark_hdfs_integration.py b/tests/tests_integration/test_file_df_connection_integration/test_spark_hdfs_integration.py index 2962aa23b..d66b78def 100644 --- a/tests/tests_integration/test_file_df_connection_integration/test_spark_hdfs_integration.py +++ 
b/tests/tests_integration/test_file_df_connection_integration/test_spark_hdfs_integration.py @@ -33,9 +33,8 @@ def test_spark_hdfs_file_connection_check_failed(spark): spark=spark, ) - with wrong_hdfs: - with pytest.raises(RuntimeError, match="Connection is unavailable"): - wrong_hdfs.check() + with wrong_hdfs, pytest.raises(RuntimeError, match="Connection is unavailable"): + wrong_hdfs.check() def test_spark_hdfs_file_connection_check_with_hooks(spark, request, hdfs_server): @@ -58,7 +57,7 @@ def is_namenode_active(host: str, cluster: str) -> bool: with pytest.raises( RuntimeError, - match="Host 'some-node2.domain.com' is not an active namenode of cluster 'rnd-dwh'", + match=r"Host 'some-node2\.domain\.com' is not an active namenode of cluster 'rnd-dwh'", ): SparkHDFS( cluster="rnd-dwh", diff --git a/tests/tests_integration/test_file_df_connection_integration/test_spark_s3_integration.py b/tests/tests_integration/test_file_df_connection_integration/test_spark_s3_integration.py index 51d81295b..38d6b0f54 100644 --- a/tests/tests_integration/test_file_df_connection_integration/test_spark_s3_integration.py +++ b/tests/tests_integration/test_file_df_connection_integration/test_spark_s3_integration.py @@ -42,9 +42,8 @@ def test_spark_s3_check_failed(spark, s3_server): spark=spark, ) - with wrong_s3: - with pytest.raises(RuntimeError, match="Connection is unavailable"): - wrong_s3.check() + with wrong_s3, pytest.raises(RuntimeError, match="Connection is unavailable"): + wrong_s3.check() def test_spark_s3_check_hadoop_config_reset(spark, s3_server, caplog): diff --git a/tests/tests_integration/test_file_format_integration/__init__.py b/tests/tests_integration/test_file_format_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/test_file_format_integration/test_avro_integration.py b/tests/tests_integration/test_file_format_integration/test_avro_integration.py index 1e701029e..15f189952 100644 --- a/tests/tests_integration/test_file_format_integration/test_avro_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_avro_integration.py @@ -40,7 +40,7 @@ def avro_schema(): @pytest.mark.parametrize( - "path, options", + ("path", "options"), [ ("without_compression", {}), ("with_compression", {"compression": "snappy"}), @@ -148,7 +148,7 @@ def test_avro_serialize_and_parse_no_schema( with pytest.raises( ValueError, - match="Avro.parse_column can be used only with defined `avroSchema` or `avroSchemaUrl`", + match=r"Avro\.parse_column can be used only with defined `avroSchema` or `avroSchemaUrl`", ): serialized_df.select(avro.parse_column(column_type("combined"))) diff --git a/tests/tests_integration/test_file_format_integration/test_csv_integration.py b/tests/tests_integration/test_file_format_integration/test_csv_integration.py index c58c534aa..1917db77f 100644 --- a/tests/tests_integration/test_file_format_integration/test_csv_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_csv_integration.py @@ -58,7 +58,7 @@ def test_csv_reader_with_infer_schema( @pytest.mark.parametrize( - "option, value", + ("option", "value"), [ ("header", True), ("delimiter", ";"), @@ -91,7 +91,7 @@ def test_csv_reader_with_options( @pytest.mark.parametrize( - "option, value", + ("option", "value"), [ ("header", "True"), ("delimiter", ";"), @@ -133,7 +133,7 @@ def test_csv_writer_with_options( @pytest.mark.parametrize( - "csv_string, schema, options, expected", + ("csv_string", "schema", "options", "expected"), [ ( 
"1,Anne", @@ -167,7 +167,7 @@ def test_csv_parse_column(spark, csv_string, schema, options, expected, column_t @pytest.mark.parametrize( - "data, schema, options, expected_csv", + ("data", "schema", "options", "expected_csv"), [ ( Row(id=1, name="Alice"), diff --git a/tests/tests_integration/test_file_format_integration/test_excel_integration.py b/tests/tests_integration/test_file_format_integration/test_excel_integration.py index 769722818..8e6f20ead 100644 --- a/tests/tests_integration/test_file_format_integration/test_excel_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_excel_integration.py @@ -56,7 +56,7 @@ def test_excel_reader_with_infer_schema( @pytest.mark.parametrize("format", ["xlsx", "xls"]) @pytest.mark.parametrize( - "path, options", + ("path", "options"), [ ("without_header", {}), ("with_header", {"header": True}), diff --git a/tests/tests_integration/test_file_format_integration/test_json_integration.py b/tests/tests_integration/test_file_format_integration/test_json_integration.py index 896d14e72..3e579472f 100644 --- a/tests/tests_integration/test_file_format_integration/test_json_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_json_integration.py @@ -29,7 +29,7 @@ @pytest.mark.parametrize( - "path, options", + ("path", "options"), [ ("without_compression", {}), ("with_compression", {"compression": "gzip"}), @@ -76,7 +76,7 @@ def test_json_writer_is_not_supported( @pytest.mark.parametrize( - "json_row, schema, expected_row", + ("json_row", "schema", "expected_row"), [ ( Row(json_column='{"id": 1, "name": "Alice"}'), @@ -105,7 +105,7 @@ def test_json_parse_column(spark, json_row, schema, expected_row, column_type): @pytest.mark.parametrize( - "row, expected_row", + ("row", "expected_row"), [ ( Row(json_column=Row(id=1, name="Alice")), @@ -160,7 +160,8 @@ def test_json_parse_column_unsupported_options_warning(spark): dropFieldIfAllNull=True, ) msg = ( - "Options `['dropFieldIfAllNull', 'encoding', 'lineSep', 'prefersDecimal', 'primitivesAsString', 'samplingRatio']` " + "Options `['dropFieldIfAllNull', 'encoding', 'lineSep', " + "'prefersDecimal', 'primitivesAsString', 'samplingRatio']` " "are set but not supported in `JSON.parse_column` or `JSON.serialize_column`." 
) diff --git a/tests/tests_integration/test_file_format_integration/test_jsonline_integration.py b/tests/tests_integration/test_file_format_integration/test_jsonline_integration.py index a246843e4..a2ac481a8 100644 --- a/tests/tests_integration/test_file_format_integration/test_jsonline_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_jsonline_integration.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize( - "path, options", + ("path", "options"), [ ("without_compression", {}), ("with_compression", {"compression": "gzip"}), diff --git a/tests/tests_integration/test_file_format_integration/test_orc_integration.py b/tests/tests_integration/test_file_format_integration/test_orc_integration.py index 40902cef9..13a6814d5 100644 --- a/tests/tests_integration/test_file_format_integration/test_orc_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_orc_integration.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize( - "path, options", + ("path", "options"), [ ("without_compression", {}), ("with_compression", {"compression": "snappy"}), diff --git a/tests/tests_integration/test_file_format_integration/test_parquet_integration.py b/tests/tests_integration/test_file_format_integration/test_parquet_integration.py index ea5844eb1..2393edb6f 100644 --- a/tests/tests_integration/test_file_format_integration/test_parquet_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_parquet_integration.py @@ -18,7 +18,7 @@ @pytest.mark.parametrize( - "path, options", + ("path", "options"), [ ("without_compression", {}), ("with_compression", {"compression": "snappy"}), diff --git a/tests/tests_integration/test_file_format_integration/test_xml_integration.py b/tests/tests_integration/test_file_format_integration/test_xml_integration.py index 162651e5c..9150cb531 100644 --- a/tests/tests_integration/test_file_format_integration/test_xml_integration.py +++ b/tests/tests_integration/test_file_format_integration/test_xml_integration.py @@ -30,7 +30,7 @@ def expected_xml_attributes_df(file_df_dataframe): @pytest.mark.parametrize( - "path, options", + ("path", "options"), [ ("without_compression", {"rowTag": "item"}), ("with_compression", {"rowTag": "item", "compression": "gzip"}), @@ -82,7 +82,8 @@ def test_xml_reader_with_infer_schema( assert read_df.count() assert read_df.schema != df.schema - # "DataFrames have different column types: StructField('id', IntegerType(), True), StructField('id', LongType(), True), etc." + # "DataFrames have different column types: + # StructField('id', IntegerType(), True), StructField('id', LongType(), True), etc." 
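Across the format-integration tests above, pytest.mark.parametrize switches its first argument from a single comma-separated string ("path, options") to a tuple of names (("path", "options")). pytest accepts both spellings; the tuple form is presumably what Ruff's pytest-style check standardises on (PT006 would be my guess, the diff does not say). A minimal sketch:

    import pytest


    @pytest.mark.parametrize(
        ("path", "options"),  # tuple of argnames instead of the string "path, options"
        [
            ("without_compression", {}),
            ("with_compression", {"compression": "gzip"}),
        ],
    )
    def test_compression_matrix(path: str, options: dict) -> None:
        # Each generated test receives one value per name, exactly as with the string form.
        assert isinstance(options, dict)
        assert path.endswith("compression")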
assert set(read_df.columns) == set(expected_xml_attributes_df.columns) assert_equal_df(read_df, expected_xml_attributes_df, order_by="id") diff --git a/tests/tests_integration/test_metrics/__init__.py b/tests/tests_integration/test_metrics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_file_df.py b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_file_df.py index 4ee96b96f..772385e5c 100644 --- a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_file_df.py +++ b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_file_df.py @@ -201,7 +201,8 @@ def test_spark_metrics_recorder_file_df_writer_executor_failed( @udf(returnType=IntegerType()) def raise_exception(): - raise ValueError("Force task failure") + msg = "Force task failure" + raise ValueError(msg) local_fs, target_path = local_fs_file_df_connection_with_path diff --git a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_hive.py b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_hive.py index 79116a3eb..d69d254b6 100644 --- a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_hive.py +++ b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_hive.py @@ -166,7 +166,8 @@ def test_spark_metrics_recorder_hive_write_executor_failed(spark, processing, ge @udf(returnType=IntegerType()) def raise_exception(): - raise ValueError("Force task failure") + msg = "Force task failure" + raise ValueError(msg) failing_df = df.select(raise_exception().alias("some")) diff --git a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py index 04913a30d..375e2df5b 100644 --- a/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py +++ b/tests/tests_integration/test_metrics/test_spark_metrics_recorder_postgres.py @@ -206,7 +206,8 @@ def test_spark_metrics_recorder_postgres_write_executor_failed(spark, processing @udf(returnType=IntegerType()) def raise_exception(): - raise ValueError("Force task failure") + msg = "Force task failure" + raise ValueError(msg) df = processing.create_spark_df(spark).limit(0) failing_df = df.select(raise_exception().alias("some")) diff --git a/tests/tests_integration/tests_core_integration/__init__.py b/tests/tests_integration/tests_core_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_core_integration/test_file_df_reader_integration/__init__.py b/tests/tests_integration/tests_core_integration/test_file_df_reader_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/__init__.py b/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/test_common_file_df_writer_integration.py b/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/test_common_file_df_writer_integration.py index 0b64ad58d..902d349a4 100644 --- a/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/test_common_file_df_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_df_writer_integration/test_common_file_df_writer_integration.py 
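A change repeated above (the raise_exception UDFs just shown, the Kafka helpers, BaseProcessing, generate_files.py) is binding the exception text to a msg variable and raising that, instead of building the literal inside the raise statement; this follows the flake8-errmsg convention that Ruff carries (EM101/EM102 are my inference, the diff cites no rule codes). A sketch of the before and after shape, using a made-up validation function:

    def check_positive_before(value: int) -> int:
        if value <= 0:
            # old style: the message literal is constructed inside the raise itself
            raise ValueError(f"value must be positive, got {value}")
        return value


    def check_positive_after(value: int) -> int:
        if value <= 0:
            # new style: bind the message first, then raise; the traceback then shows
            # the short `raise ValueError(msg)` line instead of the whole literal
            msg = f"value must be positive, got {value}"
            raise ValueError(msg)
        return value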
@@ -257,7 +257,7 @@ def test_file_df_writer_run_if_exists_replace_overlapping_partitions_to_overlapp @pytest.mark.parametrize( - "original_options, new_options, real_df_schema", + ("original_options", "new_options", "real_df_schema"), [ pytest.param( {}, @@ -512,5 +512,7 @@ def test_file_df_writer_with_streaming_df( streaming_df = spark.readStream.format("rate").load() assert streaming_df.isStreaming - with pytest.raises(ValueError, match="DataFrame is streaming. FileDFWriter supports only batch DataFrames."): + + msg = r"DataFrame is streaming\. FileDFWriter supports only batch DataFrames\." + with pytest.raises(ValueError, match=msg): writer.run(streaming_df) diff --git a/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py b/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py index 25bab3514..176858641 100644 --- a/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_downloader_integration.py @@ -591,7 +591,7 @@ def test_file_downloader_run_absolute_path_not_match_source_path( file_connection_with_path_and_files, tmp_path_factory, ): - file_connection, remote_path, uploaded_files = file_connection_with_path_and_files + file_connection, remote_path, _ = file_connection_with_path_and_files local_path = tmp_path_factory.mktemp("local_path") downloader = FileDownloader( @@ -824,7 +824,7 @@ def finalizer(): missing_file = target_path / "missing" with caplog.at_level(logging.WARNING): - download_result = downloader.run(uploaded_files + [missing_file]) + download_result = downloader.run([*uploaded_files, missing_file]) assert f"Missing file '{missing_file}', skipping" in caplog.text @@ -1022,9 +1022,8 @@ def test_file_downloader_detect_hwm_type_snapshot_batch_strategy( ) error_message = "FileDownloader(hwm=...) cannot be used with SnapshotBatchStrategy" - with pytest.raises(ValueError, match=re.escape(error_message)): - with SnapshotBatchStrategy(step=100500): - downloader.run() + with pytest.raises(ValueError, match=re.escape(error_message)), SnapshotBatchStrategy(step=100500): + downloader.run() @pytest.mark.parametrize("hwm_type", SUPPORTED_HWM_TYPES) @@ -1044,11 +1043,10 @@ def test_file_downloader_detect_hwm_type_incremental_batch_strategy( ) error_message = "FileDownloader(hwm=...) cannot be used with IncrementalBatchStrategy" - with pytest.raises(ValueError, match=re.escape(error_message)): - with IncrementalBatchStrategy( - step=timedelta(days=5), - ): - downloader.run() + with pytest.raises(ValueError, match=re.escape(error_message)), IncrementalBatchStrategy( + step=timedelta(days=5), + ): + downloader.run() @pytest.mark.parametrize("hwm_type", SUPPORTED_HWM_TYPES) @@ -1089,9 +1087,8 @@ def test_file_downloader_file_hwm_strategy_with_wrong_parameters( ) error_message = "FileDownloader(hwm=...) 
cannot be used with IncrementalStrategy(offset=1, ...)" - with pytest.raises(ValueError, match=re.escape(error_message)): - with IncrementalStrategy(offset=1): - downloader.run() + with pytest.raises(ValueError, match=re.escape(error_message)), IncrementalStrategy(offset=1): + downloader.run() with IncrementalStrategy(): downloader.run() diff --git a/tests/tests_integration/tests_core_integration/test_file_mover_integration.py b/tests/tests_integration/tests_core_integration/test_file_mover_integration.py index 7f63cc991..d6e03c687 100644 --- a/tests/tests_integration/tests_core_integration/test_file_mover_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_mover_integration.py @@ -782,7 +782,7 @@ def test_file_mover_mode_replace_entire_directory( remote_dir_exist, caplog, ): - file_connection, source_path, uploaded_files = file_connection_with_path_and_files + file_connection, source_path, _ = file_connection_with_path_and_files target_path = RemotePath(f"/tmp/test_move_{secrets.token_hex(5)}") def finalizer(): @@ -836,7 +836,7 @@ def finalizer(): missing_file = target_path / "missing" with caplog.at_level(logging.WARNING): - move_result = mover.run(uploaded_files + [missing_file]) + move_result = mover.run([*uploaded_files, missing_file]) assert f"Missing file '{missing_file}', skipping" in caplog.text @@ -946,7 +946,7 @@ def finalizer(): def test_file_mover_file_limit_custom(file_connection_with_path_and_files, caplog): - file_connection, source_path, uploaded_files = file_connection_with_path_and_files + file_connection, source_path, _ = file_connection_with_path_and_files limit = 2 target_path = f"/tmp/test_move_{secrets.token_hex(5)}" diff --git a/tests/tests_integration/tests_core_integration/test_file_uploader_integration.py b/tests/tests_integration/tests_core_integration/test_file_uploader_integration.py index f7719bc7d..1d0927b8f 100644 --- a/tests/tests_integration/tests_core_integration/test_file_uploader_integration.py +++ b/tests/tests_integration/tests_core_integration/test_file_uploader_integration.py @@ -174,7 +174,7 @@ def finalizer(): missing_file = PurePosixPath(f"/tmp/test_upload_{secrets.token_hex(5)}") with caplog.at_level(logging.WARNING): - upload_result = uploader.run(test_files + [missing_file]) + upload_result = uploader.run([*test_files, missing_file]) assert f"Missing file '{missing_file}', skipping" in caplog.text diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/__init__.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_greenplum_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_greenplum_reader_integration.py index f0ede8a82..14ba5bb35 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_greenplum_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_greenplum_reader_integration.py @@ -121,7 +121,7 @@ def test_greenplum_reader_snapshot_with_columns_duplicated(spark, processing, pr ) df2 = reader2.run() - assert df2.columns == df1.columns + ["id_int"] + assert df2.columns == [*df1.columns, "id_int"] def test_greenplum_reader_snapshot_with_columns_mixed_naming(spark, processing, get_schema_table): diff --git 
a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_hive_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_hive_reader_integration.py index 6656241b3..cf003c20c 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_hive_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_hive_reader_integration.py @@ -96,7 +96,7 @@ def test_hive_reader_snapshot_with_columns_duplicated(spark, prepare_schema_tabl ) df2 = reader2.run() - assert df2.columns == df1.columns + ["id_int"] + assert df2.columns == [*df1.columns, "id_int"] def test_hive_reader_snapshot_with_columns_mixed_naming(spark, processing, get_schema_table): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_iceberg_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_iceberg_reader_integration.py index f0d06bd1c..8690f931b 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_iceberg_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_iceberg_reader_integration.py @@ -98,7 +98,7 @@ def test_iceberg_reader_snapshot_with_columns_duplicated( ) df2 = reader2.run() - assert df2.columns == df1.columns + ["id_int"] + assert df2.columns == [*df1.columns, "id_int"] def test_iceberg_reader_snapshot_with_columns_mixed_naming( diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py index 0ee87695e..9f1761c31 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_kafka_reader_integration.py @@ -23,7 +23,7 @@ def kafka_schema(): TimestampType, ) - schema = StructType( + return StructType( [ StructField("key", BinaryType(), nullable=True), StructField("value", BinaryType(), nullable=True), @@ -34,7 +34,6 @@ def kafka_schema(): StructField("timestampType", IntegerType(), nullable=True), ], ) - return schema # noqa: WPS331 @pytest.fixture @@ -50,7 +49,7 @@ def kafka_schema_with_headers(): TimestampType, ) - schema = StructType( + return StructType( [ StructField("key", BinaryType(), nullable=True), StructField("value", BinaryType(), nullable=True), @@ -73,7 +72,6 @@ def kafka_schema_with_headers(): ), ], ) - return schema # noqa: WPS331 def test_kafka_reader(spark, processing, kafka_dataframe_schema, kafka_topic): diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py index ac35a3ced..ad6246669 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_oracle_reader_integration.py @@ -389,7 +389,7 @@ def test_oracle_reader_snapshot_with_columns_duplicated(spark, processing, prepa ], ) # https://stackoverflow.com/questions/27965130/how-to-select-column-from-table-in-oracle - with pytest.raises(Exception, match="java.sql.SQLSyntaxErrorException"): + with 
pytest.raises(Exception, match=r"java\.sql\.SQLSyntaxErrorException"): reader.run() diff --git a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py index 9e6b9929b..2ab1eb5e3 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_reader_integration/test_postgres_reader_integration.py @@ -39,7 +39,7 @@ def test_postgres_reader_snapshot(spark, processing, load_table_data): @pytest.mark.parametrize( - "mode, column", + ("mode", "column"), [ ("range", "id_int"), ("hash", "text_string"), diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/__init__.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_common_db_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_common_db_writer_integration.py index ef5439fbd..2d4e401ae 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_common_db_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_common_db_writer_integration.py @@ -23,5 +23,7 @@ def test_mongodb_writer_with_streaming_df(spark, processing, prepare_schema_tabl streaming_df = spark.readStream.format("rate").load() assert streaming_df.isStreaming - with pytest.raises(ValueError, match="DataFrame is streaming. DBWriter supports only batch DataFrames."): + + msg = r"DataFrame is streaming\. DBWriter supports only batch DataFrames\." 
+ with pytest.raises(ValueError, match=msg): writer.run(streaming_df) diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py index 90df8fe95..843fe71f1 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_hive_writer_integration.py @@ -71,7 +71,7 @@ def test_hive_writer_with_options(spark, processing, get_schema_table, options): @pytest.mark.parametrize( - "options, format", + ("options", "format"), [ (Hive.WriteOptions(), "orc"), # default (Hive.WriteOptions(format="orc"), "orc"), @@ -99,7 +99,7 @@ def test_hive_writer_with_format(spark, processing, get_schema_table, options, f @pytest.mark.parametrize( - "bucket_number, bucket_columns", + ("bucket_number", "bucket_columns"), [ (10, "id_int"), (5, ["id_int", "hwm_int"]), @@ -261,7 +261,7 @@ def test_hive_writer_create_table_if_exists(spark, processing, get_schema_table, @pytest.mark.parametrize( - "options, option_kv", + ("options", "option_kv"), [ (Hive.WriteOptions(partitionBy="str"), "{'partitionBy': 'str'}"), (Hive.WriteOptions(bucketBy=(10, "id_int")), "{'bucketBy': (10, 'id_int')}"), @@ -334,7 +334,7 @@ def test_hive_writer_insert_into_with_options_ignored(spark, processing, get_sch ], ) @pytest.mark.parametrize( - "original_options, new_options", + ("original_options", "new_options"), [ pytest.param({}, {"partitionBy": "id_int"}, id="table_not_partitioned_dataframe_is"), pytest.param({"partitionBy": "text_string"}, {}, id="table_partitioned_dataframe_is_not"), @@ -403,7 +403,7 @@ def test_hive_writer_insert_into_append( ], ) @pytest.mark.parametrize( - "original_options, new_options", + ("original_options", "new_options"), [ pytest.param({}, {"partitionBy": "id_int"}, id="table_not_partitioned_dataframe_is"), pytest.param({"partitionBy": "text_string"}, {}, id="table_partitioned_dataframe_is_not"), @@ -471,7 +471,7 @@ def test_hive_writer_insert_into_ignore( ], ) @pytest.mark.parametrize( - "original_options, new_options", + ("original_options", "new_options"), [ pytest.param({}, {"partitionBy": "id_int"}, id="table_not_partitioned_dataframe_is"), pytest.param({"partitionBy": "text_string"}, {}, id="table_partitioned_dataframe_is_not"), @@ -535,7 +535,7 @@ def test_hive_writer_insert_into_error( ], ) @pytest.mark.parametrize( - "original_options, new_options", + ("original_options", "new_options"), [ pytest.param({}, {"partitionBy": "id_int"}, id="table_not_partitioned_dataframe_is"), pytest.param({"partitionBy": "text_string"}, {}, id="table_partitioned_dataframe_is_not"), diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_iceberg_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_iceberg_writer_integration.py index a6ba83658..ac9d1ee9c 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_iceberg_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_iceberg_writer_integration.py @@ -41,7 +41,7 @@ def test_iceberg_writer_with_custom_location( connection = iceberg_connection_rest_catalog_s3_warehouse df = processing.create_spark_df(spark) table = f"{connection.catalog_name}.{get_schema_table.full_name}" - location = "s3a://" + 
os.path.join( + location = "s3a://" + os.path.join( # noqa: PTH118 connection.warehouse.bucket, connection.warehouse.path.as_posix().lstrip("/"), get_schema_table.schema, diff --git a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py index 78f12490a..ed9a822de 100644 --- a/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py +++ b/tests/tests_integration/tests_core_integration/tests_db_writer_integration/test_kafka_writer_integration.py @@ -111,7 +111,7 @@ def test_kafka_writer_no_value_column_error(spark, kafka_processing, kafka_spark @pytest.mark.parametrize( - "column, value", + ("column", "value"), [ ("offset", 0), ("timestamp", 10000), diff --git a/tests/tests_integration/tests_db_connection_integration/__init__.py b/tests/tests_integration/tests_db_connection_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_db_connection_integration/test_iceberg_integration.py b/tests/tests_integration/tests_db_connection_integration/test_iceberg_integration.py index 8481b9d55..772ce284d 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_iceberg_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_iceberg_integration.py @@ -68,7 +68,7 @@ def test_iceberg_connection_execute_ddl( suffix, ): connection = iceberg_connection_fs_catalog_local_fs_warehouse - table_name, schema, table = get_schema_table + _table_name, schema, table = get_schema_table fields = { column_name: processing.get_column_type(column_name) for column_name in processing.column_names diff --git a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py index 18bf90ff9..a67125ef3 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py @@ -997,7 +997,7 @@ def function_finalizer(): @pytest.mark.parametrize( - "options_class, options_kwargs, expected_warning", + ("options_class", "options_kwargs", "expected_warning"), [ (Postgres.ReadOptions, {"fetchsize": 5000, "sessionInitStatement": "SET timezone TO 'UTC'"}, UserWarning), (Postgres.SQLOptions, {"fetchsize": 5000, "sessionInitStatement": "SET timezone TO 'UTC'"}, None), diff --git a/tests/tests_integration/tests_file_connection_integration/__init__.py b/tests/tests_integration/tests_file_connection_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_file_connection_integration/test_file_connection_common_integration.py b/tests/tests_integration/tests_file_connection_integration/test_file_connection_common_integration.py index bbf696d84..f8cc86cbf 100644 --- a/tests/tests_integration/tests_file_connection_integration/test_file_connection_common_integration.py +++ b/tests/tests_integration/tests_file_connection_integration/test_file_connection_common_integration.py @@ -109,7 +109,7 @@ def test_file_connection_rename_dir(file_connection_with_path_and_files, path_ty def test_file_connection_rename_dir_already_exists(request, file_connection_with_path_and_files): - file_connection, remote_path, upload_files = file_connection_with_path_and_files + 
file_connection, remote_path, _ = file_connection_with_path_and_files if not isinstance(file_connection, SupportsRenameDir): # S3 does not have directories return @@ -195,7 +195,7 @@ def test_file_connection_read_bytes(file_connection_with_path_and_files): @pytest.mark.parametrize( - "pass_real_path, exception", + ("pass_real_path", "exception"), [(True, NotAFileError), (False, FileNotFoundError)], ) def test_file_connection_read_text_negative( @@ -211,7 +211,7 @@ def test_file_connection_read_text_negative( @pytest.mark.parametrize( - "pass_real_path, exception", + ("pass_real_path", "exception"), [(True, NotAFileError), (False, FileNotFoundError)], ) def test_file_connection_read_bytes_negative( @@ -345,7 +345,7 @@ def test_file_connection_upload_file(file_connection, file_connection_test_files @pytest.mark.parametrize( - "path,exception", + ("path", "exception"), [ ("exclude_dir/", NotAFileError), ("exclude_dir/file_not_exists", FileNotFoundError), @@ -385,7 +385,7 @@ def test_file_connection_download_file_wrong_target_type( @pytest.mark.parametrize( - "source,exception", + ("source", "exception"), [("exclude_dir", NotAFileError), ("missing", FileNotFoundError)], ids=["directory", "missing"], ) diff --git a/tests/tests_integration/tests_file_connection_integration/test_hdfs_file_connection_integration.py b/tests/tests_integration/tests_file_connection_integration/test_hdfs_file_connection_integration.py index 69ae359cd..3b6e6b265 100644 --- a/tests/tests_integration/tests_file_connection_integration/test_hdfs_file_connection_integration.py +++ b/tests/tests_integration/tests_file_connection_integration/test_hdfs_file_connection_integration.py @@ -102,12 +102,12 @@ def is_namenode_active(host: str, cluster: str | None) -> bool: HDFS(host=hdfs_server.host, webhdfs_port=hdfs_server.webhdfs_port).check() # no exception - with pytest.raises(RuntimeError, match="Host 'some-node2.domain.com' is not an active namenode"): + with pytest.raises(RuntimeError, match=r"Host 'some-node2\.domain\.com' is not an active namenode"): HDFS(host="some-node2.domain.com").check() with pytest.raises( RuntimeError, - match="Host 'some-node2.domain.com' is not an active namenode of cluster 'rnd-dwh'", + match=r"Host 'some-node2\.domain\.com' is not an active namenode of cluster 'rnd-dwh'", ): HDFS(host="some-node2.domain.com", cluster="rnd-dwh").check() diff --git a/tests/tests_integration/tests_file_connection_integration/test_samba_file_connection_integration.py b/tests/tests_integration/tests_file_connection_integration/test_samba_file_connection_integration.py index 7c5c8f5d5..154ac3d0c 100644 --- a/tests/tests_integration/tests_file_connection_integration/test_samba_file_connection_integration.py +++ b/tests/tests_integration/tests_file_connection_integration/test_samba_file_connection_integration.py @@ -35,9 +35,8 @@ def test_samba_file_connection_check_not_existing_share_failed(samba_server, cap password=samba_server.password, ) - with caplog.at_level(logging.INFO): - with pytest.raises(RuntimeError, match="Connection is unavailable"): - samba.check() + with caplog.at_level(logging.INFO), pytest.raises(RuntimeError, match="Connection is unavailable"): + samba.check() assert f"Share '{not_existing_share}' not found among existing shares" in caplog.text diff --git a/tests/tests_integration/tests_hwm_store_integration.py b/tests/tests_integration/tests_hwm_store_integration.py index 1b589de2d..213911478 100644 --- a/tests/tests_integration/tests_hwm_store_integration.py +++ 
b/tests/tests_integration/tests_hwm_store_integration.py @@ -11,7 +11,7 @@ hwm_store = [ MemoryHWMStore(), - YAMLHWMStore(path=tempfile.mktemp("hwmstore")), # noqa: S306 # nosec + YAMLHWMStore(path=tempfile.mktemp("hwmstore")), # nosec ] diff --git a/tests/tests_integration/tests_strategy_integration/__init__.py b/tests/tests_integration/tests_strategy_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py b/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py index e72b91e83..fe4481c7c 100644 --- a/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py +++ b/tests/tests_integration/tests_strategy_integration/test_strategy_incremental_batch.py @@ -44,9 +44,8 @@ def test_postgres_strategy_incremental_batch_outside_loop( ) error_msg = "Invalid IncrementalBatchStrategy usage!" - with pytest.raises(RuntimeError, match=re.escape(error_msg)): - with IncrementalBatchStrategy(step=1): - reader.run() + with pytest.raises(RuntimeError, match=re.escape(error_msg)), IncrementalBatchStrategy(step=1): + reader.run() def test_postgres_strategy_incremental_batch_where(spark, processing, prepare_schema_table): @@ -157,19 +156,19 @@ def test_postgres_strategy_incremental_batch_hwm_set_twice( with pytest.raises( ValueError, - match="Detected wrong IncrementalBatchStrategy usage.", + match="Detected wrong IncrementalBatchStrategy usage", ): reader2.run() with pytest.raises( ValueError, - match="Detected wrong IncrementalBatchStrategy usage.", + match="Detected wrong IncrementalBatchStrategy usage", ): reader3.run() @pytest.mark.parametrize( - "hwm_column, new_type, step", + ("hwm_column", "new_type", "step"), [ ("hwm_int", "date", 200), ("hwm_date", "integer", timedelta(days=20)), @@ -213,7 +212,7 @@ def test_postgres_strategy_incremental_batch_different_hwm_type_in_store( processing.drop_table(schema=load_table_data.schema, table=load_table_data.table) processing.create_table(schema=load_table_data.schema, table=load_table_data.table, fields=new_fields) - with pytest.raises(TypeError, match="Cannot cast HWM of type .* as .*"): + with pytest.raises(TypeError, match=r"Cannot cast HWM of type .* as .*"): with IncrementalBatchStrategy(step=step) as batches: for _ in batches: reader.run() @@ -293,7 +292,7 @@ def test_postgres_strategy_incremental_batch_different_hwm_optional_attribute_in @pytest.mark.parametrize( - "hwm_column, step", + ("hwm_column", "step"), [ ("hwm_int", 1.5), ("hwm_int", "abc"), @@ -361,14 +360,13 @@ def test_postgres_strategy_incremental_batch_wrong_step_type( values=second_span, ) - with pytest.raises((TypeError, ValueError)): - with IncrementalBatchStrategy(step=step) as part: - for _ in part: - reader.run() + with pytest.raises((TypeError, ValueError)), IncrementalBatchStrategy(step=step) as part: + for _ in part: + reader.run() @pytest.mark.parametrize( - "hwm_column, step", + ("hwm_column", "step"), [ ("hwm_int", -10), ("hwm_date", timedelta(days=-10)), @@ -431,14 +429,13 @@ def test_postgres_strategy_incremental_batch_step_negative( ) error_msg = "HWM value is not increasing, please check options passed to IncrementalBatchStrategy" - with pytest.raises(ValueError, match=error_msg): - with IncrementalBatchStrategy(step=step) as part: - for _ in part: - reader.run() + with pytest.raises(ValueError, match=error_msg), IncrementalBatchStrategy(step=step) as part: + for _ in part: + reader.run() 
@pytest.mark.parametrize( - "hwm_column, step", + ("hwm_column", "step"), [ ("hwm_int", 0.01), ("hwm_date", timedelta(days=1)), @@ -501,15 +498,14 @@ def test_postgres_strategy_incremental_batch_step_too_small( ) error_msg = f"step={step!r} parameter of IncrementalBatchStrategy leads to generating too many iterations" - with pytest.raises(ValueError, match=re.escape(error_msg)): - with IncrementalBatchStrategy(step=step) as batches: - for _ in batches: - reader.run() + with pytest.raises(ValueError, match=re.escape(error_msg)), IncrementalBatchStrategy(step=step) as batches: + for _ in batches: + reader.run() @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column, step, per_iter", + ("hwm_type", "hwm_column", "step", "per_iter"), [ (ColumnIntHWM, "hwm_int", 20, 30), # step < per_iter (ColumnIntHWM, "hwm_int", 30, 30), # step == per_iter @@ -518,7 +514,7 @@ def test_postgres_strategy_incremental_batch_step_too_small( ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (50, 100), # step < gap < span_length (50, 40), # step < gap > span_length @@ -653,7 +649,7 @@ def test_postgres_strategy_incremental_batch( @pytest.mark.parametrize( - "hwm_column, step, stop", + ("hwm_column", "step", "stop"), [ ("hwm_int", 10, 50), # step < stop ("hwm_int", 50, 10), # step > stop @@ -721,7 +717,7 @@ def test_postgres_strategy_incremental_batch_stop( @pytest.mark.parametrize( - "span_gap, span_length, hwm_column, step, offset, full", + ("span_gap", "span_length", "hwm_column", "step", "offset", "full"), [ (10, 60, "hwm_int", 100, 40 + 10 + 40 + 1, False), # step > offset, step < span_length + gap (10, 60, "hwm_int", 100, 60 + 10 + 60 + 1, True), # step < offset, step < span_length + gap diff --git a/tests/tests_integration/tests_strategy_integration/test_strategy_snapshot_batch.py b/tests/tests_integration/tests_strategy_integration/test_strategy_snapshot_batch.py index 4c8121dbc..5de7a2b7a 100644 --- a/tests/tests_integration/tests_strategy_integration/test_strategy_snapshot_batch.py +++ b/tests/tests_integration/tests_strategy_integration/test_strategy_snapshot_batch.py @@ -38,13 +38,12 @@ def test_postgres_strategy_snapshot_hwm_column_present(spark, processing, prepar ) error_message = "DBReader(hwm=...) 
cannot be used with SnapshotStrategy" - with SnapshotStrategy(): - with pytest.raises(RuntimeError, match=re.escape(error_message)): - reader.run() + with SnapshotStrategy(), pytest.raises(RuntimeError, match=re.escape(error_message)): + reader.run() @pytest.mark.parametrize( - "hwm_column, step", + ("hwm_column", "step"), [ ("hwm_int", "abc"), ("hwm_int", timedelta(hours=10)), @@ -78,14 +77,13 @@ def test_postgres_strategy_snapshot_batch_wrong_step_type( hwm=DBReader.AutoDetectHWM(name=secrets.token_hex(5), expression=hwm_column), ) - with pytest.raises((TypeError, ValueError)): - with SnapshotBatchStrategy(step=step) as part: - for _ in part: - reader.run() + with pytest.raises((TypeError, ValueError)), SnapshotBatchStrategy(step=step) as part: + for _ in part: + reader.run() @pytest.mark.parametrize( - "hwm_column, step", + ("hwm_column", "step"), [ ("hwm_int", -10), ("hwm_date", timedelta(days=-10)), @@ -116,15 +114,14 @@ def test_postgres_strategy_snapshot_batch_step_negative( ) error_msg = "HWM value is not increasing, please check options passed to SnapshotBatchStrategy" - with pytest.raises(ValueError, match=error_msg): - with SnapshotBatchStrategy(step=step) as batches: - for _ in batches: - reader.run() + with pytest.raises(ValueError, match=error_msg), SnapshotBatchStrategy(step=step) as batches: + for _ in batches: + reader.run() @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_column, step", + ("hwm_column", "step"), [ ("hwm_int", 0.5), ("hwm_date", timedelta(days=1)), @@ -154,10 +151,9 @@ def test_postgres_strategy_snapshot_batch_step_too_small( ) error_msg = f"step={step!r} parameter of SnapshotBatchStrategy leads to generating too many iterations" - with pytest.raises(ValueError, match=re.escape(error_msg)): - with SnapshotBatchStrategy(step=step) as batches: - for _ in batches: - reader.run() + with pytest.raises(ValueError, match=re.escape(error_msg)), SnapshotBatchStrategy(step=step) as batches: + for _ in batches: + reader.run() def test_postgres_strategy_snapshot_batch_outside_loop( @@ -181,9 +177,8 @@ def test_postgres_strategy_snapshot_batch_outside_loop( ) error_message = "Invalid SnapshotBatchStrategy usage!" 
- with pytest.raises(RuntimeError, match=re.escape(error_message)): - with SnapshotBatchStrategy(step=1): - reader.run() + with pytest.raises(RuntimeError, match=re.escape(error_message)), SnapshotBatchStrategy(step=1): + reader.run() def test_postgres_strategy_snapshot_batch_hwm_set_twice(spark, processing, load_table_data): @@ -223,13 +218,13 @@ def test_postgres_strategy_snapshot_batch_hwm_set_twice(spark, processing, load_ with pytest.raises( ValueError, - match="Detected wrong SnapshotBatchStrategy usage.", + match="Detected wrong SnapshotBatchStrategy usage", ): reader2.run() with pytest.raises( ValueError, - match="Detected wrong SnapshotBatchStrategy usage.", + match="Detected wrong SnapshotBatchStrategy usage", ): reader3.run() @@ -280,7 +275,7 @@ def test_postgres_strategy_snapshot_batch_where(spark, processing, prepare_schem @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_column, step, per_iter", + ("hwm_column", "step", "per_iter"), [ ( "hwm_int", @@ -296,7 +291,7 @@ def test_postgres_strategy_snapshot_batch_where(spark, processing, prepare_schem ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (50, 60), # step < gap < span_length (50, 40), # step < gap > span_length @@ -446,7 +441,7 @@ def test_postgres_strategy_snapshot_batch_ignores_hwm_value( @pytest.mark.parametrize( - "hwm_column, step, stop", + ("hwm_column", "step", "stop"), [ ("hwm_int", 10, 50), # step < stop ("hwm_int", 50, 10), # step > stop @@ -555,18 +550,18 @@ def test_postgres_strategy_snapshot_batch_handle_exception(spark, processing, pr first_df = None raise_counter = 0 - with suppress(ValueError): - with SnapshotBatchStrategy(step=step) as batches: - for _ in batches: - if first_df is None: - first_df = reader.run() - else: - first_df = first_df.union(reader.run()) - - raise_counter += step - # raise exception somewhere in the middle of the read process - if raise_counter >= span_gap + (span_length // 2): - raise ValueError("some error") + with suppress(ValueError), SnapshotBatchStrategy(step=step) as batches: + for _ in batches: + if first_df is None: + first_df = reader.run() + else: + first_df = first_df.union(reader.run()) + + raise_counter += step + # raise exception somewhere in the middle of the read process + if raise_counter >= span_gap + (span_length // 2): + msg = "some error" + raise ValueError(msg) # and then process is retried total_df = None diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/__init__.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_mongodb.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_mongodb.py index 2809cdb1b..dfefe8966 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_mongodb.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_batch_strategy_integration/test_strategy_incremental_batch_mongodb.py @@ -36,7 +36,7 @@ def df_schema(): @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column, step, per_iter", + ("hwm_type", "hwm_column", "step", "per_iter"), [ (ColumnIntHWM, "hwm_int", 20, 30), # 
step < per_iter (ColumnIntHWM, "hwm_int", 30, 30), # step == per_iter @@ -44,7 +44,7 @@ def df_schema(): ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (50, 100), # step < gap < span_length (50, 40), # step < gap > span_length diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/__init__.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_clickhouse.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_clickhouse.py index 0e4fb46ef..0a5f3b610 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_clickhouse.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_clickhouse.py @@ -13,7 +13,7 @@ @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column", + ("hwm_type", "hwm_column"), [ (ColumnIntHWM, "hwm_int"), (ColumnDateHWM, "hwm_date"), @@ -21,7 +21,7 @@ ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (10, 100), (10, 50), @@ -208,7 +208,7 @@ def test_clickhouse_strategy_incremental_nothing_to_read(spark, processing, prep only_rerun="py4j.protocol.Py4JError", ) @pytest.mark.parametrize( - "hwm_column, exception_type, error_message", + ("hwm_column", "exception_type", "error_message"), [ ("float_value", ValueError, "Expression 'float_value' returned values"), ("text_string", RuntimeError, "Cannot detect HWM type for"), @@ -288,7 +288,7 @@ def test_clickhouse_strategy_incremental_explicit_hwm_type( @pytest.mark.parametrize( - "hwm_source, hwm_expr, hwm_type, func", + ("hwm_source", "hwm_expr", "hwm_type", "func"), [ ( "hwm_int", diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_common.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_common.py index 3497266fa..08ea18111 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_common.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_common.py @@ -21,7 +21,7 @@ @pytest.mark.parametrize( - "hwm_column, new_type", + ("hwm_column", "new_type"), [ ("hwm_int", "date"), ("hwm_date", "integer"), @@ -60,9 +60,8 @@ def test_postgres_strategy_incremental_different_hwm_type_in_store( processing.drop_table(schema=load_table_data.schema, table=load_table_data.table) processing.create_table(schema=load_table_data.schema, table=load_table_data.table, fields=new_fields) - with pytest.raises(TypeError, match="Cannot cast HWM of type .* as .*"): - with IncrementalStrategy(): - reader.run() + with pytest.raises(TypeError, match=r"Cannot cast HWM of type .* as .*"), IncrementalStrategy(): + reader.run() def test_postgres_strategy_incremental_different_hwm_source_in_store( @@ -92,9 +91,8 @@ def test_postgres_strategy_incremental_different_hwm_source_in_store( source=load_table_data.full_name, hwm=old_hwm, ) - with pytest.raises(ValueError, match="Detected HWM with different `entity` attribute"): - 
with IncrementalStrategy(): - reader.run() + with pytest.raises(ValueError, match="Detected HWM with different `entity` attribute"), IncrementalStrategy(): + reader.run() @pytest.mark.parametrize("attribute", ["expression", "description"]) @@ -170,13 +168,13 @@ def test_postgres_strategy_incremental_hwm_set_twice(spark, processing, load_tab with pytest.raises( ValueError, - match="Detected wrong IncrementalStrategy usage.", + match="Detected wrong IncrementalStrategy usage", ): reader2.run() with pytest.raises( ValueError, - match="Detected wrong IncrementalStrategy usage.", + match="Detected wrong IncrementalStrategy usage", ): reader3.run() @@ -244,7 +242,7 @@ def test_postgres_strategy_incremental_where(spark, processing, prepare_schema_t @pytest.mark.parametrize( - "span_gap, span_length, hwm_column, offset", + ("span_gap", "span_length", "hwm_column", "offset"), [ (10, 50, "hwm_int", 50 + 10 + 50 + 1), # offset > span_length + gap (50, 10, "hwm_int", 10 + 50 + 10 + 1), # offset < span_length + gap @@ -368,10 +366,10 @@ def test_postgres_strategy_incremental_handle_exception(spark, processing, prepa ) # process is failed - with suppress(ValueError): - with IncrementalStrategy(): - reader.run() - raise ValueError("some error") + with suppress(ValueError), IncrementalStrategy(): + reader.run() + msg = "some error" + raise ValueError(msg) # and then process is retried with IncrementalStrategy(): diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_file.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_file.py index 8fe5964a6..964553a16 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_file.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_file.py @@ -126,7 +126,8 @@ def test_file_downloader_incremental_strategy_fail( available = downloader.view_files() downloaded = downloader.run() # simulating a failure after download - raise RuntimeError("some exception") + msg = "some exception" + raise RuntimeError(msg) assert len(available) == len(downloaded.successful) == 1 assert downloaded.successful[0].name == new_file_name @@ -193,9 +194,8 @@ def test_file_downloader_incremental_strategy_different_hwm_type_in_store( # HWM Store contains HWM with same name, but different type hwm_store.set_hwm(ColumnIntHWM(name=hwm_name, expression="hwm_int")) - with pytest.raises(TypeError, match="Cannot cast HWM of type .* as .*"): - with IncrementalStrategy(): - downloader.run() + with pytest.raises(TypeError, match=r"Cannot cast HWM of type .* as .*"), IncrementalStrategy(): + downloader.run() @pytest.mark.parametrize("hwm_type", SUPPORTED_HWM_TYPES) @@ -219,9 +219,8 @@ def test_file_downloader_incremental_strategy_different_hwm_directory_in_store( # HWM Store contains HWM with same name, but different directory hwm_store.set_hwm(hwm_type(name=hwm_name, directory=local_path)) - with pytest.raises(ValueError, match="Detected HWM with different `entity` attribute"): - with IncrementalStrategy(): - downloader.run() + with pytest.raises(ValueError, match="Detected HWM with different `entity` attribute"), IncrementalStrategy(): + downloader.run() @pytest.mark.parametrize("hwm_type", SUPPORTED_HWM_TYPES) @@ -295,12 +294,12 @@ def test_file_downloader_incremental_strategy_hwm_set_twice( with pytest.raises( ValueError, - 
match="Detected wrong IncrementalStrategy usage.", + match="Detected wrong IncrementalStrategy usage", ): downloader2.run() with pytest.raises( ValueError, - match="Detected wrong IncrementalStrategy usage.", + match="Detected wrong IncrementalStrategy usage", ): downloader3.run() diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_greenplum.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_greenplum.py index 8159bca25..c3b42e92e 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_greenplum.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_greenplum.py @@ -13,7 +13,7 @@ @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column", + ("hwm_type", "hwm_column"), [ (ColumnIntHWM, "hwm_int"), (ColumnDateHWM, "hwm_date"), @@ -21,7 +21,7 @@ ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (10, 100), (10, 50), @@ -212,7 +212,7 @@ def test_greenplum_strategy_incremental_nothing_to_read(spark, processing, prepa # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column, exception_type, error_message", + ("hwm_column", "exception_type", "error_message"), [ ("float_value", ValueError, "Expression 'float_value' returned values"), ("text_string", RuntimeError, "Cannot detect HWM type for"), @@ -305,7 +305,7 @@ def test_greenplum_strategy_incremental_explicit_hwm_type( @pytest.mark.parametrize( - "hwm_source, hwm_expr, hwm_type, func", + ("hwm_source", "hwm_expr", "hwm_type", "func"), [ ( "hwm_int", diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_hive.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_hive.py index 40ff79b06..d363af7c0 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_hive.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_hive.py @@ -13,7 +13,7 @@ @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column", + ("hwm_type", "hwm_column"), [ (ColumnIntHWM, "hwm_int"), (ColumnDateHWM, "hwm_date"), @@ -21,7 +21,7 @@ ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (10, 100), (10, 50), @@ -196,7 +196,7 @@ def test_hive_strategy_incremental_nothing_to_read(spark, processing, prepare_sc # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column, exception_type, error_message", + ("hwm_column", "exception_type", "error_message"), [ ("float_value", ValueError, "Expression 'float_value' returned values"), ("text_string", RuntimeError, "Cannot detect HWM type for"), @@ -274,7 +274,7 @@ def test_hive_strategy_incremental_explicit_hwm_type( @pytest.mark.parametrize( - "hwm_source, hwm_expr, hwm_type, func", + ("hwm_source", "hwm_expr", "hwm_type", "func"), [ ("hwm_int", "CAST(text_string AS INT)", ColumnIntHWM, str), ("hwm_date", "CAST(text_string AS DATE)", ColumnDateHWM, lambda x: x.isoformat()), diff --git 
a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_kafka.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_kafka.py index d506394d0..8977ba3bf 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_kafka.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_kafka.py @@ -231,7 +231,7 @@ def test_kafka_strategy_incremental_nothing_to_read( @pytest.mark.parametrize( - "initial_partitions, new_partitions", + ("initial_partitions", "new_partitions"), [ (3, 5), (5, 6), @@ -280,8 +280,10 @@ def test_kafka_strategy_incremental_with_new_partition( with IncrementalStrategy(): first_df = reader.run() - # it is crucial to save dataframe after reading as if number of partitions is altered before executing any subsequent operations, Spark fails to run them due to - # Caused by: java.lang.AssertionError: assertion failed: If startingOffsets contains specific offsets, you must specify all TopicPartitions. + # it is crucial to save dataframe after reading, as if number of partitions is altered before executing + # any subsequent operations, Spark fails to run them due to + # Caused by: java.lang.AssertionError: assertion failed: + # If startingOffsets contains specific offsets, you must specify all TopicPartitions. # Use -1 for latest, -2 for earliest. # Specified: Set(topic1, topic2) Assigned: Set(topic1, topic2, additional_topic3, additional_topic4) first_df.cache() diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mongodb.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mongodb.py index 699e4e973..3de9ca252 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mongodb.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mongodb.py @@ -35,14 +35,14 @@ def df_schema(): @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column", + ("hwm_type", "hwm_column"), [ (ColumnIntHWM, "hwm_int"), (ColumnDateTimeHWM, "hwm_datetime"), ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (10, 100), (10, 50), @@ -241,7 +241,7 @@ def test_mongodb_strategy_incremental_nothing_to_read( # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column, exception_type, error_message", + ("hwm_column", "exception_type", "error_message"), [ ("float_value", ValueError, "Expression 'float_value' returned values"), ("text_string", RuntimeError, "Cannot detect HWM type for"), diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mssql.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mssql.py index a03db2759..75e1213e4 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mssql.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mssql.py @@ -13,7 +13,7 @@ @pytest.mark.flaky(reruns=5) 
@pytest.mark.parametrize( - "hwm_type, hwm_column", + ("hwm_type", "hwm_column"), [ (ColumnIntHWM, "hwm_int"), (ColumnDateHWM, "hwm_date"), @@ -21,7 +21,7 @@ ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (10, 100), (10, 50), @@ -212,7 +212,7 @@ def test_mssql_strategy_incremental_nothing_to_read(spark, processing, prepare_s # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column, exception_type, error_message", + ("hwm_column", "exception_type", "error_message"), [ ("float_value", ValueError, "Expression 'float_value' returned values"), ("text_string", RuntimeError, "Cannot detect HWM type for"), @@ -305,7 +305,7 @@ def test_mssql_strategy_incremental_explicit_hwm_type( @pytest.mark.parametrize( - "hwm_source, hwm_expr, hwm_type, func", + ("hwm_source", "hwm_expr", "hwm_type", "func"), [ ( "hwm_int", diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mysql.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mysql.py index e762eb0da..bd5e90432 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mysql.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_mysql.py @@ -13,7 +13,7 @@ @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column", + ("hwm_type", "hwm_column"), [ (ColumnIntHWM, "hwm_int"), (ColumnDateHWM, "hwm_date"), @@ -21,7 +21,7 @@ ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (10, 100), (10, 50), @@ -210,7 +210,7 @@ def test_mysql_strategy_incremental_nothing_to_read(spark, processing, prepare_s # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column, exception_type, error_message", + ("hwm_column", "exception_type", "error_message"), [ ("float_value", ValueError, "Expression 'float_value' returned values"), ("text_string", RuntimeError, "Cannot detect HWM type for"), @@ -301,7 +301,7 @@ def test_mysql_strategy_incremental_explicit_hwm_type( @pytest.mark.parametrize( - "hwm_source, hwm_expr, hwm_type, func", + ("hwm_source", "hwm_expr", "hwm_type", "func"), [ ( "hwm_int", diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py index 08598b31c..c9939afbd 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_oracle.py @@ -16,7 +16,7 @@ # Do not fail in such the case @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column", + ("hwm_type", "hwm_column"), [ (ColumnIntHWM, "HWM_INT"), # there is no Date type in Oracle @@ -25,7 +25,7 @@ ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (10, 100), (10, 50), @@ -226,7 +226,7 @@ def test_oracle_strategy_incremental_nothing_to_read(spark, processing, prepare_ # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column, 
exception_type, error_message", + ("hwm_column", "exception_type", "error_message"), [ ("FLOAT_VALUE", ValueError, "Expression 'FLOAT_VALUE' returned values"), ("TEXT_STRING", RuntimeError, "Cannot detect HWM type for"), @@ -320,7 +320,7 @@ def test_oracle_strategy_incremental_explicit_hwm_type( @pytest.mark.parametrize( - "hwm_source, hwm_expr, hwm_type, func", + ("hwm_source", "hwm_expr", "hwm_type", "func"), [ ( "hwm_int", diff --git a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_postgres.py b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_postgres.py index 8d5f73705..b8093d4f0 100644 --- a/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_postgres.py +++ b/tests/tests_integration/tests_strategy_integration/tests_incremental_strategy_integration/test_strategy_increment_postgres.py @@ -13,7 +13,7 @@ @pytest.mark.flaky(reruns=5) @pytest.mark.parametrize( - "hwm_type, hwm_column", + ("hwm_type", "hwm_column"), [ (ColumnIntHWM, "hwm_int"), (ColumnDateHWM, "hwm_date"), @@ -21,7 +21,7 @@ ], ) @pytest.mark.parametrize( - "span_gap, span_length", + ("span_gap", "span_length"), [ (10, 100), (10, 50), @@ -114,7 +114,7 @@ def test_postgres_strategy_incremental( @pytest.mark.parametrize( - "hwm_source, hwm_expr, hwm_type, func", + ("hwm_source", "hwm_expr", "hwm_type", "func"), [ ( "hwm_int", @@ -312,7 +312,7 @@ def test_postgres_strategy_incremental_nothing_to_read(spark, processing, prepar # Fail if HWM is Numeric, or Decimal with fractional part, or string @pytest.mark.parametrize( - "hwm_column, exception_type, error_message", + ("hwm_column", "exception_type", "error_message"), [ ("float_value", ValueError, "Expression 'float_value' returned values"), ("text_string", RuntimeError, "Cannot detect HWM type for"), @@ -385,6 +385,5 @@ def test_postgres_strategy_incremental_explicit_hwm_type( ) # incremental run - with pytest.raises(Exception, match="operator does not exist: text <= integer"): - with IncrementalStrategy(): - reader.run() + with pytest.raises(Exception, match="operator does not exist: text <= integer"), IncrementalStrategy(): + reader.run() diff --git a/tests/tests_unit/__init__.py b/tests/tests_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_db/__init__.py b/tests/tests_unit/test_db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_db/test_db_reader_unit/__init__.py b/tests/tests_unit/test_db/test_db_reader_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py index d6c44540e..55b2435ec 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_clickhouse_reader_unit.py @@ -38,7 +38,7 @@ def test_clickhouse_reader_snapshot_error_pass_df_schema(spark_mock): def test_clickhouse_reader_wrong_table_name(spark_mock): clickhouse = Clickhouse(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBReader( connection=clickhouse, 
source="table", # Required format: source="schema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_common_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_common_reader_unit.py index 7fd04f2f8..dd3929066 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_common_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_common_reader_unit.py @@ -70,7 +70,7 @@ def test_reader_invalid_columns(spark_mock, columns): @pytest.mark.parametrize( - "columns, real_columns", + ("columns", "real_columns"), [ (None, ["*"]), (["*"], ["*"]), @@ -89,7 +89,7 @@ def test_reader_valid_columns(spark_mock, columns, real_columns): @pytest.mark.parametrize( - "column, real_columns, msg", + ("column", "real_columns", "msg"), [ ( "*", @@ -118,7 +118,7 @@ def test_reader_legacy_columns(spark_mock, column, real_columns, msg): @pytest.mark.parametrize( - "hwm_column, real_hwm_expression", + ("hwm_column", "real_hwm_expression"), [ ("hwm_column", "hwm_column"), (("hwm_column", "expression"), "expression"), @@ -126,7 +126,7 @@ def test_reader_legacy_columns(spark_mock, column, real_columns, msg): ], ) def test_reader_deprecated_hwm_column(spark_mock, hwm_column, real_hwm_expression): - error_msg = 'Passing "hwm_column" in DBReader class is deprecated since version 0.10.0' + error_msg = r'Passing "hwm_column" in DBReader class is deprecated since version 0\.10\.0' with pytest.warns(UserWarning, match=error_msg): reader = DBReader( connection=Hive(cluster="rnd-dwh", spark=spark_mock), @@ -169,7 +169,7 @@ def test_reader_hwm_has_same_source(spark_mock): def test_reader_hwm_has_different_source(spark_mock): - error_msg = "Passed `hwm.source` is different from `source`" + error_msg = r"Passed `hwm\.source` is different from `source`" with pytest.raises(ValueError, match=error_msg): DBReader( connection=Hive(cluster="rnd-dwh", spark=spark_mock), @@ -183,7 +183,7 @@ def test_reader_hwm_has_different_source(spark_mock): def test_reader_no_hwm_expression(spark_mock): - with pytest.raises(ValueError, match="`hwm.expression` cannot be None"): + with pytest.raises(ValueError, match=r"`hwm\.expression` cannot be None"): DBReader( connection=Hive(cluster="rnd-dwh", spark=spark_mock), source="schema.table", @@ -192,7 +192,7 @@ def test_reader_no_hwm_expression(spark_mock): @pytest.mark.parametrize( - "alias_key, alias_value", + ("alias_key", "alias_value"), [ ("source", "test_source"), ("topic", "test_topic"), diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py index 99b473974..a69d076eb 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_greenplum_reader_unit.py @@ -46,7 +46,7 @@ def test_greenplum_reader_snapshot_error_pass_df_schema(spark_mock): def test_greenplum_reader_wrong_table_name(spark_mock): greenplum = Greenplum(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBReader( connection=greenplum, source="table", # Required format: source="schema.table" @@ -82,7 +82,7 @@ def test_greenplum_reader_wrong_where_type(spark_mock): @pytest.mark.parametrize( - ["df_partitions", "spark_config"], + ("df_partitions", "spark_config"), [ pytest.param(30, 
{"spark.master": "local[200]"}, id="small_df, local[200]"), pytest.param(200, {"spark.master": "local[30]"}, id="large_df, local[30]"), @@ -184,7 +184,7 @@ def test_greenplum_reader_number_of_connections_less_than_warning_threshold( @pytest.mark.parametrize( - ["df_partitions", "spark_config", "parallel_connections"], + ("df_partitions", "spark_config", "parallel_connections"), [ pytest.param(31, {"spark.master": "local[200]"}, 31, id="small_df, local[200]"), pytest.param(200, {"spark.master": "local[31]"}, 31, id="large_df, local[31]"), @@ -320,7 +320,7 @@ def test_greenplum_reader_number_of_connections_higher_than_warning_threshold( @pytest.mark.parametrize( - ["df_partitions", "spark_config", "parallel_connections"], + ("df_partitions", "spark_config", "parallel_connections"), [ pytest.param(100, {"spark.master": "local[200]"}, 100, id="large_df, local[200]"), pytest.param(200, {"spark.master": "local[100]"}, 100, id="extra_large_df, local[100]"), diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py index 2373fedd7..b262315c4 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_kafka_reader_unit.py @@ -105,7 +105,7 @@ def test_kafka_reader_invalid_hwm_column(spark_mock, hwm_expression): @pytest.mark.parametrize( - "topic, error_message", + ("topic", "error_message"), [ ("*", r"source/target=\* is not supported by Kafka. Provide a singular topic."), ("topic1, topic2", "source/target=topic1, topic2 is not supported by Kafka. Provide a singular topic."), diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py index 06a5b0912..8def0db5b 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_mssql_reader_unit.py @@ -39,7 +39,7 @@ def test_mssql_reader_snapshot_error_pass_df_schema(spark_mock): def test_mssql_reader_wrong_table_name(spark_mock): mssql = MSSQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBReader( connection=mssql, source="table", # Required format: source="schema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py index 642902dca..8fd7dcf05 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_mysql_reader_unit.py @@ -39,7 +39,7 @@ def test_mysql_reader_snapshot_error_pass_df_schema(spark_mock): def test_mysql_reader_wrong_table_name(spark_mock): mysql = MySQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBReader( connection=mysql, source="table", # Required format: source="schema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py index 1f45c790b..648e14bf2 
100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_oracle_reader_unit.py @@ -38,7 +38,7 @@ def test_oracle_reader_error_df_schema(spark_mock): def test_oracle_reader_wrong_table_name(spark_mock): oracle = Oracle(host="some_host", user="user", sid="sid", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBReader( connection=oracle, source="table", # Required format: source="schema.table" diff --git a/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py b/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py index f82687a2c..4fe4359bb 100644 --- a/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py +++ b/tests/tests_unit/test_db/test_db_reader_unit/test_postgres_reader_unit.py @@ -39,7 +39,7 @@ def test_postgres_reader_snapshot_error_pass_df_schema(spark_mock): def test_postgres_reader_wrong_table_name(spark_mock): postgres = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBReader( connection=postgres, source="table", # Required format: source="schema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/__init__.py b/tests/tests_unit/test_db/test_db_writer_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py index 862b2fd28..4eae79c1f 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_clickhouse_writer_unit.py @@ -9,7 +9,7 @@ def test_clickhouse_writer_wrong_table_name(spark_mock): clickhouse = Clickhouse(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBWriter( connection=clickhouse, target="table", # Required format: target="schema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py index aac320ab2..e44f3c4a6 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_greenplum_writer_unit.py @@ -16,7 +16,7 @@ def test_greenplum_writer_wrong_table_name(spark_mock): greenplum = Greenplum(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBWriter( connection=greenplum, target="table", # Required format: target="schema.table" @@ -24,7 +24,7 @@ def test_greenplum_writer_wrong_table_name(spark_mock): @pytest.mark.parametrize( - ["df_partitions", "spark_config"], + ("df_partitions", "spark_config"), [ pytest.param(30, {"spark.master": 
"local[200]"}, id="small_df, local[200]"), pytest.param(200, {"spark.master": "local[30]"}, id="large_df, local[30]"), @@ -129,7 +129,7 @@ def test_greenplum_writer_number_of_connections_less_than_warning_threshold( @pytest.mark.parametrize( - ["df_partitions", "spark_config", "parallel_connections"], + ("df_partitions", "spark_config", "parallel_connections"), [ pytest.param(31, {"spark.master": "local[200]"}, 31, id="small_df, local[200]"), pytest.param(200, {"spark.master": "local[31]"}, 31, id="large_df, local[31]"), @@ -269,7 +269,7 @@ def test_greenplum_writer_number_of_connections_higher_than_warning_threshold( @pytest.mark.parametrize( - ["df_partitions", "spark_config", "parallel_connections"], + ("df_partitions", "spark_config", "parallel_connections"), [ pytest.param(100, {"spark.master": "local[200]"}, 100, id="large_df, local[200]"), pytest.param(200, {"spark.master": "local[100]"}, 100, id="extra_large_df, local[100]"), diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py index 83fca9ecc..111a1ae95 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_hive_writer_unit.py @@ -11,7 +11,7 @@ def test_hive_writer_wrong_table_name(spark_mock): with pytest.raises( ValueError, - match="Name should be passed in `schema.name` format", + match=r"Name should be passed in `schema\.name` format", ): DBWriter( connection=hive, diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_iceberg_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_iceberg_writer_unit.py index 367a34b19..95901d52a 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_iceberg_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_iceberg_writer_unit.py @@ -8,7 +8,7 @@ def test_iceberg_writer_wrong_table_name(iceberg_mock): with pytest.raises( ValueError, - match="Name should be passed in `schema.name` format", + match=r"Name should be passed in `schema\.name` format", ): DBWriter( connection=iceberg_mock, diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py index 57fa4dbfe..f148eafe0 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_mssql_writer_unit.py @@ -9,7 +9,7 @@ def test_mssql_writer_wrong_table_name(spark_mock): mssql = MSSQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBWriter( connection=mssql, target="table", # Required format: target="schema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py index 39b6847cc..b42eaf509 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_mysql_writer_unit.py @@ -9,7 +9,7 @@ def test_mysql_writer_wrong_table_name(spark_mock): mysql = MySQL(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with 
pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBWriter( connection=mysql, target="table", # Required format: target="schema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py index 3f968abca..cd814d2eb 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_oracle_writer_unit.py @@ -9,7 +9,7 @@ def test_oracle_writer_wrong_table_name(spark_mock): oracle = Oracle(host="some_host", user="user", sid="sid", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBWriter( connection=oracle, target="table", # Required format: target="schema.table" diff --git a/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py b/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py index 7f3dc160f..fbbd5fcb5 100644 --- a/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py +++ b/tests/tests_unit/test_db/test_db_writer_unit/test_postgres_writer_unit.py @@ -9,7 +9,7 @@ def test_postgres_writer_wrong_table_name(spark_mock): postgres = Postgres(host="some_host", user="user", database="database", password="passwd", spark=spark_mock) - with pytest.raises(ValueError, match="Name should be passed in `schema.name` format"): + with pytest.raises(ValueError, match=r"Name should be passed in `schema\.name` format"): DBWriter( connection=postgres, target="table", # Required format: target="schema.table" diff --git a/tests/tests_unit/test_file/__init__.py b/tests/tests_unit/test_file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_file/test_file_df_writer_unit.py b/tests/tests_unit/test_file/test_file_df_writer_unit.py index 77a0a0fab..3250112a2 100644 --- a/tests/tests_unit/test_file/test_file_df_writer_unit.py +++ b/tests/tests_unit/test_file/test_file_df_writer_unit.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize( - "option, value", + ("option", "value"), [ ("if_exists", "append"), ("partition_by", "month"), @@ -26,7 +26,7 @@ def test_file_df_writer_options_mode_prohibited(): @pytest.mark.parametrize( - "mode, recommended", + ("mode", "recommended"), [ ("dynamic", "replace_overlapping_partitions"), ("static", "replace_entire_directory"), diff --git a/tests/tests_unit/test_file/test_file_downloader_unit.py b/tests/tests_unit/test_file/test_file_downloader_unit.py index 7bb236616..c2e7b7381 100644 --- a/tests/tests_unit/test_file/test_file_downloader_unit.py +++ b/tests/tests_unit/test_file/test_file_downloader_unit.py @@ -238,7 +238,7 @@ def test_file_downloader_limit_legacy(file_limit): @pytest.mark.parametrize( - "options, value", + ("options", "value"), [ ({}, FileExistBehavior.ERROR), ({"if_exists": "error"}, FileExistBehavior.ERROR), @@ -252,7 +252,7 @@ def test_file_downloader_options_if_exists(options, value): @pytest.mark.parametrize( - "options, value, message", + ("options", "value", "message"), [ ( {"mode": "replace_file"}, diff --git a/tests/tests_unit/test_file/test_file_filter_unit.py b/tests/tests_unit/test_file/test_file_filter_unit.py index 17e65e151..62270dacc 100644 --- a/tests/tests_unit/test_file/test_file_filter_unit.py +++ b/tests/tests_unit/test_file/test_file_filter_unit.py @@ -18,7 +18,7 @@ def 
test_file_filter_both_glob_and_regexp(): @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (True, RemoteFile(path="nested/file3.csv", stats=RemotePathStat(st_size=20 * 1024, st_mtime=50))), @@ -55,7 +55,7 @@ def test_file_filter_glob(matched, path): @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (False, RemoteFile(path="exclude1/file3.csv", stats=RemotePathStat(st_size=20 * 1024, st_mtime=50))), @@ -91,7 +91,7 @@ def test_file_filter_exclude_dirs_relative(matched, path): @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (True, RemoteFile(path="exclude1/file3.csv", stats=RemotePathStat(st_size=20 * 1024, st_mtime=50))), @@ -111,7 +111,7 @@ def test_file_filter_exclude_dirs_absolute(matched, path): @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (False, RemoteFile(path="file.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), @@ -150,7 +150,7 @@ def test_file_filter_regexp_str(matched, path): @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (False, RemoteFile(path="file.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), diff --git a/tests/tests_unit/test_file/test_file_mover_unit.py b/tests/tests_unit/test_file/test_file_mover_unit.py index 29c18828c..d87e16c89 100644 --- a/tests/tests_unit/test_file/test_file_mover_unit.py +++ b/tests/tests_unit/test_file/test_file_mover_unit.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize( - "options, value", + ("options", "value"), [ ({}, FileExistBehavior.ERROR), ({"if_exists": "error"}, FileExistBehavior.ERROR), @@ -21,7 +21,7 @@ def test_file_mover_options_if_exists(options, value): @pytest.mark.parametrize( - "options, value, message", + ("options", "value", "message"), [ ( {"mode": "replace_file"}, diff --git a/tests/tests_unit/test_file/test_file_set_unit.py b/tests/tests_unit/test_file/test_file_set_unit.py index 7242be0e9..69df336a1 100644 --- a/tests/tests_unit/test_file/test_file_set_unit.py +++ b/tests/tests_unit/test_file/test_file_set_unit.py @@ -69,7 +69,7 @@ def test_file_set(): empty_file_set = FileSet() assert not empty_file_set - assert len(empty_file_set) == 0 # noqa: WPS507 + assert len(empty_file_set) == 0 def test_file_set_details(): diff --git a/tests/tests_unit/test_file/test_file_uploader_unit.py b/tests/tests_unit/test_file/test_file_uploader_unit.py index ba39c19a4..e0ba71a5b 100644 --- a/tests/tests_unit/test_file/test_file_uploader_unit.py +++ b/tests/tests_unit/test_file/test_file_uploader_unit.py @@ -26,7 +26,7 @@ def test_file_uploader_deprecated_import(): @pytest.mark.parametrize( - "options, value", + ("options", "value"), [ ({}, FileExistBehavior.ERROR), ({"if_exists": "error"}, FileExistBehavior.ERROR), @@ -40,7 +40,7 @@ def test_file_uploader_options_if_exists(options, value): @pytest.mark.parametrize( - "options, value, message", + ("options", "value", "message"), [ ( {"mode": "replace_file"}, diff --git a/tests/tests_unit/test_file/test_filter/__init__.py b/tests/tests_unit/test_file/test_filter/__init__.py new file mode 100644 index 
000000000..e69de29bb diff --git a/tests/tests_unit/test_file/test_filter/test_exclude_dir.py b/tests/tests_unit/test_file/test_filter/test_exclude_dir.py index dde23a787..86e871316 100644 --- a/tests/tests_unit/test_file/test_filter/test_exclude_dir.py +++ b/tests/tests_unit/test_file/test_filter/test_exclude_dir.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (False, RemoteFile(path="exclude1/file3.csv", stats=RemotePathStat(st_size=20 * 1024, st_mtime=50))), @@ -25,7 +25,7 @@ def test_exclude_dir_match_relative(matched, path): @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (True, RemoteFile(path="exclude1/file3.csv", stats=RemotePathStat(st_size=20 * 1024, st_mtime=50))), diff --git a/tests/tests_unit/test_file/test_filter/test_file_modified_time.py b/tests/tests_unit/test_file/test_filter/test_file_modified_time.py index 4953ecdc5..14026ae56 100644 --- a/tests/tests_unit/test_file/test_filter/test_file_modified_time.py +++ b/tests/tests_unit/test_file/test_filter/test_file_modified_time.py @@ -21,7 +21,7 @@ def test_file_modified_time_invalid(): # values always timezone-aware @pytest.mark.parametrize( - ["input", "expected"], + ("input", "expected"), [ ( datetime(2025, 1, 1), @@ -76,7 +76,7 @@ def test_file_modified_time_repr(): # only POSIX timestamps are compared, so all values are in UTC @pytest.mark.parametrize( - "matched, mtime", + ("matched", "mtime"), [ (False, datetime(2025, 1, 1, 11, 22, 33, 456788, tzinfo=timezone.utc)), # since-1ms (True, datetime(2025, 1, 1, 11, 22, 33, 456789, tzinfo=timezone.utc)), diff --git a/tests/tests_unit/test_file/test_filter/test_file_size_range.py b/tests/tests_unit/test_file/test_filter/test_file_size_range.py index e258132c4..c834dd220 100644 --- a/tests/tests_unit/test_file/test_filter/test_file_size_range.py +++ b/tests/tests_unit/test_file/test_filter/test_file_size_range.py @@ -28,7 +28,7 @@ def test_file_size_range_repr(): @pytest.mark.parametrize( - ["input", "expected"], + ("input", "expected"), [ ("10", 10), ("10B", 10), @@ -50,7 +50,7 @@ def test_file_size_range_parse(input: str, expected: int): @pytest.mark.parametrize( - "matched, size", + ("matched", "size"), [ (False, 1024), (True, 10 * 1024), diff --git a/tests/tests_unit/test_file/test_filter/test_glob.py b/tests/tests_unit/test_file/test_filter/test_glob.py index 429f7417b..ff7a5f3a8 100644 --- a/tests/tests_unit/test_file/test_filter/test_glob.py +++ b/tests/tests_unit/test_file/test_filter/test_glob.py @@ -10,7 +10,7 @@ def test_glob_invalid(): @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (True, RemoteFile(path="nested/file3.csv", stats=RemotePathStat(st_size=20 * 1024, st_mtime=50))), diff --git a/tests/tests_unit/test_file/test_filter/test_match_all_filters.py b/tests/tests_unit/test_file/test_filter/test_match_all_filters.py index a499d85d6..8b0b2210b 100644 --- a/tests/tests_unit/test_file/test_filter/test_match_all_filters.py +++ b/tests/tests_unit/test_file/test_filter/test_match_all_filters.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize( - "failed_filters, path", + ("failed_filters", "path"), [ (None, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (None, 
RemoteFile(path="exclude1/file3.csv", stats=RemotePathStat(st_size=20 * 1024, st_mtime=50))), diff --git a/tests/tests_unit/test_file/test_filter/test_regex.py b/tests/tests_unit/test_file/test_filter/test_regex.py index 6717e76dc..869b2006b 100644 --- a/tests/tests_unit/test_file/test_filter/test_regex.py +++ b/tests/tests_unit/test_file/test_filter/test_regex.py @@ -12,7 +12,7 @@ def test_regep_invalid(): @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (False, RemoteFile(path="file.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), @@ -35,7 +35,7 @@ def test_regexp_match_str(matched, path): @pytest.mark.parametrize( - "matched, path", + ("matched", "path"), [ (False, RemoteFile(path="file.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), (True, RemoteFile(path="file1.csv", stats=RemotePathStat(st_size=10 * 1024, st_mtime=50))), diff --git a/tests/tests_unit/test_file/test_format_unit/__init__.py b/tests/tests_unit/test_file/test_format_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_file/test_format_unit/test_avro_unit.py b/tests/tests_unit/test_file/test_format_unit/test_avro_unit.py index 94bade9ad..8fd3db56e 100644 --- a/tests/tests_unit/test_file/test_format_unit/test_avro_unit.py +++ b/tests/tests_unit/test_file/test_format_unit/test_avro_unit.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize( - "spark_version, scala_version, package", + ("spark_version", "scala_version", "package"), [ # Detect Scala version by Spark version ("3.2.0", None, "org.apache.spark:spark-avro_2.12:3.2.0"), @@ -27,7 +27,7 @@ def test_avro_get_packages(spark_version, scala_version, package): @pytest.mark.parametrize( - "value, real_value", + ("value", "real_value"), [ ({"name": "abc", "type": "string"}, {"name": "abc", "type": "string"}), ('{"name": "abc", "type": "string"}', {"name": "abc", "type": "string"}), @@ -39,7 +39,7 @@ def test_avro_options_schema(value, real_value): @pytest.mark.parametrize( - "name, real_name, value", + ("name", "real_name", "value"), [ ("avroSchema", "schema_dict", {"name": "abc", "type": "string"}), ("avroSchemaUrl", "schema_url", "http://example.com"), @@ -51,7 +51,7 @@ def test_avro_options_alias(name, real_name, value): @pytest.mark.parametrize( - "known_option, value, expected_value", + ("known_option", "value", "expected_value"), [ ("positionalFieldMatching", True, True), ("mode", "PERMISSIVE", "PERMISSIVE"), @@ -84,9 +84,12 @@ def test_avro_options_repr(): mode="PERMISSIVE", unknownOption="abc", ) - assert ( - repr(avro) - == "Avro(avroSchema={'name': 'abc', 'type': 'string'}, compression='snappy', mode='PERMISSIVE', unknownOption='abc')" + assert repr(avro) == ( + "Avro(" + "avroSchema={'name': 'abc', 'type': 'string'}, " + "compression='snappy', " + "mode='PERMISSIVE', " + "unknownOption='abc')" ) diff --git a/tests/tests_unit/test_file/test_format_unit/test_csv_unit.py b/tests/tests_unit/test_file/test_format_unit/test_csv_unit.py index be0598645..d2bf96d46 100644 --- a/tests/tests_unit/test_file/test_format_unit/test_csv_unit.py +++ b/tests/tests_unit/test_file/test_format_unit/test_csv_unit.py @@ -20,7 +20,7 @@ def test_csv_options_delimiter_alias(): @pytest.mark.parametrize( - "known_option, value, expected_value", + ("known_option", "value", "expected_value"), [ ("delimiter", ";", ";"), ("quote", "'", "'"), diff --git a/tests/tests_unit/test_file/test_format_unit/test_excel_unit.py 
b/tests/tests_unit/test_file/test_format_unit/test_excel_unit.py index 86bb63df2..000ebaa7a 100644 --- a/tests/tests_unit/test_file/test_format_unit/test_excel_unit.py +++ b/tests/tests_unit/test_file/test_format_unit/test_excel_unit.py @@ -13,12 +13,12 @@ def test_excel_get_packages_package_version_not_supported(): - with pytest.raises(ValueError, match="Package version should be at least 0.30, got 0.20.4"): + with pytest.raises(ValueError, match=r"Package version should be at least 0\.30, got 0\.20\.4"): Excel.get_packages(package_version="0.20.4", spark_version="3.2.4") @pytest.mark.parametrize( - "package_version, spark_version, scala_version, packages", + ("package_version", "spark_version", "scala_version", "packages"), [ # Detect Scala version by Spark version ("0.31.2", "3.2.4", None, ["dev.mauch:spark-excel_2.12:3.2.4_0.31.2"]), @@ -35,17 +35,12 @@ def test_excel_get_packages_package_version_not_supported(): ("0.31.2", "3.5.6", "2.12.1", ["dev.mauch:spark-excel_2.12:3.5.6_0.31.2"]), ], ) -def test_excel_get_packages(caplog, spark_version, scala_version, package_version, packages): - with caplog.at_level(level=logging.WARNING): - result = Excel.get_packages( - spark_version=spark_version, - scala_version=scala_version, - package_version=package_version, - ) - - if package_version: - assert f"Passed custom package version '{package_version}', it is not guaranteed to be supported" - +def test_excel_get_packages(spark_version, scala_version, package_version, packages): + result = Excel.get_packages( + spark_version=spark_version, + scala_version=scala_version, + package_version=package_version, + ) assert result == packages @@ -60,7 +55,7 @@ def test_excel_options_default_override(): @pytest.mark.parametrize( - "known_option, value, expected_value", + ("known_option", "value", "expected_value"), [ ("dataAddress", "value", "value"), ("treatEmptyValuesAsNulls", True, True), diff --git a/tests/tests_unit/test_file/test_format_unit/test_json_unit.py b/tests/tests_unit/test_file/test_format_unit/test_json_unit.py index 1bcc4127a..b506b72d4 100644 --- a/tests/tests_unit/test_file/test_format_unit/test_json_unit.py +++ b/tests/tests_unit/test_file/test_format_unit/test_json_unit.py @@ -23,7 +23,7 @@ def test_json_options_timezone_alias(): @pytest.mark.parametrize( - "known_option, value, expected", + ("known_option", "value", "expected"), [ ("encoding", "value", "value"), ("lineSep", "\r\n", "\r\n"), diff --git a/tests/tests_unit/test_file/test_format_unit/test_jsonline_unit.py b/tests/tests_unit/test_file/test_format_unit/test_jsonline_unit.py index 80d2b32f9..98fb8d7c2 100644 --- a/tests/tests_unit/test_file/test_format_unit/test_jsonline_unit.py +++ b/tests/tests_unit/test_file/test_format_unit/test_jsonline_unit.py @@ -23,7 +23,7 @@ def test_jsonline_options_timezone_alias(): @pytest.mark.parametrize( - "known_option, value, expected", + ("known_option", "value", "expected"), [ ("encoding", "value", "value"), ("lineSep", "\r\n", "\r\n"), diff --git a/tests/tests_unit/test_file/test_format_unit/test_orc_unit.py b/tests/tests_unit/test_file/test_format_unit/test_orc_unit.py index 3c09a0cae..23c7a9f80 100644 --- a/tests/tests_unit/test_file/test_format_unit/test_orc_unit.py +++ b/tests/tests_unit/test_file/test_format_unit/test_orc_unit.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize( - "known_option, value, expected_value", + ("known_option", "value", "expected_value"), [ ("mergeSchema", True, True), ("compression", "snappy", "snappy"), diff --git 
a/tests/tests_unit/test_file/test_format_unit/test_parquet_unit.py b/tests/tests_unit/test_file/test_format_unit/test_parquet_unit.py index d5fb153e5..ddb332453 100644 --- a/tests/tests_unit/test_file/test_format_unit/test_parquet_unit.py +++ b/tests/tests_unit/test_file/test_format_unit/test_parquet_unit.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize( - "known_option, value, expected_value", + ("known_option", "value", "expected_value"), [ ("mergeSchema", True, True), ("compression", "snappy", "snappy"), diff --git a/tests/tests_unit/test_file/test_format_unit/test_xml_unit.py b/tests/tests_unit/test_file/test_format_unit/test_xml_unit.py index 3cb1ebc11..07acae9a1 100644 --- a/tests/tests_unit/test_file/test_format_unit/test_xml_unit.py +++ b/tests/tests_unit/test_file/test_format_unit/test_xml_unit.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize( - "spark_version, scala_version, package_version, expected_packages", + ("spark_version", "scala_version", "package_version", "expected_packages"), [ ("3.2.4", None, None, ["com.databricks:spark-xml_2.12:0.18.0"]), ("3.4.1", "2.12", "0.18.0", ["com.databricks:spark-xml_2.12:0.18.0"]), @@ -26,17 +26,22 @@ ("3.2.4", "2.12.1", "0.15.0", ["com.databricks:spark-xml_2.12:0.15.0"]), ], ) -def test_xml_get_packages(spark_version, scala_version, package_version, expected_packages): - result = XML.get_packages( - spark_version=spark_version, - scala_version=scala_version, - package_version=package_version, - ) +def test_xml_get_packages(caplog, spark_version, scala_version, package_version, expected_packages): + with caplog.at_level(level=logging.WARNING): + result = XML.get_packages( + spark_version=spark_version, + scala_version=scala_version, + package_version=package_version, + ) + + if package_version: + msg = f"Passed custom package version '{package_version}', it is not guaranteed to be supported" + assert msg in caplog.text assert result == expected_packages @pytest.mark.parametrize( - "spark_version, scala_version, package_version", + ("spark_version", "scala_version", "package_version"), [ ("3.2.4", "2.12", "0.13.0"), ("3.4.1", "2.12", "0.10.0"), @@ -60,7 +65,7 @@ def test_xml_options_row_tag_case(): @pytest.mark.parametrize( - "known_option, raw_value, expected_value", + ("known_option", "raw_value", "expected_value"), [ ("samplingRatio", 0.1, 0.1), ("excludeAttribute", True, True), diff --git a/tests/tests_unit/test_file/test_limit/__init__.py b/tests/tests_unit/test_file/test_limit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_file/test_limit/test_total_files_size.py b/tests/tests_unit/test_file/test_limit/test_total_files_size.py index 75f92d60d..a486e7445 100644 --- a/tests/tests_unit/test_file/test_limit/test_total_files_size.py +++ b/tests/tests_unit/test_file/test_limit/test_total_files_size.py @@ -20,7 +20,7 @@ def test_total_files_size_repr(): @pytest.mark.parametrize( - ["input", "expected"], + ("input", "expected"), [ ("10", 10), ("10B", 10), diff --git a/tests/tests_unit/test_hooks/__init__.py b/tests/tests_unit/test_hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_hooks/test_hooks_callback.py b/tests/tests_unit/test_hooks/test_hooks_callback.py index 7fbc00ed7..7e0f45415 100644 --- a/tests/tests_unit/test_hooks/test_hooks_callback.py +++ b/tests/tests_unit/test_hooks/test_hooks_callback.py @@ -263,13 +263,13 @@ def plus(self, arg: int) -> int: @Calculator.plus.bind @hook def modify_callback2(self, arg: int): - result = yield # noqa: F841 
+ _result = yield yield 123 @Calculator.plus.bind @hook def modify_callback1(self, arg: int): - result = yield # noqa: F841 + _result = yield yield 234 # the last hook result is used @@ -289,7 +289,7 @@ def plus(self, arg: int) -> int: @Calculator.plus.bind @hook def modify_callback(self, arg: int): - yield from (i for i in ()) # noqa: WPS335 + yield from (i for i in ()) # no yield = no override assert Calculator(1).plus(2) == 3 @@ -323,7 +323,8 @@ class Calculator: @slot def plus(self, arg: int) -> int: log.info("Called original method with %s and %s", self.data, arg) - raise TypeError(f"Raised with {self.data} and {arg}") + msg = f"Raised with {self.data} and {arg}" + raise TypeError(msg) @Calculator.plus.bind @hook @@ -334,7 +335,8 @@ def context_callback(self, arg: int): log.info("After method call") except Exception as e: log.exception("Context caught exception") - raise RuntimeError("Replaced") from e + msg = "Replaced" + raise RuntimeError(msg) from e # exception successfully caught with pytest.raises(RuntimeError, match="Replaced"), caplog.at_level(logging.INFO): @@ -460,7 +462,8 @@ def plus(self, arg: int) -> int: @hook def before_callback(self, arg: int): if arg == 3: - raise ValueError("Argument value 3 is not allowed") + msg = "Argument value 3 is not allowed" + raise ValueError(msg) # exception successfully raised with pytest.raises(ValueError, match="Argument value 3 is not allowed"), caplog.at_level(logging.INFO): @@ -491,7 +494,8 @@ def plus(self, arg: int) -> int: def after_callback(self, arg: int): result = yield if result == 4: - raise ValueError("Result value 4 is not allowed") + msg = "Result value 4 is not allowed" + raise ValueError(msg) yield result # exception successfully raised @@ -525,7 +529,8 @@ def plus(self, arg: int) -> int: def after_callback(self, arg: int): yield if arg == 3: - raise ValueError("Argument value 3 is not allowed") + msg = "Argument value 3 is not allowed" + raise ValueError(msg) # exception successfully raised with pytest.raises(ValueError, match="Argument value 3 is not allowed"), caplog.at_level(logging.INFO): @@ -557,12 +562,13 @@ def plus(self, arg: int) -> int: def missing_arg(self): pass - method_name = "test_hooks_callback.test_hooks_execute_callback_wrong_signature.<locals>.Calculator.plus" - hook_name = "test_hooks_callback.test_hooks_execute_callback_wrong_signature.<locals>.missing_arg" + local_name = r"tests\.tests_unit\.test_hooks\.test_hooks_callback\.test_hooks_execute_callback_wrong_signature" + method_name = rf"{local_name}\.<locals>\.Calculator\.plus" + hook_name = rf"{local_name}\.<locals>\.missing_arg" error_msg = textwrap.dedent( rf""" - Error while passing method arguments to a hook. + Error while passing method arguments to a hook\.
Method name: '{method_name}' Method source: '{__file__}:\d+' diff --git a/tests/tests_unit/test_hooks/test_hooks_class.py b/tests/tests_unit/test_hooks/test_hooks_class.py index 81d25006d..0af449e79 100644 --- a/tests/tests_unit/test_hooks/test_hooks_class.py +++ b/tests/tests_unit/test_hooks/test_hooks_class.py @@ -121,7 +121,8 @@ def more_callback(self, arg: int): @hook(enabled=False) def never_called(self, arg: int): # stop & resume does not affect hook state, it should be enabled explicitly - raise AssertionError("Never called") + msg = "Never called" + raise AssertionError(msg) assert Calculator1(1).plus(1) == 234 assert Calculator1(1).multiply(1) == 345 @@ -186,7 +187,8 @@ def more_callback(self, arg: int): @hook(enabled=False) def never_called(self, arg: int): # skip does not affect hook state, it should be enabled explicitly - raise AssertionError("Never called") + msg = "Never called" + raise AssertionError(msg) assert Calculator1(1).plus(1) == 234 assert Calculator1(1).multiply(1) == 345 @@ -256,7 +258,8 @@ def more_callback(self, arg: int): @hook(enabled=False) def never_called(self, arg: int): # skip does not affect hook state, it should be enabled explicitly - raise AssertionError("Never called") + msg = "Never called" + raise AssertionError(msg) assert Calculator1(1).plus(1) == 234 assert Calculator1(1).multiply(1) == 345 diff --git a/tests/tests_unit/test_hooks/test_hooks_context_manager.py b/tests/tests_unit/test_hooks/test_hooks_context_manager.py index 06e2a421b..e377b24da 100644 --- a/tests/tests_unit/test_hooks/test_hooks_context_manager.py +++ b/tests/tests_unit/test_hooks/test_hooks_context_manager.py @@ -257,7 +257,8 @@ class Calculator: @slot def plus(self, arg: int) -> int: log.info("Called original method with %s and %s", self.data, arg) - raise TypeError(f"Raised with {self.data} and {arg}") + msg = f"Raised with {self.data} and {arg}" + raise TypeError(msg) @Calculator.plus.bind @hook @@ -274,7 +275,8 @@ def __exit__(self, exc_type, exc_value, traceback): if exc_type: log.exception("Context caught exception", exc_info=(exc_type, exc_value, traceback)) del traceback - raise RuntimeError("Replaced") from exc_value + msg = "Replaced" + raise RuntimeError(msg) from exc_value log.info("After method call") @@ -439,7 +441,8 @@ def plus(self, arg: int) -> int: class BeforeCallback: def __init__(self, instance: Calculator, arg: int): if arg == 3: - raise ValueError("Argument value 3 is not allowed") + msg = "Argument value 3 is not allowed" + raise ValueError(msg) self.instance = instance self.arg = arg @@ -482,7 +485,8 @@ def __init__(self, instance: Calculator, arg: int): def __enter__(self): if self.arg == 3: - raise ValueError("Argument value 3 is not allowed") + msg = "Argument value 3 is not allowed" + raise ValueError(msg) return self def __exit__(self, *args): @@ -529,7 +533,8 @@ def __exit__(self, *args): def process_result(self, result: int) -> int: if result == 4: - raise ValueError("Result value 4 is not allowed") + msg = "Result value 4 is not allowed" + raise ValueError(msg) return result # exception successfully raised @@ -570,7 +575,8 @@ def __enter__(self): def __exit__(self, *args): if self.arg == 3: - raise ValueError("Argument value 3 is not allowed") + msg = "Argument value 3 is not allowed" + raise ValueError(msg) return False # exception successfully raised @@ -610,13 +616,15 @@ def __enter__(self): def __exit__(self, *args): return False - local_name = "test_hooks_context_manager.test_hooks_execute_context_manager_wrong_signature" - method_name = 
f"{local_name}.<locals>.Calculator.plus" - hook_name = f"{local_name}.<locals>.MissingArg" + local_name = ( + r"tests\.tests_unit\.test_hooks\.test_hooks_context_manager\.test_hooks_execute_context_manager_wrong_signature" + ) + method_name = rf"{local_name}\.<locals>\.Calculator\.plus" + hook_name = rf"{local_name}\.<locals>\.MissingArg" error_msg = textwrap.dedent( rf""" - Error while passing method arguments to a hook. + Error while passing method arguments to a hook\. Method name: '{method_name}' Method source: '{__file__}:\d+' diff --git a/tests/tests_unit/test_hooks/test_hooks_global.py b/tests/tests_unit/test_hooks/test_hooks_global.py index 8cc02889f..0ce937940 100644 --- a/tests/tests_unit/test_hooks/test_hooks_global.py +++ b/tests/tests_unit/test_hooks/test_hooks_global.py @@ -62,7 +62,8 @@ def more_callback(self, arg: int): @hook(enabled=False) def never_called(self, arg: int): # stop & resume does not affect hook state, it should be enabled explicitly - raise AssertionError("Never called") + msg = "Never called" + raise AssertionError(msg) assert Calculator1(1).plus(1) == 234 assert Calculator1(1).multiply(1) == 345 @@ -160,7 +161,8 @@ def more_callback(self, arg: int): @hook(enabled=False) def never_called(self, arg: int): # skip does not affect hook state, it should be enabled explicitly - raise AssertionError("Never called") + msg = "Never called" + raise AssertionError(msg) assert Calculator1(1).plus(1) == 234 assert Calculator1(1).multiply(1) == 345 @@ -263,7 +265,8 @@ def more_callback(self, arg: int): @hook(enabled=False) def never_called(self, arg: int): # skip does not affect hook state, it should be enabled explicitly - raise AssertionError("Never called") + msg = "Never called" + raise AssertionError(msg) assert Calculator1(1).plus(1) == 234 assert Calculator1(1).multiply(1) == 345 diff --git a/tests/tests_unit/test_hooks/test_hooks_slot.py b/tests/tests_unit/test_hooks/test_hooks_slot.py index 289bd651a..2c330bf70 100644 --- a/tests/tests_unit/test_hooks/test_hooks_slot.py +++ b/tests/tests_unit/test_hooks/test_hooks_slot.py @@ -130,7 +130,8 @@ def callback2(self, arg: int): def callback3(self, arg: int): # stop & resume does not change hook state # they also have higher priority - raise AssertionError("Never called") + msg = "Never called" + raise AssertionError(msg) with caplog.at_level(logging.INFO): Calculator.plus.suspend_hooks() diff --git a/tests/tests_unit/test_impl_unit.py b/tests/tests_unit/test_impl_unit.py index e6554ec41..7db419114 100644 --- a/tests/tests_unit/test_impl_unit.py +++ b/tests/tests_unit/test_impl_unit.py @@ -124,7 +124,7 @@ def test_file_stat(): assert file_stat.st_size == 10 assert file_stat.st_mtime == 50 - assert file_stat == file_stat # noqa: WPS312 NOSONAR + assert file_stat == file_stat # noqa: PLR0124 assert RemotePathStat(st_size=10, st_mtime=50) == RemotePathStat(st_size=10, st_mtime=50) assert RemotePathStat(st_size=10, st_mtime=50) != RemotePathStat(st_size=20, st_mtime=50) @@ -138,7 +138,7 @@ def test_file_stat(): @pytest.mark.parametrize( - "item1, item2", + ("item1", "item2"), [ (RemotePath("a/b/c"), RemoteDirectory(path="a/b/c")), (LocalPath("a/b/c"), FailedLocalFile(path="a/b/c", exception=FileNotFoundError("abc"))), @@ -161,21 +161,21 @@ def test_path_compat(item1, item2): assert bytes(item1) == bytes(item2) assert os.fspath(item1) == os.fspath(item2) - assert item1 in {item1} # noqa: WPS525 - assert item2 in {item2} # noqa: WPS525 + assert item1 in {item1} + assert item2 in {item2} assert {item1} == {item2} == {item1, item2} assert len({item1, item2})
== 1 - assert item1 in {item2} # noqa: WPS525 - assert item2 in {item1} # noqa: WPS525 + assert item1 in {item2} + assert item2 in {item1} assert item1 == item2 assert item2 == item1 - assert item1 in [item1] # noqa: WPS525, WPS510 - assert item2 in [item2] # noqa: WPS525, WPS510 + assert item1 in [item1] + assert item2 in [item2] assert [item1] == [item2] - assert item1 in [item2] # noqa: WPS525, WPS510 - assert item2 in [item1] # noqa: WPS525, WPS510 + assert item1 in [item2] + assert item2 in [item1] assert item1 / "d" == item2 / "d" assert "d" / item1 == "d" / item2 @@ -185,7 +185,7 @@ def test_path_compat(item1, item2): @pytest.mark.parametrize( - "item1, item2", + ("item1", "item2"), [ (RemotePath("a/b/c"), RemoteDirectory(path="a/b/c")), (LocalPath("a/b/c"), FailedLocalFile(path="a/b/c", exception=FileNotFoundError("abc"))), @@ -277,7 +277,7 @@ def test_failed_remote_file_eq(): @pytest.mark.parametrize( - "kwargs, kind", + ("kwargs", "kind"), [ ({}, None), ({"st_mode": stat.S_IFSOCK}, "socket"), @@ -315,7 +315,7 @@ def test_path_repr_stats_with_kind(kwargs, kind): @pytest.mark.parametrize( - "st_size, details", + ("st_size", "details"), [ (0, ", size='0 Bytes'"), (10, ", size='10 Bytes'"), @@ -335,7 +335,7 @@ def test_path_repr_stats_with_size(st_size, details): @pytest.mark.parametrize( - "path_class, kind", + ("path_class", "kind"), [ (RemoteFile, "file"), (RemoteDirectory, "directory"), @@ -358,7 +358,7 @@ def test_path_repr_stats_with_mtime(path_class, kind): @pytest.mark.parametrize( - "mode, mode_str", + ("mode", "mode_str"), [ (0o777, "rwxrwxrwx"), (0o666, "rw-rw-rw-"), @@ -376,7 +376,7 @@ def test_path_repr_stats_with_mtime(path_class, kind): ], ) @pytest.mark.parametrize( - "path_class, kind", + ("path_class", "kind"), [ (RemoteFile, "file"), (RemoteDirectory, "directory"), @@ -396,21 +396,21 @@ def test_path_repr_stats_with_mode(path_class, kind, mode, mode_str): @pytest.mark.parametrize( - "user, user_str", + ("user", "user_str"), [ (123, ", uid=123"), ("me", ", uid='me'"), ], ) @pytest.mark.parametrize( - "group, group_str", + ("group", "group_str"), [ (123, ", gid=123"), ("me", ", gid='me'"), ], ) @pytest.mark.parametrize( - "path_class, kind", + ("path_class", "kind"), [ (RemoteFile, "file"), (RemoteDirectory, "directory"), @@ -430,7 +430,7 @@ def test_path_repr_stats_with_owner(path_class, kind, user, user_str, group, gro @pytest.mark.parametrize( - "exception, exception_str", + ("exception", "exception_str"), [ ( FileNotFoundError("abc"), diff --git a/tests/tests_unit/test_internal_unit/__init__.py b/tests/tests_unit/test_internal_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_internal_unit/test_generate_temp_path.py b/tests/tests_unit/test_internal_unit/test_generate_temp_path.py index 0b8f98853..eb785f0ef 100644 --- a/tests/tests_unit/test_internal_unit/test_generate_temp_path.py +++ b/tests/tests_unit/test_internal_unit/test_generate_temp_path.py @@ -1,5 +1,5 @@ import os -from datetime import datetime +from datetime import datetime, timezone from pathlib import PurePath import pytest @@ -13,7 +13,7 @@ def test_generate_temp_path(): root = PurePath("/path") - dt_prefix = datetime.now().strftime("%Y%m%d%H%M") # up to minutes, not seconds + dt_prefix = datetime.now(tz=timezone.utc).strftime("%Y%m%d%H%M") # up to minutes, not seconds with Process(name="me", host="currenthost"): temp_path = os.fspath(generate_temp_path(root)) diff --git a/tests/tests_unit/test_metrics/__init__.py 
b/tests/tests_unit/test_metrics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/test_plugins/__init__.py b/tests/tests_unit/test_plugins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/tests_db_connection_unit/__init__.py b/tests/tests_unit/tests_db_connection_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py index 0874c3288..1cbc4a5f4 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_clickhouse_unit.py @@ -19,7 +19,7 @@ def test_clickhouse_package(): @pytest.mark.parametrize( - "package_version, apache_http_client_version, expected_packages", + ("package_version", "apache_http_client_version", "expected_packages"), [ ( None, @@ -96,7 +96,7 @@ def test_clickhouse_get_packages(package_version, apache_http_client_version, ex @pytest.mark.parametrize( - "package_version, apache_http_client_version", + ("package_version", "apache_http_client_version"), [ ("0.7", "5.4.2"), ("1", "5.4.0"), @@ -106,7 +106,10 @@ def test_clickhouse_get_packages(package_version, apache_http_client_version, ex def test_clickhouse_get_packages_invalid_version(package_version, apache_http_client_version): with pytest.raises( ValueError, - match=rf"Version '{package_version}' does not have enough numeric components for requested format \(expected at least 3\).", + match=( + f"Version '{package_version}' does not have enough numeric components " + r"for requested format \(expected at least 3\)." + ), ): Clickhouse.get_packages(package_version=package_version, apache_http_client_version=apache_http_client_version) diff --git a/tests/tests_unit/tests_db_connection_unit/test_db_options_unit.py b/tests/tests_unit/tests_db_connection_unit/test_db_options_unit.py index 00b6cddde..8af9ef50e 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_db_options_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_db_options_unit.py @@ -45,7 +45,7 @@ ], ) @pytest.mark.parametrize( - "arg, value", + ("arg", "value"), [ ("url", "jdbc:postgresql://localhost:5432/postgres"), ("driver", "org.postgresql.Driver"), @@ -59,7 +59,7 @@ def test_db_options_connection_parameters_cannot_be_passed(options_class, arg, v @pytest.mark.parametrize( - "options_class, options_class_name, known_options", + ("options_class", "options_class_name", "known_options"), [ (Hive.WriteOptions, "HiveWriteOptions", {"if_exists": "replace_overlapping_partitions"}), (Hive.Options, "HiveLegacyOptions", {"if_exists": "replace_overlapping_partitions"}), @@ -95,7 +95,7 @@ def test_db_options_warn_for_unknown(options_class, options_class_name, known_op @pytest.mark.parametrize( - "options_class,options", + ("options_class", "options"), [ (Postgres.ReadOptions, Postgres.WriteOptions()), (Postgres.WriteOptions, Postgres.ReadOptions()), @@ -117,7 +117,7 @@ def test_db_options_parse_mismatch_class(options_class, options): @pytest.mark.parametrize( - "connection,options", + ("connection", "options"), [ ( Postgres, diff --git a/tests/tests_unit/tests_db_connection_unit/test_dialect_unit.py b/tests/tests_unit/tests_db_connection_unit/test_dialect_unit.py index 7dc7ac7d2..04b418f6d 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_dialect_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_dialect_unit.py @@ -265,7 +265,7 @@ 
def test_db_dialect_get_sql_query_compact_true(spark_mock): @pytest.mark.parametrize( - "limit, where, expected_query", + ("limit", "where", "expected_query"), [ (None, None, "SELECT\n *\nFROM\n default.test"), (0, None, "SELECT\n *\nFROM\n default.test\nWHERE\n 1=0"), @@ -275,7 +275,10 @@ def test_db_dialect_get_sql_query_compact_true(spark_mock): ( 5, "column1 = 'value'", - "SELECT\n *\nFROM\n default.test\nWHERE\n (column1 = 'value')\n AND\n (ROWNUM <= 5)", + "SELECT\n *\n" + "FROM\n default.test\n" + "WHERE\n (column1 = 'value')\n " + "AND\n (ROWNUM <= 5)", ), ], ) @@ -286,7 +289,7 @@ def test_oracle_dialect_get_sql_query_limit_where(spark_mock, limit, where, expe @pytest.mark.parametrize( - "limit, where, expected_query", + ("limit", "where", "expected_query"), [ (None, None, "SELECT\n *\nFROM\n default.test"), (0, None, "SELECT\n *\nFROM\n default.test\nWHERE\n 1 = 0"), diff --git a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py index 89d4b55ae..8639acc62 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_greenplum_unit.py @@ -32,7 +32,7 @@ def test_greenplum_get_packages_spark_version_not_supported(spark_version): @pytest.mark.parametrize( - "spark_version, scala_version, package", + ("spark_version", "scala_version", "package"), [ # use Scala version directly (None, "2.12", "io.pivotal:greenplum-spark_2.12:2.2.0"), @@ -49,7 +49,7 @@ def test_greenplum_get_packages(spark_version, scala_version, package): @pytest.mark.parametrize( - "package_version, scala_version, package", + ("package_version", "scala_version", "package"), [ (None, "2.12", "io.pivotal:greenplum-spark_2.12:2.2.0"), ("2.3.0", "2.12", "io.pivotal:greenplum-spark_2.12:2.3.0"), @@ -248,7 +248,7 @@ def test_greenplum_write_options_default(): @pytest.mark.parametrize( - "klass, name", + ("klass", "name"), [ (Greenplum.ReadOptions, "GreenplumReadOptions"), (Greenplum.WriteOptions, "GreenplumWriteOptions"), @@ -280,7 +280,7 @@ def test_greenplum_read_write_options_populated_by_connection_class(options_clas @pytest.mark.parametrize( - "arg, value", + ("arg", "value"), [ ("mode", "append"), ("truncate", "true"), @@ -295,7 +295,7 @@ def test_greenplum_write_options_cannot_be_used_in_read_options(arg, value): @pytest.mark.parametrize( - "arg, value", + ("arg", "value"), [ ("partitions", 10), ("numPartitions", 10), @@ -310,7 +310,7 @@ def test_greenplum_read_options_cannot_be_used_in_write_options(arg, value): @pytest.mark.parametrize( - "options, value", + ("options", "value"), [ ({}, GreenplumTableExistBehavior.APPEND), ({"if_exists": "append"}, GreenplumTableExistBehavior.APPEND), @@ -324,7 +324,7 @@ def test_greenplum_write_options_if_exists(options, value): @pytest.mark.parametrize( - "options, value, message", + ("options", "value", "message"), [ ( {"mode": "append"}, diff --git a/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py b/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py index f1a542bfc..39e0425a5 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_hive_unit.py @@ -121,7 +121,7 @@ def test_hive_write_options_sort_by_without_bucket_by(sort_by): @pytest.mark.parametrize( - "mode, recommended", + ("mode", "recommended"), [ ("dynamic", "replace_overlapping_partitions"), ("static", "replace_entire_table"), @@ -151,7 +151,7 @@ def 
test_hive_write_options_unsupported_insert_into(insert_into): @pytest.mark.parametrize( - "options, value", + ("options", "value"), [ ({}, HiveTableExistBehavior.APPEND), ({"if_exists": "append"}, HiveTableExistBehavior.APPEND), @@ -166,7 +166,7 @@ def test_hive_write_options_if_exists(options, value): @pytest.mark.parametrize( - "options, value, message", + ("options", "value", "message"), [ ( {"mode": "append"}, diff --git a/tests/tests_unit/tests_db_connection_unit/test_iceberg_unit.py b/tests/tests_unit/tests_db_connection_unit/test_iceberg_unit.py index 36f8b7e41..1ba69baaf 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_iceberg_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_iceberg_unit.py @@ -466,7 +466,7 @@ def test_iceberg_spark_stopped(iceberg_mock, spark_stopped): @pytest.mark.parametrize( - "package_version,spark_version,scala_version,package", + ("package_version", "spark_version", "scala_version", "package"), [ ("1.4.0", "3.3", None, "org.apache.iceberg:iceberg-spark-runtime-3.3_2.12:1.4.0"), ("1.10.0", "3.5", "2.12", "org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.10.0"), diff --git a/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py b/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py index ab9874eaa..a5d8fc2e7 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_jdbc_options_unit.py @@ -35,7 +35,7 @@ def test_jdbc_options_default(): @pytest.mark.parametrize( - "arg, value", + ("arg", "value"), [ ("table", "mytable"), ("dbtable", "mytable"), @@ -44,7 +44,7 @@ def test_jdbc_options_default(): ], ) @pytest.mark.parametrize( - "options_class, read_write_restriction", + ("options_class", "read_write_restriction"), [ (Postgres.FetchOptions, False), (Postgres.ExecuteOptions, False), @@ -80,7 +80,7 @@ def test_jdbc_read_write_options_populated_by_connection_class(arg, value, optio @pytest.mark.parametrize( - "options_class, options_class_name", + ("options_class", "options_class_name"), [ (Postgres.ReadOptions, "PostgresReadOptions"), (Clickhouse.ReadOptions, "ClickhouseReadOptions"), @@ -90,7 +90,7 @@ def test_jdbc_read_write_options_populated_by_connection_class(arg, value, optio ], ) @pytest.mark.parametrize( - "arg, value", + ("arg", "value"), [ ("column", "some"), ("mode", "append"), @@ -110,7 +110,7 @@ def test_jdbc_write_options_cannot_be_used_in_read_options(arg, value, options_c @pytest.mark.parametrize( - "options_class, options_class_name", + ("options_class", "options_class_name"), [ (Postgres.WriteOptions, "PostgresWriteOptions"), (Clickhouse.WriteOptions, "ClickhouseWriteOptions"), @@ -120,7 +120,7 @@ def test_jdbc_write_options_cannot_be_used_in_read_options(arg, value, options_c ], ) @pytest.mark.parametrize( - "arg, value", + ("arg", "value"), [ ("column", "some"), ("partitionColumn", "part"), @@ -148,7 +148,7 @@ def test_jdbc_read_options_cannot_be_used_in_write_options(options_class, option @pytest.mark.parametrize( - "arg, value", + ("arg", "value"), [ ("mode", "append"), ("batchsize", 10), @@ -250,7 +250,7 @@ def test_jdbc_write_options_case(): @pytest.mark.parametrize( - "options, value", + ("options", "value"), [ ({}, JDBCTableExistBehavior.APPEND), ({"if_exists": "append"}, JDBCTableExistBehavior.APPEND), @@ -264,7 +264,7 @@ def test_jdbc_write_options_if_exists(options, value): @pytest.mark.parametrize( - "options, value, message", + ("options", "value", "message"), [ ( {"mode": "append"}, @@ -305,7 +305,7 @@ 
def test_jdbc_write_options_mode_deprecated(options, value, message): @pytest.mark.parametrize( - "options_class, options", + ("options_class", "options"), [ (Postgres.WriteOptions, {"if_exists": "wrong_mode"}), (Clickhouse.WriteOptions, {"if_exists": "wrong_mode"}), @@ -320,7 +320,7 @@ def test_jdbc_write_options_mode_wrong(options_class, options): @pytest.mark.parametrize( - "options, expected_message", + ("options", "expected_message"), [ ({"numPartitions": 2}, "lowerBound and upperBound must be set if numPartitions > 1"), ({"numPartitions": 2, "lowerBound": 0}, "lowerBound and upperBound must be set if numPartitions > 1"), diff --git a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py index f4a50a341..0c7ce5699 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_kafka_unit.py @@ -34,7 +34,7 @@ def create_temp_file(tmp_path_factory): @pytest.mark.parametrize( - "spark_version, scala_version, package", + ("spark_version", "scala_version", "package"), [ ("3.2.0", None, "org.apache.spark:spark-sql-kafka-0-10_2.12:3.2.0"), ("3.2.0", "2.12", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.2.0"), @@ -67,7 +67,7 @@ def test_kafka_spark_stopped(spark_stopped): @pytest.mark.parametrize( - "option, value", + ("option", "value"), [ ("assign", "assign_value"), ("subscribe", "subscribe_value"), @@ -84,7 +84,7 @@ def test_kafka_spark_stopped(spark_stopped): ], ) @pytest.mark.parametrize( - "options_class, class_name", + ("options_class", "class_name"), [ (Kafka.ReadOptions, "KafkaReadOptions"), (Kafka.WriteOptions, "KafkaWriteOptions"), @@ -97,7 +97,7 @@ def test_kafka_options_prohibited(option, value, options_class, class_name): @pytest.mark.parametrize( - "options_class, class_name", + ("options_class", "class_name"), [ (Kafka.ReadOptions, "KafkaReadOptions"), (Kafka.WriteOptions, "KafkaWriteOptions"), @@ -112,7 +112,7 @@ def test_kafka_options_unknown(caplog, options_class, class_name): @pytest.mark.parametrize( - "option, value", + ("option", "value"), [ ("failOnDataLoss", "false"), ("kafkaConsumer.pollTimeoutMs", "30000"), @@ -247,7 +247,7 @@ def test_kafka_empty_cluster(spark_mock): @pytest.mark.parametrize( - "option, value", + ("option", "value"), [ ("bootstrap.servers", "kafka.bootstrap.servers_value"), ("security.protocol", "ssl"), @@ -270,7 +270,7 @@ def test_kafka_invalid_extras(option, value): @pytest.mark.parametrize( - "option, value", + ("option", "value"), [ ("kafka.group.id", "group_id"), ("group.id", "group_id"), @@ -282,7 +282,7 @@ def test_kafka_valid_extras(option, value): def test_kafka_kerberos_auth_not_enough_permissions_keytab_error(create_keytab): - os.chmod(create_keytab, 0o000) # noqa: S103, WPS339 + create_keytab.chmod(0o000) with pytest.raises( OSError, @@ -647,7 +647,7 @@ def test_kafka_normalize_address_hook(request, spark_mock): def normalize_address(address: str, cluster: str): if cluster == "kafka-cluster": return f"{address}:9093" - elif cluster == "local": + if cluster == "local": return f"{address}:9092" return None @@ -672,12 +672,12 @@ def get_cluster_addresses(cluster: str): "192.168.1.2", ] - with pytest.raises(ValueError, match="Cluster 'kafka-cluster' does not contain addresses {'192.168.1.3'}"): + with pytest.raises(ValueError, match=r"Cluster 'kafka-cluster' does not contain addresses \{'192.168.1.3'\}"): Kafka(cluster="kafka-cluster", spark=spark_mock, addresses=["192.168.1.1", "192.168.1.3"]) 
@pytest.mark.parametrize( - "options, value", + ("options", "value"), [ ({}, KafkaTopicExistBehaviorKafka.APPEND), ({"if_exists": "append"}, KafkaTopicExistBehaviorKafka.APPEND), @@ -689,7 +689,7 @@ def test_kafka_write_options_if_exists(options, value): @pytest.mark.parametrize( - "options, message", + ("options", "message"), [ ( {"mode": "append"}, @@ -752,7 +752,7 @@ def test_kafka_ssl_protocol_with_raw_strings(spark_mock, prefix): @pytest.mark.parametrize( - "keystore_type,truststore_type", + ("keystore_type", "truststore_type"), [ ("PEM", "PEM"), ("JKS", "JKS"), @@ -870,7 +870,7 @@ def test_kafka_ssl_protocol_with_basic_auth(spark_mock): @pytest.mark.parametrize( - "columns,expected_schema", + ("columns", "expected_schema"), [ ( ["key", "value", "offset"], diff --git a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py index 8bed934a0..09caebd41 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mongodb_unit.py @@ -23,7 +23,7 @@ def test_mongodb_get_packages_no_input(): @pytest.mark.parametrize( - "spark_version, scala_version, package_version, package", + ("spark_version", "scala_version", "package_version", "package"), [ (None, "2.12", "10.5.0", "org.mongodb.spark:mongo-spark-connector_2.12:10.5.0"), (None, "2.13", "10.5.0", "org.mongodb.spark:mongo-spark-connector_2.13:10.5.0"), @@ -55,7 +55,10 @@ def test_mongodb_get_packages(spark_version, scala_version, package_version, pac def test_mongodb_get_packages_invalid_package_version(package_version): with pytest.raises( ValueError, - match=rf"Version '{package_version}' does not have enough numeric components for requested format \(expected at least 2\).", + match=( + f"Version '{package_version}' does not have enough numeric components " + r"for requested format \(expected at least 2\)." 
+ ), ): MongoDB.get_packages(scala_version="2.12", package_version=package_version) @@ -117,7 +120,7 @@ def test_mongodb(spark_mock): ], ) def test_mongodb_prohibited_options_error(prohibited_options): - with pytest.raises(ValueError): # noqa: PT011 + with pytest.raises(ValueError): MongoDB.PipelineOptions(**prohibited_options) @@ -264,7 +267,7 @@ def test_mongodb_convert_dict_to_str(spark_mock): @pytest.mark.parametrize( - "options, value", + ("options", "value"), [ ({}, MongoDBCollectionExistBehavior.APPEND), ({"if_exists": "append"}, MongoDBCollectionExistBehavior.APPEND), @@ -278,7 +281,7 @@ def test_mongodb_write_options_if_exists(options, value): @pytest.mark.parametrize( - "options, value, message", + ("options", "value", "message"), [ ( {"mode": "append"}, diff --git a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py index 93b0fd4be..9d45f57a2 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mssql_unit.py @@ -19,7 +19,7 @@ def test_mssql_package(): @pytest.mark.parametrize( - "java_version, package_version, expected_packages", + ("java_version", "package_version", "expected_packages"), [ (None, None, ["com.microsoft.sqlserver:mssql-jdbc:13.2.1.jre8"]), ("8", None, ["com.microsoft.sqlserver:mssql-jdbc:13.2.1.jre8"]), @@ -48,7 +48,10 @@ def test_mssql_get_packages(java_version, package_version, expected_packages): def test_mssql_get_packages_invalid_version(package_version): with pytest.raises( ValueError, - match=rf"Version '{package_version}' does not have enough numeric components for requested format \(expected at least 3\).", + match=( + f"Version '{package_version}' does not have enough numeric components " + r"for requested format \(expected at least 3\)." + ), ): MSSQL.get_packages(package_version=package_version) diff --git a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py index a92c287df..dbfd4fd67 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_mysql_unit.py @@ -19,7 +19,7 @@ def test_mysql_package(): @pytest.mark.parametrize( - "package_version, expected_packages", + ("package_version", "expected_packages"), [ (None, ["com.mysql:mysql-connector-j:9.5.0"]), ("9.5.0", ["com.mysql:mysql-connector-j:9.5.0"]), @@ -41,7 +41,10 @@ def test_mysql_get_packages(package_version, expected_packages): def test_mysql_get_packages_invalid_version(package_version): with pytest.raises( ValueError, - match=rf"Version '{package_version}' does not have enough numeric components for requested format \(expected at least 3\).", + match=( + f"Version '{package_version}' does not have enough numeric components " + r"for requested format \(expected at least 3\)." 
+ ), ): MySQL.get_packages(package_version=package_version) diff --git a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py index 5c6ee1b60..81516699a 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_oracle_unit.py @@ -23,7 +23,7 @@ def test_oracle_get_packages_no_input(): @pytest.mark.parametrize( - "java_version, package_version, expected_packages", + ("java_version", "package_version", "expected_packages"), [ (None, None, ["com.oracle.database.jdbc:ojdbc8:23.26.0.0.0"]), ("8", None, ["com.oracle.database.jdbc:ojdbc8:23.26.0.0.0"]), @@ -42,7 +42,7 @@ def test_oracle_get_packages(java_version, package_version, expected_packages): @pytest.mark.parametrize( - "java_version, package_version", + ("java_version", "package_version"), [ ("8", "23.3.0"), ("11", "23.3"), @@ -52,7 +52,10 @@ def test_oracle_get_packages(java_version, package_version, expected_packages): def test_oracle_get_packages_invalid_version(java_version, package_version): with pytest.raises( ValueError, - match=rf"Version '{package_version}' does not have enough numeric components for requested format \(expected at least 4\).", + match=( + f"Version '{package_version}' does not have enough numeric components " + r"for requested format \(expected at least 4\)." + ), ): Oracle.get_packages(java_version=java_version, package_version=package_version) diff --git a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py index bdc3414e6..4dbb27cc7 100644 --- a/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py +++ b/tests/tests_unit/tests_db_connection_unit/test_postgres_unit.py @@ -19,7 +19,7 @@ def test_postgres_package(): @pytest.mark.parametrize( - "package_version, expected_packages", + ("package_version", "expected_packages"), [ (None, ["org.postgresql:postgresql:42.7.8"]), ("42.7.8", ["org.postgresql:postgresql:42.7.8"]), @@ -41,7 +41,10 @@ def test_postgres_get_packages(package_version, expected_packages): def test_postgres_get_packages_invalid_version(package_version): with pytest.raises( ValueError, - match=rf"Version '{package_version}' does not have enough numeric components for requested format \(expected at least 3\).", + match=( + f"Version '{package_version}' does not have enough numeric components " + r"for requested format \(expected at least 3\)." 
+ ), ): Postgres.get_packages(package_version=package_version) diff --git a/tests/tests_unit/tests_file_connection_unit/__init__.py b/tests/tests_unit/tests_file_connection_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/tests_file_connection_unit/test_hdfs_unit.py b/tests/tests_unit/tests_file_connection_unit/test_hdfs_unit.py index 1385f0d00..b0efcb41e 100644 --- a/tests/tests_unit/tests_file_connection_unit/test_hdfs_unit.py +++ b/tests/tests_unit/tests_file_connection_unit/test_hdfs_unit.py @@ -149,7 +149,7 @@ def finalizer(): request.addfinalizer(finalizer) with pytest.raises(ValueError, match="Please provide either `keytab` or `password` for kinit, not both"): - HDFS(host="hdfs2", webhdfs_port=50070, user="usr", password="pwd", keytab=keytab) # noqa: F841 + HDFS(host="hdfs2", webhdfs_port=50070, user="usr", password="pwd", keytab=keytab) def test_hdfs_get_known_clusters_hook(request): @@ -213,9 +213,8 @@ def test_hdfs_normalize_namenode_host_hook(request): @hook def normalize_namenode_host(host: str, cluster: str | None) -> str: host = host.lower() - if cluster == "rnd-dwh": - if not host.endswith(".domain.com"): - host += ".domain.com" + if cluster == "rnd-dwh" and not host.endswith(".domain.com"): + host += ".domain.com" return host request.addfinalizer(normalize_namenode_host.disable) diff --git a/tests/tests_unit/tests_file_df_connection_unit/__init__.py b/tests/tests_unit/tests_file_df_connection_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py index e275e2775..beca8158b 100644 --- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py +++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_hdfs_unit.py @@ -110,9 +110,8 @@ def test_spark_hdfs_normalize_namenode_host_hook(request, spark_mock): @hook def normalize_namenode_host(host: str, cluster: str) -> str: host = host.lower() - if cluster == "rnd-dwh": - if not host.endswith(".domain.com"): - host += ".domain.com" + if cluster == "rnd-dwh" and not host.endswith(".domain.com"): + host += ".domain.com" return host request.addfinalizer(normalize_namenode_host.disable) diff --git a/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py b/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py index f45fc07ad..ba58cb5ca 100644 --- a/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py +++ b/tests/tests_unit/tests_file_df_connection_unit/test_spark_s3_unit.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize( - "spark_version, scala_version, package", + ("spark_version", "scala_version", "package"), [ ("3.5.7", None, "org.apache.spark:spark-hadoop-cloud_2.12:3.5.7"), ("3.5.7", "2.12", "org.apache.spark:spark-hadoop-cloud_2.12:3.5.7"), @@ -161,7 +161,7 @@ def test_spark_s3_without_path_style_access(spark_mock_hadoop_3): @pytest.mark.parametrize( - "name, value", + ("name", "value"), [ ("attempts.maximum", 1), ("connection.establish.timeout", 300000), @@ -180,7 +180,7 @@ def test_spark_s3_extra_allowed_options(name, value, prefix): @pytest.mark.parametrize( - "name, value", + ("name", "value"), [ ("impl", "org.apache.hadoop.fs.s3a.S3AFileSystem"), ("endpoint", "http://localhost:9010"), diff --git a/tests/tests_unit/tests_hwm_store_unit/__init__.py b/tests/tests_unit/tests_hwm_store_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/tests/tests_unit/tests_hwm_store_unit/test_yaml_hwm_store_unit.py b/tests/tests_unit/tests_hwm_store_unit/test_yaml_hwm_store_unit.py index c177b608e..701acbd17 100644 --- a/tests/tests_unit/tests_hwm_store_unit/test_yaml_hwm_store_unit.py +++ b/tests/tests_unit/tests_hwm_store_unit/test_yaml_hwm_store_unit.py @@ -89,9 +89,8 @@ def test_hwm_store_yaml_context_manager(caplog): assert hwm_store.path assert hwm_store.encoding == "utf-8" - with caplog.at_level(logging.INFO): - with hwm_store as store: - assert HWMStoreStackManager.get_current() == store + with caplog.at_level(logging.INFO), hwm_store as store: + assert HWMStoreStackManager.get_current() == store assert "Using YAMLHWMStore as HWM Store" in caplog.text assert "path = " in caplog.text @@ -113,9 +112,8 @@ def finalizer(): assert hwm_store.path == path assert hwm_store.encoding == "utf-8" - with caplog.at_level(logging.INFO): - with hwm_store as store: - assert HWMStoreStackManager.get_current() == store + with caplog.at_level(logging.INFO), hwm_store as store: + assert HWMStoreStackManager.get_current() == store assert "Using YAMLHWMStore as HWM Store" in caplog.text assert str(path) in caplog.text @@ -138,9 +136,8 @@ def finalizer(): assert hwm_store.path == path assert hwm_store.encoding == "cp-1251" - with caplog.at_level(logging.INFO): - with hwm_store as store: - assert HWMStoreStackManager.get_current() == store + with caplog.at_level(logging.INFO), hwm_store as store: + assert HWMStoreStackManager.get_current() == store assert "Using YAMLHWMStore as HWM Store" in caplog.text assert str(path) in caplog.text @@ -151,7 +148,7 @@ def finalizer(): @pytest.mark.parametrize( - "qualified_name, file_name", + ("qualified_name", "file_name"), [ ( "id|partition=abc/another=cde#mydb.mytable@dbtype://host.name:1234/schema#dag.task.myprocess@myhost", @@ -198,7 +195,7 @@ def test_hwm_store_no_deprecation_warning_yaml_hwm_store(): @pytest.mark.parametrize( - "import_name, original_import", + ("import_name", "original_import"), [ ("MemoryHWMStore", OriginalMemoryHWMStore), ("BaseHWMStore", OriginalBaseHWMStore), diff --git a/tests/tests_unit/tests_strategy_unit/__init__.py b/tests/tests_unit/tests_strategy_unit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tests_unit/tests_strategy_unit/test_strategy_unit.py b/tests/tests_unit/tests_strategy_unit/test_strategy_unit.py index 82e1924a5..6514b76b0 100644 --- a/tests/tests_unit/tests_strategy_unit/test_strategy_unit.py +++ b/tests/tests_unit/tests_strategy_unit/test_strategy_unit.py @@ -32,7 +32,7 @@ def test_strategy_batch_step_is_empty(step, strategy): @patch.object(Postgres, "check") @pytest.mark.parametrize( - "strategy, kwargs", + ("strategy", "kwargs"), [ (IncrementalStrategy, {}), (IncrementalBatchStrategy, {"step": 1}), diff --git a/tests/util/assert_df.py b/tests/util/assert_df.py index 331388e83..8ed53c73e 100644 --- a/tests/util/assert_df.py +++ b/tests/util/assert_df.py @@ -33,8 +33,8 @@ def assert_equal_df( left_df = left_df.sort_values(by=order_by.lower()) right_df = right_df.sort_values(by=order_by.lower()) - left_df.reset_index(inplace=True, drop=True) - right_df.reset_index(inplace=True, drop=True) + left_df.reset_index(inplace=True, drop=True) # noqa: PD002 + right_df.reset_index(inplace=True, drop=True) # noqa: PD002 # ignore columns order left_df = left_df.sort_index(axis=1) @@ -65,7 +65,7 @@ def assert_subset_df( else: columns = [column.lower() for column in columns] - for column in columns: # noqa: WPS528 + for column in columns: 
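# Illustrative sketch (not part of the diff above): the caplog hunks flatten two nested
# `with` blocks into a single statement, and the earlier namenode-host hook hunks merge a
# nested `if` into one combined condition -- both are plain Python rewrites, presumably
# driven by ruff's flake8-simplify checks. A standalone illustration; `normalize_host`
# mirrors the hook body from the patch, while `read_first_lines` is made up.
from __future__ import annotations


def normalize_host(host: str, cluster: str | None) -> str:
    host = host.lower()
    # one combined condition instead of an outer `if cluster == ...` wrapping a second `if`
    if cluster == "rnd-dwh" and not host.endswith(".domain.com"):
        host += ".domain.com"
    return host


def read_first_lines(path_a: str, path_b: str) -> tuple[str, str]:
    # one `with` statement holding both context managers instead of two nested blocks
    with open(path_a) as left, open(path_b) as right:
        return left.readline(), right.readline()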
small_column = small_pdf[column] large_column = large_pdf[column] different_indices = ~small_column.isin(large_column) diff --git a/tests/util/rand.py b/tests/util/rand.py index 0b0a7b56e..6b3080627 100644 --- a/tests/util/rand.py +++ b/tests/util/rand.py @@ -4,4 +4,4 @@ def rand_str(alphabet: str = ascii_lowercase, length: int = 10) -> str: alphabet_length = len(alphabet) - return "".join(alphabet[randint(0, alphabet_length - 1)] for _ in range(length)) # noqa: S311 + return "".join(alphabet[randint(0, alphabet_length - 1)] for _ in range(length)) diff --git a/tests/util/upload_files.py b/tests/util/upload_files.py index 8be81990b..2e130ac39 100644 --- a/tests/util/upload_files.py +++ b/tests/util/upload_files.py @@ -30,8 +30,9 @@ def upload_files( remote_files.append(remote_filename) if not remote_files: + msg = f"Could not load file examples from {local_path}. Path should exist and should contain samples" raise RuntimeError( - f"Could not load file examples from {local_path}. Path should exist and should contain samples", + msg, ) return remote_files diff --git a/uv.lock b/uv.lock index 7ffa5410b..d7a1b6f8b 100644 --- a/uv.lock +++ b/uv.lock @@ -430,15 +430,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/82/82745642d3c46e7cea25e1885b014b033f4693346ce46b7f47483cf5d448/argon2_cffi_bindings-25.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:da0c79c23a63723aa5d782250fbf51b768abca630285262fb5144ba5ae01e520", size = 29187, upload-time = "2025-07-30T10:02:03.674Z" }, ] -[[package]] -name = "attrs" -version = "25.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6b/5c/685e6633917e101e5dcb62b9dd76946cbb57c26e133bae9e0cd36033c0a9/attrs-25.4.0.tar.gz", hash = "sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11", size = 934251, upload-time = "2025-10-06T13:54:44.725Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, -] - [[package]] name = "autodoc-pydantic" version = "1.9.1" @@ -2860,20 +2851,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/7f/a1a97644e39e7316d850784c642093c99df1290a460df4ede27659056834/filelock-3.20.1-py3-none-any.whl", hash = "sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a", size = 16666, upload-time = "2025-12-15T23:54:26.874Z" }, ] -[[package]] -name = "flake8" -version = "7.3.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mccabe", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 
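# Illustrative sketch (not part of the diff above): the upload_files.py hunk assigns the
# error text to a variable before raising instead of building the f-string inside the
# `raise` call, which is the pattern ruff's flake8-errmsg (EM) rules recommend. A minimal
# standalone version; `collect_samples` and its glob-based lookup are hypothetical.
from __future__ import annotations

from pathlib import Path


def collect_samples(local_path: Path) -> list[Path]:
    files = [path for path in local_path.glob("*") if path.is_file()]
    if not files:
        msg = f"Could not load file examples from {local_path}. Path should exist and should contain samples"
        raise RuntimeError(msg)
    return files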
'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, - { name = "pycodestyle", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, - { name = "pyflakes", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9b/af/fbfe3c4b5a657d79e5c47a2827a362f9e1b763336a52f926126aa6dc7123/flake8-7.3.0.tar.gz", hash = "sha256:fe044858146b9fc69b551a4b490d69cf960fcb78ad1edcb84e7fbb1b4a8e3872", size = 48326, upload-time = "2025-06-20T19:31:35.838Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/56/13ab06b4f93ca7cac71078fbe37fcea175d3216f31f85c3168a6bbd0bb9a/flake8-7.3.0-py2.py3-none-any.whl", hash = "sha256:b9696257b9ce8beb888cdbe31cf885c90d31928fe202be0889a7cdafad32f01e", size = 57922, upload-time = "2025-06-20T19:31:34.425Z" }, -] - [[package]] name = "frozendict" version = "2.4.7" @@ -4545,15 +4522,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/d3/fe08482b5cd995033556d45041a4f4e76e7f0521112a9c9991d40d39825f/markupsafe-3.0.3-cp39-cp39-win_arm64.whl", hash = "sha256:38664109c14ffc9e7437e86b4dceb442b0096dfe3541d7864d9cbe1da4cf36c8", size = 13928, upload-time = "2025-09-27T18:37:39.037Z" }, ] -[[package]] -name = "mccabe" -version = "0.7.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e7/ff/0ffefdcac38932a54d2b5eed4e0ba8a408f215002cd178ad1df0f2806ff8/mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325", size = 9658, upload-time = "2022-01-24T01:14:51.113Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/27/1a/1f68f9ba0c207934b35b86a8ca3aad8395a3d6dd7921c0686e23853ff5a9/mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e", size = 7350, upload-time = "2022-01-24T01:14:49.62Z" }, -] - [[package]] name = "mdit-py-plugins" version = "0.5.0" @@ -5319,7 +5287,6 @@ dev = [ { name = "prek", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, { name = "types-deprecated", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, { name = "types-pyyaml", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, - { name = "wemake-python-styleguide", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and 
extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, ] docs = [ { name = "autodoc-pydantic", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, @@ -5524,7 +5491,6 @@ dev = [ { name = "prek", marker = "python_full_version >= '3.12'", specifier = "~=0.2.25" }, { name = "types-deprecated", marker = "python_full_version >= '3.12'", specifier = "~=1.3.1" }, { name = "types-pyyaml", marker = "python_full_version >= '3.12'", specifier = "~=6.0.12" }, - { name = "wemake-python-styleguide", marker = "python_full_version >= '3.12'", specifier = "~=1.5.0" }, ] docs = [ { name = "autodoc-pydantic", marker = "python_full_version >= '3.12'", specifier = "~=1.9.1" }, @@ -7038,15 +7004,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, ] -[[package]] -name = "pycodestyle" -version = "2.14.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/11/e0/abfd2a0d2efe47670df87f3e3a0e2edda42f055053c85361f19c0e2c1ca8/pycodestyle-2.14.0.tar.gz", hash = "sha256:c4b5b517d278089ff9d0abdec919cd97262a3367449ea1c8b49b91529167b783", size = 39472, upload-time = "2025-06-20T18:49:48.75Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/27/a58ddaf8c588a3ef080db9d0b7e0b97215cee3a45df74f3a94dbbf5c893a/pycodestyle-2.14.0-py2.py3-none-any.whl", hash = "sha256:dd6bf7cb4ee77f8e016f9c8e74a35ddd9f67e1d5fd4184d86c3b98e07099f42d", size = 31594, upload-time = "2025-06-20T18:49:47.491Z" }, -] - [[package]] name = "pycparser" version = "2.21" @@ -7232,15 +7189,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/98/556e82f00b98486def0b8af85da95e69d2be7e367cf2431408e108bc3095/pydantic-1.10.26-py3-none-any.whl", hash = "sha256:c43ad70dc3ce7787543d563792426a16fd7895e14be4b194b5665e36459dd917", size = 166975, upload-time = 
"2025-12-18T15:47:44.927Z" }, ] -[[package]] -name = "pyflakes" -version = "3.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/45/dc/fd034dc20b4b264b3d015808458391acbf9df40b1e54750ef175d39180b1/pyflakes-3.4.0.tar.gz", hash = "sha256:b24f96fafb7d2ab0ec5075b7350b3d2d2218eab42003821c06344973d3ea2f58", size = 64669, upload-time = "2025-06-20T18:45:27.834Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/2f/81d580a0fb83baeb066698975cb14a618bdbed7720678566f1b046a95fe8/pyflakes-3.4.0-py2.py3-none-any.whl", hash = "sha256:f742a7dbd0d9cb9ea41e9a24a918996e8170c799fa528688d40dd582c8265f4f", size = 63551, upload-time = "2025-06-20T18:45:26.937Z" }, -] - [[package]] name = "pygments" version = "2.19.2" @@ -10566,20 +10514,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774, upload-time = "2017-04-05T20:21:32.581Z" }, ] -[[package]] -name = "wemake-python-styleguide" -version = "1.5.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "attrs", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, - { name = "flake8", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, - { name = "pygments", marker = "python_full_version >= '3.12' or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-3') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-2' and extra == 
'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-4') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-3' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-3-5') or (extra == 'group-5-onetl-test-spark-3-4' and extra == 'group-5-onetl-test-spark-4-0') or (extra == 'group-5-onetl-test-spark-3-5' and extra == 'group-5-onetl-test-spark-4-0')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0a/06/a0968903b6d41de7c41b8d8cc6a31471894365fe0c7b13684a3e6faa956d/wemake_python_styleguide-1.5.0.tar.gz", hash = "sha256:a764b30bd298ecd3ca9d5cd64b7776dec9a529d728291e8b8076a56649d6cce1", size = 156819, upload-time = "2025-12-22T19:39:10.812Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/15/2a/9e303a946df30335695b60c0f75dc49a5f339cf4a101772522c88d4133f7/wemake_python_styleguide-1.5.0-py3-none-any.whl", hash = "sha256:0743a8d1a748e3b84cad5804de4e6211641cd77e5ea0e4aaeec1993765e7f6c0", size = 219957, upload-time = "2025-12-22T19:39:09.076Z" }, -] - [[package]] name = "zipp" version = "3.15.0"