diff --git a/contrib/hamilton/contrib/dagworks/text_summarization/__init__.py b/contrib/hamilton/contrib/dagworks/text_summarization/__init__.py index 502fda293..f79b7be25 100644 --- a/contrib/hamilton/contrib/dagworks/text_summarization/__init__.py +++ b/contrib/hamilton/contrib/dagworks/text_summarization/__init__.py @@ -18,7 +18,8 @@ import concurrent import logging import tempfile -from typing import Generator, Union +from collections.abc import Generator +from typing import Union logger = logging.getLogger(__name__) @@ -49,7 +50,7 @@ def summarize_text_from_summaries_prompt(content_type: str = "an academic paper" @config.when(file_type="pdf") -def raw_text__pdf(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) -> str: +def raw_text__pdf(pdf_source: str | bytes | tempfile.SpooledTemporaryFile) -> str: """Takes a filepath to a PDF and returns a string of the PDF's contents :param pdf_source: the path, or the temporary file, to the PDF. :return: the text of the PDF. @@ -64,7 +65,7 @@ def raw_text__pdf(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) @config.when(file_type="txt") -def raw_text__txt(text_file: Union[str, tempfile.SpooledTemporaryFile]) -> str: +def raw_text__txt(text_file: str | tempfile.SpooledTemporaryFile) -> str: """Takes a filepath to a text file and returns a string of the text file's contents :param text_file: the path, or the temporary file, to the text file. :return: the contents of the file as a string. diff --git a/contrib/hamilton/contrib/user/elijahbenizzy/caption_images/__init__.py b/contrib/hamilton/contrib/user/elijahbenizzy/caption_images/__init__.py index 9fb191013..1a3efe1d9 100644 --- a/contrib/hamilton/contrib/user/elijahbenizzy/caption_images/__init__.py +++ b/contrib/hamilton/contrib/user/elijahbenizzy/caption_images/__init__.py @@ -40,7 +40,7 @@ def openai_client() -> openai.OpenAI: return openai.OpenAI() -def _encode_image(image_path_or_file: Union[str, IO], ext: str): +def _encode_image(image_path_or_file: str | IO, ext: str): """Helper fn to return a base-64 encoded image""" file_like_object = ( image_path_or_file @@ -79,8 +79,8 @@ def processed_image_url(image_url: str) -> str: def caption_prompt( core_prompt: str, - additional_prompt: Optional[str] = None, - descriptiveness: Optional[str] = None, + additional_prompt: str | None = None, + descriptiveness: str | None = None, ) -> str: """Returns the prompt used to describe an image""" out = core_prompt @@ -128,7 +128,7 @@ def caption_embeddings( openai_client: openai.OpenAI, embeddings_model: str = DEFAULT_EMBEDDINGS_MODEL, generated_caption: str = None, -) -> List[float]: +) -> list[float]: """Returns the embeddings for a generated caption""" data = ( openai_client.embeddings.create( @@ -158,7 +158,7 @@ def caption_metadata( @config.when(include_embeddings=True) def embeddings_metadata( - caption_embeddings: List[float], + caption_embeddings: list[float], embeddings_model: str = DEFAULT_EMBEDDINGS_MODEL, ) -> dict: """Returns metadata for the embeddings portion of the workflow""" @@ -170,9 +170,9 @@ def embeddings_metadata( def metadata( embeddings_metadata: dict, - caption_metadata: Optional[dict] = None, - additional_metadata: Optional[dict] = None, -) -> Dict[str, Any]: + caption_metadata: dict | None = None, + additional_metadata: dict | None = None, +) -> dict[str, Any]: """Returns the response to a given chat""" out = embeddings_metadata if caption_metadata is not None: diff --git a/contrib/hamilton/contrib/user/elijahbenizzy/convert_images_s3/__init__.py 
b/contrib/hamilton/contrib/user/elijahbenizzy/convert_images_s3/__init__.py index c53043f9d..972665c7d 100644 --- a/contrib/hamilton/contrib/user/elijahbenizzy/convert_images_s3/__init__.py +++ b/contrib/hamilton/contrib/user/elijahbenizzy/convert_images_s3/__init__.py @@ -92,7 +92,7 @@ def converted_and_saved( image: Image, file_to_convert: ToConvert, new_format: str = "jpeg", - image_params: Optional[Dict[str, Any]] = None, + image_params: dict[str, Any] | None = None, ) -> Converted: """Returns a list of all files to convert.""" s3 = _s3() @@ -121,7 +121,7 @@ def converted_and_saved( ) -def all_converted_and_saved(converted_and_saved: Collect[Converted]) -> List[Converted]: +def all_converted_and_saved(converted_and_saved: Collect[Converted]) -> list[Converted]: """Returns a list of all downloaded locations""" return list(converted_and_saved) diff --git a/contrib/hamilton/contrib/user/elijahbenizzy/parallel_load_dataframes_s3/__init__.py b/contrib/hamilton/contrib/user/elijahbenizzy/parallel_load_dataframes_s3/__init__.py index 80b628f3d..343bf703f 100644 --- a/contrib/hamilton/contrib/user/elijahbenizzy/parallel_load_dataframes_s3/__init__.py +++ b/contrib/hamilton/contrib/user/elijahbenizzy/parallel_load_dataframes_s3/__init__.py @@ -107,7 +107,7 @@ def downloaded_data( return download_location -def all_downloaded_data(downloaded_data: Collect[str]) -> List[str]: +def all_downloaded_data(downloaded_data: Collect[str]) -> list[str]: """Returns a list of all downloaded locations""" out = [] for path in downloaded_data: @@ -120,7 +120,7 @@ def _jsonl_parse(path: str) -> pd.DataFrame: return pd.read_json(path, lines=True) -def processed_dataframe(all_downloaded_data: List[str]) -> pd.DataFrame: +def processed_dataframe(all_downloaded_data: list[str]) -> pd.DataFrame: """Processes everything into a dataframe""" out = [] for floc in all_downloaded_data: diff --git a/contrib/hamilton/contrib/user/skrawcz/customize_embeddings/__init__.py b/contrib/hamilton/contrib/user/skrawcz/customize_embeddings/__init__.py index 3c2755c40..6ec50b539 100644 --- a/contrib/hamilton/contrib/user/skrawcz/customize_embeddings/__init__.py +++ b/contrib/hamilton/contrib/user/skrawcz/customize_embeddings/__init__.py @@ -63,7 +63,7 @@ @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) -def _get_embedding(text: str, model="text-similarity-davinci-001", **kwargs) -> List[float]: +def _get_embedding(text: str, model="text-similarity-davinci-001", **kwargs) -> list[float]: """Get embedding from OpenAI API. :param text: text to embed. :param model: the embedding model to use. 
@@ -394,7 +394,7 @@ def embedded_data_set( def _accuracy_and_se( cosine_similarity: list[float], labeled_similarity: list[int] -) -> Tuple[float, float]: +) -> tuple[float, float]: """Calculate accuracy (and its standard error) of predicting label=1 if similarity>x x is optimized by sweeping from -1 to 1 in steps of 0.01 @@ -465,7 +465,7 @@ def accuracy_computation( return a, se -def _embedding_multiplied_by_matrix(embedding: List[float], matrix: torch.tensor) -> np.array: +def _embedding_multiplied_by_matrix(embedding: list[float], matrix: torch.tensor) -> np.array: """Helper function to multiply an embedding by a matrix.""" embedding_tensor = torch.tensor(embedding).float() modified_embedding = embedding_tensor @ matrix @@ -530,7 +530,7 @@ def tensors_from_dataframe( embedding_column_1: str, embedding_column_2: str, similarity_label_column: str, - ) -> Tuple[torch.tensor, torch.tensor, torch.tensor]: + ) -> tuple[torch.tensor, torch.tensor, torch.tensor]: e1 = np.stack(np.array(df[embedding_column_1].values)) e2 = np.stack(np.array(df[embedding_column_2].values)) s = np.stack(np.array(df[similarity_label_column].astype("float").values)) @@ -638,7 +638,7 @@ def mse_loss(predictions, targets): optimization_result_matrices=group(*[source(k) for k in optimization_parameterization.keys()]) ) def optimization_results( - optimization_result_matrices: List[pd.DataFrame], + optimization_result_matrices: list[pd.DataFrame], ) -> pd.DataFrame: """Combine optimization results into one dataframe.""" return pd.concat(optimization_result_matrices) diff --git a/contrib/hamilton/contrib/user/skrawcz/fine_tuning/__init__.py b/contrib/hamilton/contrib/user/skrawcz/fine_tuning/__init__.py index 606520b5d..4f9531160 100644 --- a/contrib/hamilton/contrib/user/skrawcz/fine_tuning/__init__.py +++ b/contrib/hamilton/contrib/user/skrawcz/fine_tuning/__init__.py @@ -62,7 +62,7 @@ def raw_dataset( validation_size: float = 0.8, input_text_key: str = "question", output_text_key: str = "reply", -) -> Dict[str, Dataset]: +) -> dict[str, Dataset]: """Loads the raw dataset from disk and splits it into train and test sets. :param data_path: the path to the dataset. diff --git a/contrib/hamilton/contrib/user/zilto/lancedb_vdb/__init__.py b/contrib/hamilton/contrib/user/zilto/lancedb_vdb/__init__.py index 3d03953d4..40993f94d 100644 --- a/contrib/hamilton/contrib/user/zilto/lancedb_vdb/__init__.py +++ b/contrib/hamilton/contrib/user/zilto/lancedb_vdb/__init__.py @@ -16,8 +16,9 @@ # under the License. import logging +from collections.abc import Iterable from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import Dict, List, Optional, Union logger = logging.getLogger(__name__) @@ -33,11 +34,11 @@ from hamilton.function_modifiers import tag VectorType = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] -DataType = Union[Dict, List[Dict], pd.DataFrame, pa.Table, Iterable[pa.RecordBatch]] +DataType = Union[dict, list[dict], pd.DataFrame, pa.Table, Iterable[pa.RecordBatch]] TableSchema = Union[pa.Schema, LanceModel] -def client(uri: Union[str, Path] = "./.lancedb") -> lancedb.DBConnection: +def client(uri: str | Path = "./.lancedb") -> lancedb.DBConnection: """Create a LanceDB connection. 
:param uri: path to local LanceDB @@ -49,7 +50,7 @@ def client(uri: Union[str, Path] = "./.lancedb") -> lancedb.DBConnection: def _create_table( client: lancedb.DBConnection, table_name: str, - schema: Optional[TableSchema] = None, + schema: TableSchema | None = None, overwrite_table: bool = False, ) -> lancedb.db.LanceTable: """Create a new table based on schema.""" @@ -62,7 +63,7 @@ def _create_table( def table_ref( client: lancedb.DBConnection, table_name: str, - schema: Optional[TableSchema] = None, + schema: TableSchema | None = None, overwrite_table: bool = False, ) -> lancedb.db.LanceTable: """Create or reference a LanceDB table @@ -91,7 +92,7 @@ def table_ref( @tag(side_effect="True") -def reset(client: lancedb.DBConnection) -> Dict[str, List[str]]: +def reset(client: lancedb.DBConnection) -> dict[str, list[str]]: """Drop all existing tables. :param vdb_client: LanceDB connection. @@ -106,7 +107,7 @@ def reset(client: lancedb.DBConnection) -> Dict[str, List[str]]: @tag(side_effect="True") -def insert(table_ref: lancedb.db.LanceTable, data: DataType) -> Dict: +def insert(table_ref: lancedb.db.LanceTable, data: DataType) -> dict: """Push new data to the specified table. :param table_ref: Reference to the LanceDB table. @@ -121,7 +122,7 @@ def insert(table_ref: lancedb.db.LanceTable, data: DataType) -> Dict: @tag(side_effect="True") -def delete(table_ref: lancedb.db.LanceTable, delete_expression: str) -> Dict: +def delete(table_ref: lancedb.db.LanceTable, delete_expression: str) -> dict: """Delete existing data using an SQL expression. :param table_ref: Reference to the LanceDB table. @@ -138,8 +139,8 @@ def delete(table_ref: lancedb.db.LanceTable, delete_expression: str) -> Dict: def vector_search( table_ref: lancedb.db.LanceTable, vector_query: VectorType, - columns: Optional[List[str]] = None, - where: Optional[str] = None, + columns: list[str] | None = None, + where: str | None = None, prefilter_where: bool = False, limit: int = 10, ) -> pd.DataFrame: @@ -169,8 +170,8 @@ def vector_search( def full_text_search( table_ref: lancedb.db.LanceTable, full_text_query: str, - full_text_index: Union[str, List[str]], - where: Optional[str] = None, + full_text_index: str | list[str], + where: str | None = None, limit: int = 10, rebuild_index: bool = True, ) -> pd.DataFrame: diff --git a/contrib/hamilton/contrib/user/zilto/llm_generate_code/__init__.py b/contrib/hamilton/contrib/user/zilto/llm_generate_code/__init__.py index 1d4257a83..5f594fd06 100644 --- a/contrib/hamilton/contrib/user/zilto/llm_generate_code/__init__.py +++ b/contrib/hamilton/contrib/user/zilto/llm_generate_code/__init__.py @@ -30,7 +30,7 @@ import openai -def llm_client(api_key: Optional[str] = None) -> openai.OpenAI: +def llm_client(api_key: str | None = None) -> openai.OpenAI: """Create an OpenAI client.""" if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") diff --git a/contrib/hamilton/contrib/user/zilto/nixtla_mlforecast/__init__.py b/contrib/hamilton/contrib/user/zilto/nixtla_mlforecast/__init__.py index 76e7fcaed..50a67b4e2 100644 --- a/contrib/hamilton/contrib/user/zilto/nixtla_mlforecast/__init__.py +++ b/contrib/hamilton/contrib/user/zilto/nixtla_mlforecast/__init__.py @@ -16,7 +16,8 @@ # under the License. 
import logging -from typing import Any, Callable, Iterable, Optional, Union +from collections.abc import Callable, Iterable +from typing import Any, TypeAlias logger = logging.getLogger(__name__) @@ -38,12 +39,12 @@ # sklearn compliant models (including XGBoost and LightGBM) subclass BaseEstimator -MODELS_TYPE = Union[ - BaseEstimator, list[BaseEstimator], dict[str, BaseEstimator] -] # equivalent to mlforecast.core.Models -LAG_TRANSFORMS_TYPE = dict[int, list[Union[Callable, tuple[Callable, Any]]]] -DATE_FEATURES_TYPE = Iterable[Union[str, Callable]] -CONFIDENCE_INTERVAL_TYPE = Optional[list[Union[int, float]]] +MODELS_TYPE = ( + BaseEstimator | list[BaseEstimator] | dict[str, BaseEstimator] +) # equivalent to mlforecast.core.Models +LAG_TRANSFORMS_TYPE: TypeAlias = dict[int, list[Callable | tuple[Callable, Any]]] +DATE_FEATURES_TYPE: TypeAlias = Iterable[str | Callable] +CONFIDENCE_INTERVAL_TYPE: TypeAlias = list[int | float] | None def base_models() -> MODELS_TYPE: @@ -81,11 +82,11 @@ def date_features() -> DATE_FEATURES_TYPE: def forecaster( base_models: MODELS_TYPE, - freq: Union[int, str] = "M", - lags: Optional[list[int]] = None, - lag_transforms: Optional[LAG_TRANSFORMS_TYPE] = None, - date_features: Optional[DATE_FEATURES_TYPE] = None, - target_transforms: Optional[list[BaseTargetTransform]] = None, + freq: int | str = "M", + lags: list[int] | None = None, + lag_transforms: LAG_TRANSFORMS_TYPE | None = None, + date_features: DATE_FEATURES_TYPE | None = None, + target_transforms: list[BaseTargetTransform] | None = None, num_threads: int = 1, ) -> MLForecast: """Create the forecasting harness with data and models @@ -106,15 +107,15 @@ def forecaster( def cross_validation_predictions( forecaster: MLForecast, dataset: pd.DataFrame, - static_features: Optional[list[str]] = None, + static_features: list[str] | None = None, dropna: bool = True, - keep_last_n_inputs: Optional[int] = None, - train_models_for_n_horizons: Optional[int] = None, + keep_last_n_inputs: int | None = None, + train_models_for_n_horizons: int | None = None, confidence_percentile: CONFIDENCE_INTERVAL_TYPE = None, cv_n_windows: int = 2, cv_forecast_horizon: int = 12, - cv_step_size: Optional[int] = None, - cv_input_size: Optional[int] = None, + cv_step_size: int | None = None, + cv_input_size: int | None = None, cv_refit: bool = True, cv_save_train_predictions: bool = True, ) -> pd.DataFrame: @@ -182,10 +183,10 @@ def best_model_per_series(cross_validation_evaluation: pd.DataFrame) -> pd.Serie def fitted_forecaster( forecaster: MLForecast, dataset: pd.DataFrame, - static_features: Optional[list[str]] = None, + static_features: list[str] | None = None, dropna: bool = True, - keep_last_n: Optional[int] = None, - train_models_for_n_horizons: Optional[int] = None, + keep_last_n: int | None = None, + train_models_for_n_horizons: int | None = None, save_train_predictions: bool = True, ) -> MLForecast: """Fit models over full dataset""" @@ -202,9 +203,9 @@ def fitted_forecaster( def inference_predictions( fitted_forecaster: MLForecast, inference_forecast_horizon: int = 12, - inference_uids: Optional[list[str]] = None, - inference_dataset: Optional[pd.DataFrame] = None, - inference_exogenous: Optional[pd.DataFrame] = None, + inference_uids: list[str] | None = None, + inference_dataset: pd.DataFrame | None = None, + inference_exogenous: pd.DataFrame | None = None, confidence_percentile: CONFIDENCE_INTERVAL_TYPE = None, ) -> pd.DataFrame: """Infer values using the trained models @@ -221,8 +222,8 @@ def inference_predictions( def 
plotting_config( plot_max_n_series: int = 4, - plot_uids: Optional[list[str]] = None, - plot_models: Optional[list[str]] = None, + plot_uids: list[str] | None = None, + plot_models: list[str] | None = None, plot_anomalies: bool = False, plot_confidence_percentile: CONFIDENCE_INTERVAL_TYPE = None, plot_engine: str = "matplotlib", diff --git a/contrib/hamilton/contrib/user/zilto/nixtla_statsforecast/__init__.py b/contrib/hamilton/contrib/user/zilto/nixtla_statsforecast/__init__.py index dac1dc52d..7b6785ed6 100644 --- a/contrib/hamilton/contrib/user/zilto/nixtla_statsforecast/__init__.py +++ b/contrib/hamilton/contrib/user/zilto/nixtla_statsforecast/__init__.py @@ -16,7 +16,8 @@ # under the License. import logging -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional logger = logging.getLogger(__name__) @@ -155,8 +156,8 @@ def inference_predictions( def plotting_config( - plot_uids: Optional[list[str]] = None, - plot_models: Optional[list[str]] = None, + plot_uids: list[str] | None = None, + plot_models: list[str] | None = None, plot_anomalies: bool = False, plot_confidence_percentile: list[float] = [90.0], # noqa: B006 plot_engine: str = "matplotlib", diff --git a/contrib/hamilton/contrib/user/zilto/text_summarization/__init__.py b/contrib/hamilton/contrib/user/zilto/text_summarization/__init__.py index 2157d1bb5..a035a72b4 100644 --- a/contrib/hamilton/contrib/user/zilto/text_summarization/__init__.py +++ b/contrib/hamilton/contrib/user/zilto/text_summarization/__init__.py @@ -18,7 +18,8 @@ import concurrent import logging import tempfile -from typing import Generator, Union +from collections.abc import Generator +from typing import Union logger = logging.getLogger(__name__) @@ -49,7 +50,7 @@ def summarize_text_from_summaries_prompt(content_type: str = "an academic paper" @config.when(file_type="pdf") -def raw_text__pdf(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) -> str: +def raw_text__pdf(pdf_source: str | bytes | tempfile.SpooledTemporaryFile) -> str: """Takes a filepath to a PDF and returns a string of the PDF's contents :param pdf_source: the path, or the temporary file, to the PDF. :return: the text of the PDF. @@ -64,7 +65,7 @@ def raw_text__pdf(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) @config.when(file_type="txt") -def raw_text__txt(text_file: Union[str, tempfile.SpooledTemporaryFile]) -> str: +def raw_text__txt(text_file: str | tempfile.SpooledTemporaryFile) -> str: """Takes a filepath to a text file and returns a string of the text file's contents :param text_file: the path, or the temporary file, to the text file. :return: the contents of the file as a string. 
diff --git a/contrib/hamilton/contrib/user/zilto/webscraper/__init__.py b/contrib/hamilton/contrib/user/zilto/webscraper/__init__.py index c5f7c404e..665cb9262 100644 --- a/contrib/hamilton/contrib/user/zilto/webscraper/__init__.py +++ b/contrib/hamilton/contrib/user/zilto/webscraper/__init__.py @@ -45,7 +45,7 @@ class ParsingResult: parsed: Any -def url(urls: List[str]) -> Parallelizable[str]: +def url(urls: list[str]) -> Parallelizable[str]: """Iterate over the list of urls and create one branch per url :param urls: list of url to scrape and parse @@ -71,8 +71,8 @@ def html_page(url: str) -> str: def parsed_html( url: str, html_page: str, - tags_to_extract: List[str] = ["p", "li", "div"], # noqa: B006 - tags_to_remove: List[str] = ["script", "style"], # noqa: B006 + tags_to_extract: list[str] = ["p", "li", "div"], # noqa: B006 + tags_to_remove: list[str] = ["script", "style"], # noqa: B006 ) -> ParsingResult: """Parse an HTML string using BeautifulSoup @@ -104,7 +104,7 @@ def parsed_html( return ParsingResult(url=url, parsed=content) -def parsed_html_collection(parsed_html: Collect[ParsingResult]) -> List[ParsingResult]: +def parsed_html_collection(parsed_html: Collect[ParsingResult]) -> list[ParsingResult]: """Collect parallel branches of `parsed_html` :param parsed_html: receive the ParsingResult associated with each url diff --git a/contrib/hamilton/contrib/user/zilto/xgboost_optuna/__init__.py b/contrib/hamilton/contrib/user/zilto/xgboost_optuna/__init__.py index 2892d74d3..464690f7b 100644 --- a/contrib/hamilton/contrib/user/zilto/xgboost_optuna/__init__.py +++ b/contrib/hamilton/contrib/user/zilto/xgboost_optuna/__init__.py @@ -16,8 +16,9 @@ # under the License. import logging +from collections.abc import Callable, Sequence from types import FunctionType -from typing import Any, Callable, Optional, Sequence +from typing import Any, Optional logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ from hamilton.function_modifiers import config, extract_fields -def model_config(seed: int = 0, model_config_override: Optional[dict] = None) -> dict: +def model_config(seed: int = 0, model_config_override: dict | None = None) -> dict: """XGBoost model configuration ref: https://xgboost.readthedocs.io/en/stable/parameter.html @@ -68,7 +69,7 @@ def model_config(seed: int = 0, model_config_override: Optional[dict] = None) -> return config -def optuna_distributions(optuna_distributions_override: Optional[dict] = None) -> dict: +def optuna_distributions(optuna_distributions_override: dict | None = None) -> dict: """Distributions of hyperparameters to search during optimization :param optuna_distributions_override: Hyperparameter distributions to explore @@ -150,10 +151,10 @@ def cross_validation_folds( def study( higher_is_better: bool, - pruner: Optional[optuna.pruners.BasePruner] = None, - sampler: Optional[optuna.samplers.BaseSampler] = None, - study_storage: Optional[str] = None, - study_name: Optional[str] = None, + pruner: optuna.pruners.BasePruner | None = None, + sampler: optuna.samplers.BaseSampler | None = None, + study_storage: str | None = None, + study_name: str | None = None, load_if_exists: bool = False, ) -> optuna.study.Study: """Create an optuna study; use the XGBoost + Optuna integration for pruning diff --git a/contrib/hamilton/contribute.py b/contrib/hamilton/contribute.py index 5d5a7a43d..bc1ece743 100644 --- a/contrib/hamilton/contribute.py +++ b/contrib/hamilton/contribute.py @@ -18,7 +18,6 @@ import logging import os import shutil -from typing import List import click 
import git @@ -62,7 +61,7 @@ def _get_base_template_dir(base_contrib_path: str): def _create_username_dir_if_not_exists( base_contrib_path: str, sanitized_username: str, username: str -) -> List[str]: +) -> list[str]: to_add = [] username_dir = os.path.join(base_contrib_path, sanitized_username) if not os.path.exists(username_dir): @@ -108,7 +107,7 @@ def _create_username_dir_if_not_exists( def _create_dataflow_dir_if_not_exists( base_contrib_path: str, sanitized_username: str, dataflow_name: str -) -> List[str]: +) -> list[str]: to_add = [] dataflow_dir = os.path.join(base_contrib_path, sanitized_username, dataflow_name) if not os.path.exists(dataflow_dir): @@ -140,7 +139,7 @@ def _create_dataflow_dir_if_not_exists( return to_add -def _git_add(files_to_add: List[str], git_repo_path: str): +def _git_add(files_to_add: list[str], git_repo_path: str): repo = git.Repo(git_repo_path) repo.index.add(files_to_add) logger.info(f"Adding files {files_to_add} to git! Happy developing!") diff --git a/dev_tools/language_server/hamilton_lsp/server.py b/dev_tools/language_server/hamilton_lsp/server.py index 64536b1aa..042c47085 100644 --- a/dev_tools/language_server/hamilton_lsp/server.py +++ b/dev_tools/language_server/hamilton_lsp/server.py @@ -17,7 +17,6 @@ import inspect import re -from typing import Type from lsprotocol.types import ( TEXT_DOCUMENT_COMPLETION, @@ -54,7 +53,7 @@ from hamilton_lsp import __version__ -def _type_to_string(type_: Type): +def _type_to_string(type_: type): """Return the full path of type, but may not be accessible from document For example, `pandas.core.series.Series` while document defines `pandas as pd` """ diff --git a/examples/LLM_Workflows/GraphRAG/application.py b/examples/LLM_Workflows/GraphRAG/application.py index a6ee454d6..4c978eeae 100644 --- a/examples/LLM_Workflows/GraphRAG/application.py +++ b/examples/LLM_Workflows/GraphRAG/application.py @@ -17,7 +17,6 @@ import json import uuid -from typing import Tuple import falkordb import openai @@ -128,7 +127,7 @@ def run_cypher_query(graph, query): reads=[], writes=["question", "chat_history"], ) -def human_converse(state: State, user_question: str) -> Tuple[dict, State]: +def human_converse(state: State, user_question: str) -> tuple[dict, State]: """Human converse step -- make sure we get input, and store it as state.""" new_state = state.update(question=user_question) new_state = new_state.append(chat_history={"role": "user", "content": user_question}) @@ -161,7 +160,7 @@ def AI_create_cypher_query(state: State, client: openai.Client) -> tuple[dict, S reads=["tool_calls", "chat_history"], writes=["tool_calls", "chat_history"], ) -def tool_call(state: State, graph: falkordb.Graph) -> Tuple[dict, State]: +def tool_call(state: State, graph: falkordb.Graph) -> tuple[dict, State]: """Tool call step -- execute the tool call.""" tool_calls = state.get("tool_calls", []) new_state = state diff --git a/examples/LLM_Workflows/GraphRAG/notebook.ipynb b/examples/LLM_Workflows/GraphRAG/notebook.ipynb index f56e3fb24..8d4773bef 100644 --- a/examples/LLM_Workflows/GraphRAG/notebook.ipynb +++ b/examples/LLM_Workflows/GraphRAG/notebook.ipynb @@ -45,7 +45,6 @@ "# import what we need\n", "import json\n", "import uuid\n", - "from typing import Tuple\n", "\n", "import falkordb\n", "import openai\n", @@ -223,7 +222,7 @@ " reads=[],\n", " writes=[\"question\", \"chat_history\"],\n", ")\n", - "def human_converse(state: State, user_question: str) -> Tuple[dict, State]:\n", + "def human_converse(state: State, user_question: str) -> 
tuple[dict, State]:\n", " \"\"\"Human converse step -- make sure we get input, and store it as state.\"\"\"\n", " new_state = state.update(question=user_question)\n", " new_state = new_state.append(chat_history={\"role\": \"user\", \"content\": user_question})\n", @@ -281,7 +280,7 @@ " reads=[\"tool_calls\", \"chat_history\"],\n", " writes=[\"tool_calls\", \"chat_history\"],\n", ")\n", - "def tool_call(state: State, graph: falkordb.Graph) -> Tuple[dict, State]:\n", + "def tool_call(state: State, graph: falkordb.Graph) -> tuple[dict, State]:\n", " \"\"\"Tool call step -- execute the query and append to chat history.\"\"\"\n", " tool_calls = state.get(\"tool_calls\", [])\n", " new_state = state\n", diff --git a/examples/LLM_Workflows/NER_Example/lancedb_module.py b/examples/LLM_Workflows/NER_Example/lancedb_module.py index 70384a0de..c2c002f9b 100644 --- a/examples/LLM_Workflows/NER_Example/lancedb_module.py +++ b/examples/LLM_Workflows/NER_Example/lancedb_module.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Union import lancedb import numpy as np @@ -31,7 +30,7 @@ def db_client() -> lancedb.DBConnection: def _write_to_lancedb( - data: Union[list[dict], pa.Table], db: lancedb.DBConnection, table_name: str + data: list[dict] | pa.Table, db: lancedb.DBConnection, table_name: str ) -> int: """Helper function to write to lancedb. diff --git a/examples/LLM_Workflows/NER_Example/ner_extraction.py b/examples/LLM_Workflows/NER_Example/ner_extraction.py index b8287254e..4be611c92 100644 --- a/examples/LLM_Workflows/NER_Example/ner_extraction.py +++ b/examples/LLM_Workflows/NER_Example/ner_extraction.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Union import torch from datasets import Dataset, load_dataset # noqa: F401 @@ -129,7 +128,7 @@ def retriever( def _extract_named_entities_text( - title_text_batch: Union[LazyBatch, list[str]], _ner_pipeline + title_text_batch: LazyBatch | list[str], _ner_pipeline ) -> list[list[str]]: """Helper function to extract named entities given a batch of text.""" # extract named entities using the NER pipeline diff --git a/examples/LLM_Workflows/RAG_document_extract_chunk_embed/pipeline.py b/examples/LLM_Workflows/RAG_document_extract_chunk_embed/pipeline.py index 879f17372..b4de89ae2 100644 --- a/examples/LLM_Workflows/RAG_document_extract_chunk_embed/pipeline.py +++ b/examples/LLM_Workflows/RAG_document_extract_chunk_embed/pipeline.py @@ -19,7 +19,7 @@ Modules that mirrors the pipeline the code in the notebook creates. 
""" -from typing import NamedTuple, Optional +from typing import NamedTuple class Chunk(NamedTuple): @@ -28,8 +28,8 @@ class Chunk(NamedTuple): index: int document_id: str text: str - embedding: Optional[list[float]] - metadata: Optional[dict[str, str]] + embedding: list[float] | None + metadata: dict[str, str] | None def add_embedding(self, embedding: list[float]) -> "Chunk": """Required to update chunk with embeddings""" @@ -46,7 +46,7 @@ class Document(NamedTuple): id: str url: str raw_text: str - chunks: Optional[list[Chunk]] + chunks: list[Chunk] | None def add_chunks(self, chunks: list[Chunk]) -> "Document": """Required to update the document when Chunks are created""" diff --git a/examples/LLM_Workflows/image_telephone/adapters.py b/examples/LLM_Workflows/image_telephone/adapters.py index 5c74cfb8b..1c3804e7b 100644 --- a/examples/LLM_Workflows/image_telephone/adapters.py +++ b/examples/LLM_Workflows/image_telephone/adapters.py @@ -19,7 +19,8 @@ import io import json import logging -from typing import Any, Collection, Dict, Type +from collections.abc import Collection +from typing import Any from urllib import parse import boto3 @@ -39,12 +40,12 @@ class JSONS3DataSaver(DataSaver): bucket: str key: str - def save_data(self, data: dict) -> Dict[str, Any]: + def save_data(self, data: dict) -> dict[str, Any]: data = json.dumps(data).encode() client.put_object(Body=data, Bucket=self.bucket, Key=self.key) @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [dict] @classmethod @@ -74,7 +75,7 @@ class ImageS3DataSaver(DataSaver): format: str # image_convert_params: Optional[Dict[str, Any]] = None - def save_data(self, data: str) -> Dict[str, Any]: + def save_data(self, data: str) -> dict[str, Any]: image = _load_image(data, self.format) in_mem_file = io.BytesIO() image.save(in_mem_file, format=self.format) @@ -83,7 +84,7 @@ def save_data(self, data: str) -> Dict[str, Any]: return {"key": self.key, "bucket": self.bucket} @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [str] # URL or local path @classmethod @@ -97,13 +98,13 @@ class LocalImageSaver(DataSaver): format: str # image_convert_params: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict) - def save_data(self, data: str) -> Dict[str, Any]: + def save_data(self, data: str) -> dict[str, Any]: image = _load_image(data, self.format) image.save(self.path, format=self.format) return {"path": self.path} @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [str] # URL or local path @classmethod diff --git a/examples/LLM_Workflows/image_telephone/notebook.ipynb b/examples/LLM_Workflows/image_telephone/notebook.ipynb index 7eba9376a..fe6454fb8 100644 --- a/examples/LLM_Workflows/image_telephone/notebook.ipynb +++ b/examples/LLM_Workflows/image_telephone/notebook.ipynb @@ -77,8 +77,6 @@ }, "outputs": [], "source": [ - "from typing import Tuple\n", - "\n", "from burr.core import ApplicationBuilder, Result, State, default, expr\n", "from burr.core.action import action\n", "from PIL import Image\n", @@ -163,7 +161,7 @@ " reads=[\"current_image_location\"],\n", " writes=[\"current_image_caption\", \"image_location_history\"],\n", ")\n", - "def image_caption(state: State, caption_image_driver: driver.Driver) -> Tuple[dict, State]:\n", + "def image_caption(state: State, caption_image_driver: driver.Driver) -> tuple[dict, State]:\n", " 
current_image = state[\"current_image_location\"]\n", " result = caption_image_driver.execute(\n", " [\"generated_caption\"], inputs={\"image_url\": current_image}\n", @@ -178,7 +176,7 @@ " reads=[\"current_image_caption\"],\n", " writes=[\"current_image_location\", \"image_caption_history\"],\n", ")\n", - "def image_generation(state: State, generate_image_driver: driver.Driver) -> Tuple[dict, State]:\n", + "def image_generation(state: State, generate_image_driver: driver.Driver) -> tuple[dict, State]:\n", " current_caption = state[\"current_image_caption\"]\n", " result = generate_image_driver.execute(\n", " [\"generated_image\"], inputs={\"image_generation_prompt\": current_caption}\n", diff --git a/examples/LLM_Workflows/image_telephone/streamlit.py b/examples/LLM_Workflows/image_telephone/streamlit.py index 2b65ffed3..de65d2c85 100644 --- a/examples/LLM_Workflows/image_telephone/streamlit.py +++ b/examples/LLM_Workflows/image_telephone/streamlit.py @@ -22,7 +22,7 @@ import urllib.parse from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any import altair as alt import boto3 @@ -86,7 +86,7 @@ def download_s3_file(s3_client, bucket_name, file_key, _transform=lambda x: x): @st.cache_data(ttl=100) -def query_data(bucket: str, paths: Tuple[str, ...], _s3_client, _transform) -> Any: +def query_data(bucket: str, paths: tuple[str, ...], _s3_client, _transform) -> Any: """Function to query metadata from the s3 bucket""" file_contents = {} @@ -107,7 +107,7 @@ def get_url(key: str) -> str: return f"https://d1lf8m1wnxcl0a.cloudfront.net/{key}" -def list_prompts_and_images(prompt: str, paths: Tuple[str, ...]) -> tuple[list[str], list[str]]: +def list_prompts_and_images(prompt: str, paths: tuple[str, ...]) -> tuple[list[str], list[str]]: """Lists out prompts and images paths, given an image name (prompt)""" prompt_entries = [p for p in paths if p.endswith(".json")] image_entries = [p for p in paths if p.endswith(".jpeg") & ("original" not in p.split("/")[-1])] @@ -233,7 +233,7 @@ def embedding_path_plot(vis_dims, image_list, highlight_idx, image_name): st.altair_chart(altair_chart, use_container_width=True, theme="streamlit") -def prompt_dropdown(items: List[str], incomplete_items) -> str: +def prompt_dropdown(items: list[str], incomplete_items) -> str: """Selects the seed image (prompt) -- also updates the URL state""" query_params = st.experimental_get_query_params() key = "seed_image" @@ -366,7 +366,7 @@ def s3(): @st.cache_data(ttl=100) # every 20 seconds we'll update? 
-def list_files_in_bucket(_s3_client, bucket) -> Dict[str, tuple[str, ...]]: +def list_files_in_bucket(_s3_client, bucket) -> dict[str, tuple[str, ...]]: """Gets name of all files in the s3 bucket grouped by the prefix (first directory).""" paginator = _s3_client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket) @@ -438,7 +438,7 @@ def explore_display(): @st.cache_data() -def load_resources() -> Tuple[str, str, str, str]: +def load_resources() -> tuple[str, str, str, str]: """Returns the code for the dataflow and the code for the adapter""" # TODO -- get this to work when we have hamilton contrib is_local_mode = os.environ.get("LOCAL_MODE", "false").lower() == "true" diff --git a/examples/LLM_Workflows/knowledge_retrieval/arxiv_articles.py b/examples/LLM_Workflows/knowledge_retrieval/arxiv_articles.py index 8a3cad561..0d8331aee 100644 --- a/examples/LLM_Workflows/knowledge_retrieval/arxiv_articles.py +++ b/examples/LLM_Workflows/knowledge_retrieval/arxiv_articles.py @@ -16,7 +16,6 @@ # under the License. import os.path -from typing import Dict import arxiv import openai @@ -48,7 +47,7 @@ def arxiv_search_result( @extract_fields({"title": str, "summary": str, "article_url": str, "pdf_url": str}) -def result(arxiv_search_result: arxiv.Result) -> Dict[str, str]: +def result(arxiv_search_result: arxiv.Result) -> dict[str, str]: return { "title": arxiv_search_result.title, "summary": arxiv_search_result.summary, @@ -91,7 +90,7 @@ def arxiv_processed_result( pdf_url: str, arxiv_pdf: str, arxiv_result_embedding: list[float], -) -> Dict[str, str]: +) -> dict[str, str]: """creates dict with parameters as keys/values""" return { "title": title, @@ -103,7 +102,7 @@ def arxiv_processed_result( } -def arxiv_result_df(arxiv_processed_result: Collect[Dict[str, str]]) -> pd.DataFrame: +def arxiv_result_df(arxiv_processed_result: Collect[dict[str, str]]) -> pd.DataFrame: """Joins the arxiv results back to a dataframe. 
:param arxiv_processed_result: result of all the joined arxiv result information diff --git a/examples/LLM_Workflows/knowledge_retrieval/summarize_text.py b/examples/LLM_Workflows/knowledge_retrieval/summarize_text.py index 95cc87a81..bef710fbc 100644 --- a/examples/LLM_Workflows/knowledge_retrieval/summarize_text.py +++ b/examples/LLM_Workflows/knowledge_retrieval/summarize_text.py @@ -17,7 +17,7 @@ import ast import concurrent -from typing import Callable, Generator, List +from collections.abc import Callable, Generator import openai import pandas as pd @@ -45,14 +45,14 @@ def summarize_paper_from_summaries_prompt() -> str: @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3)) -def user_query_embedding(user_query: str, embedding_model_name: str) -> List[float]: +def user_query_embedding(user_query: str, embedding_model_name: str) -> list[float]: """Get the embedding for a user query from OpenAI API.""" response = openai.Embedding.create(input=user_query, model=embedding_model_name) return response["data"][0]["embedding"] def relatedness( - user_query_embedding: List[float], + user_query_embedding: list[float], embeddings: pd.Series, relatedness_fn: Callable = lambda x, y: 1 - spatial.distance.cosine(x, y), ) -> pd.Series: diff --git a/examples/LLM_Workflows/langchain_comparison/hamilton_async.py b/examples/LLM_Workflows/langchain_comparison/hamilton_async.py index 551c26455..db29986f6 100644 --- a/examples/LLM_Workflows/langchain_comparison/hamilton_async.py +++ b/examples/LLM_Workflows/langchain_comparison/hamilton_async.py @@ -16,7 +16,6 @@ # under the License. # hamilton_async.py -from typing import List import openai @@ -29,11 +28,11 @@ def joke_prompt(topic: str) -> str: return f"Tell me a short joke about {topic}" -def joke_messages(joke_prompt: str) -> List[dict]: +def joke_messages(joke_prompt: str) -> list[dict]: return [{"role": "user", "content": joke_prompt}] -async def joke_response(llm_client: openai.AsyncOpenAI, joke_messages: List[dict]) -> str: +async def joke_response(llm_client: openai.AsyncOpenAI, joke_messages: list[dict]) -> str: response = await llm_client.chat.completions.create( model="gpt-3.5-turbo", messages=joke_messages, diff --git a/examples/LLM_Workflows/langchain_comparison/hamilton_batch.py b/examples/LLM_Workflows/langchain_comparison/hamilton_batch.py index 8e2161e2a..594f5a20e 100644 --- a/examples/LLM_Workflows/langchain_comparison/hamilton_batch.py +++ b/examples/LLM_Workflows/langchain_comparison/hamilton_batch.py @@ -16,7 +16,6 @@ # under the License. 
# hamilton_batch.py -from typing import List import openai @@ -37,11 +36,11 @@ def joke_prompt(topic: str) -> str: return f"Tell me a short joke about {topic}" -def joke_messages(joke_prompt: str) -> List[dict]: +def joke_messages(joke_prompt: str) -> list[dict]: return [{"role": "user", "content": joke_prompt}] -def joke_response(llm_client: openai.OpenAI, joke_messages: List[dict]) -> str: +def joke_response(llm_client: openai.OpenAI, joke_messages: list[dict]) -> str: response = llm_client.chat.completions.create( model="gpt-3.5-turbo", messages=joke_messages, @@ -49,7 +48,7 @@ def joke_response(llm_client: openai.OpenAI, joke_messages: List[dict]) -> str: return response.choices[0].message.content -def joke_responses(joke_response: Collect[str]) -> List[str]: +def joke_responses(joke_response: Collect[str]) -> list[str]: return list(joke_response) diff --git a/examples/LLM_Workflows/langchain_comparison/hamilton_completion.py b/examples/LLM_Workflows/langchain_comparison/hamilton_completion.py index 695f8d983..de033f135 100644 --- a/examples/LLM_Workflows/langchain_comparison/hamilton_completion.py +++ b/examples/LLM_Workflows/langchain_comparison/hamilton_completion.py @@ -16,7 +16,6 @@ # under the License. # hamilton_completion.py -from typing import List import openai @@ -31,7 +30,7 @@ def joke_prompt(topic: str) -> str: return f"Tell me a short joke about {topic}" -def joke_messages(joke_prompt: str) -> List[dict]: +def joke_messages(joke_prompt: str) -> list[dict]: return [{"role": "user", "content": joke_prompt}] @@ -45,7 +44,7 @@ def joke_response__completion(llm_client: openai.OpenAI, joke_prompt: str) -> st @config.when(type="chat") -def joke_response__chat(llm_client: openai.OpenAI, joke_messages: List[dict]) -> str: +def joke_response__chat(llm_client: openai.OpenAI, joke_messages: list[dict]) -> str: response = llm_client.chat.completions.create( model="gpt-3.5-turbo", messages=joke_messages, diff --git a/examples/LLM_Workflows/langchain_comparison/hamilton_invoke.py b/examples/LLM_Workflows/langchain_comparison/hamilton_invoke.py index 0b87e5f4f..aee80d142 100644 --- a/examples/LLM_Workflows/langchain_comparison/hamilton_invoke.py +++ b/examples/LLM_Workflows/langchain_comparison/hamilton_invoke.py @@ -16,7 +16,6 @@ # under the License. # hamilton_invoke.py -from typing import List import openai @@ -29,11 +28,11 @@ def joke_prompt(topic: str) -> str: return f"Tell me a short joke about {topic}" -def joke_messages(joke_prompt: str) -> List[dict]: +def joke_messages(joke_prompt: str) -> list[dict]: return [{"role": "user", "content": joke_prompt}] -def joke_response(llm_client: openai.OpenAI, joke_messages: List[dict]) -> str: +def joke_response(llm_client: openai.OpenAI, joke_messages: list[dict]) -> str: response = llm_client.chat.completions.create( model="gpt-3.5-turbo", messages=joke_messages, diff --git a/examples/LLM_Workflows/langchain_comparison/hamilton_streamed.py b/examples/LLM_Workflows/langchain_comparison/hamilton_streamed.py index 6d998aa6c..538a0d700 100644 --- a/examples/LLM_Workflows/langchain_comparison/hamilton_streamed.py +++ b/examples/LLM_Workflows/langchain_comparison/hamilton_streamed.py @@ -16,7 +16,7 @@ # under the License. 
# hamilton_streamed.py -from typing import Iterator, List +from collections.abc import Iterator import openai @@ -29,11 +29,11 @@ def joke_prompt(topic: str) -> str: return f"Tell me a short joke about {topic}" -def joke_messages(joke_prompt: str) -> List[dict]: +def joke_messages(joke_prompt: str) -> list[dict]: return [{"role": "user", "content": joke_prompt}] -def joke_response(llm_client: openai.OpenAI, joke_messages: List[dict]) -> Iterator[str]: +def joke_response(llm_client: openai.OpenAI, joke_messages: list[dict]) -> Iterator[str]: stream = llm_client.chat.completions.create( model="gpt-3.5-turbo", messages=joke_messages, stream=True ) diff --git a/examples/LLM_Workflows/langchain_comparison/vanilla_async.py b/examples/LLM_Workflows/langchain_comparison/vanilla_async.py index 6cfc74686..e786732d7 100644 --- a/examples/LLM_Workflows/langchain_comparison/vanilla_async.py +++ b/examples/LLM_Workflows/langchain_comparison/vanilla_async.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import List import openai @@ -26,7 +25,7 @@ async_client = openai.AsyncOpenAI() -async def acall_chat_model(messages: List[dict]) -> str: +async def acall_chat_model(messages: list[dict]) -> str: response = await async_client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, diff --git a/examples/LLM_Workflows/langchain_comparison/vanilla_batch.py b/examples/LLM_Workflows/langchain_comparison/vanilla_batch.py index c87741f5f..f9fa5d28a 100644 --- a/examples/LLM_Workflows/langchain_comparison/vanilla_batch.py +++ b/examples/LLM_Workflows/langchain_comparison/vanilla_batch.py @@ -16,7 +16,6 @@ # under the License. from concurrent.futures import ThreadPoolExecutor -from typing import List import openai @@ -24,7 +23,7 @@ client = openai.OpenAI() -def call_chat_model(messages: List[dict]) -> str: +def call_chat_model(messages: list[dict]) -> str: response = client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, diff --git a/examples/LLM_Workflows/langchain_comparison/vanilla_invoke.py b/examples/LLM_Workflows/langchain_comparison/vanilla_invoke.py index eab368970..5fc8bd622 100644 --- a/examples/LLM_Workflows/langchain_comparison/vanilla_invoke.py +++ b/examples/LLM_Workflows/langchain_comparison/vanilla_invoke.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import List import openai @@ -23,7 +22,7 @@ client = openai.OpenAI() -def call_chat_model(messages: List[dict]) -> str: +def call_chat_model(messages: list[dict]) -> str: response = client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, diff --git a/examples/LLM_Workflows/langchain_comparison/vanilla_streamed.py b/examples/LLM_Workflows/langchain_comparison/vanilla_streamed.py index 2973acc8f..1353fe054 100644 --- a/examples/LLM_Workflows/langchain_comparison/vanilla_streamed.py +++ b/examples/LLM_Workflows/langchain_comparison/vanilla_streamed.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Iterator, List +from collections.abc import Iterator import openai @@ -23,7 +23,7 @@ client = openai.OpenAI() -def stream_chat_model(messages: List[dict]) -> Iterator[str]: +def stream_chat_model(messages: list[dict]) -> Iterator[str]: stream = client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, diff --git a/examples/LLM_Workflows/llm_logging/summarization.py b/examples/LLM_Workflows/llm_logging/summarization.py index 6e9cbab38..4c1069b2a 100644 --- a/examples/LLM_Workflows/llm_logging/summarization.py +++ b/examples/LLM_Workflows/llm_logging/summarization.py @@ -17,7 +17,7 @@ import os import tempfile -from typing import Generator, Union +from collections.abc import Generator import tiktoken from openai import OpenAI @@ -30,7 +30,7 @@ def openai_client() -> OpenAI: return OpenAI(api_key=os.environ["OPENAI_API_KEY"]) -def raw_text(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) -> str: +def raw_text(pdf_source: str | bytes | tempfile.SpooledTemporaryFile) -> str: """Takes a filepath to a PDF and returns a string of the PDF's contents :param pdf_source: the path, or the temporary file, to the PDF. :return: the text of the PDF. diff --git a/examples/LLM_Workflows/modular_llm_stack/marqo_module.py b/examples/LLM_Workflows/modular_llm_stack/marqo_module.py index 956d1a514..28db62525 100644 --- a/examples/LLM_Workflows/modular_llm_stack/marqo_module.py +++ b/examples/LLM_Workflows/modular_llm_stack/marqo_module.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Union +from typing import Any import marqo @@ -94,7 +94,7 @@ def query_vector_db( include_metadata: bool = True, include_vectors: bool = False, namespace: str = None, - ) -> list[dict[str, Union[Union[list[Any], dict], Any]]]: + ) -> list[dict[str, list[Any] | dict | Any]]: params = { "limit": top_k, "attributes_to_retrieve": ["*"] if include_metadata else ["_id"], diff --git a/examples/LLM_Workflows/pdf_summarizer/backend/parallel_summarization.py b/examples/LLM_Workflows/pdf_summarizer/backend/parallel_summarization.py index 24afdf68c..9a585339a 100644 --- a/examples/LLM_Workflows/pdf_summarizer/backend/parallel_summarization.py +++ b/examples/LLM_Workflows/pdf_summarizer/backend/parallel_summarization.py @@ -21,7 +21,7 @@ """ import tempfile -from typing import Generator, Union +from collections.abc import Generator import tiktoken from openai import OpenAI @@ -64,7 +64,7 @@ def summarize_text_from_summaries_prompt(content_type: str = "an academic paper" @config.when(file_type="pdf") -def raw_text(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) -> str: +def raw_text(pdf_source: str | bytes | tempfile.SpooledTemporaryFile) -> str: """Takes a filepath to a PDF and returns a string of the PDF's contents :param pdf_source: the path, or the temporary file, to the PDF. :return: the text of the PDF. 
diff --git a/examples/LLM_Workflows/pdf_summarizer/backend/summarization.py b/examples/LLM_Workflows/pdf_summarizer/backend/summarization.py index 11df3cc78..d555f0bc5 100644 --- a/examples/LLM_Workflows/pdf_summarizer/backend/summarization.py +++ b/examples/LLM_Workflows/pdf_summarizer/backend/summarization.py @@ -17,7 +17,7 @@ import concurrent import tempfile -from typing import Generator, Union +from collections.abc import Generator from openai import OpenAI @@ -45,7 +45,7 @@ def summarize_text_from_summaries_prompt(content_type: str = "an academic paper" @config.when(file_type="pdf") -def raw_text(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) -> str: +def raw_text(pdf_source: str | bytes | tempfile.SpooledTemporaryFile) -> str: """Takes a filepath to a PDF and returns a string of the PDF's contents :param pdf_source: the path, or the temporary file, to the PDF. :return: the text of the PDF. diff --git a/examples/LLM_Workflows/pdf_summarizer/run_on_spark/summarization.py b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/summarization.py index 7d8641f78..8e170ad39 100644 --- a/examples/LLM_Workflows/pdf_summarizer/run_on_spark/summarization.py +++ b/examples/LLM_Workflows/pdf_summarizer/run_on_spark/summarization.py @@ -17,7 +17,7 @@ import concurrent import tempfile -from typing import Generator, Union +from collections.abc import Generator import openai @@ -52,7 +52,7 @@ def summarize_text_from_summaries_prompt(content_type: str = "an academic paper" @config.when(file_type="pdf") -def raw_text(pdf_source: Union[str, bytes, tempfile.SpooledTemporaryFile]) -> str: +def raw_text(pdf_source: str | bytes | tempfile.SpooledTemporaryFile) -> str: """Takes a filepath to a PDF and returns a string of the PDF's contents :param pdf_source: the path, or the temporary file, to the PDF. :return: the text of the PDF. diff --git a/examples/LLM_Workflows/retrieval_augmented_generation/backend/ingestion.py b/examples/LLM_Workflows/retrieval_augmented_generation/backend/ingestion.py index 209debc2e..4b7f75464 100644 --- a/examples/LLM_Workflows/retrieval_augmented_generation/backend/ingestion.py +++ b/examples/LLM_Workflows/retrieval_augmented_generation/backend/ingestion.py @@ -17,8 +17,8 @@ import base64 import io +from collections.abc import Generator from pathlib import Path -from typing import Generator import arxiv import fastapi @@ -113,7 +113,7 @@ def raw_text(pdf_content: io.BytesIO) -> str: Throw exception if unable to read PDF """ reader = PyPDF2.PdfReader(pdf_content) - pdf_text = " ".join((page.extract_text() for page in reader.pages)) + pdf_text = " ".join(page.extract_text() for page in reader.pages) return pdf_text diff --git a/examples/airflow/plugins/absenteeism/prepare_data.py b/examples/airflow/plugins/absenteeism/prepare_data.py index d4dd2c0a4..8bf4525cb 100644 --- a/examples/airflow/plugins/absenteeism/prepare_data.py +++ b/examples/airflow/plugins/absenteeism/prepare_data.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-from typing import List import numpy as np import pandas as pd @@ -76,7 +75,7 @@ ] -def _rename_columns(columns: List[str]) -> List[str]: +def _rename_columns(columns: list[str]) -> list[str]: """convert raw data column names to snakecase and make them compatible with Hamilton's naming convention (need to be a valid Python function name) diff --git a/examples/airflow/plugins/absenteeism/train_model.py b/examples/airflow/plugins/absenteeism/train_model.py index 1b5ebe420..dc974f1e5 100644 --- a/examples/airflow/plugins/absenteeism/train_model.py +++ b/examples/airflow/plugins/absenteeism/train_model.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict, List import numpy as np import pandas as pd @@ -52,8 +51,8 @@ def preprocessed_data__dev(df: pd.DataFrame) -> pd.DataFrame: ) ) def split_indices( - preprocessed_data: pd.DataFrame, validation_user_ids: List[int] -) -> Dict[str, pd.DataFrame]: + preprocessed_data: pd.DataFrame, validation_user_ids: list[int] +) -> dict[str, pd.DataFrame]: """Creating train-validation splits based on the list of `validation_user_ids`""" validation_selection_mask = preprocessed_data.id.isin([int(i) for i in validation_user_ids]) @@ -63,7 +62,7 @@ def split_indices( ) -def data_stats(preprocessed_data: pd.DataFrame, feature_set: List[str]) -> pd.DataFrame: +def data_stats(preprocessed_data: pd.DataFrame, feature_set: list[str]) -> pd.DataFrame: return preprocessed_data[feature_set].describe() @@ -71,7 +70,7 @@ def data_stats(preprocessed_data: pd.DataFrame, feature_set: List[str]) -> pd.Da X_train=dict(df=source("train_df"), feature_set=source("feature_set")), X_validation=dict(df=source("validation_df"), feature_set=source("feature_set")), ) -def features(df: pd.DataFrame, feature_set: List[str]) -> np.ndarray: +def features(df: pd.DataFrame, feature_set: list[str]) -> np.ndarray: """Select features from `preprocessed_data` based on `feature_set`""" return df[feature_set].to_numpy() diff --git a/examples/airflow/plugins/function_modules/data_loaders.py b/examples/airflow/plugins/function_modules/data_loaders.py index 615347e9a..6f42a404b 100644 --- a/examples/airflow/plugins/function_modules/data_loaders.py +++ b/examples/airflow/plugins/function_modules/data_loaders.py @@ -25,8 +25,6 @@ the driver choose which one to use for the DAG. For the purposes of this example, we decided one file is simpler. """ -from typing import List - import pandas as pd from hamilton.function_modifiers import config, extract_columns, load_from, source, value @@ -56,7 +54,7 @@ ] -def _sanitize_columns(df_columns: List[str]) -> List[str]: +def _sanitize_columns(df_columns: list[str]) -> list[str]: """Renames columns to be valid hamilton names -- and lower cases them. :param df_columns: the current column names. diff --git a/examples/dagster/dagster_code/tutorial/resources/__init__.py b/examples/dagster/dagster_code/tutorial/resources/__init__.py index 82263ae7c..d119c74d2 100644 --- a/examples/dagster/dagster_code/tutorial/resources/__init__.py +++ b/examples/dagster/dagster_code/tutorial/resources/__init__.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. 
+from collections.abc import Sequence from dataclasses import asdict, dataclass from datetime import datetime, timedelta from random import Random -from typing import Sequence, Union +from typing import Union from dagster import ConfigurableResource from faker import Faker @@ -120,7 +121,7 @@ def get_signups_for_date(self, date: datetime) -> Sequence[Signup]: return sorted(signups, key=lambda x: x["registered_at"]) def get_signups_for_dates( - self, start_date: datetime, end_date: Union[datetime, None] = None + self, start_date: datetime, end_date: datetime | None = None ) -> Sequence[Signup]: signups = [] diff --git a/examples/dagster/hamilton_code/mock_api.py b/examples/dagster/hamilton_code/mock_api.py index 77c5687e8..5baf5f43d 100644 --- a/examples/dagster/hamilton_code/mock_api.py +++ b/examples/dagster/hamilton_code/mock_api.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. +from collections.abc import Sequence from dataclasses import asdict, dataclass from datetime import datetime, timedelta from random import Random -from typing import Sequence, Union from dagster import ConfigurableResource from faker import Faker @@ -121,7 +121,7 @@ def get_signups_for_date(self, date: datetime) -> Sequence[Signup]: return sorted(signups, key=lambda x: x["registered_at"]) def get_signups_for_dates( - self, start_date: datetime, end_date: Union[datetime, None] = None + self, start_date: datetime, end_date: datetime | None = None ) -> Sequence[Signup]: signups = [] diff --git a/examples/data_quality/pandera/data_loaders.py b/examples/data_quality/pandera/data_loaders.py index 52c821ac6..8e5d1f04c 100644 --- a/examples/data_quality/pandera/data_loaders.py +++ b/examples/data_quality/pandera/data_loaders.py @@ -27,8 +27,6 @@ the driver choose which one to use for the DAG. For the purposes of this example, we decided one file is simpler. """ -from typing import List - import pandas as pd from hamilton.function_modifiers import config, extract_columns, load_from, source, value @@ -58,7 +56,7 @@ ] -def _sanitize_columns(df_columns: List[str]) -> List[str]: +def _sanitize_columns(df_columns: list[str]) -> list[str]: """Renames columns to be valid hamilton names -- and lower cases them. :param df_columns: the current column names. diff --git a/examples/data_quality/simple/data_loaders.py b/examples/data_quality/simple/data_loaders.py index bb1eb1120..042765e51 100644 --- a/examples/data_quality/simple/data_loaders.py +++ b/examples/data_quality/simple/data_loaders.py @@ -25,8 +25,6 @@ the driver choose which one to use for the DAG. For the purposes of this example, we decided one file is simpler. """ -from typing import List - import pandas as pd from hamilton.function_modifiers import config, extract_columns, load_from, source, value @@ -56,7 +54,7 @@ ] -def _sanitize_columns(df_columns: List[str]) -> List[str]: +def _sanitize_columns(df_columns: list[str]) -> list[str]: """Renames columns to be valid hamilton names -- and lower cases them. :param df_columns: the current column names. diff --git a/examples/dbt/python_transforms/data_loader.py b/examples/dbt/python_transforms/data_loader.py index dfe8f3d72..cc7498997 100644 --- a/examples/dbt/python_transforms/data_loader.py +++ b/examples/dbt/python_transforms/data_loader.py @@ -19,8 +19,6 @@ This module contains our data loading functions. 
""" -from typing import List - import pandas as pd import pandera as pa from sklearn import datasets @@ -29,7 +27,7 @@ from hamilton.function_modifiers import check_output, config, extract_columns -def _sanitize_columns(df_columns: List[str]) -> List[str]: +def _sanitize_columns(df_columns: list[str]) -> list[str]: """Renames columns to be valid hamilton names -- and lower cases them. :param df_columns: the current column names. :return: sanitize column names that work with Hamilton diff --git a/examples/dbt/python_transforms/feature_transforms.py b/examples/dbt/python_transforms/feature_transforms.py index 04b71c708..66f6e9c44 100644 --- a/examples/dbt/python_transforms/feature_transforms.py +++ b/examples/dbt/python_transforms/feature_transforms.py @@ -20,7 +20,6 @@ """ import pickle -from typing import Set import pandas as pd @@ -33,7 +32,7 @@ from hamilton.function_modifiers import check_output, config -def rare_titles() -> Set[str]: +def rare_titles() -> set[str]: """Rare titles we've curated""" return { "Capt", @@ -63,7 +62,7 @@ def normalized_name(name: pd.Series) -> pd.Series: return name.apply(lambda x: x.split(",")[1].split(".")[0].strip()) -def title(normalized_name: pd.Series, rare_titles: Set[str]) -> pd.Series: +def title(normalized_name: pd.Series, rare_titles: set[str]) -> pd.Series: return normalized_name.apply(lambda n: "rare" if n in rare_titles else n) diff --git a/examples/dbt/python_transforms/model_pipeline.py b/examples/dbt/python_transforms/model_pipeline.py index 0c1695345..a660cd902 100644 --- a/examples/dbt/python_transforms/model_pipeline.py +++ b/examples/dbt/python_transforms/model_pipeline.py @@ -20,7 +20,6 @@ """ import pickle -from typing import Dict import numpy as np import pandas as pd @@ -47,7 +46,7 @@ def model_classifier(random_state: int) -> base.ClassifierMixin: @extract_fields({"train_set": pd.DataFrame, "test_set": pd.DataFrame}) def train_test_split( data_set: pd.DataFrame, target: pd.Series, test_size: float -) -> Dict[str, pd.DataFrame]: +) -> dict[str, pd.DataFrame]: """Splits the dataset into train & test. :param data_set: the dataset with all features already computed diff --git a/examples/decoupling_io/adapters.py b/examples/decoupling_io/adapters.py index e26060367..230ddeba2 100644 --- a/examples/decoupling_io/adapters.py +++ b/examples/decoupling_io/adapters.py @@ -16,8 +16,9 @@ # under the License. import dataclasses +from collections.abc import Collection from os import PathLike -from typing import Any, Collection, Dict, Optional, Type, Union +from typing import Any, Union # This is not necessary once this PR gets merged: https://github.com/apache/hamilton/pull/467. 
try: @@ -62,24 +63,24 @@ @dataclasses.dataclass class SklearnPlotSaver(DataSaver): - path: Union[str, PathLike] + path: str | PathLike # kwargs dpi: float = 200 format: str = "png" - metadata: Optional[dict] = None + metadata: dict | None = None bbox_inches: str = None pad_inches: float = 0.1 - backend: Optional[str] = None + backend: str | None = None papertype: str = None transparent: bool = None - bbox_extra_artists: Optional[list] = None - pil_kwargs: Optional[dict] = None + bbox_extra_artists: list | None = None + pil_kwargs: dict | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return SKLEARN_PLOT_TYPES - def _get_saving_kwargs(self) -> Dict[str, Any]: + def _get_saving_kwargs(self) -> dict[str, Any]: kwargs = {} if self.dpi is not None: kwargs["dpi"] = self.dpi @@ -103,7 +104,7 @@ def _get_saving_kwargs(self) -> Dict[str, Any]: kwargs["pil_kwargs"] = self.pil_kwargs return kwargs - def save_data(self, data: SKLEARN_PLOT_TYPES_ANNOTATION) -> Dict[str, Any]: + def save_data(self, data: SKLEARN_PLOT_TYPES_ANNOTATION) -> dict[str, Any]: data.plot() data.figure_.savefig(self.path, **self._get_saving_kwargs()) return utils.get_file_metadata(self.path) diff --git a/examples/decoupling_io/components/feature_data.py b/examples/decoupling_io/components/feature_data.py index fa63368fe..3e7b70caf 100644 --- a/examples/decoupling_io/components/feature_data.py +++ b/examples/decoupling_io/components/feature_data.py @@ -19,8 +19,6 @@ This is a module that contains our feature transforms. """ -from typing import Dict, List, Set - import pandas as pd from sklearn import impute, model_selection, preprocessing # import KNNImputer @@ -28,8 +26,8 @@ def _sanitize_columns( - df_columns: List[str], -) -> List[str]: + df_columns: list[str], +) -> list[str]: """Helper function to sanitize column names. :param df_columns: the current column names @@ -52,7 +50,7 @@ def passengers_df(titanic_data: pd.DataFrame) -> pd.DataFrame: return raw_passengers_df -def rare_titles() -> Set[str]: +def rare_titles() -> set[str]: """Rare titles we've curated""" return { "Capt", @@ -81,7 +79,7 @@ def normalized_name(name: pd.Series) -> pd.Series: return name.apply(lambda x: x.split(",")[1].split(".")[0].strip()) -def title(normalized_name: pd.Series, rare_titles: Set[str]) -> pd.Series: +def title(normalized_name: pd.Series, rare_titles: set[str]) -> pd.Series: return normalized_name.apply(lambda n: "rare" if n in rare_titles else n) @@ -241,7 +239,7 @@ def target(survived: pd.Series) -> pd.Series: @extract_fields({"train_set": pd.DataFrame, "test_set": pd.DataFrame}) def train_test_split( data_set: pd.DataFrame, target: pd.Series, test_size: float -) -> Dict[str, pd.DataFrame]: +) -> dict[str, pd.DataFrame]: """Splits the dataset into train & test. :param data_set: the dataset with all features already computed diff --git a/examples/decoupling_io/run.py b/examples/decoupling_io/run.py index 4a9d1e4d8..504a9049e 100644 --- a/examples/decoupling_io/run.py +++ b/examples/decoupling_io/run.py @@ -17,7 +17,6 @@ # Since this is not in plugins we need to import it *before* doing anything else import importlib -from typing import List, Union import click @@ -42,7 +41,7 @@ def cli(): } -def get_materializers(which_stage: str) -> List[Union[MaterializerFactory, ExtractorFactory]]: +def get_materializers(which_stage: str) -> list[MaterializerFactory | ExtractorFactory]: """Gives the set of materializers for the driver to run, separated out by stages. 
This demonstrates the evolution of additional materializers by stage, so that you can run stages individually using the same driver. diff --git a/examples/dlt/slack/__init__.py b/examples/dlt/slack/__init__.py index 1474b6158..4c4becd46 100644 --- a/examples/dlt/slack/__init__.py +++ b/examples/dlt/slack/__init__.py @@ -17,8 +17,9 @@ """Fetches Slack Conversations, History and logs.""" +from collections.abc import Iterable from functools import partial -from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple import dlt from dlt.common.typing import TAnyDateTime, TDataItem @@ -38,9 +39,9 @@ def slack_source( page_size: int = MAX_PAGE_SIZE, access_token: str = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = DEFAULT_START_DATE, - end_date: Optional[TAnyDateTime] = None, - selected_channels: Optional[List[str]] = dlt.config.value, + start_date: TAnyDateTime | None = DEFAULT_START_DATE, + end_date: TAnyDateTime | None = None, + selected_channels: list[str] | None = dlt.config.value, table_per_channel: bool = True, replies: bool = False, ) -> Iterable[DltResource]: @@ -62,8 +63,8 @@ def slack_source( Iterable[DltResource]: A list of DltResource objects representing the data resources. """ - end_dt: Optional[DateTime] = ensure_dt_type(end_date) - start_dt: Optional[DateTime] = ensure_dt_type(start_date) + end_dt: DateTime | None = ensure_dt_type(end_date) + start_dt: DateTime | None = ensure_dt_type(start_date) write_disposition: Literal["append", "merge"] = "append" if end_date is None else "merge" api = SlackAPI( @@ -72,8 +73,8 @@ def slack_source( ) def get_channels( - slack_api: SlackAPI, selected_channels: Optional[List[str]] - ) -> Tuple[List[TDataItem], List[TDataItem]]: + slack_api: SlackAPI, selected_channels: list[str] | None + ) -> tuple[list[TDataItem], list[TDataItem]]: """ Returns channel fetched from slack and list of selected channels. @@ -84,7 +85,7 @@ def get_channels( Returns: Tuple[List[TDataItem], List[TDataItem]]: fetched channels and selected fetched channels. """ - channels: List[TDataItem] = [] + channels: list[TDataItem] = [] for page_data in slack_api.get_pages( resource="conversations.list", response_path="$.channels[*]", @@ -127,7 +128,7 @@ def users_resource() -> Iterable[TDataItem]: yield page_data def get_messages( - channel_data: Dict[str, Any], start_date_ts: float, end_date_ts: float + channel_data: dict[str, Any], start_date_ts: float, end_date_ts: float ) -> Iterable[TDataItem]: """ Generator, which gets channel messages for specific dates. @@ -154,7 +155,7 @@ def get_messages( ): yield page_data - def get_thread_replies(messages: List[Dict[str, Any]]) -> Iterable[TDataItem]: + def get_thread_replies(messages: list[dict[str, Any]]) -> Iterable[TDataItem]: """ Generator, which gets replies for each message. Args: @@ -209,7 +210,7 @@ def messages_resource( yield from get_messages(channel_data, start_date_ts, end_date_ts) def per_table_messages_resource( - channel_data: Dict[str, Any], + channel_data: dict[str, Any], created_at: dlt.sources.incremental[DateTime] = None, ) -> Iterable[TDataItem]: """Yield all messages for a given channel as a DLT resource. Keep blocks column without normalization. 
diff --git a/examples/dlt/slack/helpers.py b/examples/dlt/slack/helpers.py index 519c6a338..a37a1d888 100644 --- a/examples/dlt/slack/helpers.py +++ b/examples/dlt/slack/helpers.py @@ -17,7 +17,8 @@ """Slack source helpers.""" -from typing import Any, Generator, Iterable, List, Optional +from collections.abc import Generator, Iterable +from typing import Any from urllib.parse import urljoin import pendulum @@ -98,7 +99,7 @@ def headers(self) -> Dict[str, str]: return {"Authorization": f"Bearer {self.access_token}"} def parameters( - self, params: Optional[Dict[str, Any]] = None, next_cursor: str = None + self, params: Dict[str, Any] | None = None, next_cursor: str = None ) -> Dict[str, str]: """ Generate the query parameters to use for the request. @@ -136,7 +137,7 @@ def _get_next_cursor(self, response: Dict[str, Any]) -> Any: return next(extract_jsonpath(cursor_jsonpath, response), None) def _convert_datetime_fields( - self, item: Dict[str, Any], datetime_fields: List[str] + self, item: Dict[str, Any], datetime_fields: list[str] ) -> Dict[str, Any]: """Convert timestamp fields in the item to pendulum datetime objects. @@ -167,7 +168,7 @@ def get_pages( resource: str, response_path: str = None, params: Dict[str, Any] = None, - datetime_fields: List[str] = None, + datetime_fields: list[str] = None, context: Dict[str, Any] = None, ) -> Iterable[TDataItem]: """Get all pages from slack using requests. diff --git a/examples/due_date_probabilities/probabilities.py b/examples/due_date_probabilities/probabilities.py index 832cf8d28..28fc46f56 100644 --- a/examples/due_date_probabilities/probabilities.py +++ b/examples/due_date_probabilities/probabilities.py @@ -16,7 +16,6 @@ # under the License. import datetime -from typing import Optional import pandas as pd from scipy import stats @@ -32,7 +31,7 @@ def full_pdf( start_date: datetime.datetime, due_date: datetime.datetime, probability_distribution: stats.rv_continuous, - current_date: Optional[datetime.datetime] = None, + current_date: datetime.datetime | None = None, induction_post_due_date_days: int = 14, ) -> pd.Series: """Probabilities of delivery on X date on the *full* date range. We'll filter later. diff --git a/examples/due_date_probabilities/probability_estimation.py b/examples/due_date_probabilities/probability_estimation.py index 8009cbb8b..655629daf 100644 --- a/examples/due_date_probabilities/probability_estimation.py +++ b/examples/due_date_probabilities/probability_estimation.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import List import pandas as pd from scipy import stats @@ -141,7 +140,7 @@ def raw_probabilities(raw_data: str) -> pd.DataFrame: return probabilities_df # .set_index("days") -def resampled(raw_probabilities: pd.DataFrame) -> List[int]: +def resampled(raw_probabilities: pd.DataFrame) -> list[int]: sample_data = [] for _idx, row in raw_probabilities.iterrows(): count = row.probability * 1000 diff --git a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/offline_loader.py b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/offline_loader.py index a2564aa74..8299c7e53 100644 --- a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/offline_loader.py +++ b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_1/offline_loader.py @@ -21,8 +21,6 @@ We use this to build our offline ETL featurization process. 
""" -from typing import List - import pandas as pd from hamilton.function_modifiers import extract_columns, load_from, source, value @@ -53,7 +51,7 @@ ] -def _sanitize_columns(df_columns: List[str]) -> List[str]: +def _sanitize_columns(df_columns: list[str]) -> list[str]: """Renames columns to be valid hamilton names -- and lower cases them. :param df_columns: the current column names. diff --git a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/offline_loader.py b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/offline_loader.py index a62bff3d2..46fa945c9 100644 --- a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/offline_loader.py +++ b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/offline_loader.py @@ -21,8 +21,6 @@ We use this to build our offline ETL featurization process. """ -from typing import List - import pandas as pd from hamilton.function_modifiers import extract_columns, load_from, source, value @@ -53,7 +51,7 @@ ] -def _sanitize_columns(df_columns: List[str]) -> List[str]: +def _sanitize_columns(df_columns: list[str]) -> list[str]: """Renames columns to be valid hamilton names -- and lower cases them. :param df_columns: the current column names. diff --git a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/online_loader.py b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/online_loader.py index 7e3f06157..732e50717 100644 --- a/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/online_loader.py +++ b/examples/feature_engineering/feature_engineering_multiple_contexts/scenario_2/online_loader.py @@ -21,8 +21,6 @@ We use this to build our offline ETL featurization process. """ -from typing import List - import aiohttp import pandas as pd @@ -31,7 +29,7 @@ from hamilton.function_modifiers.metadata import tag -class FeatureStoreHttpClient(object): +class FeatureStoreHttpClient: """HTTP Client -- replace this with your own implementation if you need to.""" session: aiohttp.ClientSession = None @@ -50,7 +48,7 @@ def __call__(self) -> aiohttp.ClientSession: assert self.session is not None return self.session - async def get_features(self, client_id: str, features_needed: List[str]) -> pd.DataFrame: + async def get_features(self, client_id: str, features_needed: list[str]) -> pd.DataFrame: """Makes a request to the feature store to get the data. :param client_id: id of the client to get data for. diff --git a/examples/feature_engineering/write_once_run_everywhere_blog_post/components/utils.py b/examples/feature_engineering/write_once_run_everywhere_blog_post/components/utils.py index 205ff6dfa..f36f2323b 100644 --- a/examples/feature_engineering/write_once_run_everywhere_blog_post/components/utils.py +++ b/examples/feature_engineering/write_once_run_everywhere_blog_post/components/utils.py @@ -16,7 +16,6 @@ # under the License. import random -from typing import List import pandas as pd @@ -26,7 +25,7 @@ """ -def fabricate_client_login_data(client_ids: List[int]) -> pd.DataFrame: +def fabricate_client_login_data(client_ids: list[int]) -> pd.DataFrame: """Fabricates a dataframe of client login data. 
This contains the columns client ID (int) and last_logged_in (datetime) @@ -46,7 +45,7 @@ def fabricate_client_login_data(client_ids: List[int]) -> pd.DataFrame: ) -def fabricate_survey_results_data(client_ids: List[int]) -> pd.DataFrame: +def fabricate_survey_results_data(client_ids: list[int]) -> pd.DataFrame: """Fabricates a dataframe of survey results. This has the following (random) columns: - budget -- amount they're willing to spend on an order (number between 1 and 1000) diff --git a/examples/hamilton_ui/components/model_fitting.py b/examples/hamilton_ui/components/model_fitting.py index f12b57bc7..e27fe6b7e 100644 --- a/examples/hamilton_ui/components/model_fitting.py +++ b/examples/hamilton_ui/components/model_fitting.py @@ -17,8 +17,6 @@ """This module contains basic code for model fitting.""" -from typing import Dict - import numpy as np import pandas as pd from sklearn import base, linear_model, metrics, svm @@ -54,7 +52,7 @@ def train_test_split_func( data_set: pd.DataFrame, test_size_fraction: float, shuffle_train_test_split: bool, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """Function that creates the training & test splits. It this then extracted out into constituent components and used downstream. diff --git a/examples/ibis/feature_engineering/column_dataflow.py b/examples/ibis/feature_engineering/column_dataflow.py index fac47c16f..63a99e06a 100644 --- a/examples/ibis/feature_engineering/column_dataflow.py +++ b/examples/ibis/feature_engineering/column_dataflow.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional import ibis import ibis.expr.types as ir @@ -72,7 +71,7 @@ def feature_table( def feature_set( feature_table: ir.Table, feature_selection: list[str], - condition: Optional[ibis.common.deferred.Deferred] = None, + condition: ibis.common.deferred.Deferred | None = None, ) -> ir.Table: """Select feature columns and filter rows""" return feature_table[feature_selection].filter(condition) diff --git a/examples/ibis/feature_engineering/table_dataflow.py b/examples/ibis/feature_engineering/table_dataflow.py index b70aecf08..434b64765 100644 --- a/examples/ibis/feature_engineering/table_dataflow.py +++ b/examples/ibis/feature_engineering/table_dataflow.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional import ibis import ibis.expr.types as ir @@ -42,7 +41,7 @@ def feature_table(raw_table: ir.Table) -> ir.Table: def feature_set( feature_table: ir.Table, feature_selection: list[str], - condition: Optional[ibis.common.deferred.Deferred] = None, + condition: ibis.common.deferred.Deferred | None = None, ) -> ir.Table: """Select feature columns and filter rows""" selection = feature_table[feature_selection] diff --git a/examples/ibisml/table_dataflow.py b/examples/ibisml/table_dataflow.py index 2e7ec7a9b..40619d6cb 100644 --- a/examples/ibisml/table_dataflow.py +++ b/examples/ibisml/table_dataflow.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Optional import ibis import ibis.expr.types as ir @@ -42,7 +41,7 @@ def feature_table(raw_table: ir.Table) -> ir.Table: def feature_set( feature_table: ir.Table, feature_selection: list[str], - condition: Optional[ibis.common.deferred.Deferred] = None, + condition: ibis.common.deferred.Deferred | None = None, ) -> ir.Table: """Select feature columns and filter rows""" return feature_table[feature_selection].filter(condition) diff --git a/examples/kedro/hamilton-code/src/hamilton_code/data_science.py b/examples/kedro/hamilton-code/src/hamilton_code/data_science.py index 0a6ef9ff1..be4b0766b 100644 --- a/examples/kedro/hamilton-code/src/hamilton_code/data_science.py +++ b/examples/kedro/hamilton-code/src/hamilton_code/data_science.py @@ -16,7 +16,6 @@ # under the License. import logging -from typing import List, Union import numpy as np import pandas as pd @@ -39,7 +38,7 @@ def split_data( create_model_input_table: pd.DataFrame, test_size: float, random_state: int, - features: List[str], + features: list[str], ) -> dict: """Splits data into features and targets training and test sets. @@ -84,7 +83,7 @@ def evaluate_model( train_model: LinearRegression, X_test: pd.DataFrame, y_test: pd.Series, -) -> Union[float, np.ndarray]: +) -> float | np.ndarray: """Calculates and logs the coefficient of determination. Args: diff --git a/examples/lineage/data_loading.py b/examples/lineage/data_loading.py index b053afee4..0570ce73b 100644 --- a/examples/lineage/data_loading.py +++ b/examples/lineage/data_loading.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import List import pandas as pd @@ -23,8 +22,8 @@ def _sanitize_columns( - df_columns: List[str], -) -> List[str]: + df_columns: list[str], +) -> list[str]: """Helper function to sanitize column names. :param df_columns: the current column names diff --git a/examples/lineage/lineage_commands.py b/examples/lineage/lineage_commands.py index 849c32682..1a36f286a 100644 --- a/examples/lineage/lineage_commands.py +++ b/examples/lineage/lineage_commands.py @@ -20,8 +20,6 @@ to help you interface with Hamilton and its lineage capabilities. """ -from typing import Dict, List - import data_loading import features import model_pipeline @@ -100,7 +98,7 @@ def visualize_upstream(dr: driver.Driver, end_node: str, image_path: str) -> Non ) -def what_source_data_teams_are_upstream(dr: driver.Driver, output_name: str) -> List[dict]: +def what_source_data_teams_are_upstream(dr: driver.Driver, output_name: str) -> list[dict]: """Function to return a list of teams that own the source data that is upstream of the output_name. :param dr: the driver object to use to query. @@ -121,7 +119,7 @@ def what_source_data_teams_are_upstream(dr: driver.Driver, output_name: str) -> return teams -def what_pii_is_used_where(dr: driver.Driver) -> Dict[str, List[dict]]: +def what_pii_is_used_where(dr: driver.Driver) -> dict[str, list[dict]]: """Function to return a dictionary of PII to artifacts that consume that PII directly or indirectly. :param dr: the driver object @@ -144,7 +142,7 @@ def what_pii_is_used_where(dr: driver.Driver) -> Dict[str, List[dict]]: return pii_to_artifacts -def what_artifacts_are_downstream(dr: driver.Driver, source_name: str) -> List[dict]: +def what_artifacts_are_downstream(dr: driver.Driver, source_name: str) -> list[dict]: """Function to return a list of artifacts that are downstream of a given source. :param dr: driver object to query. 
@@ -165,7 +163,7 @@ def what_artifacts_are_downstream(dr: driver.Driver, source_name: str) -> List[d return artifacts -def what_classifiers_are_downstream(dr: driver.Driver, start_node: str) -> List[dict]: +def what_classifiers_are_downstream(dr: driver.Driver, start_node: str) -> list[dict]: """Shows that you can also filter nodes based on output type. :param dr: driver object to query. @@ -186,7 +184,7 @@ def what_classifiers_are_downstream(dr: driver.Driver, start_node: str) -> List[ return models -def what_nodes_are_on_path_between(dr: driver.Driver, start_node: str, end_node: str) -> List[dict]: +def what_nodes_are_on_path_between(dr: driver.Driver, start_node: str, end_node: str) -> list[dict]: """Function that returns the nodes that are on the path between two nodes. :param dr: driver object to query. diff --git a/examples/lineage/model_pipeline.py b/examples/lineage/model_pipeline.py index e25c666c4..0f34aa23d 100644 --- a/examples/lineage/model_pipeline.py +++ b/examples/lineage/model_pipeline.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict, Union import pandas as pd from sklearn import base @@ -33,7 +32,7 @@ def train_test_split_func( target: pd.Series, validation_size_fraction: float, random_state: int, -) -> Dict[str, Union[pd.DataFrame, pd.Series]]: +) -> dict[str, pd.DataFrame | pd.Series]: """Function that creates the training & test splits. It this then extracted out into constituent components and used downstream. @@ -55,7 +54,7 @@ def train_test_split_func( return {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test} -def prefit_random_forest(random_state: int, max_depth: Union[int, None]) -> base.ClassifierMixin: +def prefit_random_forest(random_state: int, max_depth: int | None) -> base.ClassifierMixin: """Returns a Random Forest Classifier with the specified parameters. :param random_state: random state for reproducibility. 
diff --git a/examples/materialization/custom_materializers.py b/examples/materialization/custom_materializers.py index 95fe9da18..b10daa058 100644 --- a/examples/materialization/custom_materializers.py +++ b/examples/materialization/custom_materializers.py @@ -17,7 +17,8 @@ import dataclasses import pickle -from typing import Any, Collection, Dict, Type +from collections.abc import Collection +from typing import Any import numpy as np from sklearn import base @@ -38,12 +39,12 @@ def __post_init__(self): if not self.path.endswith(".csv"): raise ValueError(f"CSV files must end with .csv, got {self.path}") - def save_data(self, data: np.ndarray) -> Dict[str, Any]: + def save_data(self, data: np.ndarray) -> dict[str, Any]: np.savetxt(self.path, data, delimiter=self.sep) return utils.get_file_metadata(self.path) @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [np.ndarray] @classmethod @@ -55,12 +56,12 @@ def name(cls) -> str: class SKLearnPickler(DataSaver): path: str - def save_data(self, data: base.ClassifierMixin) -> Dict[str, Any]: + def save_data(self, data: base.ClassifierMixin) -> dict[str, Any]: pickle.dump(data, open(self.path, "wb")) return utils.get_file_metadata(self.path) @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [base.ClassifierMixin] @classmethod diff --git a/examples/materialization/model_training.py b/examples/materialization/model_training.py index 2aa92f1e0..7a0f274a1 100644 --- a/examples/materialization/model_training.py +++ b/examples/materialization/model_training.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict import numpy as np from sklearn import base, linear_model, metrics, svm @@ -52,7 +51,7 @@ def train_test_split_func( target: np.ndarray, test_size_fraction: float, shuffle_train_test_split: bool, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """Function that creates the training & test splits. It this then extracted out into constituent components and used downstream. diff --git a/examples/materialization/notebook.ipynb b/examples/materialization/notebook.ipynb index a61d2184e..cc8b8b2fb 100644 --- a/examples/materialization/notebook.ipynb +++ b/examples/materialization/notebook.ipynb @@ -539,7 +539,7 @@ } ], "source": [ - "print(open((materialization_results[\"classification_report_to_txt\"][\"path\"])).read())" + "print(open(materialization_results[\"classification_report_to_txt\"][\"path\"]).read())" ] } ], diff --git a/examples/model_examples/scikit-learn/my_train_evaluate_logic.py b/examples/model_examples/scikit-learn/my_train_evaluate_logic.py index ccbdc69e8..324ad4e45 100644 --- a/examples/model_examples/scikit-learn/my_train_evaluate_logic.py +++ b/examples/model_examples/scikit-learn/my_train_evaluate_logic.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict import numpy as np from sklearn import base, linear_model, metrics, svm @@ -52,7 +51,7 @@ def train_test_split_func( target: np.ndarray, test_size_fraction: float, shuffle_train_test_split: bool, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """Function that creates the training & test splits. It this then extracted out into constituent components and used downstream. 
diff --git a/examples/model_examples/time-series/data_loaders.py b/examples/model_examples/time-series/data_loaders.py index 8e2d4da8c..f3c92faf2 100644 --- a/examples/model_examples/time-series/data_loaders.py +++ b/examples/model_examples/time-series/data_loaders.py @@ -18,7 +18,6 @@ import gc import logging import os -from typing import Dict import pandas as pd import utils @@ -124,7 +123,7 @@ def sales_train_validation( ) def submission_loader( submission_path: str = "/kaggle/input/m5-forecasting-accuracy/sample_submission.csv", -) -> Dict[str, pd.DataFrame]: +) -> dict[str, pd.DataFrame]: """Loads the submission data. The notebook I got this from did some weird splitting with test1 and test 2. diff --git a/examples/model_examples/time-series/transforms.py b/examples/model_examples/time-series/transforms.py index 299a964a8..bcc2a98a3 100644 --- a/examples/model_examples/time-series/transforms.py +++ b/examples/model_examples/time-series/transforms.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Tuple import pandas as pd from pandas.core.groupby import generic @@ -24,7 +23,7 @@ from hamilton.function_modifiers import parameterize, source -def _label_encoder(col: pd.Series) -> Tuple[preprocessing.LabelEncoder, pd.Series]: +def _label_encoder(col: pd.Series) -> tuple[preprocessing.LabelEncoder, pd.Series]: """Creates an encoder, fits itself on the input, and then transforms the input. :param col: the column to encode. diff --git a/examples/mutate/abstract_functionality_blueprint/mutate.py b/examples/mutate/abstract_functionality_blueprint/mutate.py index 786e5982c..19b74ed87 100644 --- a/examples/mutate/abstract_functionality_blueprint/mutate.py +++ b/examples/mutate/abstract_functionality_blueprint/mutate.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, List +from typing import Any import pandas as pd @@ -52,7 +52,7 @@ def filter_(some_data: pd.DataFrame) -> pd.DataFrame: # data 2 # this is for value @mutate(data_2, missing_row=value(["c", 145])) -def add_missing_value(some_data: pd.DataFrame, missing_row: List[Any]) -> pd.DataFrame: +def add_missing_value(some_data: pd.DataFrame, missing_row: list[Any]) -> pd.DataFrame: """Add row to dataframe. The functions decorated with mutate can be viewed as steps in pipe_output in the order they diff --git a/examples/mutate/abstract_functionality_blueprint/mutate_on_output.py b/examples/mutate/abstract_functionality_blueprint/mutate_on_output.py index 571a44847..162c3a524 100644 --- a/examples/mutate/abstract_functionality_blueprint/mutate_on_output.py +++ b/examples/mutate/abstract_functionality_blueprint/mutate_on_output.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, List +from typing import Any import pandas as pd @@ -47,7 +47,7 @@ def data_3() -> pd.DataFrame: @extract_fields({"field_1": pd.Series, "field_2": pd.Series}) -def feat_A(data_1: pd.DataFrame, data_2: pd.DataFrame) -> Dict[str, pd.Series]: +def feat_A(data_1: pd.DataFrame, data_2: pd.DataFrame) -> dict[str, pd.Series]: df = ( data_1.set_index("col_2").join(data_2.reset_index(names=["col_3"]).set_index("col_1")) ).reset_index(names=["col_0"]) @@ -83,7 +83,7 @@ def filter_(some_data: pd.DataFrame) -> pd.DataFrame: # data 2 # this is for value @mutate(apply_to(data_2), missing_row=value(["c", 145])) -def add_missing_value(some_data: pd.DataFrame, missing_row: List[Any]) -> pd.DataFrame: +def add_missing_value(some_data: pd.DataFrame, missing_row: list[Any]) -> pd.DataFrame: """Add row to dataframe. The functions decorated with mutate can be viewed as steps in pipe_output in the order they diff --git a/examples/mutate/abstract_functionality_blueprint/pipe_output.py b/examples/mutate/abstract_functionality_blueprint/pipe_output.py index 47e8263e0..31120f8e5 100644 --- a/examples/mutate/abstract_functionality_blueprint/pipe_output.py +++ b/examples/mutate/abstract_functionality_blueprint/pipe_output.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, List +from typing import Any import pandas as pd @@ -36,7 +36,7 @@ def test_foo(a, b, c): # data 2 # this is for value @hamilton_exclude -def add_missing_value(some_data: pd.DataFrame, missing_row: List[Any]) -> pd.DataFrame: +def add_missing_value(some_data: pd.DataFrame, missing_row: list[Any]) -> pd.DataFrame: some_data.loc[-1] = missing_row return some_data diff --git a/examples/mutate/abstract_functionality_blueprint/pipe_output_on_output.py b/examples/mutate/abstract_functionality_blueprint/pipe_output_on_output.py index a3ccdfac8..a5295c9e5 100644 --- a/examples/mutate/abstract_functionality_blueprint/pipe_output_on_output.py +++ b/examples/mutate/abstract_functionality_blueprint/pipe_output_on_output.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict from hamilton.function_modifiers import ( extract_fields, @@ -52,5 +51,5 @@ def a() -> int: .on_output(["field_1", "field_3"]), # applied to field_1 and field_3 ) @extract_fields({"field_1": int, "field_2": int, "field_3": int}) -def foo(a: int) -> Dict[str, int]: +def foo(a: int) -> dict[str, int]: return {"field_1": 1, "field_2": 2, "field_3": 3} diff --git a/examples/mutate/abstract_functionality_blueprint/procedural.py b/examples/mutate/abstract_functionality_blueprint/procedural.py index 11b6b2ee0..6fcc64ed0 100644 --- a/examples/mutate/abstract_functionality_blueprint/procedural.py +++ b/examples/mutate/abstract_functionality_blueprint/procedural.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, List +from typing import Any import pandas as pd @@ -44,7 +44,7 @@ def _filter(some_data: pd.DataFrame) -> pd.DataFrame: # data 2 # this is for value -def _add_missing_value(some_data: pd.DataFrame, missing_row: List[Any]) -> pd.DataFrame: +def _add_missing_value(some_data: pd.DataFrame, missing_row: list[Any]) -> pd.DataFrame: some_data.loc[-1] = missing_row return some_data diff --git a/examples/openlineage/pipeline.py b/examples/openlineage/pipeline.py index 1178d3d5f..03d1482ad 100644 --- a/examples/openlineage/pipeline.py +++ b/examples/openlineage/pipeline.py @@ -16,7 +16,6 @@ # under the License. import pickle -from typing import Tuple import pandas as pd @@ -35,13 +34,13 @@ @dataloader() -def user_dataset(file_ds_path: str) -> Tuple[pd.DataFrame, dict]: +def user_dataset(file_ds_path: str) -> tuple[pd.DataFrame, dict]: df = pd.read_csv(file_ds_path) return df, utils.get_file_and_dataframe_metadata(file_ds_path, df) @dataloader() -def purchase_dataset(db_client: object) -> Tuple[pd.DataFrame, dict]: +def purchase_dataset(db_client: object) -> tuple[pd.DataFrame, dict]: query = "SELECT * FROM purchase_data" df = pd.read_sql(query, con=db_client) metadata = { diff --git a/examples/pandas/split-apply-combine/my_functions.py b/examples/pandas/split-apply-combine/my_functions.py index dda9b3bed..81082217d 100644 --- a/examples/pandas/split-apply-combine/my_functions.py +++ b/examples/pandas/split-apply-combine/my_functions.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict import numpy as np import pandas @@ -29,7 +28,7 @@ # ---------------------------------------------------------------------------------------------------------------------- -def _tax_rate(df: DataFrame, tax_rates: Dict[str, float]) -> DataFrame: +def _tax_rate(df: DataFrame, tax_rates: dict[str, float]) -> DataFrame: """ Add a series 'Tax Rate' to the DataFrame based on the tax_rates rules. :param df: The DataFrame @@ -47,7 +46,7 @@ def _tax_rate(df: DataFrame, tax_rates: Dict[str, float]) -> DataFrame: return df -def _tax_credit(df: DataFrame, tax_credits: Dict[str, float]) -> DataFrame: +def _tax_credit(df: DataFrame, tax_credits: dict[str, float]) -> DataFrame: """ Add a series 'Tax Credit' to the DataFrame based on the tax_credits rules. :param df: The DataFrame @@ -72,7 +71,7 @@ def _tax_credit(df: DataFrame, tax_credits: Dict[str, float]) -> DataFrame: @extract_fields({"under_100k": DataFrame, "over_100k": DataFrame}) # Step 1: DataFrame is split in 2 DataFrames -def split_dataframe(input: DataFrame) -> Dict[str, DataFrame]: +def split_dataframe(input: DataFrame) -> dict[str, DataFrame]: """ That function takes the DataFrame in input and split it in 2 DataFrames: - under_100k: Rows where 'Income' is under 100k diff --git a/examples/pandas/split-apply-combine/my_wrapper.py b/examples/pandas/split-apply-combine/my_wrapper.py index 77c210aca..37f0a15cb 100644 --- a/examples/pandas/split-apply-combine/my_wrapper.py +++ b/examples/pandas/split-apply-combine/my_wrapper.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Dict import my_functions from pandas import DataFrame @@ -43,7 +42,7 @@ class TaxCalculator: @staticmethod def calculate( - input: DataFrame, tax_rates: Dict[str, float], tax_credits: Dict[str, float] + input: DataFrame, tax_rates: dict[str, float], tax_credits: dict[str, float] ) -> DataFrame: return driver.execute( inputs={"input": input, "tax_rates": tax_rates, "tax_credits": tax_credits}, diff --git a/examples/pandas/split-apply-combine/notebook.ipynb b/examples/pandas/split-apply-combine/notebook.ipynb index cf44cd4ee..ec2bd7966 100644 --- a/examples/pandas/split-apply-combine/notebook.ipynb +++ b/examples/pandas/split-apply-combine/notebook.ipynb @@ -555,8 +555,6 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Dict\n", - "\n", "# import my_functions # this is imported by the cell above\n", "from pandas import DataFrame\n", "\n", @@ -594,7 +592,7 @@ "\n", " @staticmethod\n", " def calculate(\n", - " input: DataFrame, tax_rates: Dict[str, float], tax_credits: Dict[str, float]\n", + " input: DataFrame, tax_rates: dict[str, float], tax_credits: dict[str, float]\n", " ) -> DataFrame:\n", " return driver.execute(\n", " inputs={\"input\": input, \"tax_rates\": tax_rates, \"tax_credits\": tax_credits},\n", diff --git a/examples/parallelism/file_processing/list_data.py b/examples/parallelism/file_processing/list_data.py index 11da3307b..90033e322 100644 --- a/examples/parallelism/file_processing/list_data.py +++ b/examples/parallelism/file_processing/list_data.py @@ -17,12 +17,11 @@ import dataclasses import os -from typing import List from hamilton.htypes import Parallelizable -def files(data_dir: str) -> List[str]: +def files(data_dir: str) -> list[str]: """Lists oll files in the data directory""" out = [] @@ -39,7 +38,7 @@ class CityData: weekday_file: str -def city_data(files: List[str]) -> Parallelizable[CityData]: +def city_data(files: list[str]) -> Parallelizable[CityData]: """Gathers a list of per-city data for processing/analyzing""" cities = dict() diff --git a/examples/parallelism/graceful_running/functions.py b/examples/parallelism/graceful_running/functions.py index 9482ccce2..11c61b12b 100644 --- a/examples/parallelism/graceful_running/functions.py +++ b/examples/parallelism/graceful_running/functions.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. 
+from collections.abc import Callable, Iterable from itertools import cycle -from typing import Any, Callable, Iterable, List, Tuple, Union +from typing import Any import numpy as np import pandas as pd @@ -49,7 +50,7 @@ def load_data() -> pd.DataFrame: def split_to_groups( - load_data: pd.DataFrame, funcs: List[Splitter] + load_data: pd.DataFrame, funcs: list[Splitter] ) -> Parallelizable[tuple[str, pd.DataFrame]]: """Split data into interesting groups.""" for func in funcs: @@ -67,7 +68,7 @@ def average(data: pd.DataFrame) -> float: return data.Views.mean() -def model_fit(data: pd.DataFrame, group_name: str) -> Tuple[float, float, float]: +def model_fit(data: pd.DataFrame, group_name: str) -> tuple[float, float, float]: """Imagine a model fit that doesn't always work.""" if "Method:TV" in group_name: raise Exception("Fake floating point error, e.g.") @@ -79,9 +80,9 @@ def model_fit(data: pd.DataFrame, group_name: str) -> Tuple[float, float, float] @accept_error_sentinels def gather_metrics( - group_name: Union[str, None], - average: Union[float, None], - model_fit: Union[Tuple[float, float, float], None], + group_name: str | None, + average: float | None, + model_fit: tuple[float, float, float] | None, ) -> dict[str, Any]: answer = { "Name": group_name, diff --git a/examples/parallelism/graceful_running/run.py b/examples/parallelism/graceful_running/run.py index 43e747070..eaba87182 100644 --- a/examples/parallelism/graceful_running/run.py +++ b/examples/parallelism/graceful_running/run.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterable, Tuple +from collections.abc import Iterable import click import functions @@ -30,17 +30,17 @@ # Assume we define some custom methods for splittings -def split_on_region(data: pd.DataFrame) -> Iterable[Tuple[str, pd.DataFrame]]: +def split_on_region(data: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]: for idx, grp in data.groupby("Region"): yield f"Region:{idx}", grp -def split_on_attrs(data: pd.DataFrame) -> Iterable[Tuple[str, pd.DataFrame]]: +def split_on_attrs(data: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]: for (region, method), grp in data.groupby(["Region", "Method"]): yield f"Region:{region} - Method:{method}", grp -def split_on_views(data: pd.DataFrame) -> Iterable[Tuple[str, pd.DataFrame]]: +def split_on_views(data: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]: yield "Low Views", data[data.Views <= 4000.0] yield "High Views", data[data.Views > 4000.0] diff --git a/examples/parallelism/star_counting/functions.py b/examples/parallelism/star_counting/functions.py index 4118a6f1b..a78f07004 100644 --- a/examples/parallelism/star_counting/functions.py +++ b/examples/parallelism/star_counting/functions.py @@ -16,7 +16,6 @@ # under the License. from datetime import datetime -from typing import Dict, List, Tuple import pandas as pd import requests @@ -25,7 +24,7 @@ from hamilton.htypes import Collect, Parallelizable -def starcount_url(repositories: List[str]) -> Parallelizable[str]: +def starcount_url(repositories: list[str]) -> Parallelizable[str]: """Generates API URLs for counting stars on a repo. We do this so we can paginate requests later. 
@@ -37,7 +36,7 @@ def starcount_url(repositories: List[str]) -> Parallelizable[str]: yield f"https://api.github.com/repos/{repo}" -def star_count(starcount_url: str, github_api_key: str) -> Tuple[str, int]: +def star_count(starcount_url: str, github_api_key: str) -> tuple[str, int]: """Generates the star count for a given repo. :param starcount_url: URL of the repo @@ -52,7 +51,7 @@ def star_count(starcount_url: str, github_api_key: str) -> Tuple[str, int]: return data["full_name"], data["stargazers_count"] -def stars_by_repo(star_count: Collect[Tuple[str, int]]) -> Dict[str, int]: +def stars_by_repo(star_count: Collect[tuple[str, int]]) -> dict[str, int]: """Aggregates the star count for each repo into a dictionary, so we can generate paginated requests. @@ -65,7 +64,7 @@ def stars_by_repo(star_count: Collect[Tuple[str, int]]) -> Dict[str, int]: return star_count_dict -def stargazer_url(stars_by_repo: Dict[str, int], per_page: int = 100) -> Parallelizable[str]: +def stargazer_url(stars_by_repo: dict[str, int], per_page: int = 100) -> Parallelizable[str]: """Generates query objects for each repository, with the correct pagination and offset. :param stars_by_repo: The star count for each repo diff --git a/examples/parallelism/star_counting/run.py b/examples/parallelism/star_counting/run.py index 8b6cc3b12..4cce6f118 100644 --- a/examples/parallelism/star_counting/run.py +++ b/examples/parallelism/star_counting/run.py @@ -16,7 +16,6 @@ # under the License. import logging -from typing import Tuple import click import functions @@ -74,7 +73,7 @@ def _get_executor(mode: str): required=False, help="Where to run remote tasks.", ) -def main(github_api_key: str, repositories: Tuple[str, ...], mode: str): +def main(github_api_key: str, repositories: tuple[str, ...], mode: str): remote_executor, shutdown = _get_executor(mode) dr = ( driver.Builder() diff --git a/examples/plotly/model_training.py b/examples/plotly/model_training.py index 29448f1e2..b3e8ac9f7 100644 --- a/examples/plotly/model_training.py +++ b/examples/plotly/model_training.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict import numpy as np import plotly.express as px @@ -76,7 +75,7 @@ def train_test_split_func( target: np.ndarray, test_size_fraction: float, shuffle_train_test_split: bool, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """Function that creates the training & test splits. It this then extracted out into constituent components and used downstream. diff --git a/examples/prefect/prepare_data.py b/examples/prefect/prepare_data.py index d4dd2c0a4..8bf4525cb 100644 --- a/examples/prefect/prepare_data.py +++ b/examples/prefect/prepare_data.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import List import numpy as np import pandas as pd @@ -76,7 +75,7 @@ ] -def _rename_columns(columns: List[str]) -> List[str]: +def _rename_columns(columns: list[str]) -> list[str]: """convert raw data column names to snakecase and make them compatible with Hamilton's naming convention (need to be a valid Python function name) diff --git a/examples/prefect/train_model.py b/examples/prefect/train_model.py index 60a243164..08d121c84 100644 --- a/examples/prefect/train_model.py +++ b/examples/prefect/train_model.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Dict, List import numpy as np import pandas as pd @@ -52,8 +51,8 @@ def preprocessed_data__dev(df: pd.DataFrame) -> pd.DataFrame: ) ) def split_indices( - preprocessed_data: pd.DataFrame, validation_user_ids: List[int] -) -> Dict[str, pd.DataFrame]: + preprocessed_data: pd.DataFrame, validation_user_ids: list[int] +) -> dict[str, pd.DataFrame]: """Creating train-validation splits based on the list of `validation_user_ids`""" validation_selection_mask = preprocessed_data.id.isin([int(i) for i in validation_user_ids]) @@ -63,7 +62,7 @@ def split_indices( ) -def data_stats(preprocessed_data: pd.DataFrame, feature_set: List[str]) -> pd.DataFrame: +def data_stats(preprocessed_data: pd.DataFrame, feature_set: list[str]) -> pd.DataFrame: return preprocessed_data[feature_set].describe() @@ -71,7 +70,7 @@ def data_stats(preprocessed_data: pd.DataFrame, feature_set: List[str]) -> pd.Da X_train=dict(df=source("train_df"), feature_set=source("feature_set")), X_validation=dict(df=source("validation_df"), feature_set=source("feature_set")), ) -def features(df: pd.DataFrame, feature_set: List[str]) -> np.ndarray: +def features(df: pd.DataFrame, feature_set: list[str]) -> np.ndarray: """Select features from `preprocessed_data` based on `feature_set`""" return df[feature_set].to_numpy() diff --git a/examples/reusing_functions/main.py b/examples/reusing_functions/main.py index d74a252d5..bbeff0303 100644 --- a/examples/reusing_functions/main.py +++ b/examples/reusing_functions/main.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any import pandas as pd import reusable_subdags @@ -119,7 +119,7 @@ def is_time_series(series: Any): return False return True - def build_result(self, **outputs: Dict[str, Any]) -> Any: + def build_result(self, **outputs: dict[str, Any]) -> Any: non_ts_output = [ key for key, value in outputs.items() diff --git a/examples/reusing_functions/reusing_functions.ipynb b/examples/reusing_functions/reusing_functions.ipynb index 1455ebe62..4047eb347 100644 --- a/examples/reusing_functions/reusing_functions.ipynb +++ b/examples/reusing_functions/reusing_functions.ipynb @@ -73,7 +73,7 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Any, Dict\n", + "from typing import Any\n", "\n", "import pandas as pd\n", "\n", @@ -1424,7 +1424,7 @@ " return False\n", " return True\n", "\n", - " def build_result(self, **outputs: Dict[str, Any]) -> Any:\n", + " def build_result(self, **outputs: dict[str, Any]) -> Any:\n", " non_ts_output = [\n", " key\n", " for key, value in outputs.items()\n", diff --git a/examples/scikit-learn/species_distribution_modeling/grids.py b/examples/scikit-learn/species_distribution_modeling/grids.py index 6c1ebe013..282535899 100644 --- a/examples/scikit-learn/species_distribution_modeling/grids.py +++ b/examples/scikit-learn/species_distribution_modeling/grids.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Tuple import numpy as np import numpy.typing as npt @@ -25,21 +24,21 @@ from hamilton.function_modifiers import pipe_input, step -def _construct_grids(batch: Bunch) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: +def _construct_grids(batch: Bunch) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: """Our wrapper around and external function to integrate it as a node in the DAG.""" return construct_grids(batch=batch) @pipe_input(step(_construct_grids)) def data_grid_( - data: Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]], -) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: + data: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]], +) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: return data def meshgrid( - data_grid_: Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]], -) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: + data_grid_: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]], +) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: return np.meshgrid(data_grid_[0], data_grid_[1][::-1]) diff --git a/examples/scikit-learn/species_distribution_modeling/postprocessing_results.py b/examples/scikit-learn/species_distribution_modeling/postprocessing_results.py index a4fce770c..b9f9e48e8 100644 --- a/examples/scikit-learn/species_distribution_modeling/postprocessing_results.py +++ b/examples/scikit-learn/species_distribution_modeling/postprocessing_results.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Tuple +from typing import Any import numpy as np import numpy.typing as npt @@ -55,13 +55,13 @@ def area_under_curve( def plot_species_distribution( - meshgrid: Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]], + meshgrid: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]], prediction_train: npt.NDArray[np.float64], land_reference: npt.NDArray[np.float64], levels: npt.NDArray[np.float64], bunch: Bunch, area_under_curve: float, -) -> Dict[str, Any]: +) -> dict[str, Any]: return { "X": meshgrid[0], "Y": meshgrid[1], diff --git a/examples/scikit-learn/species_distribution_modeling/preprocessing.py b/examples/scikit-learn/species_distribution_modeling/preprocessing.py index 083e8b27a..22b41c381 100644 --- a/examples/scikit-learn/species_distribution_modeling/preprocessing.py +++ b/examples/scikit-learn/species_distribution_modeling/preprocessing.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Dict, Tuple import numpy as np import numpy.typing as npt @@ -28,7 +27,7 @@ def _create_species_bunch( species_name: str, data: Bunch, - data_grid: Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]], + data_grid: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]], ) -> npt.NDArray[np.float64]: """Our wrapper around and external function to integrate it as a node in the DAG.""" return create_species_bunch( @@ -38,7 +37,7 @@ def _create_species_bunch( def _standardize_features( species_bunch: npt.NDArray[np.float64], -) -> Tuple[ +) -> tuple[ npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], @@ -62,12 +61,12 @@ def _standardize_features( step(_standardize_features), ) def species( - chosen_species: Tuple[ + chosen_species: tuple[ npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], ], -) -> Dict[str, npt.NDArray[np.float64]]: +) -> dict[str, npt.NDArray[np.float64]]: train_cover_std = (chosen_species[0].cov_train - chosen_species[1]) / chosen_species[2] return { "bunch": chosen_species[0], diff --git a/examples/scikit-learn/transformer/hamilton_notebook.ipynb b/examples/scikit-learn/transformer/hamilton_notebook.ipynb index 9c05b7ca1..ec057a9ce 100644 --- a/examples/scikit-learn/transformer/hamilton_notebook.ipynb +++ b/examples/scikit-learn/transformer/hamilton_notebook.ipynb @@ -86,7 +86,7 @@ "from __future__ import annotations # noqa: F404\n", "\n", "import logging\n", - "from typing import TYPE_CHECKING, Any, Dict, List\n", + "from typing import TYPE_CHECKING, Any\n", "\n", "import numpy as np\n", "import pandas as pd\n", @@ -185,9 +185,9 @@ " def __init__(\n", " self,\n", " config: dict = None,\n", - " modules: List[ModuleType] = None,\n", + " modules: list[ModuleType] = None,\n", " adapter: base.HamiltonGraphAdapter = None,\n", - " final_vars: List[str] = None,\n", + " final_vars: list[str] = None,\n", " ):\n", " self.config = {} if config is None else config\n", " self.modules = [] if modules is None else modules\n", @@ -228,7 +228,7 @@ " \"\"\"\n", " return {\"requires_fit\": True, \"requires_y\": False}\n", "\n", - " def fit(self, X, y=None, overrides: Dict[str, Any] = None) -> HamiltonTransformer:\n", + " def fit(self, X, y=None, overrides: dict[str, Any] = None) -> HamiltonTransformer:\n", " \"\"\"Instantiate Hamilton driver.Driver object\n", "\n", " :param X: Input 2D array\n", diff --git a/examples/scikit-learn/transformer/run.py b/examples/scikit-learn/transformer/run.py index ed131037e..3c341f016 100644 --- a/examples/scikit-learn/transformer/run.py +++ b/examples/scikit-learn/transformer/run.py @@ -19,7 +19,7 @@ import importlib import logging -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -40,9 +40,9 @@ class HamiltonTransformer(BaseEstimator, TransformerMixin): def __init__( self, config: dict = None, - modules: List[ModuleType] = None, + modules: list[ModuleType] = None, adapter: base.HamiltonGraphAdapter = None, - final_vars: List[str] = None, + final_vars: list[str] = None, ): self.config = {} if config is None else config self.modules = [] if modules is None else modules @@ -83,7 +83,7 @@ def _get_tags(self) -> dict: """ return {"requires_fit": True, "requires_y": False} - def fit(self, X, y=None, overrides: Dict[str, Any] = None) -> HamiltonTransformer: + def fit(self, X, y=None, overrides: dict[str, Any] = None) -> HamiltonTransformer: """Instantiate Hamilton driver.Driver object :param X: Input 2D array diff 
--git a/examples/spark/pyspark/dataflow.py b/examples/spark/pyspark/dataflow.py index f5bfed60e..e1f0f1a76 100644 --- a/examples/spark/pyspark/dataflow.py +++ b/examples/spark/pyspark/dataflow.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict import map_transforms import pandas as pd @@ -98,7 +97,7 @@ def base_df(spark_session: ps.SparkSession) -> ps.DataFrame: "spend_std_dev": float, } ) -def spend_statistics(base_df: ps.DataFrame) -> Dict[str, float]: +def spend_statistics(base_df: ps.DataFrame) -> dict[str, float]: """Computes the mean and standard deviation of the spend column. Note that this is a blocking (collect) operation, but it doesn't have to be if you use an aggregation. In that case diff --git a/examples/spark/pyspark_feature_catalog/example_usage.ipynb b/examples/spark/pyspark_feature_catalog/example_usage.ipynb index 674d05d85..8328ec6ff 100644 --- a/examples/spark/pyspark_feature_catalog/example_usage.ipynb +++ b/examples/spark/pyspark_feature_catalog/example_usage.ipynb @@ -101,7 +101,7 @@ "source": [ "class MyCustomBuilder(base.ResultMixin):\n", " @staticmethod\n", - " def build_result(**outputs: typing.Dict[str, typing.Any]) -> ps.DataFrame:\n", + " def build_result(**outputs: dict[str, typing.Any]) -> ps.DataFrame:\n", " # TODO: add error handling when incompatible outputs are created\n", " level_info = outputs[\"level_info\"]\n", " zone_counts = outputs[\"zone_counts\"]\n", diff --git a/examples/spark/tpc-h/query_1.py b/examples/spark/tpc-h/query_1.py index 80a16d778..1e7a620e0 100644 --- a/examples/spark/tpc-h/query_1.py +++ b/examples/spark/tpc-h/query_1.py @@ -31,7 +31,7 @@ def start_date() -> str: def lineitem_filtered(lineitem: ps.DataFrame, start_date: str) -> ps.DataFrame: - return lineitem.filter((lineitem.l_shipdate <= start_date)) + return lineitem.filter(lineitem.l_shipdate <= start_date) def disc_price( diff --git a/examples/streamlit/app.py b/examples/streamlit/app.py index e816722ef..6d8ad66ad 100644 --- a/examples/streamlit/app.py +++ b/examples/streamlit/app.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional import logic import streamlit as st @@ -33,8 +32,8 @@ def get_hamilton_driver() -> driver.Driver: @st.cache_data def _execute( final_vars: list[str], - inputs: Optional[dict] = None, - overrides: Optional[dict] = None, + inputs: dict | None = None, + overrides: dict | None = None, ) -> dict: """Generic utility to cache Hamilton results""" dr = get_hamilton_driver() diff --git a/examples/styling_visualization/data_loading.py b/examples/styling_visualization/data_loading.py index b053afee4..0570ce73b 100644 --- a/examples/styling_visualization/data_loading.py +++ b/examples/styling_visualization/data_loading.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import List import pandas as pd @@ -23,8 +22,8 @@ def _sanitize_columns( - df_columns: List[str], -) -> List[str]: + df_columns: list[str], +) -> list[str]: """Helper function to sanitize column names. 
:param df_columns: the current column names diff --git a/examples/styling_visualization/model_pipeline.py b/examples/styling_visualization/model_pipeline.py index e25c666c4..0f34aa23d 100644 --- a/examples/styling_visualization/model_pipeline.py +++ b/examples/styling_visualization/model_pipeline.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict, Union import pandas as pd from sklearn import base @@ -33,7 +32,7 @@ def train_test_split_func( target: pd.Series, validation_size_fraction: float, random_state: int, -) -> Dict[str, Union[pd.DataFrame, pd.Series]]: +) -> dict[str, pd.DataFrame | pd.Series]: """Function that creates the training & test splits. It this then extracted out into constituent components and used downstream. @@ -55,7 +54,7 @@ def train_test_split_func( return {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test} -def prefit_random_forest(random_state: int, max_depth: Union[int, None]) -> base.ClassifierMixin: +def prefit_random_forest(random_state: int, max_depth: int | None) -> base.ClassifierMixin: """Returns a Random Forest Classifier with the specified parameters. :param random_state: random state for reproducibility. diff --git a/examples/validate_examples.py b/examples/validate_examples.py index 133f0bcec..16f531c77 100644 --- a/examples/validate_examples.py +++ b/examples/validate_examples.py @@ -20,10 +20,13 @@ import argparse import logging import pathlib -from typing import Sequence +from typing import TYPE_CHECKING import nbformat +if TYPE_CHECKING: + from collections.abc import Sequence + logger = logging.getLogger(__name__) IGNORE_PRAGMA = "## ignore_ci" diff --git a/examples/validation/static_validator/notebook.ipynb b/examples/validation/static_validator/notebook.ipynb index f884a2721..ecf0db6ba 100644 --- a/examples/validation/static_validator/notebook.ipynb +++ b/examples/validation/static_validator/notebook.ipynb @@ -174,7 +174,6 @@ "outputs": [], "source": [ "# Validator\n", - "from typing import Optional\n", "\n", "from hamilton.graph_types import HamiltonNode\n", "from hamilton.lifecycle import api\n", @@ -185,7 +184,7 @@ "\n", " def run_to_validate_node(\n", " self, *, node: HamiltonNode, **future_kwargs\n", - " ) -> tuple[bool, Optional[str]]:\n", + " ) -> tuple[bool, str | None]:\n", " if node.tags.get(\"node_type\", \"\") == \"output\":\n", " table_name = node.tags.get(\"table_name\")\n", " if not table_name: # None or empty\n", diff --git a/examples/validation/static_validator/run.py b/examples/validation/static_validator/run.py index fbee74a65..c7e69c782 100644 --- a/examples/validation/static_validator/run.py +++ b/examples/validation/static_validator/run.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
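The examples/validate_examples.py hunk above moves Sequence behind a TYPE_CHECKING guard so the import is only seen by type checkers. A minimal, self-contained sketch of that pattern (the helper name is hypothetical and not part of this patch):

```python
from __future__ import annotations  # annotations stay strings, so the guarded name is safe to use

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for static analysis; never executed at runtime.
    from collections.abc import Sequence


def first_or_none(items: Sequence[str]) -> str | None:
    """Hypothetical helper: guarded names remain usable in annotations."""
    return items[0] if items else None


if __name__ == "__main__":
    print(first_or_none(["a", "b"]))  # -> a
    print(first_or_none([]))          # -> None
```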
-from typing import Optional from hamilton.graph_types import HamiltonNode from hamilton.lifecycle import api @@ -24,7 +23,7 @@ class MyTagValidator(api.StaticValidator): def run_to_validate_node( self, *, node: HamiltonNode, **future_kwargs - ) -> tuple[bool, Optional[str]]: + ) -> tuple[bool, str | None]: if node.tags.get("node_type", "") == "output": table_name = node.tags.get("table_name") if not table_name: # None or empty diff --git a/hamilton/ad_hoc_utils.py b/hamilton/ad_hoc_utils.py index a3e67d63f..82d476d8f 100644 --- a/hamilton/ad_hoc_utils.py +++ b/hamilton/ad_hoc_utils.py @@ -25,8 +25,8 @@ import tempfile import types import uuid +from collections.abc import Callable from types import ModuleType -from typing import Callable, Optional def _copy_func(f): @@ -81,7 +81,7 @@ def create_temporary_module(*functions: Callable, module_name: str = None) -> Mo return module -def module_from_source(source: str, module_name: Optional[str] = None) -> ModuleType: +def module_from_source(source: str, module_name: str | None = None) -> ModuleType: """Create a temporary module from source code.""" module_name = module_name or _generate_unique_temp_module_name() module_object = ModuleType(module_name) diff --git a/hamilton/async_driver.py b/hamilton/async_driver.py index 4c80223d6..a6c45b61f 100644 --- a/hamilton/async_driver.py +++ b/hamilton/async_driver.py @@ -23,7 +23,7 @@ import typing import uuid from types import ModuleType -from typing import Any, Dict, Optional, Tuple +from typing import Any import hamilton.lifecycle.base as lifecycle_base from hamilton import base, driver, graph, lifecycle, node, telemetry @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) -async def await_dict_of_tasks(task_dict: Dict[str, typing.Awaitable]) -> Dict[str, Any]: +async def await_dict_of_tasks(task_dict: dict[str, typing.Awaitable]) -> dict[str, Any]: """Util to await a dictionary of tasks as asyncio.gather is kind of garbage""" keys = sorted(task_dict.keys()) coroutines = [task_dict[key] for key in keys] @@ -59,7 +59,7 @@ class AsyncGraphAdapter(lifecycle_base.BaseDoNodeExecute, lifecycle.ResultBuilde def __init__( self, result_builder: base.ResultMixin = None, - async_lifecycle_adapters: Optional[lifecycle_base.LifecycleAdapterSet] = None, + async_lifecycle_adapters: lifecycle_base.LifecycleAdapterSet | None = None, ): """Creates an AsyncGraphAdapter class. Note this will *only* work with the AsyncDriver class. @@ -83,8 +83,8 @@ def do_node_execute( *, run_id: str, node_: node.Node, - kwargs: typing.Dict[str, typing.Any], - task_id: Optional[str] = None, + kwargs: dict[str, typing.Any], + task_id: str | None = None, ) -> typing.Any: """Executes a node. Note this doesn't actually execute it -- rather, it returns a task. This does *not* use async def, as we want it to be awaited on later -- this await is done @@ -176,8 +176,8 @@ def build_result(self, **outputs: Any) -> Any: def separate_sync_from_async( - adapters: typing.List[lifecycle.LifecycleAdapter], -) -> Tuple[typing.List[lifecycle.LifecycleAdapter], typing.List[lifecycle.LifecycleAdapter]]: + adapters: list[lifecycle.LifecycleAdapter], +) -> tuple[list[lifecycle.LifecycleAdapter], list[lifecycle.LifecycleAdapter]]: """Separates the sync and async adapters from a list of adapters. Note this only works with hooks -- we'll be dealing with methods later. 
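The hamilton/async_driver.py hunks above switch helpers such as await_dict_of_tasks to builtin generics. A rough stand-alone sketch of the same idea, not Hamilton's implementation, written in the modernized annotation style:

```python
import asyncio
from collections.abc import Awaitable
from typing import Any


async def gather_dict(tasks: dict[str, Awaitable[Any]]) -> dict[str, Any]:
    """Await every value of the dict and return results under the same keys."""
    keys = sorted(tasks)
    values = await asyncio.gather(*(tasks[key] for key in keys))
    return dict(zip(keys, values))


async def _demo() -> None:
    async def double(x: int) -> int:
        await asyncio.sleep(0)
        return 2 * x

    print(await gather_dict({"a": double(1), "b": double(2)}))  # {'a': 2, 'b': 4}


if __name__ == "__main__":
    asyncio.run(_demo())
```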
@@ -213,8 +213,8 @@ def __init__( self, config, *modules, - result_builder: Optional[base.ResultMixin] = None, - adapters: typing.List[lifecycle.LifecycleAdapter] = None, + result_builder: base.ResultMixin | None = None, + adapters: list[lifecycle.LifecycleAdapter] = None, allow_module_overrides: bool = False, ): """Instantiates an asynchronous driver. @@ -290,12 +290,12 @@ async def ainit(self) -> "AsyncDriver": async def raw_execute( self, - final_vars: typing.List[str], - overrides: Dict[str, Any] = None, + final_vars: list[str], + overrides: dict[str, Any] = None, display_graph: bool = False, # don't care - inputs: Dict[str, Any] = None, + inputs: dict[str, Any] = None, _fn_graph: graph.FunctionGraph = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Executes the graph, returning a dictionary of strings (node keys) to final results. :param final_vars: Variables to execute (+ upstream) @@ -355,10 +355,10 @@ async def raw_execute( async def execute( self, - final_vars: typing.List[str], - overrides: Dict[str, Any] = None, + final_vars: list[str], + overrides: dict[str, Any] = None, display_graph: bool = False, - inputs: Dict[str, Any] = None, + inputs: dict[str, Any] = None, ) -> Any: """Executes computation. @@ -409,9 +409,9 @@ async def make_coroutine(): def capture_constructor_telemetry( self, - error: Optional[str], - modules: Tuple[ModuleType], - config: Dict[str, Any], + error: str | None, + modules: tuple[ModuleType], + config: dict[str, Any], adapter: base.HamiltonGraphAdapter, ): """Ensures we capture constructor telemetry the right way in an async context. @@ -484,7 +484,7 @@ def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> self._not_supported("enable_dynamic_execution") def with_materializers( - self, *materializers: typing.Union[ExtractorFactory, MaterializerFactory] + self, *materializers: ExtractorFactory | MaterializerFactory ) -> "Builder": self._not_supported("with_materializers") diff --git a/hamilton/base.py b/hamilton/base.py index 02241aaca..0cdedaec0 100644 --- a/hamilton/base.py +++ b/hamilton/base.py @@ -23,7 +23,7 @@ import abc import collections import logging -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any import numpy as np import pandas as pd @@ -76,15 +76,15 @@ class DictResult(ResultMixin): """ @staticmethod - def build_result(**outputs: Dict[str, Any]) -> Dict: + def build_result(**outputs: dict[str, Any]) -> dict: """This function builds a simple dict of output -> computed values.""" return outputs - def input_types(self) -> Optional[List[Type[Type]]]: + def input_types(self) -> list[type[type]] | None: return [Any] - def output_type(self) -> Type: - return Dict[str, Any] + def output_type(self) -> type: + return dict[str, Any] class PandasDataFrameResult(ResultMixin): @@ -108,8 +108,8 @@ class PandasDataFrameResult(ResultMixin): @staticmethod def pandas_index_types( - outputs: Dict[str, Any], - ) -> Tuple[Dict[str, List[str]], Dict[str, List[str]], Dict[str, List[str]]]: + outputs: dict[str, Any], + ) -> tuple[dict[str, list[str]], dict[str, list[str]], dict[str, list[str]]]: """This function creates three dictionaries according to whether there is an index type or not. 
The three dicts we create are: @@ -124,7 +124,7 @@ def pandas_index_types( time_indexes = collections.defaultdict(list) no_indexes = collections.defaultdict(list) - def index_key_name(pd_object: Union[pd.DataFrame, pd.Series]) -> str: + def index_key_name(pd_object: pd.DataFrame | pd.Series) -> str: """Creates a string helping identify the index and it's type. Useful for disambiguating time related indexes.""" return f"{pd_object.index.__class__.__name__}:::{pd_object.index.dtype}" @@ -160,9 +160,9 @@ def get_parent_time_index_type(): @staticmethod def check_pandas_index_types_match( - all_index_types: Dict[str, List[str]], - time_indexes: Dict[str, List[str]], - no_indexes: Dict[str, List[str]], + all_index_types: dict[str, list[str]], + time_indexes: dict[str, list[str]], + no_indexes: dict[str, list[str]], ) -> bool: """Checks that pandas index types match. @@ -212,7 +212,7 @@ def check_pandas_index_types_match( return types_match @staticmethod - def build_result(**outputs: Dict[str, Any]) -> pd.DataFrame: + def build_result(**outputs: dict[str, Any]) -> pd.DataFrame: """Builds a Pandas DataFrame from the outputs. This function will check the index types of the outputs, and log warnings if they don't match. @@ -244,7 +244,7 @@ def build_result(**outputs: Dict[str, Any]) -> pd.DataFrame: return pd.DataFrame(outputs) # this does an implicit outer join based on index. @staticmethod - def build_dataframe_with_dataframes(outputs: Dict[str, Any]) -> pd.DataFrame: + def build_dataframe_with_dataframes(outputs: dict[str, Any]) -> pd.DataFrame: """Builds a dataframe from the outputs in an "outer join" manner based on index. The behavior of pd.Dataframe(outputs) is that it will do an outer join based on indexes of the Series passed in. @@ -294,12 +294,12 @@ def get_output_name(output_name: str, column_name: str) -> str: return pd.DataFrame(flattened_outputs) - def input_types(self) -> List[Type[Type]]: + def input_types(self) -> list[type[type]]: """Currently this just shoves anything into a dataframe. We should probably tighten this up.""" return [Any] - def output_type(self) -> Type: + def output_type(self) -> type: return pd.DataFrame @@ -324,7 +324,7 @@ class StrictIndexTypePandasDataFrameResult(PandasDataFrameResult): """ @staticmethod - def build_result(**outputs: Dict[str, Any]) -> pd.DataFrame: + def build_result(**outputs: dict[str, Any]) -> pd.DataFrame: # TODO check inputs are pd.Series, arrays, or scalars -- else error output_index_type_tuple = PandasDataFrameResult.pandas_index_types(outputs) indexes_match = PandasDataFrameResult.check_pandas_index_types_match( @@ -357,7 +357,7 @@ class NumpyMatrixResult(ResultMixin): """ @staticmethod - def build_result(**outputs: Dict[str, Any]) -> np.matrix: + def build_result(**outputs: dict[str, Any]) -> np.matrix: """Builds a numpy matrix from the passed in, inputs. Note: this does not check that the inputs are all numpy arrays/array like things. 
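The hamilton/base.py hunks above rewrite the ResultMixin implementations (DictResult, PandasDataFrameResult, NumpyMatrixResult) with builtin generics. A custom builder written against that same contract could look like the hypothetical sketch below, which only mirrors the method signatures shown in the diff:

```python
from __future__ import annotations  # keeps the annotations importable on Python 3.9

from typing import Any

import pandas as pd

from hamilton import base


class ListOfSeriesResult(base.ResultMixin):
    """Hypothetical result builder: returns the requested outputs as named Series."""

    @staticmethod
    def build_result(**outputs: Any) -> list[pd.Series]:
        return [pd.Series(value, name=name) for name, value in outputs.items()]

    def input_types(self) -> list[type] | None:
        # Mirrors DictResult above: accept anything.
        return [Any]

    def output_type(self) -> type:
        return list
```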
@@ -397,11 +397,11 @@ def build_result(**outputs: Dict[str, Any]) -> np.matrix: # Create the matrix with columns as rows and then transpose return np.asmatrix(list_of_columns).T - def input_types(self) -> List[Type[Type]]: + def input_types(self) -> list[type[type]]: """Currently returns anything as numpy types are relatively new and""" return [Any] # Typing - def output_type(self) -> Type: + def output_type(self) -> type: return pd.DataFrame @@ -422,14 +422,14 @@ class SimplePythonDataFrameGraphAdapter(HamiltonGraphAdapter, PandasDataFrameRes """ @staticmethod - def check_input_type(node_type: Type, input_value: Any) -> bool: + def check_input_type(node_type: type, input_value: Any) -> bool: return htypes.check_input_type(node_type, input_value) @staticmethod - def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool: + def check_node_type_equivalence(node_type: type, input_type: type) -> bool: return node_type == input_type - def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any: + def execute_node(self, node: node.Node, kwargs: dict[str, Any]) -> Any: return node.callable(**kwargs) @@ -453,11 +453,11 @@ def __init__(self, result_builder: ResultMixin = None): result_builder = DictResult() self.result_builder = result_builder - def build_result(self, **outputs: Dict[str, Any]) -> Any: + def build_result(self, **outputs: dict[str, Any]) -> Any: """Delegates to the result builder function supplied.""" return self.result_builder.build_result(**outputs) - def output_type(self) -> Type: + def output_type(self) -> type: return self.result_builder.output_type() diff --git a/hamilton/caching/adapter.py b/hamilton/caching/adapter.py index 33c606b2b..17b717bf7 100644 --- a/hamilton/caching/adapter.py +++ b/hamilton/caching/adapter.py @@ -23,8 +23,9 @@ import logging import pathlib import uuid +from collections.abc import Callable, Collection from datetime import datetime, timezone -from typing import Any, Callable, Collection, Dict, List, Literal, Optional, TypeVar, Union +from typing import Any, Literal, TypeVar import hamilton.node from hamilton import graph_types @@ -156,9 +157,9 @@ class CachingEvent: actor: Literal["adapter", "metadata_store", "result_store"] event_type: CachingEventType node_name: str - task_id: Optional[str] = None - msg: Optional[str] = None - value: Optional[Any] = None + task_id: str | None = None + msg: str | None = None + value: Any | None = None timestamp: float = dataclasses.field( default_factory=lambda: datetime.now(timezone.utc).timestamp() ) @@ -214,16 +215,16 @@ class HamiltonCacheAdapter( def __init__( self, - path: Union[str, pathlib.Path] = ".hamilton_cache", - metadata_store: Optional[MetadataStore] = None, - result_store: Optional[ResultStore] = None, - default: Optional[Union[Literal[True], Collection[str]]] = None, - recompute: Optional[Union[Literal[True], Collection[str]]] = None, - ignore: Optional[Union[Literal[True], Collection[str]]] = None, - disable: Optional[Union[Literal[True], Collection[str]]] = None, - default_behavior: Optional[CACHING_BEHAVIORS] = None, - default_loader_behavior: Optional[CACHING_BEHAVIORS] = None, - default_saver_behavior: Optional[CACHING_BEHAVIORS] = None, + path: str | pathlib.Path = ".hamilton_cache", + metadata_store: MetadataStore | None = None, + result_store: ResultStore | None = None, + default: Literal[True] | Collection[str] | None = None, + recompute: Literal[True] | Collection[str] | None = None, + ignore: Literal[True] | Collection[str] | None = None, + disable: Literal[True] 
| Collection[str] | None = None, + default_behavior: CACHING_BEHAVIORS | None = None, + default_loader_behavior: CACHING_BEHAVIORS | None = None, + default_saver_behavior: CACHING_BEHAVIORS | None = None, log_to_file: bool = False, **kwargs, ): @@ -263,21 +264,21 @@ def __init__( self.default_saver_behavior = default_saver_behavior # attributes populated at execution time - self.run_ids: List[str] = [] - self._fn_graphs: Dict[str, FunctionGraph] = {} # {run_id: graph} - self._data_savers: Dict[str, Collection[str]] = {} # {run_id: list[node_name]} - self._data_loaders: Dict[str, Collection[str]] = {} # {run_id: list[node_name]} - self.behaviors: Dict[ - str, Dict[str, CachingBehavior] + self.run_ids: list[str] = [] + self._fn_graphs: dict[str, FunctionGraph] = {} # {run_id: graph} + self._data_savers: dict[str, Collection[str]] = {} # {run_id: list[node_name]} + self._data_loaders: dict[str, Collection[str]] = {} # {run_id: list[node_name]} + self.behaviors: dict[ + str, dict[str, CachingBehavior] ] = {} # {run_id: {node_name: behavior}} - self.data_versions: Dict[ - str, Dict[str, Union[str, Dict[str, str]]] + self.data_versions: dict[ + str, dict[str, str | dict[str, str]] ] = {} # {run_id: {node_name: version}} or {run_id: {node_name: {task_id: version}}} - self.code_versions: Dict[str, Dict[str, str]] = {} # {run_id: {node_name: version}} - self.cache_keys: Dict[ - str, Dict[str, Union[str, Dict[str, str]]] + self.code_versions: dict[str, dict[str, str]] = {} # {run_id: {node_name: version}} + self.cache_keys: dict[ + str, dict[str, str | dict[str, str]] ] = {} # {run_id: {node_name: key}} or {run_id: {node_name: {task_id: key}}} - self._logs: Dict[str, List[CachingEvent]] = {} # {run_id: [logs]} + self._logs: dict[str, list[CachingEvent]] = {} # {run_id: [logs]} @property def last_run_id(self): @@ -319,9 +320,9 @@ def _log_event( node_name: str, actor: Literal["adapter", "metadata_store", "result_store"], event_type: CachingEventType, - msg: Optional[str] = None, - value: Optional[Any] = None, - task_id: Optional[str] = None, + msg: str | None = None, + value: Any | None = None, + task_id: str | None = None, ) -> None: """Add a single event to logs stored in state, keyed by run_id @@ -362,7 +363,7 @@ def _log_event( def _log_by_node_name( self, run_id: str, level: Literal["debug", "info"] = "info" - ) -> Dict[str, List[str]]: + ) -> dict[str, list[str]]: """For a given run, group logs to key them by ``node_name`` or ``(node_name, run_id)`` if applicable.""" run_logs = collections.defaultdict(list) for event in self._logs[run_id]: @@ -377,7 +378,7 @@ def _log_by_node_name( run_logs[key].append(event) return dict(run_logs) - def logs(self, run_id: Optional[str] = None, level: Literal["debug", "info"] = "info") -> dict: + def logs(self, run_id: str | None = None, level: Literal["debug", "info"] = "info") -> dict: """Execution logs of the cache adapter. :param run_id: If ``None``, return all logged runs. If provided a ``run_id``, group logs by node. @@ -428,10 +429,10 @@ def logs(self, run_id: Optional[str] = None, level: Literal["debug", "info"] = " def _view_run( fn_graph: FunctionGraph, logs, - final_vars: List[str], + final_vars: list[str], inputs: dict, overrides: dict, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, ): """Create a Hamilton visualization of the execution and the cache hits/misses. 
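CachingEvent and the HamiltonCacheAdapter constructor above carry most of the Optional[...] to `... | None` rewrites. The toy dataclass below (field names hypothetical) shows that style in runnable form; the future import keeps it valid on Python versions older than 3.10:

```python
from __future__ import annotations  # lets `X | None` annotations run on Python < 3.10

import dataclasses
from datetime import datetime, timezone
from typing import Any


@dataclasses.dataclass
class Event:
    """Hypothetical event record, loosely shaped like CachingEvent above."""

    node_name: str
    task_id: str | None = None   # previously: Optional[str] = None
    msg: str | None = None       # previously: Optional[str] = None
    value: Any | None = None     # previously: Optional[Any] = None
    timestamp: float = dataclasses.field(
        default_factory=lambda: datetime.now(timezone.utc).timestamp()
    )


if __name__ == "__main__":
    print(Event(node_name="train_model"))
```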
@@ -468,7 +469,7 @@ def _visualization_styling_function(*, node, node_class, logs): # TODO make this work directly from the metadata_store too # visualization from logs is convenient when debugging someone else's issue - def view_run(self, run_id: Optional[str] = None, output_file_path: Optional[str] = None): + def view_run(self, run_id: str | None = None, output_file_path: str | None = None): """View the dataflow execution, including cache hits/misses. :param run_id: If ``None``, view the last run. If provided a ``run_id``, view that run. @@ -534,7 +535,7 @@ def view_run(self, run_id: Optional[str] = None, output_file_path: Optional[str] ) def _get_node_role( - self, run_id: str, node_name: str, task_id: Optional[str] + self, run_id: str, node_name: str, task_id: str | None ) -> NodeRoleInTaskExecution: """Determine based on the node name and task_id if a node is part of parallel execution.""" if task_id is None: @@ -552,9 +553,7 @@ def _get_node_role( return role - def get_cache_key( - self, run_id: str, node_name: str, task_id: Optional[str] = None - ) -> Union[str, S]: + def get_cache_key(self, run_id: str, node_name: str, task_id: str | None = None) -> str | S: """Get the ``cache_key`` stored in-memory for a specific ``run_id``, ``node_name``, and ``task_id``. This method is public-facing and can be used directly to inspect the cache. @@ -597,7 +596,7 @@ def get_cache_key( return cache_key def _set_cache_key( - self, run_id: str, node_name: str, cache_key: str, task_id: Optional[str] = None + self, run_id: str, node_name: str, cache_key: str, task_id: str | None = None ) -> None: """Set the ``cache_key`` stored in-memory for a specific ``run_id``, ``node_name``, and ``task_id``. @@ -631,8 +630,8 @@ def _set_cache_key( ) def _get_memory_data_version( - self, run_id: str, node_name: str, task_id: Optional[str] = None - ) -> Union[str, S]: + self, run_id: str, node_name: str, task_id: str | None = None + ) -> str | S: """Get the ``data_version`` stored in-memory for a specific ``run_id``, ``node_name``, and ``task_id``. The behavior depends on the ``CacheBehavior`` (e.g., RECOMPUTE, IGNORE, DISABLE, DEFAULT) and @@ -675,8 +674,8 @@ def _get_memory_data_version( return data_version def _get_stored_data_version( - self, run_id: str, node_name: str, cache_key: str, task_id: Optional[str] = None - ) -> Union[str, S]: + self, run_id: str, node_name: str, cache_key: str, task_id: str | None = None + ) -> str | S: """Get the ``data_version`` stored in the metadata store associated with the ``cache_key``. The ``run_id``, ``node_name``, and ``task_id`` are included only for logging purposes. @@ -698,9 +697,9 @@ def get_data_version( self, run_id: str, node_name: str, - cache_key: Optional[str] = None, - task_id: Optional[str] = None, - ) -> Union[str, S]: + cache_key: str | None = None, + task_id: str | None = None, + ) -> str | S: """Get the ``data_version`` for a specific ``run_id``, ``node_name``, and ``task_id``. This method is public-facing and can be used directly to inspect the cache. 
This will check data versions @@ -737,7 +736,7 @@ def get_data_version( return data_version def _set_memory_metadata( - self, run_id: str, node_name: str, data_version: str, task_id: Optional[str] = None + self, run_id: str, node_name: str, data_version: str, task_id: str | None = None ) -> None: """Set in-memory data_version whether a task_id is specified or not""" assert data_version is not None @@ -774,7 +773,7 @@ def _set_stored_metadata( node_name: str, cache_key: str, data_version: str, - task_id: Optional[str] = None, + task_id: str | None = None, ) -> None: """Set data_version in the metadata store associated with the cache_key""" self.metadata_store.set( @@ -794,7 +793,7 @@ def _set_stored_metadata( ) def _version_data( - self, node_name: str, run_id: str, result: Any, task_id: Optional[str] = None + self, node_name: str, run_id: str, result: Any, task_id: str | None = None ) -> str: """Create a unique data version for the result""" data_version = fingerprinting.hash_value(result) @@ -827,7 +826,7 @@ def version_data(self, result: Any, run_id: str = None) -> str: # stuff the internal function call to not log event return self._version_data(result=result, run_id=run_id, node_name=None) - def version_code(self, node_name: str, run_id: Optional[str] = None) -> str: + def version_code(self, node_name: str, run_id: str | None = None) -> str: """Create a unique code version for the source code defining the node""" run_id = self.last_run_id if run_id is None else run_id node = self._fn_graphs[run_id].nodes[node_name] @@ -838,8 +837,8 @@ def _execute_node( run_id: str, node_name: str, node_callable: Callable, - node_kwargs: Dict[str, Any], - task_id: Optional[str] = None, + node_kwargs: dict[str, Any], + task_id: str | None = None, ) -> Any: """Simple wrapper that logs the regular execution of a node.""" logger.debug(node_name) @@ -856,10 +855,10 @@ def _execute_node( @staticmethod def _resolve_node_behavior( node: hamilton.node.Node, - default: Optional[Collection[str]] = None, - disable: Optional[Collection[str]] = None, - recompute: Optional[Collection[str]] = None, - ignore: Optional[Collection[str]] = None, + default: Collection[str] | None = None, + disable: Collection[str] | None = None, + recompute: Collection[str] | None = None, + ignore: Collection[str] | None = None, default_behavior: CACHING_BEHAVIORS = "default", default_loader_behavior: CACHING_BEHAVIORS = "default", default_saver_behavior: CACHING_BEHAVIORS = "default", @@ -906,7 +905,7 @@ def _resolve_node_behavior( else: return CachingBehavior.from_string(default_behavior) - def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]: + def resolve_behaviors(self, run_id: str) -> dict[str, CachingBehavior]: """Resolve the caching behavior for each node based on the ``@cache`` decorator and the ``Builder.with_cache()`` parameters for a specific ``run_id``. @@ -1011,10 +1010,10 @@ def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]: def resolve_code_versions( self, run_id: str, - final_vars: Optional[List[str]] = None, - inputs: Optional[Dict[str, Any]] = None, - overrides: Optional[Dict[str, Any]] = None, - ) -> Dict[str, str]: + final_vars: list[str] | None = None, + inputs: dict[str, Any] | None = None, + overrides: dict[str, Any] | None = None, + ) -> dict[str, str]: """Resolve the code version for each node for a specific ``run_id``. This is a user-facing method. 
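Since the adapter surface above (the constructor parameters, logs(), get_cache_key(), get_data_version()) is what users interact with, one plausible usage sketch follows. It only leans on calls that appear in this diff (HamiltonCacheAdapter, ad_hoc_utils.create_temporary_module, logs()); the node defined here is invented for illustration, and wiring via with_adapters is an assumption — Builder.with_cache(), referenced in the resolve_behaviors docstring, is the higher-level entry point:

```python
from hamilton import ad_hoc_utils, driver
from hamilton.caching.adapter import HamiltonCacheAdapter


def expensive_node() -> int:
    """Hypothetical node that is worth caching."""
    return sum(range(1_000_000))


temp_module = ad_hoc_utils.create_temporary_module(expensive_node)

adapter = HamiltonCacheAdapter(path=".hamilton_cache")
dr = driver.Builder().with_modules(temp_module).with_adapters(adapter).build()

print(dr.execute(["expensive_node"]))  # first run computes and stores the result
print(dr.execute(["expensive_node"]))  # a second run can be served from the cache
print(adapter.logs(level="info"))      # caching events grouped as described above
```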
@@ -1083,8 +1082,8 @@ def _process_override(self, run_id: str, node_name: str, value: Any) -> None: @staticmethod def _resolve_default_parameter_values( - node_: hamilton.node.Node, node_kwargs: Dict[str, Any] - ) -> Dict[str, Any]: + node_: hamilton.node.Node, node_kwargs: dict[str, Any] + ) -> dict[str, Any]: """ If a node uses the function's default parameter values, they won't be part of the node_kwargs. To ensure a consistent `cache_key` we want to retrieve default parameter @@ -1104,9 +1103,9 @@ def pre_graph_execute( *, run_id: str, graph: FunctionGraph, - final_vars: List[str], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + final_vars: list[str], + inputs: dict[str, Any], + overrides: dict[str, Any], ): """Set up the state of the adapter for a new execution. @@ -1152,8 +1151,8 @@ def pre_node_execute( *, run_id: str, node_: hamilton.node.Node, - kwargs: Dict[str, Any], - task_id: Optional[str] = None, + kwargs: dict[str, Any], + task_id: str | None = None, **future_kwargs, ): """Before node execution or retrieval, create the cache_key and set it in memory. @@ -1253,8 +1252,8 @@ def do_node_execute( *, run_id: str, node_: hamilton.node.Node, - kwargs: Dict[str, Any], - task_id: Optional[str] = None, + kwargs: dict[str, Any], + task_id: str | None = None, **future_kwargs, ): """Try to retrieve stored result from previous executions or execute the node. @@ -1406,10 +1405,10 @@ def post_node_execute( *, run_id: str, node_: hamilton.node.Node, - result: Optional[str], + result: str | None, success: bool = True, - error: Optional[Exception] = None, - task_id: Optional[str] = None, + error: Exception | None = None, + task_id: str | None = None, **future_kwargs, ): """Get the cache_key and data_version stored in memory (respectively from diff --git a/hamilton/caching/cache_key.py b/hamilton/caching/cache_key.py index 6c57c95d9..5b368181e 100644 --- a/hamilton/caching/cache_key.py +++ b/hamilton/caching/cache_key.py @@ -17,7 +17,7 @@ import base64 import zlib -from typing import Dict, Mapping +from collections.abc import Mapping def _compress_string(string: str) -> str: @@ -58,7 +58,7 @@ def decode_key(cache_key: str) -> dict: def create_cache_key( - node_name: str, code_version: str, dependencies_data_versions: Dict[str, str] + node_name: str, code_version: str, dependencies_data_versions: dict[str, str] ) -> str: if len(dependencies_data_versions.keys()) > 0: dependencies_stringified = _encode_str_dict(dependencies_data_versions) diff --git a/hamilton/caching/fingerprinting.py b/hamilton/caching/fingerprinting.py index 3291c429e..119344372 100644 --- a/hamilton/caching/fingerprinting.py +++ b/hamilton/caching/fingerprinting.py @@ -40,7 +40,6 @@ import logging import sys from collections.abc import Mapping, Sequence, Set -from typing import Dict from hamilton.experimental import h_databackends @@ -183,7 +182,7 @@ def hash_unordered_mapping(obj, *args, depth: int = 0, **kwargs) -> str: hash_mapping(foo) == hash_mapping(bar) """ - hashed_mapping: Dict[str, str] = {} + hashed_mapping: dict[str, str] = {} for key, value in obj.items(): hashed_mapping[hash_value(key, depth=depth + 1)] = hash_value(value, depth=depth + 1) diff --git a/hamilton/caching/stores/base.py b/hamilton/caching/stores/base.py index 6e65c2d12..6b53c8e93 100644 --- a/hamilton/caching/stores/base.py +++ b/hamilton/caching/stores/base.py @@ -17,8 +17,9 @@ import abc import pickle +from collections.abc import Sequence from datetime import datetime, timedelta, timezone -from typing import Any, Optional, Sequence, 
Tuple, Type +from typing import Any from hamilton.htypes import custom_subclass_check from hamilton.io.data_adapters import DataLoader, DataSaver @@ -33,7 +34,7 @@ class ResultRetrievalError(Exception): # Ideally, it would be done earlier in the caching lifecycle. def search_data_adapter_registry( name: str, type_: type -) -> Tuple[Type[DataSaver], Type[DataLoader]]: +) -> tuple[type[DataSaver], type[DataLoader]]: """Find pair of DataSaver and DataLoader registered with `name` and supporting `type_`""" if name not in SAVER_REGISTRY or name not in LOADER_REGISTRY: raise KeyError( @@ -75,7 +76,7 @@ def set(self, data_version: str, result: Any, **kwargs) -> None: """Store ``result`` keyed by ``data_version``.""" @abc.abstractmethod - def get(self, data_version: str, **kwargs) -> Optional[Any]: + def get(self, data_version: str, **kwargs) -> Any | None: """Try to retrieve ``result`` keyed by ``data_version``. If retrieval misses, return ``None``. """ @@ -105,14 +106,14 @@ def initialize(self, run_id: str) -> None: """Setup the metadata store and log the start of the run""" @abc.abstractmethod - def set(self, cache_key: str, data_version: str, **kwargs) -> Optional[Any]: + def set(self, cache_key: str, data_version: str, **kwargs) -> Any | None: """Store the mapping ``cache_key -> data_version``. Can include other metadata (e.g., node name, run id, code version) depending on the implementation. """ @abc.abstractmethod - def get(self, cache_key: str, **kwargs) -> Optional[str]: + def get(self, cache_key: str, **kwargs) -> str | None: """Try to retrieve ``data_version`` keyed by ``cache_key``. If retrieval misses return ``None``. """ @@ -185,9 +186,9 @@ def __init__( def new( cls, value: Any, - expires_in: Optional[timedelta] = None, - saver: Optional[DataSaver] = None, - loader: Optional[DataLoader] = None, + expires_in: timedelta | None = None, + saver: DataSaver | None = None, + loader: DataLoader | None = None, ) -> "StoredResult": if expires_in is not None and not isinstance(expires_in, timedelta): expires_in = timedelta(seconds=expires_in) diff --git a/hamilton/caching/stores/file.py b/hamilton/caching/stores/file.py index b6685280b..11d285bcc 100644 --- a/hamilton/caching/stores/file.py +++ b/hamilton/caching/stores/file.py @@ -18,7 +18,7 @@ import inspect import shutil from pathlib import Path -from typing import Any, Optional +from typing import Any try: from typing import override @@ -49,7 +49,7 @@ def _write_result(file_path: Path, stored_result: StoredResult) -> None: file_path.write_bytes(stored_result.save()) @staticmethod - def _load_result_from_path(path: Path) -> Optional[StoredResult]: + def _load_result_from_path(path: Path) -> StoredResult | None: try: data = path.read_bytes() return StoredResult.load(data) @@ -73,8 +73,8 @@ def set( self, data_version: str, result: Any, - saver_cls: Optional[DataSaver] = None, - loader_cls: Optional[DataLoader] = None, + saver_cls: DataSaver | None = None, + loader_cls: DataLoader | None = None, ) -> None: # != operator on boolean is XOR if bool(saver_cls is not None) != bool(loader_cls is not None): @@ -114,7 +114,7 @@ def set( self._write_result(result_path, stored_result) @override - def get(self, data_version: str) -> Optional[Any]: + def get(self, data_version: str) -> Any | None: result_path = self._path_from_data_version(data_version) stored_result = self._load_result_from_path(result_path) diff --git a/hamilton/caching/stores/memory.py b/hamilton/caching/stores/memory.py index 7717b1994..7ed7ec49f 100644 --- 
a/hamilton/caching/stores/memory.py +++ b/hamilton/caching/stores/memory.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Sequence +from collections.abc import Sequence +from typing import Any try: from typing import override @@ -31,9 +32,9 @@ class InMemoryMetadataStore(MetadataStore): def __init__(self) -> None: - self._data_versions: Dict[str, str] = {} # {cache_key: data_version} - self._cache_keys_by_run: Dict[str, List[str]] = {} # {run_id: [cache_key]} - self._run_ids: List[str] = [] + self._data_versions: dict[str, str] = {} # {cache_key: data_version} + self._cache_keys_by_run: dict[str, list[str]] = {} # {run_id: [cache_key]} + self._run_ids: list[str] = [] @override def __len__(self) -> int: @@ -52,13 +53,13 @@ def initialize(self, run_id: str) -> None: self._run_ids.append(run_id) @override - def set(self, cache_key: str, data_version: str, run_id: str, **kwargs) -> Optional[Any]: + def set(self, cache_key: str, data_version: str, run_id: str, **kwargs) -> Any | None: """Set the ``data_version`` for ``cache_key`` and associate it with the ``run_id``.""" self._data_versions[cache_key] = data_version self._cache_keys_by_run[run_id].append(cache_key) @override - def get(self, cache_key: str) -> Optional[str]: + def get(self, cache_key: str) -> str | None: """Retrieve the ``data_version`` for ``cache_key``.""" return self._data_versions.get(cache_key, None) @@ -72,7 +73,7 @@ def delete_all(self) -> None: """Delete all stored metadata.""" self._data_versions.clear() - def persist_to(self, metadata_store: Optional[MetadataStore] = None) -> None: + def persist_to(self, metadata_store: MetadataStore | None = None) -> None: """Persist in-memory metadata using another MetadataStore implementation. :param metadata_store: MetadataStore implementation to use for persistence. @@ -155,12 +156,12 @@ def load_from(cls, metadata_store: MetadataStore) -> "InMemoryMetadataStore": return in_memory_metadata_store @override - def get_run_ids(self) -> List[str]: + def get_run_ids(self) -> list[str]: """Return a list of all ``run_id`` values stored.""" return self._run_ids @override - def get_run(self, run_id: str) -> List[Dict[str, str]]: + def get_run(self, run_id: str) -> list[dict[str, str]]: """Return a list of node metadata associated with a run.""" if self._cache_keys_by_run.get(run_id, None) is None: raise IndexError(f"Run ID not found: {run_id}") @@ -183,7 +184,7 @@ def get_run(self, run_id: str) -> List[Dict[str, str]]: class InMemoryResultStore(ResultStore): def __init__(self, persist_on_exit: bool = False) -> None: - self._results: Dict[str, StoredResult] = {} # {data_version: result} + self._results: dict[str, StoredResult] = {} # {data_version: result} @override def exists(self, data_version: str) -> bool: @@ -195,7 +196,7 @@ def set(self, data_version: str, result: Any, **kwargs) -> None: self._results[data_version] = StoredResult.new(value=result) @override - def get(self, data_version: str) -> Optional[Any]: + def get(self, data_version: str) -> Any | None: stored_result = self._results.get(data_version, None) if stored_result is None: return None @@ -222,7 +223,7 @@ def delete_expired(self) -> None: for data_version in to_delete: self.delete(data_version) - def persist_to(self, result_store: Optional[ResultStore] = None) -> None: + def persist_to(self, result_store: ResultStore | None = None) -> None: """Persist in-memory results using another ``ResultStore`` implementation. 
:param result_store: ResultStore implementation to use for persistence. @@ -238,8 +239,8 @@ def persist_to(self, result_store: Optional[ResultStore] = None) -> None: def load_from( cls, result_store: ResultStore, - metadata_store: Optional[MetadataStore] = None, - data_versions: Optional[Sequence[str]] = None, + metadata_store: MetadataStore | None = None, + data_versions: Sequence[str] | None = None, ) -> "InMemoryResultStore": """Load in-memory results from another ResultStore instance. diff --git a/hamilton/caching/stores/sqlite.py b/hamilton/caching/stores/sqlite.py index 9a5fe425a..ec285e405 100644 --- a/hamilton/caching/stores/sqlite.py +++ b/hamilton/caching/stores/sqlite.py @@ -18,7 +18,6 @@ import pathlib import sqlite3 import threading -from typing import List, Optional from hamilton.caching.cache_key import decode_key from hamilton.caching.stores.base import MetadataStore @@ -28,7 +27,7 @@ class SQLiteMetadataStore(MetadataStore): def __init__( self, path: str, - connection_kwargs: Optional[dict] = None, + connection_kwargs: dict | None = None, ) -> None: self._directory = pathlib.Path(path).resolve() self._directory.mkdir(parents=True, exist_ok=True) @@ -173,7 +172,7 @@ def set( self.connection.commit() - def get(self, cache_key: str) -> Optional[str]: + def get(self, cache_key: str) -> str | None: cur = self.connection.cursor() cur.execute( """\ @@ -217,7 +216,7 @@ def exists(self, cache_key: str) -> bool: return result is not None - def get_run_ids(self) -> List[str]: + def get_run_ids(self) -> list[str]: """Return a list of run ids, sorted from oldest to newest start time.""" cur = self.connection.cursor() cur.execute("SELECT run_id FROM run_ids ORDER BY id") @@ -244,7 +243,7 @@ def _run_exists(self, run_id: str) -> bool: # SELECT EXISTS returns 1 for True, i.e., `run_id` is found return result[0] == 1 - def get_run(self, run_id: str) -> List[dict]: + def get_run(self, run_id: str) -> list[dict]: """Return a list of node metadata associated with a run. :param run_id: ID of the run to retrieve diff --git a/hamilton/cli/__main__.py b/hamilton/cli/__main__.py index e5321bcb4..13036954d 100644 --- a/hamilton/cli/__main__.py +++ b/hamilton/cli/__main__.py @@ -20,9 +20,10 @@ import logging import os import warnings +from collections.abc import Callable from pathlib import Path from pprint import pprint -from typing import Annotated, Any, Callable, List, Optional +from typing import Annotated, Any import typer @@ -45,17 +46,17 @@ class Response: class CliState: - verbose: Optional[bool] = None - json_out: Optional[bool] = None - dr: Optional[driver.Driver] = None - name: Optional[str] = None + verbose: bool | None = None + json_out: bool | None = None + dr: driver.Driver | None = None + name: str | None = None cli = typer.Typer(rich_markup_mode="rich") state = CliState() MODULES_ANNOTATIONS = Annotated[ - List[Path], + list[Path], typer.Argument( help="Paths to Hamilton modules", exists=True, @@ -66,12 +67,12 @@ class CliState: ] NAME_ANNOTATIONS = Annotated[ - Optional[str], + str | None, typer.Option("--name", "-n", help="Name of the dataflow. 
Default: Derived from MODULES."), ] CONTEXT_ANNOTATIONS = Annotated[ - Optional[Path], + Path | None, typer.Option( "--context", "-ctx", @@ -303,7 +304,7 @@ def ui( no_migration: bool = False, no_open: bool = False, settings_file: str = "mini", - config_file: Optional[str] = None, + config_file: str | None = None, ): """Runs the Hamilton UI on sqllite in port 8241""" try: diff --git a/hamilton/cli/commands.py b/hamilton/cli/commands.py index 51b5f03d7..d83cd9719 100644 --- a/hamilton/cli/commands.py +++ b/hamilton/cli/commands.py @@ -16,13 +16,12 @@ # under the License. from pathlib import Path -from typing import List, Optional from hamilton import ad_hoc_utils, driver from hamilton.cli import logic -def build(modules: List[Path], context_path: Optional[Path] = None): +def build(modules: list[Path], context_path: Path | None = None): """Build a Hamilton driver from the passed modules, and load the Driver config from the context file. @@ -43,11 +42,11 @@ def build(modules: List[Path], context_path: Optional[Path] = None): def diff( current_dr: driver.Driver, - modules: List[Path], - git_reference: Optional[str] = "HEAD", + modules: list[Path], + git_reference: str | None = "HEAD", view: bool = False, output_file_path: Path = Path("./diff.png"), - context_path: Optional[Path] = None, + context_path: Path | None = None, ) -> dict: """Get the diff of""" context = logic.load_context(context_path) if context_path else {} diff --git a/hamilton/cli/logic.py b/hamilton/cli/logic.py index e5f0f563b..5a427b791 100644 --- a/hamilton/cli/logic.py +++ b/hamilton/cli/logic.py @@ -17,7 +17,6 @@ from pathlib import Path from types import ModuleType -from typing import Dict, List, Union from hamilton import driver @@ -48,7 +47,7 @@ def get_git_base_directory() -> str: raise FileNotFoundError("Git command not found. Please make sure Git is installed.") from e -def get_git_reference(git_relative_path: Union[str, Path], git_reference: str) -> str: +def get_git_reference(git_relative_path: str | Path, git_reference: str) -> str: """Get the source code from the specified file and git reference""" import subprocess @@ -72,11 +71,11 @@ def get_git_reference(git_relative_path: Union[str, Path], git_reference: str) - raise FileNotFoundError("Git command not found. 
Please make sure Git is installed.") from e -def version_hamilton_functions(module: ModuleType) -> Dict[str, str]: +def version_hamilton_functions(module: ModuleType) -> dict[str, str]: """Hash the source code of Hamilton functions from a module""" from hamilton import graph_types, graph_utils - origins_version: Dict[str, str] = dict() + origins_version: dict[str, str] = dict() for origin_name, _ in graph_utils.find_functions(module): origin_callable = getattr(module, origin_name) @@ -85,7 +84,7 @@ def version_hamilton_functions(module: ModuleType) -> Dict[str, str]: return origins_version -def hash_hamilton_nodes(dr: driver.Driver) -> Dict[str, str]: +def hash_hamilton_nodes(dr: driver.Driver) -> dict[str, str]: """Hash the source code of Hamilton functions from nodes in a Driver""" from hamilton import graph_types @@ -93,7 +92,7 @@ def hash_hamilton_nodes(dr: driver.Driver) -> Dict[str, str]: return {n.name: n.version for n in graph.nodes} -def map_nodes_to_functions(dr: driver.Driver) -> Dict[str, str]: +def map_nodes_to_functions(dr: driver.Driver) -> dict[str, str]: """Get a mapping from node name to Hamilton function name""" from hamilton import graph_types @@ -111,7 +110,7 @@ def map_nodes_to_functions(dr: driver.Driver) -> Dict[str, str]: return node_to_function -def hash_dataflow(nodes_version: Dict[str, str]) -> str: +def hash_dataflow(nodes_version: dict[str, str]) -> str: """Create a dataflow hash from the hashes of its nodes""" import hashlib @@ -120,8 +119,8 @@ def hash_dataflow(nodes_version: Dict[str, str]) -> str: def load_modules_from_git( - module_paths: List[Path], git_reference: str = "HEAD" -) -> List[ModuleType]: + module_paths: list[Path], git_reference: str = "HEAD" +) -> list[ModuleType]: """Dynamically import modules for a git reference""" from hamilton import ad_hoc_utils @@ -138,9 +137,9 @@ def load_modules_from_git( def diff_nodes_against_functions( - nodes_version: Dict[str, str], - origins_version: Dict[str, str], - node_to_origin: Dict[str, str], + nodes_version: dict[str, str], + origins_version: dict[str, str], + node_to_origin: dict[str, str], ) -> dict: """Compare the nodes version from a built Driver to the origins version from module source code when a second @@ -184,7 +183,7 @@ def diff_nodes_against_functions( ) -def diff_versions(current_map: Dict[str, str], reference_map: Dict[str, str]) -> dict: +def diff_versions(current_map: dict[str, str], reference_map: dict[str, str]) -> dict: """Generic diff of two {name: hash} mappings (can be node or origin name) :mapping_v1: mapping from node (or function) name to its function hash @@ -214,9 +213,9 @@ def _custom_diff_style( *, node, node_class, - current_only: List[str], - reference_only: List[str], - edit: List[str], + current_only: list[str], + reference_only: list[str], + edit: list[str], ): """Custom visualization style for the diff of 2 dataflows""" if node.name in current_only: @@ -237,9 +236,9 @@ def _custom_diff_style( def visualize_diff( current_dr: driver.Driver, reference_dr: driver.Driver, - current_only: List[str], - reference_only: List[str], - edit: List[str], + current_only: list[str], + reference_only: list[str], + edit: list[str], ): """Visualize the diff of 2 dataflows. diff --git a/hamilton/common/__init__.py b/hamilton/common/__init__.py index 02b183b3a..bd308b560 100644 --- a/hamilton/common/__init__.py +++ b/hamilton/common/__init__.py @@ -16,12 +16,13 @@ # under the License. 
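The hamilton/cli changes above standardize on Annotated[...] with PEP 604 optionals for Typer parameters. The stand-alone command below is hypothetical (not part of hamilton.cli) and only demonstrates that style; it assumes Python 3.10+ and a recent Typer release, since these annotations are evaluated at runtime:

```python
from pathlib import Path
from typing import Annotated

import typer

app = typer.Typer()


@app.command()
def show(
    modules: Annotated[list[Path], typer.Argument(help="Paths to Hamilton modules")],
    name: Annotated[str | None, typer.Option("--name", "-n", help="Optional dataflow name")] = None,
):
    """Hypothetical command; only the annotation style mirrors hamilton.cli."""
    typer.echo(f"modules={modules!r} name={name!r}")


if __name__ == "__main__":
    app()
```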
# code in this module should no depend on much -from typing import Any, Callable, List, Optional, Set, Tuple, Union +from collections.abc import Callable +from typing import Any, List, Optional, Set, Tuple, Union def convert_output_value( - output_value: Union[str, Callable, Any], module_set: Set[str] -) -> Tuple[Optional[str], Optional[str]]: + output_value: str | Callable | Any, module_set: set[str] +) -> tuple[str | None, str | None]: """Converts output values that one can request into strings. It checks that if it's a function, it's in the passed in module set. @@ -51,8 +52,8 @@ def convert_output_value( def convert_output_values( - output_values: List[Union[str, Callable, Any]], module_set: Set[str] -) -> List[str]: + output_values: list[str | Callable | Any], module_set: set[str] +) -> list[str]: """Checks & converts outputs values to strings. This is used in building dependencies for the DAG. :param output_values: the values to convert. diff --git a/hamilton/data_quality/base.py b/hamilton/data_quality/base.py index 95b459e76..4e498cb0f 100644 --- a/hamilton/data_quality/base.py +++ b/hamilton/data_quality/base.py @@ -19,7 +19,7 @@ import dataclasses import enum import logging -from typing import Any, Dict, List, Tuple, Type +from typing import Any logger = logging.getLogger(__name__) @@ -37,7 +37,7 @@ class DataValidationLevel(enum.Enum): class ValidationResult: passes: bool # Whether or not this passed the validation message: str # Error message or success message - diagnostics: Dict[str, Any] = dataclasses.field( + diagnostics: dict[str, Any] = dataclasses.field( default_factory=dict ) # Any extra diagnostics information needed, free-form @@ -53,7 +53,7 @@ def importance(self) -> DataValidationLevel: return self._importance @abc.abstractmethod - def applies_to(self, datatype: Type[Type]) -> bool: + def applies_to(self, datatype: type[type]) -> bool: """Whether or not this data validator can apply to the specified dataset :param datatype: @@ -105,7 +105,7 @@ def _create_error_string(node_name, validation_result, validator): ) -def act_fail_bulk(node_name: str, failures: List[Tuple[ValidationResult, DataValidator]]): +def act_fail_bulk(node_name: str, failures: list[tuple[ValidationResult, DataValidator]]): """This is the current default for acting on the validation result when you want to fail. Note that we might move this at some point -- we'll want to make it configurable. But for now, this seems like a fine place to put it. @@ -135,7 +135,7 @@ def __init__(self, importance: str): @classmethod @abc.abstractmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: pass @abc.abstractmethod diff --git a/hamilton/data_quality/default_validators.py b/hamilton/data_quality/default_validators.py index c48376c49..a01fcab02 100644 --- a/hamilton/data_quality/default_validators.py +++ b/hamilton/data_quality/default_validators.py @@ -17,7 +17,8 @@ import logging import numbers -from typing import Any, Iterable, List, Tuple, Type, Union +from collections.abc import Iterable +from typing import Any import numpy as np import pandas as pd @@ -28,7 +29,7 @@ class DataInRangeValidatorPandasSeries(base.BaseDefaultValidator): - def __init__(self, range: Tuple[float, float], importance: str): + def __init__(self, range: tuple[float, float], importance: str): """Data validator that tells if data is in a range. This applies to primitives (ints, floats). 
:param range: Inclusive range of parameters @@ -41,7 +42,7 @@ def arg(cls) -> str: return "range" @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, pd.Series) # TODO -- handle dataframes? def description(self) -> str: @@ -86,7 +87,7 @@ def arg(cls) -> str: return "values_in" @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, pd.Series) # TODO -- handle dataframes? def description(self) -> str: @@ -121,7 +122,7 @@ def validate(self, data: pd.Series) -> base.ValidationResult: class DataInRangeValidatorPrimitives(base.BaseDefaultValidator): - def __init__(self, range: Tuple[numbers.Real, numbers.Real], importance: str): + def __init__(self, range: tuple[numbers.Real, numbers.Real], importance: str): """Data validator that tells if data is in a range. This applies to primitives (ints, floats). :param range: Inclusive range of parameters @@ -130,7 +131,7 @@ def __init__(self, range: Tuple[numbers.Real, numbers.Real], importance: str): self.range = range @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, numbers.Real) def description(self) -> str: @@ -170,7 +171,7 @@ def arg(cls) -> str: return "values_in" @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, numbers.Real) or issubclass( datatype, str ) # TODO support list, dict and typing.* variants @@ -178,7 +179,7 @@ def applies_to(cls, datatype: Type[Type]) -> bool: def description(self) -> str: return f"Validates that python values are from a fixed set of values: ({self.values})." - def validate(self, data: Union[numbers.Real, str]) -> base.ValidationResult: + def validate(self, data: numbers.Real | str) -> base.ValidationResult: if hasattr(data, "dask"): data = data.compute() is_valid_value = data in self.values @@ -208,7 +209,7 @@ def _to_percent(fraction: float): return "{0:.2%}".format(fraction) @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, pd.Series) def description(self) -> str: @@ -260,7 +261,7 @@ def arg(cls) -> str: class DataTypeValidatorPandasSeries(base.BaseDefaultValidator): - def __init__(self, data_type: Type[Type], importance: str): + def __init__(self, data_type: type[type], importance: str): """Constructor :param data_type: the numpy data type to expect. @@ -270,7 +271,7 @@ def __init__(self, data_type: Type[Type], importance: str): self.datatype = data_type @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, pd.Series) def description(self) -> str: @@ -293,7 +294,7 @@ def arg(cls) -> str: class DataTypeValidatorPrimitives(base.BaseDefaultValidator): - def __init__(self, data_type: Type[Type], importance: str): + def __init__(self, data_type: type[type], importance: str): """Constructor :param data_type: the python data type to expect. 
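The default_validators.py hunks above all follow the same contract: applies_to(datatype: type[type]), description(), arg(), and validate() returning a ValidationResult. A hypothetical custom validator written against that contract (not part of this patch) could look like:

```python
import pandas as pd

from hamilton.data_quality import base


class MaxRowCountValidator(base.BaseDefaultValidator):
    """Hypothetical validator capping the number of rows in a pandas Series."""

    def __init__(self, max_row_count: int, importance: str):
        super().__init__(importance)
        self.max_row_count = max_row_count

    @classmethod
    def applies_to(cls, datatype: type[type]) -> bool:
        return issubclass(datatype, pd.Series)

    def description(self) -> str:
        return f"Validates that the series has at most {self.max_row_count} rows."

    @classmethod
    def arg(cls) -> str:
        return "max_row_count"

    def validate(self, data: pd.Series) -> base.ValidationResult:
        row_count = len(data)
        passes = row_count <= self.max_row_count
        return base.ValidationResult(
            passes=passes,
            message=f"Series has {row_count} rows; the limit is {self.max_row_count}.",
            diagnostics={"row_count": row_count, "max_row_count": self.max_row_count},
        )


if __name__ == "__main__":
    validator = MaxRowCountValidator(max_row_count=3, importance="warn")
    print(validator.validate(pd.Series([1, 2, 3, 4])).passes)  # False
```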
@@ -303,14 +304,14 @@ def __init__(self, data_type: Type[Type], importance: str): self.datatype = data_type @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, numbers.Real) or datatype in (str, bool) def description(self) -> str: return f"Validates that the datatype of the pandas series is a subclass of: {self.datatype}" def validate( - self, data: Union[numbers.Real, str, bool, int, float, list, dict] + self, data: numbers.Real | str | bool | int | float | list | dict ) -> base.ValidationResult: if hasattr(data, "dask"): data = data.compute() @@ -336,7 +337,7 @@ def __init__(self, max_standard_dev: float, importance: str): self.max_standard_dev = max_standard_dev @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, pd.Series) def description(self) -> str: @@ -362,12 +363,12 @@ def arg(cls) -> str: class MeanInRangeValidatorPandasSeries(base.BaseDefaultValidator): - def __init__(self, mean_in_range: Tuple[float, float], importance: str): + def __init__(self, mean_in_range: tuple[float, float], importance: str): super(MeanInRangeValidatorPandasSeries, self).__init__(importance) self.mean_in_range = mean_in_range @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, pd.Series) def description(self) -> str: @@ -398,7 +399,7 @@ def __init__(self, allow_none: bool, importance: str): self.allow_none = allow_none @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return True def description(self) -> str: @@ -425,7 +426,7 @@ def arg(cls) -> str: class StrContainsValidator(base.BaseDefaultValidator): - def __init__(self, contains: Union[str, List[str]], importance: str): + def __init__(self, contains: str | list[str], importance: str): super(StrContainsValidator, self).__init__(importance) if isinstance(contains, str): self.contains = [contains] @@ -433,7 +434,7 @@ def __init__(self, contains: Union[str, List[str]], importance: str): self.contains = contains @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return datatype == str def description(self) -> str: @@ -457,7 +458,7 @@ def arg(cls) -> str: class StrDoesNotContainValidator(base.BaseDefaultValidator): - def __init__(self, does_not_contain: Union[str, List[str]], importance: str): + def __init__(self, does_not_contain: str | list[str], importance: str): super(StrDoesNotContainValidator, self).__init__(importance) if isinstance(does_not_contain, str): self.does_not_contain = [does_not_contain] @@ -465,7 +466,7 @@ def __init__(self, does_not_contain: Union[str, List[str]], importance: str): self.does_not_contain = does_not_contain @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return datatype == str def description(self) -> str: @@ -526,11 +527,11 @@ def _append_pandera_to_default_validators(): def resolve_default_validators( - output_type: Type[Type], + output_type: type[type], importance: str, - available_validators: List[Type[base.BaseDefaultValidator]] = None, + available_validators: list[type[base.BaseDefaultValidator]] = None, **default_validator_kwargs, -) -> List[base.BaseDefaultValidator]: +) -> list[base.BaseDefaultValidator]: """Resolves default 
validators given a set pof parameters and the type to which they apply. Note that each (kwarg, type) combination should map to a validator :param importance: importance level of the validator to instantiate diff --git a/hamilton/data_quality/pandera_validators.py b/hamilton/data_quality/pandera_validators.py index c1908e582..25ec89374 100644 --- a/hamilton/data_quality/pandera_validators.py +++ b/hamilton/data_quality/pandera_validators.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Type +from typing import Any import pandera as pa @@ -34,7 +34,7 @@ def __init__(self, schema: pa.DataFrameSchema, importance: str): self.schema = schema @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: for extension_name in pandera_supported_extensions: if extension_name in registry.DF_TYPE_AND_COLUMN_TYPES: df_type = registry.DF_TYPE_AND_COLUMN_TYPES[extension_name][registry.DATAFRAME_TYPE] @@ -80,7 +80,7 @@ def __init__(self, schema: pa.SeriesSchema, importance: str): self.schema = schema @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: for extension_name in pandera_supported_extensions: if extension_name in registry.DF_TYPE_AND_COLUMN_TYPES: df_type = registry.DF_TYPE_AND_COLUMN_TYPES[extension_name][registry.COLUMN_TYPE] diff --git a/hamilton/dataflows/__init__.py b/hamilton/dataflows/__init__.py index cae9df9d8..1fece4ab6 100644 --- a/hamilton/dataflows/__init__.py +++ b/hamilton/dataflows/__init__.py @@ -32,11 +32,15 @@ import time import urllib.error import urllib.request +from collections.abc import Callable from types import ModuleType -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Tuple, Type, Union from hamilton import driver, telemetry +if TYPE_CHECKING: + import builtins + logger = logging.getLogger(__name__) """ @@ -76,7 +80,7 @@ def track_call(*args, **kwargs): return track_call -def _track_download(is_official: bool, user: Optional[str], dataflow_name: str, version: str): +def _track_download(is_official: bool, user: str | None, dataflow_name: str, version: str): """Inner function to track "downloads" of a dataflow. :param is_official: is this an official dataflow? False == user. @@ -94,7 +98,7 @@ def _track_download(is_official: bool, user: Optional[str], dataflow_name: str, telemetry.send_event_json(event_json) -def _get_request(url: str) -> Tuple[int, str]: +def _get_request(url: str) -> tuple[int, str]: """Makes a GET request to the given URL and returns the status code and response data. :param url: the url to make the request to. @@ -302,8 +306,8 @@ class InspectResult(NamedTuple): version: str # git commit sha/package version user: str # github user URL dataflow: str # dataflow URL - python_dependencies: List[str] # python dependencies - configurations: List[str] # configurations for the dataflow stored as a JSON string + python_dependencies: list[str] # python dependencies + configurations: list[str] # configurations for the dataflow stored as a JSON string @_track_function_call @@ -345,7 +349,7 @@ def inspect(dataflow: str, user: str = None, version: str = "latest") -> Inspect f"Dataflow {user or 'dagworks'}/{dataflow} with version {version} does not exist locally. Not inspecting." 
) # return dictionary of python deps, inputs, nodes, designated outputs, commit hash - info: Dict[str, Union[str, List[Dict], List[str]]] = { + info: dict[str, str | builtins.list[dict] | builtins.list[str]] = { "version": version, "user": user_url, "dataflow": dataflow_url, @@ -367,11 +371,11 @@ class InspectModuleResult(NamedTuple): version: str # git commit sha/package version user: str # github user URL dataflow: str # dataflow URL - python_dependencies: List[str] # python dependencies - configurations: List[str] # configurations for the dataflow stored as a JSON string - possible_inputs: List[Tuple[str, Type]] - nodes: List[Tuple[str, Type]] - designated_outputs: List[Tuple[str, Type]] + python_dependencies: list[str] # python dependencies + configurations: list[str] # configurations for the dataflow stored as a JSON string + possible_inputs: list[tuple[str, type]] + nodes: list[tuple[str, type]] + designated_outputs: list[tuple[str, type]] @_track_function_call @@ -414,7 +418,9 @@ def inspect_module(module: ModuleType) -> InspectModuleResult: f"Dataflow {user or 'dagworks'}/{dataflow} with version {version} does not exist locally. Not inspecting." ) # return dictionary of python deps, inputs, nodes, designated outputs, commit hash - info: Dict[str, Union[str, List[Dict], List[str], List[Tuple[str, Type]]]] = { + info: dict[ + str, str | builtins.list[dict] | builtins.list[str] | builtins.list[tuple[str, type]] + ] = { "version": version, "user": user_url, "dataflow": dataflow_url, diff --git a/hamilton/dev_utils/deprecation.py b/hamilton/dev_utils/deprecation.py index f03f257cb..61de491e2 100644 --- a/hamilton/dev_utils/deprecation.py +++ b/hamilton/dev_utils/deprecation.py @@ -19,7 +19,7 @@ import functools import logging import types -from typing import Callable, Optional, Tuple, Union +from collections.abc import Callable from hamilton import version @@ -36,7 +36,7 @@ def __gt__(self, other: "Version"): return (self.major, self.minor, self.patch) > (other.major, other.minor, other.patch) @staticmethod - def from_version_tuple(version_tuple: Tuple[Union[int, str], ...]) -> "Version": + def from_version_tuple(version_tuple: tuple[int | str, ...]) -> "Version": version_ = version_tuple if len(version_) > 3: # This means we have an RC version_ = version_tuple[0:3] # Then let's ignore it @@ -75,16 +75,16 @@ class parameterized(...): """ - warn_starting: Union[Tuple[int, int, int], Version] - fail_starting: Union[Tuple[int, int, int], Version] - use_this: Optional[ - Callable - ] # If this is None, it means this functionality is no longer supported. + warn_starting: tuple[int, int, int] | Version + fail_starting: tuple[int, int, int] | Version + use_this: ( + Callable | None + ) # If this is None, it means this functionality is no longer supported. 
explanation: str - migration_guide: Optional[ - str - ] # If this is None, this means that the use_this is a drop in replacement - current_version: Union[Tuple[int, int, int], Version] = dataclasses.field( + migration_guide: ( + str | None + ) # If this is None, this means that the use_this is a drop in replacement + current_version: tuple[int, int, int] | Version = dataclasses.field( default_factory=lambda: CURRENT_VERSION ) warn_action: Callable[[str], None] = dataclasses.field(default=logger.warning) @@ -97,7 +97,7 @@ def _raise_failure(message: str): raise DeprecationError(message) @staticmethod - def _ensure_version_type(version_spec: Union[Tuple[int, int, int], Version]) -> Version: + def _ensure_version_type(version_spec: tuple[int, int, int] | Version) -> Version: if isinstance(version_spec, tuple): return Version(*version_spec) return version_spec diff --git a/hamilton/driver.py b/hamilton/driver.py index 3281a950e..6186aa710 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -29,20 +29,13 @@ # required if we want to run this code stand alone. import typing import uuid +from collections.abc import Callable, Collection, Sequence from datetime import datetime from types import ModuleType from typing import ( Any, - Callable, - Collection, - Dict, - List, Literal, Optional, - Sequence, - Set, - Tuple, - Union, ) import pandas as pd @@ -124,11 +117,11 @@ class GraphExecutor(abc.ABC): def execute( self, fg: graph.FunctionGraph, - final_vars: List[Union[str, Callable, Variable]], - overrides: Dict[str, Any], - inputs: Dict[str, Any], + final_vars: list[str | Callable | Variable], + overrides: dict[str, Any], + inputs: dict[str, Any], run_id: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Executes a graph in a blocking function. :param fg: Graph to execute @@ -142,7 +135,7 @@ def execute( pass @abc.abstractmethod - def validate(self, nodes_to_execute: List[node.Node]): + def validate(self, nodes_to_execute: list[node.Node]): """Validates whether the executor can execute the given graph. Some executors allow API constructs that others do not support (such as Parallelizable[]/Collect[]) @@ -156,14 +149,14 @@ def validate(self, nodes_to_execute: List[node.Node]): class DefaultGraphExecutor(GraphExecutor): DEFAULT_TASK_NAME = "root" # Not task-based, so we just assign a default name for a task - def __init__(self, adapter: Optional[lifecycle_base.LifecycleAdapterSet] = None): + def __init__(self, adapter: lifecycle_base.LifecycleAdapterSet | None = None): """Constructor for the default graph executor. :param adapter: Adapter to use for execution (optional). """ self.adapter = adapter - def validate(self, nodes_to_execute: List[node.Node]): + def validate(self, nodes_to_execute: list[node.Node]): """The default graph executor cannot handle parallelizable[]/collect[] nodes. :param nodes_to_execute: @@ -180,11 +173,11 @@ def validate(self, nodes_to_execute: List[node.Node]): def execute( self, fg: graph.FunctionGraph, - final_vars: List[str], - overrides: Dict[str, Any], - inputs: Dict[str, Any], + final_vars: list[str], + overrides: dict[str, Any], + inputs: dict[str, Any], run_id: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Basic executor for a function graph. 
Does no task-based execution, just does a DFS and executes the graph in order, in memory.""" memoized_computation = dict() # memoized storage @@ -201,7 +194,7 @@ def execute( class TaskBasedGraphExecutor(GraphExecutor): - def validate(self, nodes_to_execute: List[node.Node]): + def validate(self, nodes_to_execute: list[node.Node]): """Currently this can run every valid graph""" pass @@ -225,11 +218,11 @@ def __init__( def execute( self, fg: graph.FunctionGraph, - final_vars: List[str], - overrides: Dict[str, Any], - inputs: Dict[str, Any], + final_vars: list[str], + overrides: dict[str, Any], + inputs: dict[str, Any], run_id: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Executes a graph, task by task. This blocks until completion. This does the following: @@ -332,13 +325,10 @@ def __setstate__(self, state): @staticmethod def normalize_adapter_input( - adapter: Optional[ - Union[ - lifecycle_base.LifecycleAdapter, - List[lifecycle_base.LifecycleAdapter], - lifecycle_base.LifecycleAdapterSet, - ] - ], + adapter: lifecycle_base.LifecycleAdapter + | list[lifecycle_base.LifecycleAdapter] + | lifecycle_base.LifecycleAdapterSet + | None, use_legacy_adapter: bool = True, ) -> lifecycle_base.LifecycleAdapterSet: """Normalizes the adapter argument in the driver to a list of adapters. Adds back the legacy adapter if needed. @@ -426,13 +416,13 @@ def _perform_graph_validations( def __init__( self, - config: Dict[str, Any], + config: dict[str, Any], *modules: ModuleType, - adapter: Optional[ - Union[lifecycle_base.LifecycleAdapter, List[lifecycle_base.LifecycleAdapter]] - ] = None, + adapter: lifecycle_base.LifecycleAdapter + | list[lifecycle_base.LifecycleAdapter] + | None = None, allow_module_overrides: bool = False, - _materializers: typing.Sequence[Union[ExtractorFactory, MaterializerFactory]] = None, + _materializers: typing.Sequence[ExtractorFactory | MaterializerFactory] = None, _graph_executor: GraphExecutor = None, _use_legacy_adapter: bool = True, ): @@ -506,9 +496,9 @@ def _repr_mimebundle_(self, include=None, exclude=None, **kwargs): def capture_constructor_telemetry( self, - error: Optional[str], - modules: Tuple[ModuleType], - config: Dict[str, Any], + error: str | None, + modules: tuple[ModuleType], + config: dict[str, Any], adapter: lifecycle_base.LifecycleAdapterSet, ): """Captures constructor telemetry. Notes: @@ -548,13 +538,11 @@ def capture_constructor_telemetry( @staticmethod def validate_inputs( fn_graph: graph.FunctionGraph, - adapter: Union[ - lifecycle_base.LifecycleAdapter, - List[lifecycle_base.LifecycleAdapter], - lifecycle_base.LifecycleAdapterSet, - ], + adapter: lifecycle_base.LifecycleAdapter + | list[lifecycle_base.LifecycleAdapter] + | lifecycle_base.LifecycleAdapterSet, user_nodes: Collection[node.Node], - inputs: typing.Optional[Dict[str, Any]] = None, + inputs: dict[str, Any] | None = None, nodes_set: Collection[node.Node] = None, ): """Validates that inputs meet our expectations. This means that: @@ -609,10 +597,10 @@ def validate_inputs( def execute( self, - final_vars: List[Union[str, Callable, Variable]], - overrides: Dict[str, Any] = None, + final_vars: list[str | Callable | Variable], + overrides: dict[str, Any] = None, display_graph: bool = False, - inputs: Dict[str, Any] = None, + inputs: dict[str, Any] = None, ) -> Any: """Executes computation. 
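Note on the Driver signature changes above: the `Optional[...]`/`Union[...]` spellings are folded into PEP 604 pipes, including inside subscripts such as `list[str | Callable | Variable]`. A minimal sketch in plain Python (no Hamilton imports, illustrative values only) of why this is purely cosmetic on the Python 3.10+ interpreters the pipe syntax requires:

    from typing import Any, Optional, Union

    # The pipe form compares equal to the typing spelling it replaces, so
    # runtime introspection of these annotations is unchanged.
    assert (str | None) == Optional[str]
    assert (dict[str, Any] | None) == Optional[dict[str, Any]]
    assert (int | str) == Union[str, int]  # member order does not matter

    # Unions of plain classes are also valid second arguments to isinstance():
    assert isinstance("final_var", str | bytes)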
@@ -677,7 +665,7 @@ def execute( ) return outputs - def _create_final_vars(self, final_vars: List[Union[str, Callable, Variable]]) -> List[str]: + def _create_final_vars(self, final_vars: list[str | Callable | Variable]) -> list[str]: """Creates the final variables list - converting functions names as required. :param final_vars: @@ -689,10 +677,10 @@ def _create_final_vars(self, final_vars: List[Union[str, Callable, Variable]]) - def capture_execute_telemetry( self, - error: Optional[str], - final_vars: List[str], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + error: str | None, + final_vars: list[str], + inputs: dict[str, Any], + overrides: dict[str, Any], run_successful: bool, duration: float, ): @@ -715,8 +703,8 @@ def capture_execute_telemetry( run_successful, duration, len(final_vars) if final_vars else 0, - len(overrides) if isinstance(overrides, Dict) else 0, - len(inputs) if isinstance(overrides, Dict) else 0, + len(overrides) if isinstance(overrides, dict) else 0, + len(inputs) if isinstance(overrides, dict) else 0, self.driver_run_id, error, ) @@ -735,12 +723,12 @@ def capture_execute_telemetry( ) def raw_execute( self, - final_vars: List[str], - overrides: Dict[str, Any] = None, + final_vars: list[str], + overrides: dict[str, Any] = None, display_graph: bool = False, - inputs: Dict[str, Any] = None, + inputs: dict[str, Any] = None, _fn_graph: graph.FunctionGraph = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Raw execute function that does the meat of execute. Don't use this entry point for execution directly. Always go through `.execute()` or `.materialize()`. @@ -811,13 +799,13 @@ def raw_execute( def __raw_execute( self, - final_vars: List[str], - overrides: Dict[str, Any] = None, + final_vars: list[str], + overrides: dict[str, Any] = None, display_graph: bool = False, - inputs: Dict[str, Any] = None, + inputs: dict[str, Any] = None, _fn_graph: graph.FunctionGraph = None, _run_id: str = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Raw execute function that does the meat of execute. Private method since the result building and post_graph_execute lifecycle hooks are performed outside and so this returns an incomplete result. @@ -861,8 +849,8 @@ def __raw_execute( @capture_function_usage def list_available_variables( - self, *, tag_filter: Dict[str, Union[Optional[str], List[str]]] = None - ) -> List[Variable]: + self, *, tag_filter: dict[str, str | None | list[str]] = None + ) -> list[Variable]: """Returns available variables, i.e. outputs. 
These variables correspond 1:1 with nodes in the DAG, and contain the following information: @@ -970,12 +958,12 @@ def display_all_functions( def _visualize_execution_helper( fn_graph: graph.FunctionGraph, adapter: lifecycle_base.LifecycleAdapterSet, - final_vars: List[str], + final_vars: list[str], output_file_path: str, render_kwargs: dict, - inputs: Dict[str, Any] = None, + inputs: dict[str, Any] = None, graphviz_kwargs: dict = None, - overrides: Dict[str, Any] = None, + overrides: dict[str, Any] = None, show_legend: bool = True, orient: str = "LR", hide_inputs: bool = False, @@ -1052,12 +1040,12 @@ def _visualize_execution_helper( @capture_function_usage def visualize_execution( self, - final_vars: List[Union[str, Callable, Variable]], + final_vars: list[str | Callable | Variable], output_file_path: str = None, render_kwargs: dict = None, - inputs: Dict[str, Any] = None, + inputs: dict[str, Any] = None, graphviz_kwargs: dict = None, - overrides: Dict[str, Any] = None, + overrides: dict[str, Any] = None, show_legend: bool = True, orient: str = "LR", hide_inputs: bool = False, @@ -1124,9 +1112,9 @@ def visualize_execution( @capture_function_usage def export_execution( self, - final_vars: List[str], - inputs: Dict[str, Any] = None, - overrides: Dict[str, Any] = None, + final_vars: list[str], + inputs: dict[str, Any] = None, + overrides: dict[str, Any] = None, ) -> str: """Method to create JSON representation of the Graph. @@ -1146,7 +1134,7 @@ def export_execution( @capture_function_usage def has_cycles( self, - final_vars: List[Union[str, Callable, Variable]], + final_vars: list[str | Callable | Variable], _fn_graph: graph.FunctionGraph = None, ) -> bool: """Checks that the created graph does not have cycles. @@ -1162,7 +1150,7 @@ def has_cycles( return self.graph.has_cycles(nodes, user_nodes) @capture_function_usage - def what_is_downstream_of(self, *node_names: str) -> List[Variable]: + def what_is_downstream_of(self, *node_names: str) -> list[Variable]: """Tells you what is downstream of this function(s), i.e. node(s). :param node_names: names of function(s) that are starting points for traversing the graph. @@ -1306,7 +1294,7 @@ def display_upstream_of( logger.warning(f"Unable to import {e}", exc_info=True) @capture_function_usage - def what_is_upstream_of(self, *node_names: str) -> List[Variable]: + def what_is_upstream_of(self, *node_names: str) -> list[Variable]: """Tells you what is upstream of this function(s), i.e. node(s). :param node_names: names of function(s) that are starting points for traversing the graph backwards. @@ -1319,7 +1307,7 @@ def what_is_upstream_of(self, *node_names: str) -> List[Variable]: @capture_function_usage def what_is_the_path_between( self, upstream_node_name: str, downstream_node_name: str - ) -> List[Variable]: + ) -> list[Variable]: """Tells you what nodes are on the path between two nodes. Note: this is inclusive of the two nodes, and returns an unsorted list of nodes. @@ -1345,7 +1333,7 @@ def what_is_the_path_between( def _get_nodes_between( self, upstream_node_name: str, downstream_node_name: str - ) -> Set[node.Node]: + ) -> set[node.Node]: """Gets the nodes representing the path between two nodes, inclusive of the two nodes. Assumes that the nodes exist in the graph. 
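The `tag_filter` hunk a little further up turns `Dict[str, Union[Optional[str], List[str]]]` into `dict[str, str | None | list[str]]`. A small sketch (plain typing, no Hamilton objects) showing that the nested Optional/Union spelling already flattens to the same three-way union, so callers passing a string, None, or a list of strings are unaffected:

    from typing import List, Optional, Union

    # Union flattens nested unions and Optional into one flat union:
    assert Union[Optional[str], List[str]] == Union[str, None, List[str]]

    # The pipe form flattens the same way and ignores member order:
    assert (str | None | list[str]) == (list[str] | None | str)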
@@ -1365,7 +1353,7 @@ def visualize_path_between( self, upstream_node_name: str, downstream_node_name: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, render_kwargs: dict = None, graphviz_kwargs: dict = None, strict_path_visualization: bool = False, @@ -1467,8 +1455,8 @@ def visualize_path_between( logger.warning(f"Unable to import {e}", exc_info=True) def _process_materializers( - self, materializers: typing.Sequence[Union[MaterializerFactory, ExtractorFactory]] - ) -> Tuple[List[MaterializerFactory], List[ExtractorFactory]]: + self, materializers: typing.Sequence[MaterializerFactory | ExtractorFactory] + ) -> tuple[list[MaterializerFactory], list[ExtractorFactory]]: """Processes materializers, splitting them into materializers and extractors. Note that this also sanitizes the variable names in the materializer dependencies, so one can pass in functions instead of strings. @@ -1488,13 +1476,11 @@ def _process_materializers( @capture_function_usage def materialize( self, - *materializers: Union[ - materialization.MaterializerFactory, materialization.ExtractorFactory - ], - additional_vars: List[Union[str, Callable, Variable]] = None, - overrides: Dict[str, Any] = None, - inputs: Dict[str, Any] = None, - ) -> Tuple[Any, Dict[str, Any]]: + *materializers: materialization.MaterializerFactory | materialization.ExtractorFactory, + additional_vars: list[str | Callable | Variable] = None, + overrides: dict[str, Any] = None, + inputs: dict[str, Any] = None, + ) -> tuple[Any, dict[str, Any]]: """Executes and materializes with ad-hoc materializers (`to`) and extractors (`from_`).This does the following: 1. Creates a new graph, appending the desired materialization nodes and prepending the desired extraction nodes @@ -1748,13 +1734,13 @@ def materialize( @capture_function_usage def visualize_materialization( self, - *materializers: Union[MaterializerFactory, ExtractorFactory], + *materializers: MaterializerFactory | ExtractorFactory, output_file_path: str = None, render_kwargs: dict = None, - additional_vars: List[Union[str, Callable, Variable]] = None, - inputs: Dict[str, Any] = None, + additional_vars: list[str | Callable | Variable] = None, + inputs: dict[str, Any] = None, graphviz_kwargs: dict = None, - overrides: Dict[str, Any] = None, + overrides: dict[str, Any] = None, show_legend: bool = True, orient: str = "LR", hide_inputs: bool = False, @@ -1817,9 +1803,9 @@ def visualize_materialization( def validate_execution( self, - final_vars: List[Union[str, Callable, Variable]], - overrides: Dict[str, Any] = None, - inputs: Dict[str, Any] = None, + final_vars: list[str | Callable | Variable], + overrides: dict[str, Any] = None, + inputs: dict[str, Any] = None, ): """Validates execution of the graph. One can call this to validate execution, independently of actually executing. Note this has no return -- it will raise a ValueError if there is an issue. @@ -1836,9 +1822,9 @@ def validate_execution( def validate_materialization( self, *materializers: materialization.MaterializerFactory, - additional_vars: List[Union[str, Callable, Variable]] = None, - overrides: Dict[str, Any] = None, - inputs: Dict[str, Any] = None, + additional_vars: list[str | Callable | Variable] = None, + overrides: dict[str, Any] = None, + inputs: dict[str, Any] = None, ): """Validates materialization of the graph. Effectively .materialize() with a dry-run. Note this has no return -- it will raise a ValueError if there is an issue. 
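The materialize()/_process_materializers() hunks above annotate a variadic `*materializers` parameter with a bare union and return built-in `tuple[list[...], list[...]]` generics. A self-contained sketch of the same shape; `SaverSpec`, `LoaderSpec`, and `split_specs` are made-up stand-ins, not Hamilton's factory classes:

    class SaverSpec: ...
    class LoaderSpec: ...

    def split_specs(*specs: SaverSpec | LoaderSpec) -> tuple[list[SaverSpec], list[LoaderSpec]]:
        """Splits the positional arguments by type, mirroring the materializer/extractor split."""
        savers = [s for s in specs if isinstance(s, SaverSpec)]
        loaders = [s for s in specs if isinstance(s, LoaderSpec)]
        return savers, loaders

    savers, loaders = split_specs(SaverSpec(), LoaderSpec(), SaverSpec())
    assert len(savers) == 2 and len(loaders) == 1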
@@ -1899,7 +1885,7 @@ def __init__(self): self.legacy_graph_adapter = None # Standard execution fields - self.adapters: List[lifecycle_base.LifecycleAdapter] = [] + self.adapters: list[lifecycle_base.LifecycleAdapter] = [] # Dynamic execution fields self.execution_manager = None @@ -1933,7 +1919,7 @@ def enable_dynamic_execution(self, *, allow_experimental_mode: bool = False) -> self.v2_executor = True return self - def with_config(self, config: Dict[str, Any]) -> "Builder": + def with_config(self, config: dict[str, Any]) -> "Builder": """Adds the specified configuration to the config. This can be called multilple times -- later calls will take precedence. @@ -1978,7 +1964,7 @@ def with_adapters(self, *adapters: lifecycle_base.LifecycleAdapter) -> "Builder" return self def with_materializers( - self, *materializers: Union[ExtractorFactory, MaterializerFactory] + self, *materializers: ExtractorFactory | MaterializerFactory ) -> "Builder": """Add materializer nodes to the `Driver` The generated nodes can be referenced by name in `.execute()` @@ -2003,13 +1989,13 @@ def with_materializers( def with_cache( self, - path: Union[str, pathlib.Path] = ".hamilton_cache", - metadata_store: Optional[MetadataStore] = None, - result_store: Optional[ResultStore] = None, - default: Optional[Union[Literal[True], Sequence[str]]] = None, - recompute: Optional[Union[Literal[True], Sequence[str]]] = None, - ignore: Optional[Union[Literal[True], Sequence[str]]] = None, - disable: Optional[Union[Literal[True], Sequence[str]]] = None, + path: str | pathlib.Path = ".hamilton_cache", + metadata_store: MetadataStore | None = None, + result_store: ResultStore | None = None, + default: Literal[True] | Sequence[str] | None = None, + recompute: Literal[True] | Sequence[str] | None = None, + ignore: Literal[True] | Sequence[str] | None = None, + disable: Literal[True] | Sequence[str] | None = None, default_behavior: Literal["default", "recompute", "disable", "ignore"] = "default", default_loader_behavior: Literal["default", "recompute", "disable", "ignore"] = "default", default_saver_behavior: Literal["default", "recompute", "disable", "ignore"] = "default", @@ -2073,7 +2059,7 @@ def with_cache( return self @property - def cache(self) -> Optional[HamiltonCacheAdapter]: + def cache(self) -> HamiltonCacheAdapter | None: """Attribute to check if a cache was set, either via `.with_cache()` or `.with_adapters(SmartCacheAdapter())` diff --git a/hamilton/execution/debugging_utils.py b/hamilton/execution/debugging_utils.py index 016002560..a404f7b1a 100644 --- a/hamilton/execution/debugging_utils.py +++ b/hamilton/execution/debugging_utils.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import List from hamilton.execution.grouping import NodeGroup, NodeGroupPurpose, TaskSpec @@ -28,7 +27,7 @@ } -def print_node_groups(node_groups: List[NodeGroup]): +def print_node_groups(node_groups: list[NodeGroup]): """Prints out the node groups in a clean, tree-like format. :param node_groups: @@ -41,7 +40,7 @@ def print_node_groups(node_groups: List[NodeGroup]): print(f" • {node_.name} [ƒ({','.join(map(lambda n: n.name, node_.dependencies))})]") -def print_tasks(tasks: List[TaskSpec]): +def print_tasks(tasks: list[TaskSpec]): """Prints out the node groups in a clean, tree-like format. 
:param tasks: diff --git a/hamilton/execution/executors.py b/hamilton/execution/executors.py index 1abc9a675..1a0da224f 100644 --- a/hamilton/execution/executors.py +++ b/hamilton/execution/executors.py @@ -26,8 +26,9 @@ else: from concurrent.futures.process import ProcessPoolExecutor +from collections.abc import Callable from concurrent.futures import Executor, Future, ThreadPoolExecutor -from typing import Any, Callable, Dict, List, Protocol +from typing import Any, Protocol from hamilton import node from hamilton.execution.graph_functions import execute_subdag @@ -105,7 +106,7 @@ def _modify_callable(node_source: node.NodeType, callabl: Callable): return callabl -def base_execute_task(task: TaskImplementation) -> Dict[str, Any]: +def base_execute_task(task: TaskImplementation) -> dict[str, Any]: """This is a utility function to execute a base task. In an ideal world this would be recursive, (as in we can use the same task execution/management system as we would otherwise) but for now we just call out to good old DFS. Note that this only returns the result that @@ -359,7 +360,7 @@ class ExecutionManager(abc.ABC): theoretically add metadata in a task as well. """ - def __init__(self, executors: List[TaskExecutor]): + def __init__(self, executors: list[TaskExecutor]): """Initializes the execution manager. Note this does not start it up/claim resources -- you need to call init() to do that. diff --git a/hamilton/execution/graph_functions.py b/hamilton/execution/graph_functions.py index 795e7345a..d94551ea2 100644 --- a/hamilton/execution/graph_functions.py +++ b/hamilton/execution/graph_functions.py @@ -17,8 +17,9 @@ import logging import pprint +from collections.abc import Collection from functools import partial -from typing import Any, Collection, Dict, List, Optional, Set, Tuple +from typing import Any from hamilton import node from hamilton.lifecycle.base import LifecycleAdapterSet @@ -30,7 +31,7 @@ """ -def topologically_sort_nodes(nodes: List[node.Node]) -> List[node.Node]: +def topologically_sort_nodes(nodes: list[node.Node]) -> list[node.Node]: """Topologically sorts a list of nodes based on their dependencies. Note that we bypass utilizing the preset dependencies/depended_on_by attributes of the node, as we may want to use this before these nodes get put in a function graph. @@ -79,7 +80,7 @@ def topologically_sort_nodes(nodes: List[node.Node]) -> List[node.Node]: return sorted_nodes -def get_node_levels(topologically_sorted_nodes: List[node.Node]) -> Dict[str, int]: +def get_node_levels(topologically_sorted_nodes: list[node.Node]) -> dict[str, int]: """Gets the levels for a group of topologically sorted nodes. This only works if its topologically sorted, of course... @@ -98,7 +99,7 @@ def get_node_levels(topologically_sorted_nodes: List[node.Node]) -> Dict[str, in return node_levels -def combine_config_and_inputs(config: Dict[str, Any], inputs: Dict[str, Any]) -> Dict[str, Any]: +def combine_config_and_inputs(config: dict[str, Any], inputs: dict[str, Any]) -> dict[str, Any]: """Validates and combines config and inputs, ensuring that they're mutually disjoint. :param config: Config to construct, run the DAG with. 
:param inputs: Inputs to run the DAG on at runtime @@ -165,13 +166,13 @@ def create_error_message(kwargs: dict, node_: node.Node, step: str) -> str: def execute_subdag( nodes: Collection[node.Node], - inputs: Dict[str, Any], + inputs: dict[str, Any], adapter: LifecycleAdapterSet = None, - computed: Dict[str, Any] = None, - overrides: Dict[str, Any] = None, + computed: dict[str, Any] = None, + overrides: dict[str, Any] = None, run_id: str = None, task_id: str = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Base function to execute a subdag. This conducts a depth first traversal of the graph. :param nodes: Nodes to compute @@ -274,7 +275,7 @@ def execute_lifecycle_for_node( __adapter: LifecycleAdapterSet, __run_id: str, __task_id: str, - **__kwargs: Dict[str, Any], + **__kwargs: dict[str, Any], ): """Helper function to properly execute node lifecycle. @@ -351,7 +352,7 @@ def execute_lifecycle_for_node( def nodes_between( end_node: node.Node, search_condition: lambda node_: bool, -) -> Tuple[Optional[node.Node], List[node.Node]]: +) -> tuple[node.Node | None, list[node.Node]]: """Utility function to search backwards from an end node to a start node. This returns all nodes for which both of the following conditions are met: @@ -399,7 +400,7 @@ def dfs_traverse(node_: node.Node): return search_node, list(out) -def node_is_required_by_anything(node_: node.Node, node_set: Set[node.Node]) -> bool: +def node_is_required_by_anything(node_: node.Node, node_set: set[node.Node]) -> bool: """Checks dependencies on this node and determines if at least one requires it. Nodes can be optionally depended upon, i.e. the function parameter has a default value. We want to check that diff --git a/hamilton/execution/grouping.py b/hamilton/execution/grouping.py index 59800363e..7fdfdbf57 100644 --- a/hamilton/execution/grouping.py +++ b/hamilton/execution/grouping.py @@ -19,7 +19,8 @@ import dataclasses import enum from collections import defaultdict -from typing import Any, Collection, Dict, List, Optional, Set, Tuple +from collections.abc import Collection +from typing import Any from hamilton import node from hamilton.execution import graph_functions @@ -59,11 +60,11 @@ class NodeGroup: """ base_id: str # Unique ID for node group. - spawning_task_base_id: Optional[str] - nodes: List[Node] + spawning_task_base_id: str | None + nodes: list[Node] purpose: NodeGroupPurpose # TODO -- derive this (or not?) 
# set of available nodes by this task for querying - available_nodes: Set[str] = dataclasses.field(init=False) + available_nodes: set[str] = dataclasses.field(init=False) def __post_init__(self): self.available_nodes = {node_.name for node_ in self.nodes} @@ -74,14 +75,14 @@ def __hash__(self): def __eq__(self, other): return self.base_id == other.base_id - def get_expander_node(self) -> Optional[Node]: + def get_expander_node(self) -> Node | None: """Returns the expander node for this node group, if it exists""" candidates = [n for n in self.nodes if n.node_role == NodeType.EXPAND] if candidates: return candidates[0] return None - def get_collector_node(self) -> Optional[Node]: + def get_collector_node(self) -> Node | None: """Returns the collector node for this node group, if it exists""" candidates = [n for n in self.nodes if n.node_role == NodeType.COLLECT] if candidates: @@ -102,11 +103,11 @@ class TaskSpec(NodeGroup): """ outputs_to_compute: Collection[str] # list of output names to compute - overrides: Dict[str, Any] # overrides for the task, fixed at the time of creation + overrides: dict[str, Any] # overrides for the task, fixed at the time of creation adapter: lifecycle_base.LifecycleAdapterSet - base_dependencies: List[str] # list of tasks that must be completed before this task can run + base_dependencies: list[str] # list of tasks that must be completed before this task can run - def get_input_vars(self) -> Tuple[List[str], List[str]]: + def get_input_vars(self) -> tuple[list[str], list[str]]: """Returns the node-level dependencies for this node group. This is all of the sources in the subdag. @@ -148,20 +149,20 @@ class TaskImplementation(TaskSpec): # task whose result spawned these tasks # If this is none, it means the graph itself spawned these tasks - group_id: Optional[str] - realized_dependencies: Dict[str, List[str]] # realized dependencies are the actual dependencies + group_id: str | None + realized_dependencies: dict[str, list[str]] # realized dependencies are the actual dependencies # Note that these are lists as we have "gather" operations - spawning_task_id: Optional[str] # task that spawned this task + spawning_task_id: str | None # task that spawned this task task_id: str = dataclasses.field(init=False) - dynamic_inputs: Dict[str, Any] = dataclasses.field(default_factory=dict) + dynamic_inputs: dict[str, Any] = dataclasses.field(default_factory=dict) run_id: str = dataclasses.field(default_factory=str) - def bind(self, dynamic_inputs: Dict[str, Any]) -> "TaskImplementation": + def bind(self, dynamic_inputs: dict[str, Any]) -> "TaskImplementation": """Binds dynamic inputs to the task spec, returning a new task spec""" return dataclasses.replace(self, dynamic_inputs={**dynamic_inputs, **self.dynamic_inputs}) @staticmethod - def determine_task_id(base_id: str, spawning_task: Optional[str], group_id: Optional[str]): + def determine_task_id(base_id: str, spawning_task: str | None, group_id: str | None): return ".".join( filter(lambda i: i is not None, [spawning_task, group_id, base_id]) ) # This will do for now... 
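The grouping.py dataclasses above (NodeGroup, TaskSpec, TaskImplementation) now annotate their fields with `str | None`, `list[...]` and `dict[...]` while keeping the `dataclasses.field` defaults. A self-contained sketch of the same pattern; `Group` and its field values are illustrative stand-ins rather than the real classes:

    import dataclasses
    from typing import Any

    @dataclasses.dataclass
    class Group:
        base_id: str
        spawning_task_id: str | None  # None when the graph itself spawned the group
        nodes: list[str]              # member node names
        dynamic_inputs: dict[str, Any] = dataclasses.field(default_factory=dict)

    g = Group(base_id="expand-1", spawning_task_id=None, nodes=["a", "b"])
    assert g.spawning_task_id is None and g.dynamic_inputs == {}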
@@ -183,7 +184,7 @@ class GroupingStrategy(abc.ABC): """Base class for grouping nodes""" @abc.abstractmethod - def group_nodes(self, nodes: List[node.Node]) -> List[NodeGroup]: + def group_nodes(self, nodes: list[node.Node]) -> list[NodeGroup]: """Groups nodes into a list of node groups""" pass @@ -198,7 +199,7 @@ class GroupByRepeatableBlocks(GroupingStrategy): @staticmethod def nodes_after_last_expand_block( collect_node: node.Node, - ) -> Tuple[node.Node, List[node.Node]]: + ) -> tuple[node.Node, list[node.Node]]: """Utility function to yield all nodes between a start and an end node. This returns all nodes for which the following conditions are met: @@ -214,7 +215,7 @@ def is_expander(node_: node.Node) -> bool: return graph_functions.nodes_between(collect_node, is_expander) - def group_nodes(self, nodes: List[node.Node]) -> List[NodeGroup]: + def group_nodes(self, nodes: list[node.Node]) -> list[NodeGroup]: """Groups nodes into blocks. This works as follows: 1. Fina all the Parallelizable[] nodes in the DAG 2. For each of those, do a DFS until the next Collect[] node @@ -293,7 +294,7 @@ def convert_node_type_to_group_purpose(node_type: NodeType) -> NodeGroupPurpose: class GroupNodesIndividually(GroupingStrategy): """Groups nodes into individual blocks.""" - def group_nodes(self, nodes: List[node.Node]): + def group_nodes(self, nodes: list[node.Node]): return [ NodeGroup( base_id=node_.name, @@ -308,7 +309,7 @@ def group_nodes(self, nodes: List[node.Node]): class GroupNodesAllAsOne(GroupingStrategy): """Groups nodes all into one block. TODO -- add validation.""" - def group_nodes(self, nodes: List[node.Node]): + def group_nodes(self, nodes: list[node.Node]): return [ NodeGroup( base_id="root", @@ -320,7 +321,7 @@ def group_nodes(self, nodes: List[node.Node]): class GroupNodesByLevel(GroupingStrategy): - def group_nodes(self, nodes: List[node.Node]) -> List[NodeGroup]: + def group_nodes(self, nodes: list[node.Node]) -> list[NodeGroup]: in_order = topologically_sort_nodes(nodes) node_levels = get_node_levels(in_order) nodes_by_level = defaultdict(list) @@ -342,11 +343,11 @@ def group_nodes(self, nodes: List[node.Node]) -> List[NodeGroup]: def create_task_plan( - node_groups: List[NodeGroup], - outputs: List[str], - overrides: Dict[str, Any], + node_groups: list[NodeGroup], + outputs: list[str], + overrides: dict[str, Any], adapter: lifecycle_base.LifecycleAdapterSet, -) -> List[TaskSpec]: +) -> list[TaskSpec]: """Creates tasks from node groups. This occurs after we group and after execute() is called in the driver. Knowing what the user wants, we can finally create the tasks. diff --git a/hamilton/execution/state.py b/hamilton/execution/state.py index e356db334..a28b2800f 100644 --- a/hamilton/execution/state.py +++ b/hamilton/execution/state.py @@ -19,7 +19,7 @@ import collections import enum import logging -from typing import Any, Dict, List, Optional +from typing import Any from hamilton.execution.grouping import NodeGroupPurpose, TaskImplementation, TaskSpec @@ -49,9 +49,9 @@ class ResultCache(abc.ABC): @abc.abstractmethod def write( self, - results: Dict[str, Any], - group_id: Optional[str] = None, - spawning_task_id: Optional[str] = None, + results: dict[str, Any], + group_id: str | None = None, + spawning_task_id: str | None = None, ): """Writes results to the cache. This is called after a task is run. 
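The state.py import hunk directly above repeats the split used throughout this diff: abstract container types come from collections.abc, while typing retains only the special forms that have no runtime class (Any, Literal, Protocol, and so on). A short sketch of the resulting idiom; `apply_all` is a made-up helper, not part of Hamilton:

    from collections.abc import Callable, Collection, Sequence
    from typing import Any

    def apply_all(fns: Collection[Callable[[Any], Any]], values: Sequence[Any]) -> list[Any]:
        """Applies every function to every value, in declaration order."""
        return [fn(value) for fn in fns for value in values]

    assert apply_all([str, repr], [1]) == ["1", "1"]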
@@ -63,11 +63,11 @@ def write( @abc.abstractmethod def read( self, - keys: List[str], - group_id: Optional[str] = None, - spawning_task_id: Optional[str] = None, + keys: list[str], + group_id: str | None = None, + spawning_task_id: str | None = None, optional: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Reads results in bulk from the cache. :param spawning_task_id: Task ID of the task that spawned the task @@ -83,19 +83,17 @@ def read( class DictBasedResultCache(ResultCache): """Cache of intermediate results. Will likely want to add pruning to this...""" - def __init__(self, cache: Dict[str, Any]): + def __init__(self, cache: dict[str, Any]): self.cache = cache - def _format_key( - self, group_id: Optional[str], spawning_task_id: Optional[str], key: str - ) -> str: + def _format_key(self, group_id: str | None, spawning_task_id: str | None, key: str) -> str: return ":".join([item for item in [spawning_task_id, group_id, key] if item is not None]) def write( self, - results: Dict[str, Any], - group_id: Optional[str] = None, - spawning_task_id: Optional[str] = None, + results: dict[str, Any], + group_id: str | None = None, + spawning_task_id: str | None = None, ): results_with_key_assigned = { self._format_key(group_id, spawning_task_id, key): value @@ -105,11 +103,11 @@ def write( def read( self, - keys: List[str], - group_id: Optional[str] = None, - spawning_task_id: Optional[str] = None, + keys: list[str], + group_id: str | None = None, + spawning_task_id: str | None = None, optional: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Reads results in bulk from the cache. If its optional, we don't mind if its not there. :param keys: Keys to read @@ -139,7 +137,7 @@ class ExecutionState: 3. Prep the task for execution (give it the results it needs) """ - def __init__(self, tasks: List[TaskSpec], result_cache: ResultCache, run_id: str): + def __init__(self, tasks: list[TaskSpec], result_cache: ResultCache, run_id: str): """Initializes an ExecutionState to all uninitialized. TBD if we want to add in an initialization step that can, say, read from a db. @@ -162,7 +160,7 @@ def __init__(self, tasks: List[TaskSpec], result_cache: ResultCache, run_id: str self.base_reverse_dependencies = self.compute_reverse_dependencies(tasks) @staticmethod - def compute_reverse_dependencies(tasks: List[TaskSpec]) -> Dict[str, List[TaskSpec]]: + def compute_reverse_dependencies(tasks: list[TaskSpec]) -> dict[str, list[TaskSpec]]: """Computes dependencies in reverse order, E.G. what tasks depend on this task. :param tasks: @@ -174,7 +172,7 @@ def compute_reverse_dependencies(tasks: List[TaskSpec]) -> Dict[str, List[TaskSp reverse_dependencies[dependency].append(task) return reverse_dependencies - def _initialize_task_pool(self, tasks: List[TaskSpec]): + def _initialize_task_pool(self, tasks: list[TaskSpec]): """Initializes the task pool to all nodes that have no dependencies. :param tasks: @@ -187,7 +185,7 @@ def _initialize_task_pool(self, tasks: List[TaskSpec]): if len(task.base_dependencies) == 0: self.realize_task(task, None, None, None) - def _initialize_base_task_pool(self, tasks: List[TaskSpec]): + def _initialize_base_task_pool(self, tasks: list[TaskSpec]): """Initializes the base task pool to all nodes. 
:param tasks: @@ -210,10 +208,10 @@ def _initialize_task_queue(self): def realize_task( self, task_spec: TaskSpec, - spawning_task: Optional[str], - group_id: Optional[str], - dependencies: Dict[str, List[str]] = None, - bind: Dict[str, Any] = None, + spawning_task: str | None, + group_id: str | None, + dependencies: dict[str, list[str]] = None, + bind: dict[str, Any] = None, ): """Creates a task and enqueues it to the internal queue. This takes a task in "Plan" state (E.G. a TaskSpec), and trasnforms it into execution-ready state, freezing the dependencies. @@ -255,9 +253,9 @@ def realize_task( def realize_parameterized_group( self, spawning_task_id: str, - parameterizations: Dict[str, Any], + parameterizations: dict[str, Any], input_to_parameterize: str, - ) -> List[TaskImplementation]: + ) -> list[TaskImplementation]: """Parameterizes an unordered expand group. These are tasks that are all part of the same group. For every result in the list, the input of the next task gets the result. @@ -332,7 +330,7 @@ def realize_parameterized_group( out.append(new_task) return out - def write_task_results(self, writer: TaskImplementation, results: Dict[str, Any]): + def write_task_results(self, writer: TaskImplementation, results: dict[str, Any]): results_to_write = results if writer.purpose.is_expander(): # In this case we need to write each result individually @@ -342,9 +340,7 @@ def write_task_results(self, writer: TaskImplementation, results: Dict[str, Any] # write the rest with the appropriate namespace self.result_cache.write(results_to_write, writer.group_id, writer.spawning_task_id) - def update_task_state( - self, task_id: str, new_state: TaskState, results: Optional[Dict[str, Any]] - ): + def update_task_state(self, task_id: str, new_state: TaskState, results: dict[str, Any] | None): """Updates the state of a task based on an external push. This also determines which tasks to create/enqueue next. If the node finished succesfully, this has two steps: @@ -514,7 +510,7 @@ def bind_task(self, task: TaskImplementation): } return task.bind(dynamic_inputs) - def release_next_task(self) -> Optional[TaskImplementation]: + def release_next_task(self) -> TaskImplementation | None: """Gives the next task to run, and inserts inputs it needs to run. Note that this can return None, which means there is nothing to run. That indicates that, either: diff --git a/hamilton/experimental/decorators/parameterize_frame.py b/hamilton/experimental/decorators/parameterize_frame.py index 10aee4a53..98f234662 100644 --- a/hamilton/experimental/decorators/parameterize_frame.py +++ b/hamilton/experimental/decorators/parameterize_frame.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-from typing import List import pandas as pd @@ -40,7 +39,7 @@ def _get_dep_type(dep_type: str) -> UpstreamDependency: raise ValueError(f"Invalid dep type: {dep_type}") -def _get_index_levels(index: pd.MultiIndex) -> List[list]: +def _get_index_levels(index: pd.MultiIndex) -> list[list]: out = [[] for _ in index[0]] for specific_index in index: for i, key in enumerate(specific_index): @@ -59,7 +58,7 @@ def _validate_df_parameterization(parameterization: pd.DataFrame): ) -def _convert_params_from_df(parameterization: pd.DataFrame) -> List[ParameterizedExtract]: +def _convert_params_from_df(parameterization: pd.DataFrame) -> list[ParameterizedExtract]: _validate_df_parameterization(parameterization) args, dep_types = _get_index_levels(parameterization.columns) dep_types_converted = [_get_dep_type(val) for val in dep_types] diff --git a/hamilton/experimental/h_cache.py b/hamilton/experimental/h_cache.py index af5913475..5ddd7c6b5 100644 --- a/hamilton/experimental/h_cache.py +++ b/hamilton/experimental/h_cache.py @@ -19,8 +19,9 @@ import logging import os import pickle +from collections.abc import Callable from functools import singledispatch -from typing import Any, Callable, Dict, Optional, Set, Type +from typing import Any import typing_inspect @@ -301,9 +302,9 @@ def __init__( self, cache_path: str, *args, - force_compute: Optional[Set[str]] = None, - writers: Optional[Dict[str, Callable[[Any, str, str], None]]] = None, - readers: Optional[Dict[str, Callable[[Any, str], Any]]] = None, + force_compute: set[str] | None = None, + writers: dict[str, Callable[[Any, str, str], None]] | None = None, + readers: dict[str, Callable[[Any, str], Any]] | None = None, **kwargs, ): """Constructs the adapter. @@ -357,12 +358,12 @@ def _read_cache(self, fmt: str, expected_type: Any, filepath: str) -> None: self._check_format(fmt) return self.readers[fmt](expected_type, filepath) - def _get_empty_expected_type(self, expected_type: Type) -> Any: + def _get_empty_expected_type(self, expected_type: type) -> Any: if typing_inspect.is_generic_type(expected_type): return typing_inspect.get_origin(expected_type)() return expected_type() # This ASSUMES that we can just do `str()`, `pd.DataFrame()`, etc. - def execute_node(self, node: Node, kwargs: Dict[str, Any]) -> Any: + def execute_node(self, node: Node, kwargs: dict[str, Any]) -> Any: """Executes nodes conditionally according to caching rules. 
This node is executed if at least one of these is true: @@ -406,7 +407,7 @@ def execute_node(self, node: Node, kwargs: Dict[str, Any]) -> Any: self.computed_nodes.add(node.name) return node.callable(**kwargs) - def build_result(self, **outputs: Dict[str, Any]) -> Any: + def build_result(self, **outputs: dict[str, Any]) -> Any: """Clears the computed nodes information and delegates to the super class.""" self.computed_nodes = set() return super().build_result(**outputs) diff --git a/hamilton/experimental/h_databackends.py b/hamilton/experimental/h_databackends.py index 2cb70967b..088d8be42 100644 --- a/hamilton/experimental/h_databackends.py +++ b/hamilton/experimental/h_databackends.py @@ -47,7 +47,6 @@ def _(df: h_databackends.AbstractIbisDataFrame) -> pyarrow.Schema: import importlib import inspect -from typing import Tuple from hamilton.experimental.databackend import AbstractBackend @@ -152,7 +151,7 @@ class AbstractNumpyArray(AbstractBackend): _backends = [("numpy", "ndarray")] -def register_backends() -> Tuple[Tuple[type], Tuple[type]]: +def register_backends() -> tuple[tuple[type], tuple[type]]: """Register databackends defined in this module that include `DataFrame` and `Column` in their class name """ diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py index bdad4e55e..bc5c89348 100644 --- a/hamilton/function_modifiers/adapters.py +++ b/hamilton/function_modifiers/adapters.py @@ -18,7 +18,8 @@ import inspect import logging import typing -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type +from collections.abc import Callable, Collection +from typing import Any import typing_inspect @@ -46,7 +47,7 @@ class AdapterFactory: """Factory for data loaders. This handles the fact that we pass in source(...) and value(...) parameters to the data loaders.""" - def __init__(self, adapter_cls: Type[AdapterCommon], **kwargs: ParametrizedDependency): + def __init__(self, adapter_cls: type[AdapterCommon], **kwargs: ParametrizedDependency): """Initializes an adapter factory. This takes in parameterized dependencies and stores them for later resolution. @@ -97,7 +98,7 @@ def create_saver(self, **resolved_kwargs: Any) -> DataSaver: return self.adapter_cls(**resolved_kwargs) -def resolve_kwargs(kwargs: Dict[str, Any]) -> Tuple[Dict[str, str], Dict[str, Any]]: +def resolve_kwargs(kwargs: dict[str, Any]) -> tuple[dict[str, str], dict[str, Any]]: """Resolves kwargs to a list of dependencies, and a dictionary of name to resolved literal values. @@ -116,15 +117,15 @@ def resolve_kwargs(kwargs: Dict[str, Any]) -> Tuple[Dict[str, str], Dict[str, An def resolve_adapter_class( - type_: Type[Type], loader_classes: List[Type[AdapterCommon]] -) -> Optional[Type[AdapterCommon]]: + type_: type[type], loader_classes: list[type[AdapterCommon]] +) -> type[AdapterCommon] | None: """Resolves the loader class for a function. This will return the most recently registered loader class that applies to the injection type, hence the reversed order. :param fn: Function to inject the loaded data into. :return: The loader class to use. 
""" - applicable_adapters: List[Type[AdapterCommon]] = [] + applicable_adapters: list[type[AdapterCommon]] = [] loaders_with_any = [] for loader_cls in reversed(loader_classes): # We do this here, rather than in applies_to, as its a bit of a special case @@ -150,7 +151,7 @@ def resolve_adapter_class( class LoadFromDecorator(NodeInjector): def __init__( self, - loader_classes: typing.Sequence[Type[DataLoader]], + loader_classes: typing.Sequence[type[DataLoader]], inject_=None, **kwargs: ParametrizedDependency, ): @@ -165,7 +166,7 @@ def __init__( self.kwargs = kwargs self.inject = inject_ - def _select_param_to_inject(self, params: List[str], fn: Callable) -> str: + def _select_param_to_inject(self, params: list[str], fn: Callable) -> str: """Chooses a parameter to inject, given the parameters available. If self.inject is None (meaning we inject the only parameter), then that's the one. If it is not None, then we need to ensure it is one of the available parameters, in which case we choose it. @@ -185,8 +186,8 @@ def _select_param_to_inject(self, params: List[str], fn: Callable) -> str: return self.inject def get_loader_nodes( - self, inject_parameter: str, load_type: Type[Type], namespace: str = None - ) -> List[node.Node]: + self, inject_parameter: str, load_type: type[type], namespace: str = None + ) -> list[node.Node]: loader_cls = resolve_adapter_class( load_type, self.loader_classes, @@ -205,12 +206,12 @@ def get_loader_nodes( def load_data( __loader_factory: AdapterFactory = loader_factory, - __load_type: Type[Type] = load_type, + __load_type: type[type] = load_type, __resolved_kwargs=resolved_kwargs, __dependencies=dependencies_inverted, __optional_params=loader_cls.get_optional_arguments(), # noqa: B008 **input_kwargs: Any, - ) -> Tuple[load_type, Dict[str, Any]]: + ) -> tuple[load_type, dict[str, Any]]: input_args_with_fixed_dependencies = { __dependencies.get(key, key): value for key, value in input_kwargs.items() } @@ -242,7 +243,7 @@ def get_input_type_key(key: str) -> str: loader_node = node.Node( name=f"{inject_parameter}", callabl=load_data, - typ=Tuple[Dict[str, Any], load_type], + typ=tuple[dict[str, Any], load_type], input_types=input_types, tags={ "hamilton.data_loader": True, @@ -284,8 +285,8 @@ def filter_function(_inject_parameter=inject_parameter, **kwargs): return [loader_node, filter_node] def inject_nodes( - self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable - ) -> Tuple[Collection[node.Node], Dict[str, str]]: + self, params: dict[str, type[type]], config: dict[str, Any], fn: Callable + ) -> tuple[Collection[node.Node], dict[str, str]]: """Generates two nodes: 1. A node that loads the data from the data source, and returns that + metadata 2. A node that takes the data from the data source, injects it into, and runs, the function. @@ -305,7 +306,7 @@ def inject_nodes( return [loader_node, filter_node], {inject_parameter: filter_node.name} - def _get_inject_parameter_from_function(self, fn: Callable) -> Tuple[str, Type[Type]]: + def _get_inject_parameter_from_function(self, fn: Callable) -> tuple[str, type[type]]: """Gets the name of the parameter to inject the data into. :param fn: The function to decorate. @@ -471,7 +472,7 @@ def __call__(self, *args, **kwargs): @classmethod def decorator_factory( - cls, loaders: typing.Sequence[Type[DataLoader]] + cls, loaders: typing.Sequence[type[DataLoader]] ) -> Callable[..., LoadFromDecorator]: """Effectively a partial function for the load_from decorator. 
Broken into its own ( rather than using functools.partial) as it is a little clearer to parse. @@ -510,7 +511,7 @@ def __getattr__(cls, item: str): class SaveToDecorator(SingleNodeNodeTransformer): def __init__( self, - saver_classes_: typing.Sequence[Type[DataSaver]], + saver_classes_: typing.Sequence[type[DataSaver]], output_name_: str = None, target_: str = None, **kwargs: ParametrizedDependency, @@ -522,7 +523,7 @@ def __init__( self.target = target_ def create_saver_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> node.Node: artifact_name = self.artifact_name artifact_namespace = () @@ -554,7 +555,7 @@ def save_data( __resolved_kwargs=resolved_kwargs, __data_node_name=node_to_save, **input_kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: input_args_with_fixed_dependencies = { __dependencies.get(key, key): value for key, value in input_kwargs.items() } @@ -587,7 +588,7 @@ def get_input_type_key(key: str) -> str: save_node = node.Node( name=artifact_name, callabl=save_data, - typ=Dict[str, Any], + typ=dict[str, Any], input_types=input_types, namespace=artifact_namespace, tags={ @@ -599,7 +600,7 @@ def get_input_type_key(key: str) -> str: return save_node def transform_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Transforms the node to a data saver. @@ -691,7 +692,7 @@ def __call__(self, *args, **kwargs): @classmethod def decorator_factory( - cls, savers: typing.Sequence[Type[DataSaver]] + cls, savers: typing.Sequence[type[DataSaver]] ) -> Callable[..., SaveToDecorator]: """Effectively a partial function for the load_from decorator. Broken into its own ( rather than using functools.partial) as it is a little clearer to parse. @@ -778,7 +779,7 @@ def validate(self, fn: Callable): f"Function: {fn.__qualname__} must return a tuple of type (SOME_TYPE, dict[str, ...]). Instead got (SOME_TYPE, dict[{second_arg_params[0]}, ...]" ) - def generate_nodes(self, fn: Callable, config) -> List[node.Node]: + def generate_nodes(self, fn: Callable, config) -> list[node.Node]: """Generates two nodes. We have to add tags appropriately. The first one is just the fn - with a slightly different name. @@ -877,10 +878,10 @@ def validate(self, fn: Callable): f"Function: {fn.__qualname__} must have a return annotation." ) # check that the return type is a dict - if return_annotation not in (dict, Dict): + if return_annotation not in (dict, dict): raise InvalidDecoratorException(f"Function: {fn.__qualname__} must return a dict.") - def generate_nodes(self, fn: Callable, config) -> List[node.Node]: + def generate_nodes(self, fn: Callable, config) -> list[node.Node]: """Generates same node but all this does is add tags to it. :param fn: :param config: diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index 0b538fc39..9a67fefa2 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -27,7 +27,8 @@ except ImportError: # python3.10 and above EllipsisType = type(...) 
-from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union +from collections.abc import Callable, Collection +from typing import Any, Union from hamilton import node, registry, settings @@ -129,7 +130,7 @@ def __call__(self, fn: Callable): setattr(fn, lifecycle_name, [self]) return fn - def required_config(self) -> Optional[List[str]]: + def required_config(self) -> list[str] | None: """Declares the required configuration keys for this decorator. Note that these configuration keys will be filtered and passed to the `configuration` parameter of the functions that this decorator uses. @@ -141,7 +142,7 @@ def required_config(self) -> Optional[List[str]]: """ return [] - def optional_config(self) -> Optional[Dict[str, Any]]: + def optional_config(self) -> dict[str, Any] | None: """Declares the optional configuration keys for this decorator. These are configuration keys that can be used by the decorator, but are not required. Along with these we have *defaults*, which we will use to pass to the config. @@ -165,7 +166,7 @@ class NodeResolver(NodeTransformLifecycle): """Decorator to resolve a nodes function. Can modify anything about the function and is run at DAG creation time.""" @abc.abstractmethod - def resolve(self, fn: Callable, config: Dict[str, Any]) -> Optional[Callable]: + def resolve(self, fn: Callable, config: dict[str, Any]) -> Callable | None: """Determines what a function resolves to. Returns None if it should not be included in the DAG. :param fn: Function to resolve @@ -197,7 +198,7 @@ class NodeCreator(NodeTransformLifecycle, abc.ABC): """Abstract class for nodes that "expand" functions into other nodes.""" @abc.abstractmethod - def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]: + def generate_nodes(self, fn: Callable, config: dict[str, Any]) -> list[node.Node]: """Given a function, converts it to a series of nodes that it produces. :param config: @@ -227,7 +228,7 @@ def allows_multiple(cls) -> bool: class SubDAGModifier(NodeTransformLifecycle, abc.ABC): @abc.abstractmethod def transform_dag( - self, nodes: Collection[node.Node], config: Dict[str, Any], fn: Callable + self, nodes: Collection[node.Node], config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Modifies a DAG consisting of a set of nodes. Note that this is to support the following two base classes. @@ -254,7 +255,7 @@ def processed_data(data: pd.DataFrame) -> pd.DataFrame: """ @staticmethod - def find_injectable_params(nodes: Collection[node.Node]) -> Dict[str, Type[Type]]: + def find_injectable_params(nodes: Collection[node.Node]) -> dict[str, type[type]]: """Identifies required nodes of this subDAG (nodes produced by this function) that aren't satisfied by the nodes inside it. These are "injectable", meaning that we can add more nodes that feed into them. @@ -274,7 +275,7 @@ def find_injectable_params(nodes: Collection[node.Node]) -> Dict[str, Type[Type] return output_deps def transform_dag( - self, nodes: Collection[node.Node], config: Dict[str, Any], fn: Callable + self, nodes: Collection[node.Node], config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Transforms the subDAG by getting the injectable parameters (anything not produced by nodes inside it), then calling the inject_nodes function on it. 
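function_modifiers/base.py above keeps its maps of parameter names to node types, now spelled `dict[str, type[type]]` with the lowercase built-ins, and the validators earlier in this diff run issubclass checks against such values. A brief sketch of that usage, simplified to `dict[str, type]`; `numeric_params` is a hypothetical helper, not Hamilton code:

    import numbers

    def numeric_params(params: dict[str, type]) -> list[str]:
        """Returns the parameter names whose declared type is a real-number class."""
        return [name for name, typ in params.items() if issubclass(typ, numbers.Real)]

    assert numeric_params({"n_rows": int, "label": str, "ratio": float}) == ["n_rows", "ratio"]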
@@ -298,8 +299,8 @@ def transform_dag( @abc.abstractmethod def inject_nodes( - self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable - ) -> Tuple[List[node.Node], Dict[str, str]]: + self, params: dict[str, type[type]], config: dict[str, Any], fn: Callable + ) -> tuple[list[node.Node], dict[str, str]]: """Adds a set of nodes to inject into the DAG. These get injected into the specified param name, meaning that exactly one of the output nodes will have that name. Note that this also allows input renaming, meaning that the injector can rename the input to something else (to avoid @@ -336,7 +337,7 @@ class NodeExpander(SubDAGModifier): EXPAND_NODES = "expand_nodes" def transform_dag( - self, nodes: Collection[node.Node], config: Dict[str, Any], fn: Callable + self, nodes: Collection[node.Node], config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: if len(nodes) != 1: raise ValueError( @@ -348,7 +349,7 @@ def transform_dag( @abc.abstractmethod def expand_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Given a single node, expands into multiple nodes. Note that this node list includes: 1. Each "output" node (think sink in a DAG) @@ -497,7 +498,7 @@ def compliment( return [node_ for node_ in all_nodes if node_ not in nodes_to_transform] def transform_targets( - self, targets: Collection[node.Node], config: Dict[str, Any], fn: Callable + self, targets: Collection[node.Node], config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Transforms a set of target nodes. Note that this is just a loop, but abstracting t away gives subclasses control over how this is done, @@ -516,7 +517,7 @@ def transform_targets( return out def transform_dag( - self, nodes: Collection[node.Node], config: Dict[str, Any], fn: Callable + self, nodes: Collection[node.Node], config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Finds the sources and sinks and runs the transformer on each sink. Then returns the result of the entire set of sinks. Note that each sink has to have a unique name. @@ -534,7 +535,7 @@ def transform_dag( @abc.abstractmethod def transform_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: pass @@ -567,7 +568,7 @@ def __init__(self): super().__init__(target=None) def transform_targets( - self, targets: Collection[node.Node], config: Dict[str, Any], fn: Callable + self, targets: Collection[node.Node], config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Transforms the target set of nodes. Exists to validate the target set. @@ -606,7 +607,7 @@ def validate_node(self, node_: node.Node): pass def transform_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Transforms the node. 
Delegates to decorate_node @@ -640,7 +641,7 @@ def decorate_node(self, node_: node.Node) -> node.Node: class DefaultNodeCreator(NodeCreator): - def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]: + def generate_nodes(self, fn: Callable, config: dict[str, Any]) -> list[node.Node]: return [node.Node.from_fn(fn)] def validate(self, fn: Callable): @@ -648,7 +649,7 @@ def validate(self, fn: Callable): class DefaultNodeResolver(NodeResolver): - def resolve(self, fn: Callable, config: Dict[str, Any]) -> Callable: + def resolve(self, fn: Callable, config: dict[str, Any]) -> Callable: return fn def validate(self, fn): @@ -665,10 +666,10 @@ def decorate_node(self, node_: node.Node) -> node.Node: def resolve_config( name_for_error: str, - config: Dict[str, Any], - config_required: Optional[List[str]], - config_optional_with_defaults: Dict[str, Any], -) -> Dict[str, Any]: + config: dict[str, Any], + config_required: list[str] | None, + config_optional_with_defaults: dict[str, Any], +) -> dict[str, Any]: """Resolves the configuration that a decorator utilizes :param name_for_error: @@ -716,7 +717,7 @@ def validate(self, fn: Callable): pass -def filter_config(config: Dict[str, Any], decorator: NodeTransformLifecycle) -> Dict[str, Any]: +def filter_config(config: dict[str, Any], decorator: NodeTransformLifecycle) -> dict[str, Any]: """Filters the config to only include the keys in config_required :param config: The config to filter :param config_required: The keys to include @@ -729,8 +730,8 @@ def filter_config(config: Dict[str, Any], decorator: NodeTransformLifecycle) -> def get_node_decorators( - fn: Callable, config: Dict[str, Any] -) -> Dict[str, List[NodeTransformLifecycle]]: + fn: Callable, config: dict[str, Any] +) -> dict[str, list[NodeTransformLifecycle]]: """Gets the decorators for a function. Contract is this will have one entry for every step of the decorator lifecycle that can always be run (currently everything except NodeExpander) @@ -761,7 +762,7 @@ def get_node_decorators( return defaults -def _add_original_function_to_nodes(fn: Callable, nodes: List[node.Node]) -> List[node.Node]: +def _add_original_function_to_nodes(fn: Callable, nodes: list[node.Node]) -> list[node.Node]: """Adds the original function to the nodes. We do this so that we can have appropriate metadata on the function -- this is valuable to see if/how the function changes over time to manage node versions, etc... @@ -792,7 +793,7 @@ def _resolve_nodes_error(fn: Callable) -> str: return f"Exception occurred while compiling function: {fn.__name__} to nodes" -def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: +def resolve_nodes(fn: Callable, config: dict[str, Any]) -> Collection[node.Node]: """Gets a list of nodes from a function. This is meant to be an abstraction between the node and the function that it implements. This will end up coordinating with the decorators we build to modify nodes. diff --git a/hamilton/function_modifiers/configuration.py b/hamilton/function_modifiers/configuration.py index 7f941b021..41b8c3184 100644 --- a/hamilton/function_modifiers/configuration.py +++ b/hamilton/function_modifiers/configuration.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Collection, Dict, List, Optional +from collections.abc import Callable, Collection +from typing import Any from . 
import base @@ -26,16 +27,16 @@ class ConfigResolver: """Base class for resolving configuration so we can share the tooling between different functions.""" - def __init__(self, resolves: Callable[[Dict[str, Any]], bool], config_used: List[str]): + def __init__(self, resolves: Callable[[dict[str, Any]], bool], config_used: list[str]): self.resolves = resolves self._config_used = config_used @property - def optional_config(self) -> Dict[str, Any]: + def optional_config(self) -> dict[str, Any]: """Gives the optional configuration for this resolver -- to be used by the @config decorator.""" return {key: None for key in self._config_used} - def __call__(self, config: Dict[str, Any]) -> bool: + def __call__(self, config: dict[str, Any]) -> bool: return self.resolves(config) @staticmethod @@ -47,7 +48,7 @@ def when(**key_value_pairs) -> "ConfigResolver": :return: a configuration decorator """ - def resolves(configuration: Dict[str, Any]) -> bool: + def resolves(configuration: dict[str, Any]) -> bool: return all(value == configuration.get(key) for key, value in key_value_pairs.items()) return ConfigResolver(resolves, config_used=list(key_value_pairs.keys())) @@ -60,7 +61,7 @@ def when_not(**key_value_pairs: Any) -> "ConfigResolver": :return: a configuration decorator """ - def resolves(configuration: Dict[str, Any]) -> bool: + def resolves(configuration: dict[str, Any]) -> bool: return all(value != configuration.get(key) for key, value in key_value_pairs.items()) return ConfigResolver(resolves, config_used=list(key_value_pairs.keys())) @@ -74,7 +75,7 @@ def when_in(**key_value_group_pairs: Collection[Any]) -> "ConfigResolver": :return: a configuration decorator """ - def resolves(configuration: Dict[str, Any]) -> bool: + def resolves(configuration: dict[str, Any]) -> bool: return all( configuration.get(key) in value for key, value in key_value_group_pairs.items() ) @@ -89,7 +90,7 @@ def when_not_in(**key_value_group_pairs: Collection[Any]) -> "ConfigResolver": :return: a configuration decorator """ - def resolves(configuration: Dict[str, Any]) -> bool: + def resolves(configuration: dict[str, Any]) -> bool: return all( configuration.get(key) not in value for key, value in key_value_group_pairs.items() ) @@ -151,9 +152,9 @@ def my_transform__uk(some_input: pd.Series, some_input_c: pd.Series) -> pd.Serie def __init__( self, - resolves: Callable[[Dict[str, Any]], bool], + resolves: Callable[[dict[str, Any]], bool], target_name: str = None, - config_used: List[str] = None, + config_used: list[str] = None, ): """Decorator that resolves a function based on the configuration... @@ -166,7 +167,7 @@ def __init__( self.target_name = target_name self._config_used = config_used - def required_config(self) -> Optional[List[str]]: + def required_config(self) -> list[str] | None: """This returns the required configuration elements. Note that "none" is a sentinel value that means that we actaully don't know what it uses. 
If either required or optional configs are None, we @@ -180,7 +181,7 @@ def required_config(self) -> Optional[List[str]]: """ return None if self._config_used is None else [] - def optional_config(self) -> Optional[Dict[str, Any]]: + def optional_config(self) -> dict[str, Any] | None: """Everything is optional with None as the required value""" return {key: None for key in self._config_used} if self._config_used is not None else None @@ -189,7 +190,7 @@ def _get_function_name(self, fn: Callable) -> str: return self.target_name return base.sanitize_function_name(fn.__name__) - def resolve(self, fn, config: Dict[str, Any]) -> Callable: + def resolve(self, fn, config: dict[str, Any]) -> Callable: if not self.does_resolve(config): return None # attaches config keys used to resolve function @@ -295,7 +296,7 @@ def helper(...) -> ...: def __init__(self): pass - def resolve(self, *args, **kwargs) -> Optional[Callable]: + def resolve(self, *args, **kwargs) -> Callable | None: """Returning None defaults to not be included in the DAG. :param fn: Function to resolve diff --git a/hamilton/function_modifiers/delayed.py b/hamilton/function_modifiers/delayed.py index fb209b75d..d22ef9231 100644 --- a/hamilton/function_modifiers/delayed.py +++ b/hamilton/function_modifiers/delayed.py @@ -17,7 +17,8 @@ import enum import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple +from collections.abc import Callable +from typing import Any from hamilton import settings from hamilton.function_modifiers.base import ( @@ -34,7 +35,7 @@ class ResolveAt(enum.Enum): VALID_PARAM_KINDS = [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY] -def extract_and_validate_params(fn: Callable) -> Tuple[List[str], Dict[str, Any]]: +def extract_and_validate_params(fn: Callable) -> tuple[list[str], dict[str, Any]]: """Gets the parameters from a function, while validating that the function has *only* named arguments. @@ -140,13 +141,13 @@ def __init__(self, *, when: ResolveAt, decorate_with: Callable[..., NodeTransfor self.decorate_with = decorate_with self._required_config, self._optional_config = extract_and_validate_params(decorate_with) - def required_config(self) -> Optional[List[str]]: + def required_config(self) -> list[str] | None: return self._required_config - def optional_config(self) -> Optional[Dict[str, Any]]: + def optional_config(self) -> dict[str, Any] | None: return self._optional_config - def resolve(self, config: Dict[str, Any], fn: Callable) -> NodeTransformLifecycle: + def resolve(self, config: dict[str, Any], fn: Callable) -> NodeTransformLifecycle: if not config[settings.ENABLE_POWER_USER_MODE]: raise InvalidDecoratorException( "Dynamic functions are only allowed in power user mode!" 
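As a reminder of the public surface these resolvers back, here is a hedged sketch of @config.when / @config.when_not; the node names and rates are made up.

import pandas as pd

from hamilton.function_modifiers import config


@config.when(region="US")
def tax__us(income: pd.Series) -> pd.Series:
    # Chosen when the driver config has region == "US"; the __us suffix is stripped,
    # so both variants resolve to a single node named "tax".
    return income * 0.30


@config.when_not(region="US")
def tax__intl(income: pd.Series) -> pd.Series:
    return income * 0.25

when_in / when_not_in follow the same pattern but take collections of allowed (or disallowed) values per key.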
diff --git a/hamilton/function_modifiers/dependencies.py b/hamilton/function_modifiers/dependencies.py index 4e89b13bc..eeefb5c4c 100644 --- a/hamilton/function_modifiers/dependencies.py +++ b/hamilton/function_modifiers/dependencies.py @@ -19,7 +19,8 @@ import dataclasses import enum import typing -from typing import Any, Dict, List, Mapping, Sequence, Type +from collections.abc import Mapping, Sequence +from typing import Any import typing_inspect @@ -73,7 +74,7 @@ def get_dependency_type(self) -> ParametrizedDependencySource: class GroupedDependency(ParametrizedDependency, abc.ABC): @classmethod @abc.abstractmethod - def resolve_dependency_type(cls, annotated_type: Type[Type], param_name: str) -> Type[Type]: + def resolve_dependency_type(cls, annotated_type: type[type], param_name: str) -> type[type]: """Resolves dependency type for an annotated parameter. E.G. List[str] -> str, or Dict[str, int] -> int. @@ -86,10 +87,10 @@ def resolve_dependency_type(cls, annotated_type: Type[Type], param_name: str) -> @dataclasses.dataclass class GroupedListDependency(GroupedDependency): - sources: List[ParametrizedDependency] + sources: list[ParametrizedDependency] @classmethod - def resolve_dependency_type(cls, annotated_type: Type[Sequence[Type]], param_name: str): + def resolve_dependency_type(cls, annotated_type: type[Sequence[type]], param_name: str): if typing_inspect.is_optional_type( annotated_type ): # need to pull out the type from Optional. @@ -116,13 +117,13 @@ def get_dependency_type(self) -> ParametrizedDependencySource: @dataclasses.dataclass class GroupedDictDependency(GroupedDependency): - sources: typing.Dict[str, ParametrizedDependency] + sources: dict[str, ParametrizedDependency] def get_dependency_type(self) -> ParametrizedDependencySource: return ParametrizedDependencySource.GROUPED_DICT @classmethod - def resolve_dependency_type(cls, annotated_type: Type[Mapping[str, Type]], param_name: str): + def resolve_dependency_type(cls, annotated_type: type[Mapping[str, type]], param_name: str): if typing_inspect.is_optional_type( annotated_type ): # need to pull out the type from Optional. @@ -184,8 +185,8 @@ def configuration(dependency_on: str) -> ConfigDependency: def _validate_group_params( - dependency_args: List[ParametrizedDependency], - dependency_kwargs: Dict[str, ParametrizedDependency], + dependency_args: list[ParametrizedDependency], + dependency_kwargs: dict[str, ParametrizedDependency], ): """Validates the following for params to group(...): 1. That either dependency_args or dependency_kwargs is non-empty, but not both. diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 3bcf1f74d..0cfe1ea62 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -20,7 +20,8 @@ import functools import inspect import typing -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union +from collections.abc import Callable, Collection +from typing import Any import typing_extensions import typing_inspect @@ -119,10 +120,8 @@ def concat(to_concat: List[str]) -> Any: def __init__( self, - **parametrization: Union[ - Dict[str, ParametrizedDependency], - Tuple[Dict[str, ParametrizedDependency], str], - ], + **parametrization: dict[str, ParametrizedDependency] + | tuple[dict[str, ParametrizedDependency], str], ): """Decorator to use to create many functions. 
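To ground the dependency helpers typed above (source() and value()), a small sketch of @parameterize as it is typically used; the upstream node names are illustrative.

import pandas as pd

from hamilton.function_modifiers import parameterize, source, value


@parameterize(
    us_revenue={"spend": source("us_spend"), "rate": value(1.0)},
    eu_revenue={"spend": source("eu_spend"), "rate": value(0.92)},
)
def revenue(spend: pd.Series, rate: float) -> pd.Series:
    # Produces two nodes, us_revenue and eu_revenue, from one function.
    return spend * rate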
@@ -150,8 +149,8 @@ def __init__( } def split_parameterizations( - self, parameterizations: Dict[str, ParametrizedDependency] - ) -> Dict[ParametrizedDependencySource, Dict[str, ParametrizedDependency]]: + self, parameterizations: dict[str, ParametrizedDependency] + ) -> dict[ParametrizedDependencySource, dict[str, ParametrizedDependency]]: """Split parameterizations into two groups: those that are literal values, and those that are upstream nodes. Will have a key for each existing dependency type. @@ -168,7 +167,7 @@ def _get_grouped_list_name(self, index: int, arg_name: str): return f"__{arg_name}_{index}" def expand_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: nodes = [] for ( @@ -447,7 +446,7 @@ def create_one_off_dates(date_index: pd.Series, one_off_date: str) -> pd.Series: """ - def __init__(self, parameter: str, assigned_output: Dict[Tuple[str, str], Any]): + def __init__(self, parameter: str, assigned_output: dict[tuple[str, str], Any]): """Constructor for a modifier that expands a single function into n, each of which corresponds to a function in which the parameter value is replaced by that *specific value*. @@ -455,7 +454,7 @@ def __init__(self, parameter: str, assigned_output: Dict[Tuple[str, str], Any]): :param assigned_output: A map of tuple of [parameter names, documentation] to values """ for node_ in assigned_output.keys(): - if not isinstance(node_, Tuple): + if not isinstance(node_, tuple): raise base.InvalidDecoratorException( f"assigned_output key is incorrect: {node_}. The parameterized decorator needs a dict of " "[name, doc string] -> value to function." @@ -499,7 +498,7 @@ def date_shifter(one_off_date: pd.Series) -> pd.Series: """ - def __init__(self, **parameterization: Dict[str, str]): + def __init__(self, **parameterization: dict[str, str]): """Constructor for a modifier that expands a single function into n, each of which corresponds to replacing\ some subset of the specified parameters with specific upstream nodes. @@ -546,7 +545,7 @@ def __init__(self, **parameterization: Dict[str, str]): "-parameterized", ) class parametrized_input(parameterize): - def __init__(self, parameter: str, variable_inputs: Dict[str, Tuple[str, str]]): + def __init__(self, parameter: str, variable_inputs: dict[str, tuple[str, str]]): """Constructor for a modifier that expands a single function into n, each of which corresponds to the specified parameter replaced by a *specific input column*. @@ -563,7 +562,7 @@ def __init__(self, parameter: str, variable_inputs: Dict[str, Tuple[str, str]]): :param variable_inputs: A map of tuple of [parameter names, documentation] to values """ for val in variable_inputs.values(): - if not isinstance(val, Tuple): + if not isinstance(val, tuple): raise base.InvalidDecoratorException( f"assigned_output key is incorrect: {node}. The parameterized decorator needs a dict of " "input column -> [name, description] to function." 
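For comparison with the deprecated spellings above, a sketch of @parameterize_sources (the current name for this pattern); the date node names are invented.

import pandas as pd

from hamilton.function_modifiers import parameterize_sources


@parameterize_sources(
    election_2016_shifted={"one_off_date": "D_ELECTION_2016"},
    election_2020_shifted={"one_off_date": "D_ELECTION_2020"},
)
def date_shifter(one_off_date: pd.Series) -> pd.Series:
    # Each entry swaps the one_off_date parameter for a different upstream node.
    return one_off_date.shift(1)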
@@ -588,7 +587,7 @@ class parameterized_inputs(parameterize_sources): class extract_columns(base.SingleNodeNodeTransformer): - def __init__(self, *columns: Union[Tuple[str, str], str], fill_with: Any = None): + def __init__(self, *columns: tuple[str, str] | str, fill_with: Any = None): """Constructor for a modifier that expands a single function into the following nodes: - n functions, each of which take in the original dataframe and output a specific column @@ -635,7 +634,7 @@ def validate(self, fn: Callable): extract_columns.validate_return_type(fn) def transform_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """For each column to extract, output a node that extracts that column. Also, output the original dataframe generator. @@ -676,7 +675,7 @@ def df_generator(*args, **kwargs) -> Any: series_type = registry.get_column_type_from_df_type(output_type) for column in self.columns: doc_string = base_doc # default doc string of base function. - if isinstance(column, Tuple): # Expand tuple into constituents + if isinstance(column, tuple): # Expand tuple into constituents column, doc_string = column if inspect.iscoroutinefunction(fn): @@ -717,8 +716,8 @@ def extractor_fn( def _determine_fields_to_extract( - fields: Optional[Union[Dict[str, Any], List[str]]], output_type: Any -) -> Dict[str, Any]: + fields: dict[str, Any] | list[str] | None, output_type: Any +) -> dict[str, Any]: """Determines which fields to extract based on user requested fields and the output type of the return type of the function. @@ -733,7 +732,7 @@ def _determine_fields_to_extract( f"`typing.Dict[str, int]`), not: {output_type}" ) - if output_type == dict or output_type == Dict: + if output_type == dict or output_type == typing.Dict: # NOTE: typing_inspect.is_generic_type(typing.Dict) without type parameters returns True, # so we need to address the bare dictionaries first before generics. if fields is None or not isinstance(fields, dict): @@ -743,7 +742,7 @@ def _determine_fields_to_extract( ) elif typing_inspect.is_generic_type(output_type): base_type = typing_inspect.get_origin(output_type) - if base_type != dict and base_type != Dict: + if base_type != dict and base_type != typing.Dict: raise base.InvalidDecoratorException(output_type_error) if fields is None: raise base.InvalidDecoratorException( @@ -840,11 +839,11 @@ class extract_fields(base.SingleNodeNodeTransformer): """Extracts fields from a dictionary of output.""" output_type: Any - resolved_fields: Dict[str, Type] + resolved_fields: dict[str, type] def __init__( self, - fields: Optional[Union[Dict[str, Any], List[str], Any]] = None, + fields: dict[str, Any] | list[str] | Any | None = None, *others, fill_with: Any = None, ): @@ -878,7 +877,7 @@ def validate(self, fn: Callable): self.resolved_fields = _determine_fields_to_extract(self.fields, self.output_type) def transform_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """For each field to extract, output a node that extracts that field. Also, output the original TypedDict generator.
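A quick sketch of the two extractors whose internals are retyped above; the column and field names are illustrative.

import pandas as pd

from hamilton.function_modifiers import extract_columns, extract_fields


@extract_columns("signups", "spend")
def marketing_df(marketing_path: str) -> pd.DataFrame:
    # Exposes the dataframe node plus one node per listed column.
    return pd.read_csv(marketing_path)


@extract_fields({"train_accuracy": float, "test_accuracy": float})
def evaluation() -> dict:
    # Exposes train_accuracy and test_accuracy as nodes alongside the dict node.
    return {"train_accuracy": 0.95, "test_accuracy": 0.91}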
@@ -941,7 +940,7 @@ def extractor_fn(field_to_extract: str = field, **kwargs) -> field_type: # type return output_nodes -def _determine_fields_to_unpack(fields: List[str], output_type: Any) -> List[Type]: +def _determine_fields_to_unpack(fields: list[str], output_type: Any) -> list[type]: """Determines which fields to unpack based on user requested fields and the output type of the return type of the function. @@ -951,7 +950,7 @@ def _determine_fields_to_unpack(fields: List[str], output_type: Any) -> List[Typ """ base_type = typing_inspect.get_origin(output_type) # Returns None when output_type is None - if base_type != tuple and base_type != Tuple: + if base_type != tuple and base_type != typing.Tuple: message = ( f"For unpacking fields, the decorated function output type must be either an " f"explicit length tuple (e.g.`tuple[int, str]`, `typing.Tuple[int, str]`) or an " @@ -1017,7 +1016,7 @@ class unpack_fields(base.SingleNodeNodeTransformer): """ output_type: Any - field_types: List[Type] + field_types: list[type] def __init__(self, *fields: str): super().__init__() @@ -1037,7 +1036,7 @@ def validate(self, fn: Callable): @override def transform_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Unpacks the specified fields form the tuple output into separate nodes. @@ -1102,8 +1101,8 @@ class ParameterizedExtract: parameter to the function. """ - outputs: Tuple[str, ...] - input_mapping: Dict[str, ParametrizedDependency] + outputs: tuple[str, ...] + input_mapping: dict[str, ParametrizedDependency] class parameterize_extract_columns(base.NodeExpander): @@ -1147,7 +1146,7 @@ def __init__(self, *extract_config: ParameterizedExtract, reassign_columns: bool self.reassign_columns = reassign_columns def expand_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Expands a node into multiple, given the extract_config passed to parameterize_extract_columns. Goes through all parameterizations, diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py index f331e05dc..3258e83f2 100644 --- a/hamilton/function_modifiers/macros.py +++ b/hamilton/function_modifiers/macros.py @@ -21,7 +21,8 @@ import logging import typing from collections import Counter, defaultdict -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union +from collections.abc import Callable, Collection +from typing import Any, Union import pandas as pd @@ -103,7 +104,7 @@ def copy_of_x(x: pd.Series) -> pd.Series: to just allow functions that consist only of one argument, a generic \\*\\*kwargs. """ - def __init__(self, replacing_function: Callable, **argument_mapping: Union[str, List[str]]): + def __init__(self, replacing_function: Callable, **argument_mapping: str | list[str]): """Constructor for a modifier that replaces the annotated functions functionality with something else. Right now this has a very strict validation requirements to make compliance with the framework easy. @@ -115,7 +116,7 @@ def __init__(self, replacing_function: Callable, **argument_mapping: Union[str, self.argument_mapping = argument_mapping @staticmethod - def map_kwargs(kwargs: Dict[str, Any], argument_mapping: Dict[str, str]) -> Dict[str, Any]: + def map_kwargs(kwargs: dict[str, Any], argument_mapping: dict[str, str]) -> dict[str, Any]: """Maps kwargs using the argument mapping.
This does 2 things: 1. Replaces all kwargs in passed_in_kwargs with their mapping @@ -137,7 +138,7 @@ def map_kwargs(kwargs: Dict[str, Any], argument_mapping: Dict[str, str]) -> Dict def test_function_signatures_compatible( fn_signature: inspect.Signature, replace_with_signature: inspect.Signature, - argument_mapping: Dict[str, str], + argument_mapping: dict[str, str], ) -> bool: """Tests whether a function signature and the signature of the replacing function are compatible. @@ -170,7 +171,7 @@ def test_function_signatures_compatible( def ensure_function_signature_compatible( og_function: Callable, replacing_function: Callable, - argument_mapping: Dict[str, str], + argument_mapping: dict[str, str], ): """Ensures that a function signature is compatible with the replacing function, given the argument mapping @@ -220,7 +221,7 @@ def validate(self, fn: Callable): fn, self.replacing_function, self.argument_mapping ) - def generate_nodes(self, fn: Callable, config) -> List[node.Node]: + def generate_nodes(self, fn: Callable, config) -> list[node.Node]: """Returns one node which has the replaced functionality :param fn: Function to decorate :param config: Configuration (not used in this) @@ -241,7 +242,7 @@ def wrapper_function(**kwargs): return [node.Node.from_fn(fn).copy_with(callabl=wrapper_function)] -def get_default_tags(fn: Callable) -> Dict[str, str]: +def get_default_tags(fn: Callable) -> dict[str, str]: """Function that encapsulates default tags on a function. :param fn: the function we want to create default tags for. @@ -264,7 +265,7 @@ def get_default_tags(fn: Callable) -> Dict[str, str]: class dynamic_transform(base.NodeCreator): def __init__( self, - transform_cls: Type[models.BaseModel], + transform_cls: type[models.BaseModel], config_param: str, **extra_transform_params, ): @@ -293,7 +294,7 @@ def validate(self, fn: Callable): "Models must have no parameters -- all are passed in through the config" ) - def generate_nodes(self, fn: Callable, config: Dict[str, Any] = None) -> List[node.Node]: + def generate_nodes(self, fn: Callable, config: dict[str, Any] = None) -> list[node.Node]: if self.config_param not in config: raise base.InvalidDecoratorException( f"Configuration has no parameter: {self.config_param}. Did you define it? If so did you spell it right?" @@ -313,7 +314,7 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any] = None) -> List[no ) ] - def require_config(self) -> List[str]: + def require_config(self) -> list[str]: """Returns the configuration parameters that this model requires :return: Just the one config param used by this model @@ -339,13 +340,13 @@ class Applicable: def __init__( self, - fn: Union[Callable, str, None], - args: Tuple[Union[Any, SingleDependency], ...], - kwargs: Dict[str, Union[Any, SingleDependency]], - target_fn: Union[Callable, str, None] = None, - _resolvers: List[ConfigResolver] = None, - _name: Optional[str] = None, - _namespace: Union[str, None, EllipsisType] = ..., + fn: Callable | str | None, + args: tuple[Any | SingleDependency, ...], + kwargs: dict[str, Any | SingleDependency], + target_fn: Callable | str | None = None, + _resolvers: list[ConfigResolver] = None, + _name: str | None = None, + _namespace: str | None | EllipsisType = ..., _target: base.TargetType = None, ): """Instantiates an Applicable. 
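The kwargs-mapping helpers above serve @does; here is a hedged sketch of its usage, with invented spend columns.

import pandas as pd

from hamilton.function_modifiers import does


def _sum_series(**series: pd.Series) -> pd.Series:
    # Generic implementation: sums whatever series it is handed.
    return sum(series.values())


@does(_sum_series)
def total_spend(tv_spend: pd.Series, radio_spend: pd.Series) -> pd.Series:
    """Total spend across channels; the body is supplied by _sum_series."""
    ...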
@@ -450,7 +451,7 @@ def namespaced(self, namespace: NamespaceType) -> "Applicable": target_fn=self.target_fn, ) - def resolves(self, config: Dict[str, Any]) -> bool: + def resolves(self, config: dict[str, Any]) -> bool: """Returns whether the Applicable resolves with the given config :param config: Configuration to check @@ -533,7 +534,7 @@ def on_output(self, target: base.TargetType) -> "Applicable": target_fn=self.target_fn, ) - def get_config_elements(self) -> List[str]: + def get_config_elements(self) -> list[str]: """Returns the config elements that this Applicable uses""" out = [] for resolver in self.resolvers: @@ -607,7 +608,7 @@ def validate(self, chain_first_param: bool, allow_custom_namespace: bool): "Current workarounds are to define a wrapper function that assigns types with the proper keyword-friendly arguments." ) from e - def resolve_namespace(self, default_namespace: str) -> Tuple[str, ...]: + def resolve_namespace(self, default_namespace: str) -> tuple[str, ...]: """Resolves the namespace -- see rules in `named` for more details. :param default_namespace: namespace to use as a default if we do not wish to override it @@ -622,8 +623,8 @@ def resolve_namespace(self, default_namespace: str) -> Tuple[str, ...]: ) def bind_function_args( - self, current_param: Optional[str] - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + self, current_param: str | None + ) -> tuple[dict[str, Any], dict[str, Any]]: """Binds function arguments, given current, chained parameter :param current_param: Current, chained parameter. None, if we're not chaining. @@ -651,9 +652,7 @@ def bind_function_args( return upstream_inputs, literal_inputs -def step( - fn, *args: Union[SingleDependency, Any], **kwargs: Union[SingleDependency, Any] -) -> Applicable: +def step(fn, *args: SingleDependency | Any, **kwargs: SingleDependency | Any) -> Applicable: """Applies a function to for a node (or a subcomponent of a node). See documentation for `pipe` to see how this is used. @@ -928,8 +927,8 @@ def __init__( raise NotImplementedError("@flow() is not yet supported -- this is ") def _distribute_transforms_to_parameters( - self, params: Dict[str, Type[Type]] - ) -> Dict[str, List[Applicable]]: + self, params: dict[str, type[type]] + ) -> dict[str, list[Applicable]]: """Resolves target option on the transform level. Adds option that we can decide for each applicable which input parameter it will target on top of the global target (if it is set). @@ -963,8 +962,8 @@ def _distribute_transforms_to_parameters( return selected_transforms def _create_valid_parameters_transforms_mapping( - self, mapping: Dict[str, List[Applicable]], fn: Callable, params: Dict[str, Type[Type]] - ) -> Dict[str, List[Applicable]]: + self, mapping: dict[str, list[Applicable]], fn: Callable, params: dict[str, type[type]] + ) -> dict[str, list[Applicable]]: """Checks for a valid distribution of transforms to parameters.""" sig = inspect.signature(fn) param_names = [] @@ -1025,8 +1024,8 @@ def _resolve_namespace( return f"{self.namespace}_{param}" def inject_nodes( - self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable - ) -> Tuple[List[node.Node], Dict[str, str]]: + self, params: dict[str, type[type]], config: dict[str, Any], fn: Callable + ) -> tuple[list[node.Node], dict[str, str]]: """Injects nodes into the graph. 
This creates a node for each pipe() step, then reassigns the inputs to pass it in.""" @@ -1067,7 +1066,7 @@ def validate(self, fn: Callable): ) # TODO -- validate that the types match on the chain (this is de-facto done later) - def optional_config(self) -> Dict[str, Any]: + def optional_config(self) -> dict[str, Any]: """Declares the optional configuration keys for this decorator. These are configuration keys that can be used by the decorator, but are not required. Along with these we have *defaults*, which we will use to pass to the config. @@ -1221,7 +1220,7 @@ def foo(a:int)->Dict[str,int]: """ @classmethod - def _validate_single_target_level(cls, target: base.TargetType, transforms: Tuple[Applicable]): + def _validate_single_target_level(cls, target: base.TargetType, transforms: tuple[Applicable]): """We want to make sure that target gets applied on a single level. Either choose for each step individually what it targets or set it on the global level where all steps will target the same node(s). @@ -1294,7 +1293,7 @@ def _filter_individual_target(self, node_): return tuple(selected_transforms) def transform_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Injects nodes into the graph. @@ -1356,7 +1355,7 @@ def validate(self, fn: Callable): ) # TODO -- validate that the types match on the chain (this is de-facto done later) - def optional_config(self) -> Dict[str, Any]: + def optional_config(self) -> dict[str, Any]: """Declares the optional configuration keys for this decorator. These are configuration keys that can be used by the decorator, but are not required. Along with these we have *defaults*, which we will use to pass to the config. @@ -1374,9 +1373,9 @@ def optional_config(self) -> Dict[str, Any]: def chain_transforms( target_arg: str, - transforms: List[Applicable], + transforms: list[Applicable], namespace: str, - config: Dict[str, Any], + config: dict[str, Any], fn: Callable, ): """Chaining nodes together sequentially through the a specified argument. @@ -1424,7 +1423,7 @@ def chain_transforms( return nodes, target_arg -def apply_to(fn_: Union[Callable, str], **mutating_fn_kwargs: Union[SingleDependency, Any]): +def apply_to(fn_: Callable | str, **mutating_fn_kwargs: SingleDependency | Any): """Creates an applicable placeholder with potential kwargs that will be applied to a node (or a subcomponent of a node). See documentation for ``mutate`` to see how this is used. It de facto allows a postponed ``step``. @@ -1542,10 +1541,10 @@ def foo(a:int)->Dict[str,int]: def __init__( self, - *target_functions: Union[Applicable, Callable], + *target_functions: Applicable | Callable, collapse: bool = False, _chain: bool = False, - **mutating_function_kwargs: Union[SingleDependency, Any], + **mutating_function_kwargs: SingleDependency | Any, ): """Instantiates a ``mutate`` decorator. 
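The Applicable/step machinery above drives @pipe; a hedged sketch of the user-facing pattern follows, with invented helper names and an assumed upstream scaling_factors node.

import pandas as pd

from hamilton.function_modifiers import pipe, source, step, value


def _fill_missing(df: pd.DataFrame, fill: float) -> pd.DataFrame:
    return df.fillna(fill)


def _scale(df: pd.DataFrame, factor: pd.Series) -> pd.DataFrame:
    return df.mul(factor, axis=0)


@pipe(
    step(_fill_missing, fill=value(0.0)),
    step(_scale, factor=source("scaling_factors")),
)
def cleaned_data(raw_data: pd.DataFrame) -> pd.DataFrame:
    # raw_data arrives here after both steps have been chained onto it.
    return raw_data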
diff --git a/hamilton/function_modifiers/metadata.py b/hamilton/function_modifiers/metadata.py index 786e00e42..4c708d652 100644 --- a/hamilton/function_modifiers/metadata.py +++ b/hamilton/function_modifiers/metadata.py @@ -18,7 +18,8 @@ """Decorators that attach metadata to nodes""" import json -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from collections.abc import Callable +from typing import Any, Literal from hamilton import htypes, node, registry from hamilton.function_modifiers import base @@ -82,7 +83,7 @@ def __init__( *, target_: base.TargetType = None, bypass_reserved_namespaces_: bool = False, - **tags: Union[str, List[str]], + **tags: str | list[str], ): """Constructor for adding tag annotations to a function. @@ -178,7 +179,7 @@ def validate(self, fn: Callable): class tag_outputs(base.NodeDecorator): - def __init__(self, **tag_mapping: Dict[str, Union[str, List[str]]]): + def __init__(self, **tag_mapping: dict[str, str | list[str]]): """Creates a tag_outputs decorator. Note that this currently does not validate whether the nodes are spelled correctly as it takes in a superset of\ @@ -227,7 +228,7 @@ def decorate_node(self, node_: node.Node) -> node.Node: class SchemaOutput(tag): - def __init__(self, *fields: Tuple[str, str], target_: Optional[str] = None): + def __init__(self, *fields: tuple[str, str], target_: str | None = None): """Initializes SchemaOutput. See docs for `@schema.output` for more details.""" tag_value = ",".join([f"{key}={value}" for key, value in fields]) @@ -271,7 +272,7 @@ class schema: INTERNAL_SCHEMA_OUTPUT_KEY = "hamilton.internal.schema_output" @staticmethod - def output(*fields: Tuple[str, str], target_: Optional[str] = None) -> SchemaOutput: + def output(*fields: tuple[str, str], target_: str | None = None) -> SchemaOutput: """Initializes a `@schema.output` decorator. This takes in a list of fields, which are tuples of the form `(field_name, field_type)`. The field type must be one of the function_modifiers.SchemaTypes types. @@ -302,7 +303,7 @@ def example_schema() -> pd.DataFrame: class RayRemote(tag): - def __init__(self, **options: Union[int, Dict[str, int]]): + def __init__(self, **options: int | dict[str, int]): """Initializes RayRemote. See docs for `@ray_remote_options` for more details.""" ray_tags = {f"ray_remote.{option}": json.dumps(value) for option, value in options.items()} @@ -310,7 +311,7 @@ def __init__(self, **options: Union[int, Dict[str, int]]): super(RayRemote, self).__init__(bypass_reserved_namespaces_=True, **ray_tags) -def ray_remote_options(**kwargs: Union[int, Dict[str, int]]) -> RayRemote: +def ray_remote_options(**kwargs: int | dict[str, int]) -> RayRemote: """Initializes a `@ray_remote_options` decorator. This takes in a list of options to pass to ray.remote(). Supported options include resources, as well as other options: @@ -362,8 +363,8 @@ class cache(base.NodeDecorator): def __init__( self, *, - behavior: Optional[CACHE_BEHAVIORS] = None, - format: Optional[Union[CACHE_MATERIALIZERS, str]] = None, + behavior: CACHE_BEHAVIORS | None = None, + format: CACHE_MATERIALIZERS | str | None = None, target_: base.TargetType = ..., ): """The ``@cache`` decorator can define the behavior and format of a specific node. 
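A short sketch of the metadata decorators retyped above; the tag keys and schema fields are illustrative, and both decorators only attach tags to the resulting node.

import pandas as pd

from hamilton.function_modifiers import schema, tag


@tag(owner="data-platform", importance="production")
@schema.output(("user_id", "int"), ("spend", "float"))
def spend_by_user(raw_events: pd.DataFrame) -> pd.DataFrame:
    """Aggregates spend per user."""
    return raw_events.groupby("user_id", as_index=False)["spend"].sum()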
diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index 8e99237b1..f0b283955 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -20,8 +20,9 @@ import sys import typing from collections import defaultdict +from collections.abc import Callable, Collection from types import ModuleType -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, TypedDict, Union +from typing import Any, TypedDict _sys_version_info = sys.version_info _version_tuple = (_sys_version_info.major, _sys_version_info.minor, _sys_version_info.micro) @@ -60,7 +61,7 @@ def derive_type(dependency: dependencies.LiteralDependency): def create_identity_node( - from_: str, typ: Type[Type], name: str, namespace: Tuple[str, ...], tags: Dict[str, Any] + from_: str, typ: type[type], name: str, namespace: tuple[str, ...], tags: dict[str, Any] ) -> node.Node: """Creates an identity node -- this passes through the exact value returned by the upstream node. @@ -87,7 +88,7 @@ def identity(**kwargs): ) -def extract_all_known_types(nodes: Collection[node.Node]) -> Dict[str, Type[Type]]: +def extract_all_known_types(nodes: Collection[node.Node]) -> dict[str, type[type]]: """Extracts all known types from a set of nodes given the dependencies. We have to do this as we don't know the dependency types at compile-time of upstream nodes. That said, this is only used for guessing dependency types of @@ -106,7 +107,7 @@ def extract_all_known_types(nodes: Collection[node.Node]) -> Dict[str, Type[Type def create_static_node( - typ: Type, name: str, value: Any, namespace: Tuple[str, ...], tags: Dict[str, Any] + typ: type, name: str, value: Any, namespace: tuple[str, ...], tags: dict[str, Any] ) -> node.Node: """Utility function to create a static node -- this helps us bridge nodes together. @@ -125,7 +126,7 @@ def node_fn(_value=value): ) -def _validate_config_inputs(config: Dict[str, Any], inputs: Dict[str, Any]): +def _validate_config_inputs(config: dict[str, Any], inputs: dict[str, Any]): """Validates that the inputs specified in the config are valid. :param original_config: Original configuration @@ -149,8 +150,8 @@ def _validate_config_inputs(config: Dict[str, Any], inputs: Dict[str, Any]): def _resolve_subdag_configuration( - configuration: Dict[str, Any], fields: Dict[str, Any], function_name: str -) -> Dict[str, Any]: + configuration: dict[str, Any], fields: dict[str, Any], function_name: str +) -> dict[str, Any]: """Resolves the configuration for a subdag. :param configuration: the Hamilton configuration @@ -249,12 +250,12 @@ def feature_engineering(feature_df: pd.DataFrame) -> pd.DataFrame: def __init__( self, - *load_from: Union[ModuleType, Callable], - inputs: Dict[str, ParametrizedDependency] = None, - config: Dict[str, Any] = None, + *load_from: ModuleType | Callable, + inputs: dict[str, ParametrizedDependency] = None, + config: dict[str, Any] = None, namespace: str = None, final_node_name: str = None, - external_inputs: List[str] = None, + external_inputs: list[str] = None, ): """Adds a subDAG to the main DAG. @@ -283,8 +284,8 @@ def __init__( @staticmethod def collect_functions( - load_from: Union[Collection[ModuleType], Collection[Callable]], - ) -> List[Callable]: + load_from: Collection[ModuleType] | Collection[Callable], + ) -> list[Callable]: """Utility function to collect functions from a list of callables/modules. 
:param load_from: A list of callables or modules to load from @@ -302,7 +303,7 @@ def collect_functions( return out @staticmethod - def collect_nodes(config: Dict[str, Any], subdag_functions: List[Callable]) -> List[node.Node]: + def collect_nodes(config: dict[str, Any], subdag_functions: list[Callable]) -> list[node.Node]: nodes = [] for fn in subdag_functions: for node_ in base.resolve_nodes(fn, config): @@ -356,11 +357,11 @@ def _create_additional_static_nodes( @staticmethod def add_namespace( - nodes: List[node.Node], + nodes: list[node.Node], namespace: str, - inputs: Dict[str, Any] = None, - config: Dict[str, Any] = None, - ) -> List[node.Node]: + inputs: dict[str, Any] = None, + config: dict[str, Any] = None, + ) -> list[node.Node]: """Utility function to add a namespace to nodes. :param nodes: @@ -475,7 +476,7 @@ def _derive_name(self, fn: Callable) -> str: """ return fn.__name__ if self.final_node_name is None else self.final_node_name - def generate_nodes(self, fn: Callable, configuration: Dict[str, Any]) -> Collection[node.Node]: + def generate_nodes(self, fn: Callable, configuration: dict[str, Any]) -> Collection[node.Node]: # Resolve all nodes from passed in functions # if self.config has configuration() or value() in it, we need to resolve it resolved_config = _resolve_subdag_configuration(configuration, self.config, fn.__name__) @@ -512,7 +513,7 @@ def validate(self, fn): self._validate_parameterization() - def required_config(self) -> Optional[List[str]]: + def required_config(self) -> list[str] | None: """Currently we do not filter for subdag as we do not *statically* know what configuration is required. This is because we need to parse the function so that we can figure it out, and that is not available at the time that we call required_config. We need to think about @@ -528,9 +529,9 @@ def required_config(self) -> Optional[List[str]]: class SubdagParams(TypedDict): - inputs: NotRequired[Dict[str, ParametrizedDependency]] - config: NotRequired[Dict[str, Any]] - external_inputs: NotRequired[List[str]] + inputs: NotRequired[dict[str, ParametrizedDependency]] + config: NotRequired[dict[str, Any]] + external_inputs: NotRequired[list[str]] class parameterized_subdag(base.NodeCreator): @@ -592,12 +593,12 @@ def feature_engineering(feature_df: pd.DataFrame) -> pd.DataFrame: def __init__( self, - *load_from: Union[ModuleType, Callable], - inputs: Dict[ - str, Union[dependencies.ParametrizedDependency, dependencies.LiteralDependency] + *load_from: ModuleType | Callable, + inputs: dict[ + str, dependencies.ParametrizedDependency | dependencies.LiteralDependency ] = None, - config: Dict[str, Any] = None, - external_inputs: List[str] = None, + config: dict[str, Any] = None, + external_inputs: list[str] = None, **parameterization: SubdagParams, ): """Initializes a parameterized_subdag decorator. 
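The namespacing helpers above support @subdag; a hedged sketch of its use, where feature_functions is a hypothetical module of Hamilton functions and us_raw_df / feature_df are invented node names.

import pandas as pd

from hamilton.function_modifiers import source, subdag

import feature_functions  # hypothetical module of Hamilton functions


@subdag(
    feature_functions,
    inputs={"raw_df": source("us_raw_df")},
    config={"region": "US"},
)
def us_features(feature_df: pd.DataFrame) -> pd.DataFrame:
    # feature_df is produced inside the subdag; its nodes are namespaced
    # under "us_features" in the final DAG.
    return feature_df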
@@ -625,7 +626,7 @@ def __init__( self.parameterization = parameterization self.external_inputs = external_inputs if external_inputs is not None else [] - def _gather_subdag_generators(self) -> List[subdag]: + def _gather_subdag_generators(self) -> list[subdag]: subdag_generators = [] for key, parameterization in self.parameterization.items(): subdag_generators.append( @@ -640,7 +641,7 @@ def _gather_subdag_generators(self) -> List[subdag]: ) return subdag_generators - def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]: + def generate_nodes(self, fn: Callable, config: dict[str, Any]) -> list[node.Node]: generated_nodes = [] for subdag_generator in self._gather_subdag_generators(): generated_nodes.extend(subdag_generator.generate_nodes(fn, config)) @@ -650,7 +651,7 @@ def validate(self, fn: Callable): for subdag_generator in self._gather_subdag_generators(): subdag_generator.validate(fn) - def required_config(self) -> Optional[List[str]]: + def required_config(self) -> list[str] | None: """See note for subdag.required_config -- this is the same pattern. :return: Any required config items. @@ -658,7 +659,7 @@ def required_config(self) -> Optional[List[str]]: return None -def prune_nodes(nodes: List[node.Node], select: Optional[List[str]] = None) -> List[node.Node]: +def prune_nodes(nodes: list[node.Node], select: list[str] | None = None) -> list[node.Node]: """Prunes the nodes to only include those upstream from the select columns. Conducts a depth-first search using the nodes `input_types` field. @@ -725,7 +726,7 @@ class with_columns_base(base.NodeInjector, abc.ABC): # TODO: if we rename the column nodes into something smarter this can be avoided and # can also modify columns in place @staticmethod - def contains_duplicates(nodes_: List[node.Node]) -> bool: + def contains_duplicates(nodes_: list[node.Node]) -> bool: """Ensures that we don't run into name clashing of columns and group operations. In the case when we extract columns for the user, because ``columns_to_pass`` was used, we want @@ -743,7 +744,7 @@ def contains_duplicates(nodes_: List[node.Node]) -> bool: @staticmethod def validate_dataframe( - fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]], required_type: Type + fn: Callable, inject_parameter: str, params: dict[str, type[type]], required_type: type ) -> None: input_types = typing.get_type_hints(fn) if inject_parameter not in params: @@ -762,14 +763,14 @@ def validate_dataframe( def __init__( self, - *load_from: Union[Callable, ModuleType], - columns_to_pass: List[str] = None, + *load_from: Callable | ModuleType, + columns_to_pass: list[str] = None, pass_dataframe_as: str = None, on_input: str = None, - select: List[str] = None, + select: list[str] = None, namespace: str = None, - config_required: List[str] = None, - dataframe_type: Type = None, + config_required: list[str] = None, + dataframe_type: type = None, ): """Instantiates a ``@with_columns`` decorator. @@ -832,13 +833,13 @@ def __init__( self.dataframe_type = dataframe_type - def required_config(self) -> List[str]: + def required_config(self) -> list[str]: return self.config_required @abc.abstractmethod def get_initial_nodes( - self, fn: Callable, params: Dict[str, Type[Type]] - ) -> Tuple[str, Collection[node.Node]]: + self, fn: Callable, params: dict[str, type[type]] + ) -> tuple[str, Collection[node.Node]]: """Preparation stage where columns get extracted into nodes. 
In case `pass_dataframe_as` or `on_input` is used, this should return an empty list (no column nodes) since the users will extract it themselves. @@ -851,7 +852,7 @@ def get_initial_nodes( pass @abc.abstractmethod - def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: + def get_subdag_nodes(self, fn: Callable, config: dict[str, Any]) -> Collection[node.Node]: """Creates subdag from the passed in module / functions. :param config: Configuration with which the DAG was constructed. @@ -873,8 +874,8 @@ def chain_subdag_nodes( pass def inject_nodes( - self, params: Dict[str, Type[Type]], config: Dict[str, Any], fn: Callable - ) -> Tuple[List[node.Node], Dict[str, str]]: + self, params: dict[str, type[type]], config: dict[str, Any], fn: Callable + ) -> tuple[list[node.Node], dict[str, str]]: namespace = fn.__name__ if self.namespace is None else self.namespace inject_parameter, initial_nodes = self.get_initial_nodes(fn=fn, params=params) diff --git a/hamilton/function_modifiers/validation.py b/hamilton/function_modifiers/validation.py index cad487912..019cd7c27 100644 --- a/hamilton/function_modifiers/validation.py +++ b/hamilton/function_modifiers/validation.py @@ -17,7 +17,8 @@ import abc from collections import defaultdict -from typing import Any, Callable, Collection, Dict, List, Type +from collections.abc import Callable, Collection +from typing import Any from hamilton import node from hamilton.data_quality import base as dq_base @@ -32,7 +33,7 @@ class BaseDataValidationDecorator(base.NodeTransformer): @abc.abstractmethod - def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]: + def get_validators(self, node_to_validate: node.Node) -> list[dq_base.DataValidator]: """Returns a list of validators used to transform the nodes. :param node_to_validate: Nodes to which the output of the validator will apply @@ -41,7 +42,7 @@ def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValida pass def transform_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: raw_node = node.Node( name=node_.name @@ -159,7 +160,7 @@ def __init__(self, *validators: dq_base.DataValidator, target_: base.TargetType super(check_output_custom, self).__init__(target=target_) self.validators = list(validators) - def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]: + def get_validators(self, node_to_validate: node.Node) -> list[dq_base.DataValidator]: return self.validators @@ -204,7 +205,7 @@ def builds_dataframe(...) 
-> pd.DataFrame: """ - def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]: + def get_validators(self, node_to_validate: node.Node) -> list[dq_base.DataValidator]: try: return default_validators.resolve_default_validators( node_to_validate.type, @@ -222,7 +223,7 @@ def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValida def __init__( self, importance: str = dq_base.DataValidationLevel.WARN.value, - default_validator_candidates: List[Type[dq_base.BaseDefaultValidator]] = None, + default_validator_candidates: list[type[dq_base.BaseDefaultValidator]] = None, target_: base.TargetType = None, **default_validator_kwargs: Any, ): diff --git a/hamilton/graph.py b/hamilton/graph.py index a8e929471..78595c4fe 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -29,9 +29,10 @@ import os.path import pathlib import uuid +from collections.abc import Callable, Collection from enum import Enum from types import ModuleType -from typing import Any, Callable, Collection, Dict, FrozenSet, List, Optional, Set, Tuple, Type +from typing import Any, Optional import hamilton.lifecycle.base as lifecycle_base from hamilton import graph_types, node @@ -57,9 +58,9 @@ class VisualizationNodeModifiers(Enum): def add_dependency( func_node: node.Node, func_name: str, - nodes: Dict[str, node.Node], + nodes: dict[str, node.Node], param_name: str, - param_type: Type, + param_type: type, adapter: lifecycle_base.LifecycleAdapterSet, ): """Adds dependencies to the node objects. @@ -132,7 +133,7 @@ def add_dependency( def update_dependencies( - nodes: Dict[str, node.Node], + nodes: dict[str, node.Node], adapter: lifecycle_base.LifecycleAdapterSet, reset_dependencies: bool = True, ): @@ -161,11 +162,11 @@ def update_dependencies( def create_function_graph( *modules: ModuleType, - config: Dict[str, Any], + config: dict[str, Any], adapter: lifecycle_base.LifecycleAdapterSet = None, fg: Optional["FunctionGraph"] = None, allow_module_overrides: bool = False, -) -> Dict[str, node.Node]: +) -> dict[str, node.Node]: """Creates a graph of all available functions & their dependencies. :param modules: A set of modules over which one wants to compute the function graph :param config: Dictionary that we will inspect to get values from in building the function graph. @@ -217,10 +218,10 @@ def _check_keyword_args_only(func: Callable) -> bool: def create_graphviz_graph( - nodes: Set[node.Node], + nodes: set[node.Node], comment: str, graphviz_kwargs: dict, - node_modifiers: Dict[str, Set[VisualizationNodeModifiers]], + node_modifiers: dict[str, set[VisualizationNodeModifiers]], strictly_display_only_nodes_passed_in: bool, show_legend: bool = True, orient: str = "LR", @@ -276,8 +277,8 @@ def create_graphviz_graph( def _get_node_label( n: node.Node, - name: Optional[str] = None, - type_string: Optional[str] = None, + name: str | None = None, + type_string: str | None = None, ) -> str: """Get a graphviz HTML-like node label. It uses the DAG node name and type but values can be overridden. Overriding is currently @@ -316,7 +317,7 @@ def _get_node_label( escaped_type_string = html.escape(type_string, quote=True) return f"<{escaped_display_name}

{escaped_type_string}>" - def _get_input_label(input_nodes: FrozenSet[node.Node]) -> str: + def _get_input_label(input_nodes: frozenset[node.Node]) -> str: """Get a graphviz HTML-like node label formatted as a table. Each row is a different input node with one column containing the name (or display_name if present) and the other the type. @@ -354,7 +355,7 @@ def _get_node_type(n: node.Node) -> str: else: return "function" - def _get_node_style(node_type: str) -> Dict[str, str]: + def _get_node_style(node_type: str) -> dict[str, str]: """Get the style of a node type. Graphviz needs values to be strings. """ @@ -391,7 +392,7 @@ def _get_node_style(node_type: str) -> Dict[str, str]: return node_style - def _get_function_modifier_style(modifier: str) -> Dict[str, str]: + def _get_function_modifier_style(modifier: str) -> dict[str, str]: """Get the style of a modifier. The dictionary returned is used to overwrite values of the base node style. Graphviz needs values to be strings. @@ -417,7 +418,7 @@ def _get_function_modifier_style(modifier: str) -> Dict[str, str]: return modifier_style - def _get_edge_style(from_type: str, to_type: str) -> Dict: + def _get_edge_style(from_type: str, to_type: str) -> dict: """ Graphviz needs values to be strings. @@ -436,7 +437,7 @@ def _get_edge_style(from_type: str, to_type: str) -> Dict: return edge_style def _get_legend( - node_types: Set[str], extra_legend_nodes: Dict[Tuple[str, str], Dict[str, str]] + node_types: set[str], extra_legend_nodes: dict[tuple[str, str], dict[str, str]] ): """Create a visualization legend as a graphviz subgraph. The legend includes the node types and modifiers presente in the visualization. @@ -599,7 +600,7 @@ def _get_legend( seen_node_types.add("cluster") seen_node_types.add("field") - def _create_equal_length_cols(schema_tag: str) -> List[str]: + def _create_equal_length_cols(schema_tag: str) -> list[str]: cols = schema_tag.split(",") for i in range(len(cols)): @@ -705,7 +706,7 @@ def _insert_space_after_colon(col: str) -> str: def create_networkx_graph( - nodes: Set[node.Node], user_nodes: Set[node.Node], name: str + nodes: set[node.Node], user_nodes: set[node.Node], name: str ) -> "networkx.DiGraph": # noqa: F821 """Helper function to create a networkx graph. @@ -736,8 +737,8 @@ class FunctionGraph: def __init__( self, - nodes: Dict[str, Node], - config: Dict[str, Any], + nodes: dict[str, Node], + config: dict[str, Any], adapter: lifecycle_base.LifecycleAdapterSet = None, ): """Initializes a function graph from specified nodes. See note on `from_modules` if you @@ -757,7 +758,7 @@ def __init__( @staticmethod def from_modules( *modules: ModuleType, - config: Dict[str, Any], + config: dict[str, Any], adapter: lifecycle_base.LifecycleAdapterSet = None, allow_module_overrides: bool = False, ): @@ -778,7 +779,7 @@ def from_modules( ) return FunctionGraph(nodes, config, adapter) - def with_nodes(self, nodes: Dict[str, Node]) -> "FunctionGraph": + def with_nodes(self, nodes: dict[str, Node]) -> "FunctionGraph": """Creates a new function graph with the additional specified nodes. Note that if there is a duplication in the node definitions, it will error out. 
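For context, the graphviz helpers and FunctionGraph.from_modules above are usually reached through the driver; a sketch with a hypothetical module name (rendering requires graphviz to be installed):

from hamilton import driver

import my_dataflow  # hypothetical module of Hamilton functions

dr = driver.Builder().with_modules(my_dataflow).with_config({}).build()
# Renders the whole DAG; under the hood this walks the FunctionGraph and calls
# the graphviz construction helpers shown in this diff.
dr.display_all_functions("./my_dataflow_dag.png")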
@@ -798,10 +799,10 @@ def config(self): return self._config @property - def decorator_counter(self) -> Dict[str, int]: + def decorator_counter(self) -> dict[str, int]: return fm_base.DECORATOR_COUNTER - def get_nodes(self) -> List[node.Node]: + def get_nodes(self) -> list[node.Node]: return list(self.nodes.values()) def display_all( @@ -862,7 +863,7 @@ def display_all( keep_dot=keep_dot, ) - def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool: + def has_cycles(self, nodes: set[node.Node], user_nodes: set[node.Node]) -> bool: """Checks that the graph created does not contain cycles. :param nodes: the set of nodes that need to be computed. @@ -872,7 +873,7 @@ def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool: cycles = self.get_cycles(nodes, user_nodes) return True if cycles else False - def get_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> List[List[str]]: + def get_cycles(self, nodes: set[node.Node], user_nodes: set[node.Node]) -> list[list[str]]: """Returns cycles found in the graph. :param nodes: the set of nodes that need to be computed. @@ -893,11 +894,11 @@ def get_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> List[ @staticmethod def display( - nodes: Set[node.Node], - output_file_path: Optional[str] = None, + nodes: set[node.Node], + output_file_path: str | None = None, render_kwargs: dict = None, graphviz_kwargs: dict = None, - node_modifiers: Dict[str, Set[VisualizationNodeModifiers]] = None, + node_modifiers: dict[str, set[VisualizationNodeModifiers]] = None, strictly_display_only_passed_in_nodes: bool = False, show_legend: bool = True, orient: str = "LR", @@ -987,7 +988,7 @@ def display( pathlib.Path(output_file_path).write_bytes(dot.pipe(**kwargs)) return dot - def get_impacted_nodes(self, var_changes: List[str]) -> Set[node.Node]: + def get_impacted_nodes(self, var_changes: list[str]) -> set[node.Node]: """DEPRECATED - use `get_downstream_nodes` instead.""" logger.warning( "FunctionGraph.get_impacted_nodes is deprecated. " @@ -996,7 +997,7 @@ def get_impacted_nodes(self, var_changes: List[str]) -> Set[node.Node]: ) return self.get_downstream_nodes(var_changes) - def get_downstream_nodes(self, var_changes: List[str]) -> Set[node.Node]: + def get_downstream_nodes(self, var_changes: list[str]) -> set[node.Node]: """Given our function graph, and a list of nodes that are changed, returns the subgraph that they will impact. @@ -1010,10 +1011,10 @@ def get_downstream_nodes(self, var_changes: List[str]) -> Set[node.Node]: def get_upstream_nodes( self, - final_vars: List[str], - runtime_inputs: Dict[str, Any] = None, - runtime_overrides: Dict[str, Any] = None, - ) -> Tuple[Set[node.Node], Set[node.Node]]: + final_vars: list[str], + runtime_inputs: dict[str, Any] = None, + runtime_overrides: dict[str, Any] = None, + ) -> tuple[set[node.Node], set[node.Node]]: """Given our function graph, and a list of desired output variables, returns the subgraph required to compute them. @@ -1027,7 +1028,7 @@ def get_upstream_nodes( :return: a tuple of sets: - set of all nodes. - subset of nodes that human input is required for. 
""" - def next_nodes_function(n: node.Node) -> List[node.Node]: + def next_nodes_function(n: node.Node) -> list[node.Node]: deps = [] if runtime_overrides is not None and n.name in runtime_overrides: return deps @@ -1051,7 +1052,7 @@ def next_nodes_function(n: node.Node) -> List[node.Node]: next_nodes_function, starting_nodes=final_vars, runtime_inputs=runtime_inputs ) - def nodes_between(self, start: str, end: str) -> Set[node.Node]: + def nodes_between(self, start: str, end: str) -> set[node.Node]: """Given our function graph, and a list of desired output variables, returns the subgraph required to compute them. Note that this returns an empty set if no path exists. @@ -1073,8 +1074,8 @@ def nodes_between(self, start: str, end: str) -> Set[node.Node]: def directional_dfs_traverse( self, next_nodes_fn: Callable[[node.Node], Collection[node.Node]], - starting_nodes: List[str], - runtime_inputs: Dict[str, Any] = None, + starting_nodes: list[str], + runtime_inputs: dict[str, Any] = None, ): """Traverses the DAG directionally using a DFS. @@ -1116,11 +1117,11 @@ def dfs_traverse(node: node.Node): def execute( self, nodes: Collection[node.Node] = None, - computed: Dict[str, Any] = None, - overrides: Dict[str, Any] = None, - inputs: Dict[str, Any] = None, + computed: dict[str, Any] = None, + overrides: dict[str, Any] = None, + inputs: dict[str, Any] = None, run_id: str = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Executes the DAG, given potential inputs/previously computed components. :param nodes: Nodes to compute diff --git a/hamilton/graph_types.py b/hamilton/graph_types.py index 89952148c..e4920ca7c 100644 --- a/hamilton/graph_types.py +++ b/hamilton/graph_types.py @@ -87,7 +87,7 @@ def _remove_docs_and_comments(source: str) -> str: return ast.unparse(parsed) -def hash_source_code(source: typing.Union[str, typing.Callable], strip: bool = False) -> str: +def hash_source_code(source: str | typing.Callable, strip: bool = False) -> str: """Hashes the source code of a function (str). The `strip` parameter requires Python 3.9 @@ -118,14 +118,14 @@ class HamiltonNode: Furthermore, we can always add attributes and maintain backwards compatibility.""" name: str - type: typing.Type - tags: typing.Dict[str, typing.Union[str, typing.List[str]]] + type: type + tags: dict[str, str | list[str]] is_external_input: bool - originating_functions: typing.Optional[typing.Tuple[typing.Callable, ...]] - documentation: typing.Optional[str] - required_dependencies: typing.Set[str] - optional_dependencies: typing.Set[str] - optional_dependencies_default_values: typing.Dict[str, typing.Any] + originating_functions: tuple[typing.Callable, ...] | None + documentation: str | None + required_dependencies: set[str] + optional_dependencies: set[str] + optional_dependencies_default_values: dict[str, typing.Any] def as_dict(self, include_optional_dependencies_default_values: bool = False) -> dict: """Create a dictionary representation of the Node that is JSON serializable. @@ -183,7 +183,7 @@ def from_node(n: node.Node) -> "HamiltonNode": ) @functools.cached_property - def version(self) -> typing.Optional[str]: + def version(self) -> str | None: """Generate a hash of the node originating function source code. Note that this will be `None` if the node is an external input/has no @@ -222,7 +222,7 @@ class HamiltonGraph: Note that you do not construct this class directly -- instead, you will get this at various points in the API. 
""" - nodes: typing.List[HamiltonNode] + nodes: list[HamiltonNode] # store the original graph for internal use @staticmethod @@ -247,7 +247,7 @@ def version(self) -> str: return hashlib.sha256(str(sorted_node_versions).encode()).hexdigest() @functools.cached_property - def __nodes_lookup(self) -> typing.Dict[str, HamiltonNode]: + def __nodes_lookup(self) -> dict[str, HamiltonNode]: """Cache the mapping {node_name: node} for faster `__getitem__`""" return {n.name: n for n in self.nodes} @@ -259,8 +259,6 @@ def __getitem__(self, key: str) -> HamiltonNode: """ return self.__nodes_lookup[key] - def filter_nodes( - self, filter: typing.Callable[[HamiltonNode], bool] - ) -> typing.List[HamiltonNode]: + def filter_nodes(self, filter: typing.Callable[[HamiltonNode], bool]) -> list[HamiltonNode]: """Return Hamilton nodes matching the filter criteria""" return [n for n in self.nodes if filter(n) is True] diff --git a/hamilton/graph_utils.py b/hamilton/graph_utils.py index e09a1f5d6..1913ef3c0 100644 --- a/hamilton/graph_utils.py +++ b/hamilton/graph_utils.py @@ -16,15 +16,15 @@ # under the License. import inspect +from collections.abc import Callable from types import ModuleType -from typing import Callable, List, Tuple def is_submodule(child: ModuleType, parent: ModuleType): return parent.__name__ in child.__name__ -def find_functions(function_module: ModuleType) -> List[Tuple[str, Callable]]: +def find_functions(function_module: ModuleType) -> list[tuple[str, Callable]]: """Function to determine the set of functions we want to build a graph from. This iterates through the function module and grabs all function definitions. diff --git a/hamilton/htypes.py b/hamilton/htypes.py index 89be4cac3..0c10bcfe0 100644 --- a/hamilton/htypes.py +++ b/hamilton/htypes.py @@ -18,7 +18,8 @@ import inspect import sys import typing -from typing import Any, Iterable, Literal, Optional, Protocol, Tuple, Type, TypeVar, Union +from collections.abc import Iterable +from typing import Any, Literal, Protocol, TypeVar, Union import typing_inspect @@ -27,7 +28,7 @@ BASE_ARGS_FOR_GENERICS = (typing.T,) -def _safe_subclass(candidate_type: Type, base_type: Type) -> bool: +def _safe_subclass(candidate_type: type, base_type: type) -> bool: """Safely checks subclass, returning False if python's subclass does not work. This is *not* a true subclass check, and will not tell you whether hamilton considers the types to be equivalent. Rather, it is used to short-circuit further @@ -48,7 +49,7 @@ def _safe_subclass(candidate_type: Type, base_type: Type) -> bool: return False -def custom_subclass_check(requested_type: Type, param_type: Type): +def custom_subclass_check(requested_type: type, param_type: type): """This is a custom check around generics & classes. It probably misses a few edge cases. We will likely need to revisit this in the future (perhaps integrate with graphadapter?) @@ -104,7 +105,7 @@ def custom_subclass_check(requested_type: Type, param_type: Type): return False -def get_type_as_string(type_: Type) -> Optional[str]: +def get_type_as_string(type_: type) -> str | None: """Get a string representation of a type. The logic supports the evolution of the type system between 3.8 and 3.10. @@ -127,7 +128,7 @@ def get_type_as_string(type_: Type) -> Optional[str]: return type_string -def types_match(param_type: Type[Type], required_node_type: Any) -> bool: +def types_match(param_type: type[type], required_node_type: Any) -> bool: """Checks that we have "types" that "match". 
Matching can be loose here -- and depends on the adapter being used as to what is @@ -188,7 +189,7 @@ def types_match(param_type: Type[Type], required_node_type: Any) -> bool: else: ANNOTATE_ALLOWED = True - from typing import Annotated, Type + from typing import Annotated column = Annotated @@ -202,7 +203,7 @@ def types_match(param_type: Type[Type], required_node_type: Any) -> bool: from typing import get_origin as _get_origin -def _is_annotated_type(type_: Type[Type]) -> bool: +def _is_annotated_type(type_: type[type]) -> bool: """Utility function to tell if a type is Annotated""" return _get_origin(type_) == column @@ -222,7 +223,7 @@ class InvalidTypeException(Exception): ) -def _is_valid_series_type(candidate_type: Type[Type]) -> bool: +def _is_valid_series_type(candidate_type: type[type]) -> bool: """Tells if something is a valid series type, using the registry we have. :param candidate_type: Type to check @@ -236,7 +237,7 @@ def _is_valid_series_type(candidate_type: Type[Type]) -> bool: return False -def validate_type_annotation(annotation: Type[Type]): +def validate_type_annotation(annotation: type[type]): """Validates a type annotation for a hamilton function. If it is not an Annotated type, it will be fine. If it is the Annotated type, it will check that @@ -271,7 +272,7 @@ def validate_type_annotation(annotation: Type[Type]): ) -def get_type_information(some_type: Any) -> Tuple[Type[Type], list]: +def get_type_information(some_type: Any) -> tuple[type[type], list]: """Gets the type information for a given type. If it is an annotated type, it will return the original type and the annotation. @@ -311,7 +312,7 @@ class Parallelizable(Iterable[ParallelizableElement], Protocol[ParallelizableEle pass -def is_parallelizable_type(type_: Type) -> bool: +def is_parallelizable_type(type_: type) -> bool: return _get_origin(type_) == Parallelizable @@ -326,7 +327,7 @@ class Collect(Iterable[CollectElement], Protocol[CollectElement]): """ -def check_input_type(node_type: Type, input_value: Any) -> bool: +def check_input_type(node_type: type, input_value: Any) -> bool: """Checks an input value against the declare input type. This is a utility function to be used for checking types against values. Note we are looser here than in custom_subclass_check, as runtime-typing is less specific. diff --git a/hamilton/io/data_adapters.py b/hamilton/io/data_adapters.py index 0a918ad2e..851cbbfb6 100644 --- a/hamilton/io/data_adapters.py +++ b/hamilton/io/data_adapters.py @@ -18,7 +18,8 @@ import abc import dataclasses import typing -from typing import Any, Collection, Dict, Tuple, Type +from collections.abc import Collection +from typing import Any from hamilton.htypes import custom_subclass_check @@ -26,7 +27,7 @@ class AdapterCommon(abc.ABC): @classmethod @abc.abstractmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: """Returns the types that this data loader can load to. These will be checked against the desired type to determine whether this is a suitable loader for that type. @@ -44,7 +45,7 @@ def applicable_types(cls) -> Collection[Type]: @classmethod @abc.abstractmethod - def applies_to(cls, type_: Type[Type]) -> bool: + def applies_to(cls, type_: type[type]) -> bool: """Tells whether or not this adapter applies to the given type. Note: you need to understand the edge direction to properly determine applicability. 
@@ -78,7 +79,7 @@ def _ensure_dataclass(cls): ) @classmethod - def get_required_arguments(cls) -> Dict[str, Type[Type]]: + def get_required_arguments(cls) -> dict[str, type[type]]: """Gives the required arguments for the class. Note that this just uses the type hints from the dataclass. @@ -93,7 +94,7 @@ def get_required_arguments(cls) -> Dict[str, Type[Type]]: } @classmethod - def get_optional_arguments(cls) -> Dict[str, Type[Type]]: + def get_optional_arguments(cls) -> dict[str, type[type]]: """Gives the optional arguments for the class. Note that this just uses the type hints from the dataclass. @@ -135,7 +136,7 @@ class DataLoader(AdapterCommon, abc.ABC): """ @abc.abstractmethod - def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]: + def load_data(self, type_: type[type]) -> tuple[type, dict[str, Any]]: """Loads the data from the data source. Note this uses the constructor parameters to determine how to load the data. @@ -149,7 +150,7 @@ def can_load(cls) -> bool: return True @classmethod - def applies_to(cls, type_: Type[Type]) -> bool: + def applies_to(cls, type_: type[type]) -> bool: """Tells whether or not this data loader can load to a specific type. For instance, a CSV data loader might be able to load to a dataframe, a json, but not an integer. @@ -176,7 +177,7 @@ class DataSaver(AdapterCommon, abc.ABC): """ @abc.abstractmethod - def save_data(self, data: Any) -> Dict[str, Any]: + def save_data(self, data: Any) -> dict[str, Any]: """Saves the data to the data source. Note this uses the constructor parameters to determine how to save the data. @@ -192,7 +193,7 @@ def can_save(cls) -> bool: return True @classmethod - def applies_to(cls, type_: Type[Type]) -> bool: + def applies_to(cls, type_: type[type]) -> bool: """Tells whether or not this data saver can ingest a specific type to save it. I.e. is the adapter type a superclass of the passed in type? 
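For context on how the modernized adapter signatures above are meant to be used, here is a minimal, illustrative sketch (not part of this patch) of a custom loader written against the updated annotations -- builtin generics (`Collection[type]`, `tuple[...]`, `dict[str, Any]`) and PEP 604 unions (`str | None`) in place of the `typing.Type`/`Dict`/`Optional` spellings being removed. The `DataLoader` base class, `applicable_types`, `load_data`, `name`, and `get_file_metadata` come from the hunks shown here; the loader class itself (`TextLinesDataLoader`) and its fields are hypothetical.

import dataclasses
from collections.abc import Collection
from typing import Any

from hamilton.io.data_adapters import DataLoader
from hamilton.io.utils import get_file_metadata


@dataclasses.dataclass
class TextLinesDataLoader(DataLoader):  # hypothetical example, not part of this diff
    path: str
    encoding: str | None = None  # PEP 604 union instead of Optional[str]

    @classmethod
    def applicable_types(cls) -> Collection[type]:  # builtin `type`, not typing.Type
        return [list]

    @classmethod
    def name(cls) -> str:
        return "text_lines"

    def load_data(self, type_: type) -> tuple[list[str], dict[str, Any]]:
        # Returns (data, metadata), mirroring the loaders in default_data_loaders.py
        with open(self.path, "r", encoding=self.encoding or "utf-8") as f:
            return f.read().splitlines(), get_file_metadata(self.path)

Savers follow the same pattern with `save_data(self, data: Any) -> dict[str, Any]`, as shown in the DataSaver hunks above.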
diff --git a/hamilton/io/default_data_loaders.py b/hamilton/io/default_data_loaders.py index 14a76e3eb..1810e1fc8 100644 --- a/hamilton/io/default_data_loaders.py +++ b/hamilton/io/default_data_loaders.py @@ -21,7 +21,8 @@ import os import pathlib import pickle -from typing import Any, Collection, Dict, Tuple, Type, Union +from collections.abc import Collection +from typing import Any from hamilton.io.data_adapters import DataLoader, DataSaver from hamilton.io.utils import get_file_metadata @@ -32,10 +33,10 @@ class JSONDataLoader(DataLoader): path: str @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [dict, list] - def load_data(self, type_: Type) -> Tuple[dict, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[dict, dict[str, Any]]: with open(self.path, "r") as f: return json.load(f), get_file_metadata(self.path) @@ -49,14 +50,14 @@ class JSONDataSaver(DataSaver): path: str @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [dict, list] @classmethod def name(cls) -> str: return "json" - def save_data(self, data: Any) -> Dict[str, Any]: + def save_data(self, data: Any) -> dict[str, Any]: with open(self.path, "w") as f: json.dump(data, f) return get_file_metadata(self.path) @@ -67,12 +68,12 @@ class RawFileDataLoader(DataLoader): path: str encoding: str = "utf-8" - def load_data(self, type_: Type) -> Tuple[str, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[str, dict[str, Any]]: with open(self.path, "r", encoding=self.encoding) as f: return f.read(), get_file_metadata(self.path) @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [str] @classmethod @@ -86,14 +87,14 @@ class RawFileDataSaver(DataSaver): encoding: str = "utf-8" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [str] @classmethod def name(cls) -> str: return "file" - def save_data(self, data: Any) -> Dict[str, Any]: + def save_data(self, data: Any) -> dict[str, Any]: with open(self.path, "w", encoding=self.encoding) as f: f.write(data) return get_file_metadata(self.path) @@ -101,17 +102,17 @@ def save_data(self, data: Any) -> Dict[str, Any]: @dataclasses.dataclass class RawFileDataSaverBytes(DataSaver): - path: Union[pathlib.Path, str] + path: pathlib.Path | str @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [bytes, io.BytesIO] @classmethod def name(cls) -> str: return "file" - def save_data(self, data: Union[bytes, io.BytesIO]) -> Dict[str, Any]: + def save_data(self, data: bytes | io.BytesIO) -> dict[str, Any]: if isinstance(data, io.BytesIO): data_bytes = data.getvalue() # Extract bytes from BytesIO else: @@ -128,14 +129,14 @@ class PickleLoader(DataLoader): path: str @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [object, Any] @classmethod def name(cls) -> str: return "pickle" - def load_data(self, type_: Type[object]) -> Tuple[object, Dict[str, Any]]: + def load_data(self, type_: type[object]) -> tuple[object, dict[str, Any]]: with open(self.path, "rb") as f: return pickle.load(f), get_file_metadata(self.path) @@ -145,14 +146,14 @@ class PickleSaver(DataSaver): path: str @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return 
[object] @classmethod def name(cls) -> str: return "pickle" - def save_data(self, data: Any) -> Dict[str, Any]: + def save_data(self, data: Any) -> dict[str, Any]: with open(self.path, "wb") as f: pickle.dump(data, f) return get_file_metadata(self.path) @@ -160,9 +161,9 @@ def save_data(self, data: Any) -> Dict[str, Any]: @dataclasses.dataclass class EnvVarDataLoader(DataLoader): - names: Tuple[str, ...] + names: tuple[str, ...] - def load_data(self, type_: Type[dict]) -> Tuple[dict, Dict[str, Any]]: + def load_data(self, type_: type[dict]) -> tuple[dict, dict[str, Any]]: return {name: os.environ[name] for name in self.names}, {} @classmethod @@ -170,7 +171,7 @@ def name(cls) -> str: return "environment" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [dict] @@ -179,10 +180,10 @@ class LiteralValueDataLoader(DataLoader): value: Any @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [Any] - def load_data(self, type_: Type) -> Tuple[dict, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[dict, dict[str, Any]]: return self.value, {} @classmethod @@ -198,7 +199,7 @@ class InMemoryResult(DataSaver): """ @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [Any] def save_data(self, data: Any) -> Any: diff --git a/hamilton/io/materialization.py b/hamilton/io/materialization.py index b23e17cbe..c41bb5944 100644 --- a/hamilton/io/materialization.py +++ b/hamilton/io/materialization.py @@ -19,7 +19,7 @@ import functools import inspect import typing -from typing import Any, Dict, List, Optional, Protocol, Set, Type, Union +from typing import Any, Protocol from hamilton import base, common, graph, lifecycle, node from hamilton.function_modifiers.adapters import LoadFromDecorator, SaveToDecorator @@ -109,8 +109,8 @@ def __getattr__(cls, item: str) -> "_ExtractorFactoryProtocol": def process_kwargs( - data_saver_kwargs: Dict[str, Union[Any, SingleDependency]], -) -> Dict[str, SingleDependency]: + data_saver_kwargs: dict[str, Any | SingleDependency], +) -> dict[str, SingleDependency]: """Processes raw strings from the user, converting them into dependency specs. This goes according to the following rules. @@ -135,8 +135,8 @@ class ExtractorFactory: def __init__( self, target: str, - loaders: List[Type[DataLoader]], - **data_loader_kwargs: Union[Any, SingleDependency], + loaders: list[type[DataLoader]], + **data_loader_kwargs: Any | SingleDependency, ): """Instantiates an ExtractorFactory. Note this is not a public API -- this is internally what gets called (through a factory method) to create it. Called using `from_`, @@ -150,7 +150,7 @@ def __init__( self.loaders = loaders self.data_loader_kwargs = process_kwargs(data_loader_kwargs) - def generate_nodes(self, fn_graph: graph.FunctionGraph) -> List[node.Node]: + def generate_nodes(self, fn_graph: graph.FunctionGraph) -> list[node.Node]: """Resolves the extractor, returning the set of nodes that should get added to the function graph. Note that this is an upsert operation -- these nodes can replace existing nodes. 
@@ -178,9 +178,9 @@ class MaterializerFactory: def __init__( self, id: str, - savers: List[Type[DataSaver]], - result_builder: Optional[base.ResultMixin], - dependencies: List[Union[str, Any]], + savers: list[type[DataSaver]], + result_builder: base.ResultMixin | None, + dependencies: list[str | Any], **data_saver_kwargs: Any, ): """Creates a materializer factory. @@ -200,7 +200,7 @@ def __init__( self.dependencies = dependencies self.data_saver_kwargs = process_kwargs(data_saver_kwargs) - def sanitize_dependencies(self, module_set: Set[str]) -> "MaterializerFactory": + def sanitize_dependencies(self, module_set: set[str]) -> "MaterializerFactory": """Sanitizes the dependencies to ensure they're strings. This replaces the internal value for self.dependencies and returns a new object. @@ -218,7 +218,7 @@ def sanitize_dependencies(self, module_set: Set[str]) -> "MaterializerFactory": **self.data_saver_kwargs, ) - def _resolve_dependencies(self, fn_graph: graph.FunctionGraph) -> List[node.Node]: + def _resolve_dependencies(self, fn_graph: graph.FunctionGraph) -> list[node.Node]: out = [] missing_nodes = [] for name in self.dependencies: @@ -232,7 +232,7 @@ def _resolve_dependencies(self, fn_graph: graph.FunctionGraph) -> List[node.Node ) return [fn_graph.nodes[name] for name in self.dependencies] - def generate_nodes(self, fn_graph: graph.FunctionGraph) -> List[node.Node]: + def generate_nodes(self, fn_graph: graph.FunctionGraph) -> list[node.Node]: """Generates additional nodes from a materializer, returning the set of nodes that should get appended to the function graph. This does two things: @@ -286,25 +286,25 @@ class _MaterializerFactoryProtocol(Protocol): def __call__( self, id: str, - dependencies: List[str], + dependencies: list[str], combine: lifecycle.ResultBuilder = None, - **kwargs: Union[str, SingleDependency], + **kwargs: str | SingleDependency, ) -> MaterializerFactory: pass @typing.runtime_checkable class _ExtractorFactoryProtocol(Protocol): - def __call__(self, target: str, **kwargs: Union[str, SingleDependency]) -> ExtractorFactory: + def __call__(self, target: str, **kwargs: str | SingleDependency) -> ExtractorFactory: pass -def partial_materializer(data_savers: List[Type[DataSaver]]) -> _MaterializerFactoryProtocol: +def partial_materializer(data_savers: list[type[DataSaver]]) -> _MaterializerFactoryProtocol: """Creates a partial materializer, with the specified data savers.""" def create_materializer_factory( id: str, - dependencies: List[str], + dependencies: list[str], combine: base.ResultMixin = None, **kwargs: typing.Any, ) -> MaterializerFactory: @@ -320,7 +320,7 @@ def create_materializer_factory( def partial_extractor( - data_loaders: List[Type[DataLoader]], + data_loaders: list[type[DataLoader]], ) -> _ExtractorFactoryProtocol: """Creates a partial materializer, with the specified data savers.""" @@ -371,8 +371,8 @@ def _set_materializer_attrs(): This is so one can get auto-complete""" def with_modified_signature( - fn: Type[_MaterializerFactoryProtocol], - dataclasses_union: List[Type[dataclasses.dataclass]], + fn: type[_MaterializerFactoryProtocol], + dataclasses_union: list[type[dataclasses.dataclass]], key: str, ): """Modifies the signature to include the parameters from *all* dataclasses. 
@@ -455,8 +455,8 @@ def wrapper(*args, **kwargs): def modify_graph( fn_graph: FunctionGraph, - materializer_factories: List[MaterializerFactory], - extractor_factories: List[ExtractorFactory], + materializer_factories: list[MaterializerFactory], + extractor_factories: list[ExtractorFactory], ) -> FunctionGraph: """Modifies the function graph, adding in the specified materialization/loader nodes. diff --git a/hamilton/io/utils.py b/hamilton/io/utils.py index 199bdc636..0953eb9bd 100644 --- a/hamilton/io/utils.py +++ b/hamilton/io/utils.py @@ -20,7 +20,7 @@ from datetime import datetime from os import PathLike from pathlib import Path -from typing import Any, Dict, Union +from typing import Any from urllib import parse import pandas as pd @@ -30,7 +30,7 @@ FILE_METADATA = "file_metadata" -def get_file_metadata(path: Union[str, Path, PathLike]) -> Dict[str, Any]: +def get_file_metadata(path: str | Path | PathLike) -> dict[str, Any]: """Gives metadata from loading a file. Note: we reserve the right to change this schema. So if you're using this come @@ -71,7 +71,7 @@ def get_file_metadata(path: Union[str, Path, PathLike]) -> Dict[str, Any]: } -def get_dataframe_metadata(df: pd.DataFrame) -> Dict[str, Any]: +def get_dataframe_metadata(df: pd.DataFrame) -> dict[str, Any]: """Gives metadata from loading a dataframe. Note: we reserve the right to change this schema. So if you're using this come @@ -106,7 +106,7 @@ def get_dataframe_metadata(df: pd.DataFrame) -> Dict[str, Any]: return {DATAFRAME_METADATA: metadata} -def get_file_and_dataframe_metadata(path: str, df: pd.DataFrame) -> Dict[str, Any]: +def get_file_and_dataframe_metadata(path: str, df: pd.DataFrame) -> dict[str, Any]: """Gives metadata from loading a file and a dataframe. Note: we reserve the right to change this schema. So if you're using this come @@ -127,7 +127,7 @@ def get_file_and_dataframe_metadata(path: str, df: pd.DataFrame) -> Dict[str, An return {**get_file_metadata(path), **get_dataframe_metadata(df)} -def get_sql_metadata(query_or_table: str, results: Union[int, pd.DataFrame]) -> Dict[str, Any]: +def get_sql_metadata(query_or_table: str, results: int | pd.DataFrame) -> dict[str, Any]: """Gives metadata from reading a SQL table or writing to SQL db. Note: we reserve the right to change this schema. So if you're using this come diff --git a/hamilton/lifecycle/api.py b/hamilton/lifecycle/api.py index c05027ad9..575616118 100644 --- a/hamilton/lifecycle/api.py +++ b/hamilton/lifecycle/api.py @@ -17,8 +17,9 @@ import abc from abc import ABC +from collections.abc import Collection from types import ModuleType -from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Type, final +from typing import TYPE_CHECKING, Any, final from hamilton import graph_types, node @@ -81,13 +82,13 @@ def build_result(self, **outputs: Any) -> Any: @override @final - def do_build_result(self, outputs: Dict[str, Any]) -> Any: + def do_build_result(self, outputs: dict[str, Any]) -> Any: """Implements the do_build_result method from the BaseDoBuildResult class. This is kept from the user as the public-facing API is build_result, allowing us to change the API/implementation of the internal set of hooks""" return self.build_result(**outputs) - def input_types(self) -> List[Type[Type]]: + def input_types(self) -> list[type[type]]: """Gives the applicable types to this result builder. This is optional for backwards compatibility, but is recommended. 
@@ -95,7 +96,7 @@ def input_types(self) -> List[Type[Type]]: """ return [Any] - def output_type(self) -> Type: + def output_type(self) -> type: """Returns the output type of this result builder :return: the type that this creates """ @@ -129,7 +130,7 @@ class GraphAdapter( @staticmethod @abc.abstractmethod - def check_input_type(node_type: Type, input_value: Any) -> bool: + def check_input_type(node_type: type, input_value: Any) -> bool: """Used to check whether the user inputs match what the execution strategy & functions can handle. Static purely for legacy reasons. @@ -142,7 +143,7 @@ def check_input_type(node_type: Type, input_value: Any) -> bool: @staticmethod @abc.abstractmethod - def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool: + def check_node_type_equivalence(node_type: type, input_type: type) -> bool: """Used to check whether two types are equivalent. Static, purely for legacy reasons. @@ -159,7 +160,7 @@ def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool: @override @final def do_node_execute( - self, run_id: str, node_: node.Node, kwargs: Dict[str, Any], task_id: Optional[str] = None + self, run_id: str, node_: node.Node, kwargs: dict[str, Any], task_id: str | None = None ) -> Any: return self.execute_node(node_, kwargs) @@ -174,7 +175,7 @@ def do_check_edge_types_match(self, type_from: type, type_to: type) -> bool: return self.check_node_type_equivalence(type_to, type_from) @abc.abstractmethod - def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any: + def execute_node(self, node: node.Node, kwargs: dict[str, Any]) -> Any: """Given a node that represents a hamilton function, execute it. Note, in some adapters this might just return some type of "future". @@ -193,12 +194,12 @@ def run_before_node_execution( self, *, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], node_return_type: type, - task_id: Optional[str], + task_id: str | None, run_id: str, - node_input_types: Dict[str, Any], + node_input_types: dict[str, Any], **future_kwargs: Any, ): """Hook that is executed prior to node execution. @@ -221,8 +222,8 @@ def pre_node_execute( *, run_id: str, node_: node.Node, - kwargs: Dict[str, Any], - task_id: Optional[str] = None, + kwargs: dict[str, Any], + task_id: str | None = None, ): """Wraps the before_execution method, providing a bridge to an external-facing API. Do not override this!""" self.run_before_node_execution( @@ -240,13 +241,13 @@ def run_after_node_execution( self, *, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], node_return_type: type, result: Any, - error: Optional[Exception], + error: Exception | None, success: bool, - task_id: Optional[str], + task_id: str | None, run_id: str, **future_kwargs: Any, ): @@ -272,11 +273,11 @@ def post_node_execute( *, run_id: str, node_: node.Node, - kwargs: Dict[str, Any], + kwargs: dict[str, Any], success: bool, - error: Optional[Exception], - result: Optional[Any], - task_id: Optional[str] = None, + error: Exception | None, + result: Any | None, + task_id: str | None = None, ): """Wraps the after_execution method, providing a bridge to an external-facing API. 
Do not override this!""" self.run_after_node_execution( @@ -303,8 +304,8 @@ def post_graph_execute( run_id: str, graph: "FunctionGraph", success: bool, - error: Optional[Exception], - results: Optional[Dict[str, Any]], + error: Exception | None, + results: dict[str, Any] | None, ): """Just delegates to the interface method, passing in the right data.""" return self.run_after_graph_execution( @@ -322,9 +323,9 @@ def pre_graph_execute( *, run_id: str, graph: "FunctionGraph", - final_vars: List[str], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + final_vars: list[str], + inputs: dict[str, Any], + overrides: dict[str, Any], ): """Implementation of the pre_graph_execute hook. This just converts the inputs to the format the user-facing hook is expecting -- performing a walk of the DAG to pass in @@ -345,9 +346,9 @@ def run_before_graph_execution( self, *, graph: graph_types.HamiltonGraph, - final_vars: List[str], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + final_vars: list[str], + inputs: dict[str, Any], + overrides: dict[str, Any], execution_path: Collection[str], run_id: str, **future_kwargs: Any, @@ -372,8 +373,8 @@ def run_after_graph_execution( *, graph: graph_types.HamiltonGraph, success: bool, - error: Optional[Exception], - results: Optional[Dict[str, Any]], + error: Exception | None, + results: dict[str, Any] | None, run_id: str, **future_kwargs: Any, ): @@ -400,10 +401,10 @@ def pre_task_submission( *, run_id: str, task_id: str, - nodes: List["node.Node"], - inputs: Dict[str, Any], - overrides: Dict[str, Any], - spawning_task_id: Optional[str], + nodes: list["node.Node"], + inputs: dict[str, Any], + overrides: dict[str, Any], + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): self.run_before_task_submission( @@ -422,10 +423,10 @@ def run_before_task_submission( *, run_id: str, task_id: str, - nodes: List["node.Node"], - inputs: Dict[str, Any], - overrides: Dict[str, Any], - spawning_task_id: Optional[str], + nodes: list["node.Node"], + inputs: dict[str, Any], + overrides: dict[str, Any], + spawning_task_id: str | None, purpose: NodeGroupPurpose, **future_kwargs, ): @@ -454,11 +455,11 @@ def post_task_return( *, run_id: str, task_id: str, - nodes: List["node.Node"], + nodes: list["node.Node"], result: Any, success: bool, - error: Optional[Exception], - spawning_task_id: Optional[str], + error: Exception | None, + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): self.run_after_task_return( @@ -478,11 +479,11 @@ def run_after_task_return( *, run_id: str, task_id: str, - nodes: List["node.Node"], + nodes: list["node.Node"], result: Any, success: bool, - error: Optional[Exception], - spawning_task_id: Optional[str], + error: Exception | None, + spawning_task_id: str | None, purpose: NodeGroupPurpose, **future_kwargs, ): @@ -511,10 +512,10 @@ def pre_task_execute( *, run_id: str, task_id: str, - nodes: List["node.Node"], - inputs: Dict[str, Any], - overrides: Dict[str, Any], - spawning_task_id: Optional[str], + nodes: list["node.Node"], + inputs: dict[str, Any], + overrides: dict[str, Any], + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): self.run_before_task_execution( @@ -532,11 +533,11 @@ def post_task_execute( *, run_id: str, task_id: str, - nodes: List["node.Node"], - results: Optional[Dict[str, Any]], + nodes: list["node.Node"], + results: dict[str, Any] | None, success: bool, error: Exception, - spawning_task_id: Optional[str], + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): self.run_after_task_execution( @@ 
-556,10 +557,10 @@ def run_before_task_execution( *, task_id: str, run_id: str, - nodes: List[HamiltonNode], - inputs: Dict[str, Any], - overrides: Dict[str, Any], - spawning_task_id: Optional[str], + nodes: list[HamiltonNode], + inputs: dict[str, Any], + overrides: dict[str, Any], + spawning_task_id: str | None, purpose: NodeGroupPurpose, **future_kwargs, ): @@ -583,11 +584,11 @@ def run_after_task_execution( *, task_id: str, run_id: str, - nodes: List[HamiltonNode], - results: Optional[Dict[str, Any]], + nodes: list[HamiltonNode], + results: dict[str, Any] | None, success: bool, error: Exception, - spawning_task_id: Optional[str], + spawning_task_id: str | None, purpose: NodeGroupPurpose, **future_kwargs, ): @@ -662,8 +663,8 @@ def do_node_execute( *, run_id: str, node_: node.Node, - kwargs: Dict[str, Any], - task_id: Optional[str] = None, + kwargs: dict[str, Any], + task_id: str | None = None, ) -> Any: return self.run_to_execute_node( node_name=node_.name, @@ -680,10 +681,10 @@ def run_to_execute_node( self, *, node_name: str, - node_tags: Dict[str, Any], + node_tags: dict[str, Any], node_callable: Any, - node_kwargs: Dict[str, Any], - task_id: Optional[str], + node_kwargs: dict[str, Any], + task_id: str | None, is_expand: bool, is_collect: bool, **future_kwargs: Any, @@ -727,7 +728,7 @@ def run_to_validate_node( def run_to_validate_node( self, *, node: HamiltonNode, **future_kwargs - ) -> Tuple[bool, Optional[str]]: + ) -> tuple[bool, str | None]: """Override this to build custom node validations! Defaults to just returning that a node is valid so you don't have to implement it if you want to just implement a single method. Runs post node construction to validate a node. You have access to a bunch of metadata about the node, stored in the hamilton_node argument @@ -741,7 +742,7 @@ def run_to_validate_node( def run_to_validate_graph( self, graph: HamiltonGraph, **future_kwargs - ) -> Tuple[bool, Optional[str]]: + ) -> tuple[bool, str | None]: """Override this to build custom DAG validations! Default to just returning that the graph is valid, so you don't have to implement it if you want to just implement a single method. Runs post graph construction to validate a graph. You have access to a bunch of metadata about the graph, stored in the graph argument. 
@@ -754,14 +755,14 @@ def run_to_validate_graph( @override @final - def validate_node(self, *, created_node: node.Node) -> Tuple[bool, Optional[Exception]]: + def validate_node(self, *, created_node: node.Node) -> tuple[bool, Exception | None]: return self.run_to_validate_node(node=HamiltonNode.from_node(created_node)) @override @final def validate_graph( - self, *, graph: "FunctionGraph", modules: List[ModuleType], config: Dict[str, Any] - ) -> Tuple[bool, Optional[Exception]]: + self, *, graph: "FunctionGraph", modules: list[ModuleType], config: dict[str, Any] + ) -> tuple[bool, Exception | None]: return self.run_to_validate_graph(graph=HamiltonGraph.from_graph(graph)) @@ -771,16 +772,16 @@ class TaskGroupingHook(BasePostTaskGroup, BasePostTaskExpand): @override @final - def post_task_group(self, *, run_id: str, task_ids: List[str]): + def post_task_group(self, *, run_id: str, task_ids: list[str]): return self.run_after_task_grouping(run_id=run_id, task_ids=task_ids) @override @final - def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]): + def post_task_expand(self, *, run_id: str, task_id: str, parameters: dict[str, Any]): return self.run_after_task_expansion(run_id=run_id, task_id=task_id, parameters=parameters) @abc.abstractmethod - def run_after_task_grouping(self, *, run_id: str, task_ids: List[str], **future_kwargs): + def run_after_task_grouping(self, *, run_id: str, task_ids: list[str], **future_kwargs): """Runs after task grouping. This allows you to capture information about which tasks were created for a given run. @@ -792,7 +793,7 @@ def run_after_task_grouping(self, *, run_id: str, task_ids: List[str], **future_ @abc.abstractmethod def run_after_task_expansion( - self, *, run_id: str, task_id: str, parameters: Dict[str, Any], **future_kwargs + self, *, run_id: str, task_id: str, parameters: dict[str, Any], **future_kwargs ): """Runs after task expansion in Parallelize/Collect blocks. This allows you to capture information about the task that was expanded. @@ -813,13 +814,13 @@ class GraphConstructionHook(BasePostGraphConstruct, abc.ABC): """ def post_graph_construct( - self, *, graph: "FunctionGraph", modules: List[ModuleType], config: Dict[str, Any] + self, *, graph: "FunctionGraph", modules: list[ModuleType], config: dict[str, Any] ): self.run_after_graph_construction(graph=HamiltonGraph.from_graph(graph), config=config) @abc.abstractmethod def run_after_graph_construction( - self, *, graph: HamiltonGraph, config: Dict[str, Any], **future_kwargs: Any + self, *, graph: HamiltonGraph, config: dict[str, Any], **future_kwargs: Any ): """Hook that is run post graph construction. This allows you to register/capture info on the graph. A common pattern is to store something in your object's state here so that you can use it later diff --git a/hamilton/lifecycle/base.py b/hamilton/lifecycle/base.py index 1c881220f..1140ff728 100644 --- a/hamilton/lifecycle/base.py +++ b/hamilton/lifecycle/base.py @@ -47,8 +47,9 @@ import collections import dataclasses import inspect +from collections.abc import Callable from types import ModuleType -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Union from hamilton import htypes @@ -69,16 +70,16 @@ # as it is a clear, simple way to manage the metadata. This allows us to track the registered hooks/methods/validators. 
# A set of registered hooks -- each one refers to a string -REGISTERED_SYNC_HOOKS: Set[str] = set() -REGISTERED_ASYNC_HOOKS: Set[str] = set() +REGISTERED_SYNC_HOOKS: set[str] = set() +REGISTERED_ASYNC_HOOKS: set[str] = set() # A set of registered methods -- each one refers to a string, which is the name of the metho -REGISTERED_SYNC_METHODS: Set[str] = set() -REGISTERED_ASYNC_METHODS: Set[str] = set() +REGISTERED_SYNC_METHODS: set[str] = set() +REGISTERED_ASYNC_METHODS: set[str] = set() # A set of registered validators -- these have attached Exception data to them # Note we do not curently have async validators -- see no need now -REGISTERED_SYNC_VALIDATORS: Set[str] = set() +REGISTERED_SYNC_VALIDATORS: set[str] = set() # constants to refer to internally for hooks SYNC_HOOK = "hooks" @@ -95,7 +96,7 @@ @dataclasses.dataclass class ValidationResult: success: bool - error: Optional[str] + error: str | None validator: object # validator so we can make the error message more friendly @@ -110,7 +111,7 @@ class InvalidLifecycleAdapter(Exception): def validate_lifecycle_adapter_function( - fn: Callable, returns_value: bool, return_type: Optional[Type] = None + fn: Callable, returns_value: bool, return_type: type | None = None ): """Validates that a function has arguments that are keyword-only, and either does or does not return a value, depending on the value of returns_value. @@ -313,7 +314,7 @@ def do_validate_input(self, *, node_type: type, input_value: Any) -> bool: @lifecycle.base_validator("validate_node") class BaseValidateNode(abc.ABC): @abc.abstractmethod - def validate_node(self, *, created_node: "node.Node") -> Tuple[bool, Optional[Exception]]: + def validate_node(self, *, created_node: "node.Node") -> tuple[bool, Exception | None]: """Validates a node. This will raise an InvalidNodeException if the node is invalid. @@ -330,9 +331,9 @@ def validate_graph( self, *, graph: "graph.FunctionGraph", - modules: List[ModuleType], - config: Dict[str, Any], - ) -> Tuple[bool, Optional[str]]: + modules: list[ModuleType], + config: dict[str, Any], + ) -> tuple[bool, str | None]: """Validates the graph. This will raise an InvalidNodeException :param graph: Graph that has been constructed. @@ -349,8 +350,8 @@ def post_graph_construct( self, *, graph: "graph.FunctionGraph", - modules: List[ModuleType], - config: Dict[str, Any], + modules: list[ModuleType], + config: dict[str, Any], ): """Hooks that is called after the graph is constructed. @@ -368,8 +369,8 @@ async def post_graph_construct( self, *, graph: "graph.FunctionGraph", - modules: List[ModuleType], - config: Dict[str, Any], + modules: list[ModuleType], + config: dict[str, Any], ): """Asynchronous hook that is called after the graph is constructed. @@ -388,9 +389,9 @@ def pre_graph_execute( *, run_id: str, graph: "graph.FunctionGraph", - final_vars: List[str], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + final_vars: list[str], + inputs: dict[str, Any], + overrides: dict[str, Any], ): """Hook that is called immediately prior to graph execution. @@ -411,9 +412,9 @@ async def pre_graph_execute( *, run_id: str, graph: "graph.FunctionGraph", - final_vars: List[str], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + final_vars: list[str], + inputs: dict[str, Any], + overrides: dict[str, Any], ): """Asynchronous Hook that is called immediately prior to graph execution. 
@@ -429,7 +430,7 @@ async def pre_graph_execute( @lifecycle.base_hook("post_task_group") class BasePostTaskGroup(abc.ABC): @abc.abstractmethod - def post_task_group(self, *, run_id: str, task_ids: List[str]): + def post_task_group(self, *, run_id: str, task_ids: list[str]): """Hook that is called immediately after a task group is created. Note that this is only useful in dynamic execution, although we reserve the right to add this back into the standard hamilton execution pattern. @@ -441,7 +442,7 @@ def post_task_group(self, *, run_id: str, task_ids: List[str]): @lifecycle.base_hook("post_task_expand") class BasePostTaskExpand(abc.ABC): @abc.abstractmethod - def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]): + def post_task_expand(self, *, run_id: str, task_id: str, parameters: dict[str, Any]): """Hook that is called immediately after a task is expanded into parallelizable tasks. Note that this is only useful in dynamic execution. @@ -459,10 +460,10 @@ def pre_task_submission( *, run_id: str, task_id: str, - nodes: List["node.Node"], - inputs: Dict[str, Any], - overrides: Dict[str, Any], - spawning_task_id: Optional[str], + nodes: list["node.Node"], + inputs: dict[str, Any], + overrides: dict[str, Any], + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): """Hook that is called immediately prior to task submission to an executor as a task future. @@ -488,10 +489,10 @@ def pre_task_execute( *, run_id: str, task_id: str, - nodes: List["node.Node"], - inputs: Dict[str, Any], - overrides: Dict[str, Any], - spawning_task_id: Optional[str], + nodes: list["node.Node"], + inputs: dict[str, Any], + overrides: dict[str, Any], + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): """Hook that is called immediately prior to task execution. Note that this is only useful in dynamic @@ -516,10 +517,10 @@ async def pre_task_execute( *, run_id: str, task_id: str, - nodes: List["node.Node"], - inputs: Dict[str, Any], - overrides: Dict[str, Any], - spawning_task_id: Optional[str], + nodes: list["node.Node"], + inputs: dict[str, Any], + overrides: dict[str, Any], + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): """Hook that is called immediately prior to task execution. Note that this is only useful in dynamic @@ -544,8 +545,8 @@ def pre_node_execute( *, run_id: str, node_: "node.Node", - kwargs: Dict[str, Any], - task_id: Optional[str] = None, + kwargs: dict[str, Any], + task_id: str | None = None, ): """Hook that is called immediately prior to node execution. @@ -565,8 +566,8 @@ async def pre_node_execute( *, run_id: str, node_: "node.Node", - kwargs: Dict[str, Any], - task_id: Optional[str] = None, + kwargs: dict[str, Any], + task_id: str | None = None, ): """Asynchronous hook that is called immediately prior to node execution. @@ -586,8 +587,8 @@ def do_node_execute( *, run_id: str, node_: "node.Node", - kwargs: Dict[str, Any], - task_id: Optional[str] = None, + kwargs: dict[str, Any], + task_id: str | None = None, ) -> Any: """Method that is called to implement node execution. This can replace the execution of a node with something all together, augment it, or delegate it. @@ -607,7 +608,7 @@ def do_remote_execute( self, *, node: "node.Node", - kwargs: Dict[str, Any], + kwargs: dict[str, Any], execute_lifecycle_for_node: Callable, ) -> Any: """Method that is called to implement correct remote execution of hooks. 
This makes sure that all the pre-node and post-node hooks get executed in the remote environment which is necessary for some adapters. Node execution is called the same as before through "do_node_execute". @@ -628,8 +629,8 @@ async def do_node_execute( *, run_id: str, node_: "node.Node", - kwargs: Dict[str, Any], - task_id: Optional[str] = None, + kwargs: dict[str, Any], + task_id: str | None = None, ) -> Any: """Asynchronous method that is called to implement node execution. This can replace the execution of a node with something all together, augment it, or delegate it. @@ -650,11 +651,11 @@ def post_node_execute( *, run_id: str, node_: "node.Node", - kwargs: Dict[str, Any], + kwargs: dict[str, Any], success: bool, - error: Optional[Exception], - result: Optional[Any], - task_id: Optional[str] = None, + error: Exception | None, + result: Any | None, + task_id: str | None = None, ): """Hook that is called immediately after node execution. @@ -677,11 +678,11 @@ async def post_node_execute( *, run_id: str, node_: "node.Node", - kwargs: Dict[str, Any], + kwargs: dict[str, Any], success: bool, - error: Optional[Exception], + error: Exception | None, result: Any, - task_id: Optional[str] = None, + task_id: str | None = None, ): """Hook that is called immediately after node execution. @@ -704,11 +705,11 @@ def post_task_execute( *, run_id: str, task_id: str, - nodes: List["node.Node"], - results: Optional[Dict[str, Any]], + nodes: list["node.Node"], + results: dict[str, Any] | None, success: bool, error: Exception, - spawning_task_id: Optional[str], + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): """Hook called immediately after task execution. Note that this is only useful in dynamic @@ -734,11 +735,11 @@ async def post_task_execute( *, run_id: str, task_id: str, - nodes: List["node.Node"], - results: Optional[Dict[str, Any]], + nodes: list["node.Node"], + results: dict[str, Any] | None, success: bool, error: Exception, - spawning_task_id: Optional[str], + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): """Asynchronous Hook called immediately after task execution. Note that this is only useful in dynamic @@ -764,11 +765,11 @@ def post_task_return( *, run_id: str, task_id: str, - nodes: List["node.Node"], + nodes: list["node.Node"], result: Any, success: bool, error: Exception, - spawning_task_id: Optional[str], + spawning_task_id: str | None, purpose: NodeGroupPurpose, ): """Hook called immediately after a task returns from an executor. Note that this is only @@ -795,8 +796,8 @@ def post_graph_execute( run_id: str, graph: "graph.FunctionGraph", success: bool, - error: Optional[Exception], - results: Optional[Dict[str, Any]], + error: Exception | None, + results: dict[str, Any] | None, ): """Hook called immediately after graph execution. @@ -818,8 +819,8 @@ async def post_graph_execute( run_id: str, graph: "graph.FunctionGraph", success: bool, - error: Optional[Exception], - results: Optional[Dict[str, Any]], + error: Exception | None, + results: dict[str, Any] | None, ): """Asynchronous Hook called immediately after graph execution. 
@@ -893,7 +894,7 @@ def __init__(self, *adapters: LifecycleAdapter): self.sync_methods, self.async_methods = self._get_lifecycle_methods() self.sync_validators = self._get_lifecycle_validators() - def _uniqify_adapters(self, adapters: List[LifecycleAdapter]) -> List[LifecycleAdapter]: + def _uniqify_adapters(self, adapters: list[LifecycleAdapter]) -> list[LifecycleAdapter]: """Removes duplicate adapters from the list of adapters -- this often happens on how they're passed in and we don't want to have the same adapter twice. Specifically, this came up due to parsing/splitting out adapters with async lifecycle hooks -- there were cases in which we were passed duplicates. This was compounded as we would pass @@ -909,7 +910,7 @@ def _uniqify_adapters(self, adapters: List[LifecycleAdapter]) -> List[LifecycleA def _get_lifecycle_validators( self, - ) -> Dict[str, List[LifecycleAdapter]]: + ) -> dict[str, list[LifecycleAdapter]]: sync_validators = collections.defaultdict(set) for adapter in self.adapters: for cls in inspect.getmro(adapter.__class__): @@ -920,7 +921,7 @@ def _get_lifecycle_validators( def _get_lifecycle_hooks( self, - ) -> Tuple[Dict[str, List[LifecycleAdapter]], Dict[str, List[LifecycleAdapter]]]: + ) -> tuple[dict[str, list[LifecycleAdapter]], dict[str, list[LifecycleAdapter]]]: sync_hooks = collections.defaultdict(list) async_hooks = collections.defaultdict(list) for adapter in self.adapters: @@ -940,7 +941,7 @@ def _get_lifecycle_hooks( def _get_lifecycle_methods( self, - ) -> Tuple[Dict[str, List[LifecycleAdapter]], Dict[str, List[LifecycleAdapter]]]: + ) -> tuple[dict[str, list[LifecycleAdapter]], dict[str, list[LifecycleAdapter]]]: sync_methods = collections.defaultdict(set) async_methods = collections.defaultdict(set) for adapter in self.adapters: @@ -968,7 +969,7 @@ def _get_lifecycle_methods( {method: list(adapters) for method, adapters in async_methods.items()}, ) - def does_hook(self, hook_name: str, is_async: Optional[bool] = None) -> bool: + def does_hook(self, hook_name: str, is_async: bool | None = None) -> bool: """Whether or not a hook is implemented by any of the adapters in this group. If this hook is not registered, this will raise a ValueError. @@ -991,7 +992,7 @@ def does_hook(self, hook_name: str, is_async: Optional[bool] = None) -> bool: has_sync = hook_name in self.sync_hooks return (has_async or has_sync) if either else has_async if is_async else has_sync - def does_method(self, method_name: str, is_async: Optional[bool] = None) -> bool: + def does_method(self, method_name: str, is_async: bool | None = None) -> bool: """Whether a method is implemented by any of the adapters in this group. If this method is not registered, this will raise a ValueError. 
@@ -1094,7 +1095,7 @@ async def call_lifecycle_method_async(self, method_name: str, **kwargs): def call_all_validators_sync( self, validator_name: str, output_only_failures: bool = True, **kwargs - ) -> List[ValidationResult]: + ) -> list[ValidationResult]: """Calls all the lifecycle validators in this group, by validator name (stage) :param validator_name: Name of the validators to call @@ -1109,7 +1110,7 @@ def call_all_validators_sync( return results @property - def adapters(self) -> List[LifecycleAdapter]: + def adapters(self) -> list[LifecycleAdapter]: """Gives the adapters in this group :return: A list of adapters diff --git a/hamilton/lifecycle/default.py b/hamilton/lifecycle/default.py index 37a0b9a77..d8cde2d24 100644 --- a/hamilton/lifecycle/default.py +++ b/hamilton/lifecycle/default.py @@ -25,8 +25,9 @@ import random import shelve import time +from collections.abc import Callable from functools import partial -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Union from hamilton import graph_types, htypes from hamilton.graph_types import HamiltonGraph @@ -41,15 +42,15 @@ NodeFilter = Union[ Callable[ - [str, Dict[str, Any]], bool + [str, dict[str, Any]], bool ], # filter function for nodes, mapping node name to a boolean - List[str], # list of node names to run + list[str], # list of node names to run str, # node name to run None, # run all nodes ] # filter function for nodes, mapping node name and node tags to a boolean -def should_run_node(node_name: str, node_tags: Dict[str, Any], node_filter: NodeFilter) -> bool: +def should_run_node(node_name: str, node_tags: dict[str, Any], node_filter: NodeFilter) -> bool: if node_filter is None: return True if isinstance(node_filter, str): @@ -83,7 +84,7 @@ def _validate_verbosity(verbosity: int): raise ValueError(f"Verbosity must be one of [1, 2], got {verbosity}") @staticmethod - def _format_node_name(node_name: str, task_id: Optional[str]) -> str: + def _format_node_name(node_name: str, task_id: str | None) -> str: """Formats a node name and task id into a unique node name.""" if task_id is not None: return f"{task_id}:{node_name}" @@ -116,9 +117,9 @@ def run_before_node_execution( self, *, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], - task_id: Optional[str], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], + task_id: str | None, **future_kwargs: Any, ): """Runs before a node executes. Prints out the node name and inputs if verbosity is 2. @@ -142,12 +143,12 @@ def run_after_node_execution( self, *, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], result: Any, - error: Optional[Exception], + error: Exception | None, success: bool, - task_id: Optional[str], + task_id: str | None, **future_kwargs: Any, ): """Runs after a node executes. Prints out the node name and time it took, the output if verbosity is 1. @@ -212,10 +213,10 @@ def run_to_execute_node( self, *, node_name: str, - node_tags: Dict[str, Any], + node_tags: dict[str, Any], node_callable: Any, - node_kwargs: Dict[str, Any], - task_id: Optional[str], + node_kwargs: dict[str, Any], + task_id: str | None, **future_kwargs: Any, ) -> Any: """Executes the node with a PDB debugger. 
This modifies the global PDBDebugger.CONTEXT variable to contain information about the node, @@ -241,16 +242,14 @@ def run_to_execute_node( "future_kwargs": future_kwargs, } logger.warning( - ( - f"Placing you in a PDB debugger for node {node_name}." - "\nYou can access additional node information via PDBDebugger.CONTEXT. Data is:" - f"\n - node_name: {PDBDebugger._truncate_repr(node_name)}" - f"\n - node_tags: {PDBDebugger._truncate_repr(node_tags)}" - f"\n - node_callable: {PDBDebugger._truncate_repr(node_callable)}" - f"\n - node_kwargs: {PDBDebugger._truncate_repr(', '.join(list(node_kwargs.keys())))}" - f"\n - task_id: {PDBDebugger._truncate_repr(task_id)}" - f"\n - future_kwargs: {PDBDebugger._truncate_repr(future_kwargs)}" - ) + f"Placing you in a PDB debugger for node {node_name}." + "\nYou can access additional node information via PDBDebugger.CONTEXT. Data is:" + f"\n - node_name: {PDBDebugger._truncate_repr(node_name)}" + f"\n - node_tags: {PDBDebugger._truncate_repr(node_tags)}" + f"\n - node_callable: {PDBDebugger._truncate_repr(node_callable)}" + f"\n - node_kwargs: {PDBDebugger._truncate_repr(', '.join(list(node_kwargs.keys())))}" + f"\n - task_id: {PDBDebugger._truncate_repr(task_id)}" + f"\n - future_kwargs: {PDBDebugger._truncate_repr(future_kwargs)}" ) out = pdb.runcall(node_callable, **node_kwargs) logger.info(f"Finished executing node {node_name}.") @@ -267,10 +266,10 @@ def run_before_node_execution( self, *, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], node_return_type: type, - task_id: Optional[str], + task_id: str | None, **future_kwargs: Any, ): """Executes before a node executes. Does nothing, just runs pdb.set_trace() @@ -285,15 +284,13 @@ def run_before_node_execution( """ if should_run_node(node_name, node_tags, self.node_filter) and self.run_before: logger.warning( - ( - f"Placing you in a PDB debugger prior to execution of node: {node_name}." - "\nYou can access additional node information via the following variables:" - f"\n - node_name: {PDBDebugger._truncate_repr(node_name)}" - f"\n - node_tags: {PDBDebugger._truncate_repr(node_tags)}" - f"\n - node_kwargs: {PDBDebugger._truncate_repr(', '.join(list(node_kwargs.keys())))}" - f"\n - node_return_type: {PDBDebugger._truncate_repr(node_return_type)}" - f"\n - task_id: {PDBDebugger._truncate_repr(task_id)}" - ) + f"Placing you in a PDB debugger prior to execution of node: {node_name}." + "\nYou can access additional node information via the following variables:" + f"\n - node_name: {PDBDebugger._truncate_repr(node_name)}" + f"\n - node_tags: {PDBDebugger._truncate_repr(node_tags)}" + f"\n - node_kwargs: {PDBDebugger._truncate_repr(', '.join(list(node_kwargs.keys())))}" + f"\n - node_return_type: {PDBDebugger._truncate_repr(node_return_type)}" + f"\n - task_id: {PDBDebugger._truncate_repr(task_id)}" ) pdb.set_trace() @@ -301,13 +298,13 @@ def run_after_node_execution( self, *, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], node_return_type: type, result: Any, - error: Optional[Exception], + error: Exception | None, success: bool, - task_id: Optional[str], + task_id: str | None, **future_kwargs: Any, ): """Executes after a node, whether or not it was successful. Does nothing, just runs pdb.set_trace(). 
@@ -324,18 +321,16 @@ def run_after_node_execution( """ if should_run_node(node_name, node_tags, self.node_filter) and self.run_after: logger.warning( - ( - f"Placing you in a PDB debugger post execution of node: {node_name}." - "\nYou can access additional node information via the following variables:" - f"\n - node_name: {PDBDebugger._truncate_repr(node_name)}" - f"\n - node_tags: {PDBDebugger._truncate_repr(node_tags)}" - f"\n - node_kwargs: {PDBDebugger._truncate_repr(', '.join(list(node_kwargs.keys())))}" - f"\n - node_return_type: {PDBDebugger._truncate_repr(node_return_type)}" - f"\n - result: {PDBDebugger._truncate_repr(result)}" - f"\n - error: {PDBDebugger._truncate_repr(error)}" - f"\n - success: {PDBDebugger._truncate_repr(success)}" - f"\n - task_id: {PDBDebugger._truncate_repr(task_id)}" - ) + f"Placing you in a PDB debugger post execution of node: {node_name}." + "\nYou can access additional node information via the following variables:" + f"\n - node_name: {PDBDebugger._truncate_repr(node_name)}" + f"\n - node_tags: {PDBDebugger._truncate_repr(node_tags)}" + f"\n - node_kwargs: {PDBDebugger._truncate_repr(', '.join(list(node_kwargs.keys())))}" + f"\n - node_return_type: {PDBDebugger._truncate_repr(node_return_type)}" + f"\n - result: {PDBDebugger._truncate_repr(result)}" + f"\n - error: {PDBDebugger._truncate_repr(error)}" + f"\n - success: {PDBDebugger._truncate_repr(success)}" + f"\n - task_id: {PDBDebugger._truncate_repr(task_id)}" ) pdb.set_trace() @@ -356,9 +351,7 @@ class CacheAdapter(NodeExecutionHook, NodeExecutionMethod, GraphExecutionHook): nodes_history_key: str = "_nodes_history" - def __init__( - self, cache_vars: Union[List[str], None] = None, cache_path: str = "./hamilton-cache" - ): + def __init__(self, cache_vars: list[str] | None = None, cache_path: str = "./hamilton-cache"): """Initialize the cache :param cache_vars: List of nodes for which to store/load results. Passing None will use the cache @@ -368,10 +361,10 @@ def __init__( self.cache_vars = cache_vars or [] self.cache_path = cache_path self.cache = shelve.open(self.cache_path) - self.nodes_history: Dict[str, List[str]] = self.cache.get( + self.nodes_history: dict[str, list[str]] = self.cache.get( key=CacheAdapter.nodes_history_key, default=dict() ) - self.used_nodes_hash: Dict[str, str] = dict() + self.used_nodes_hash: dict[str, str] = dict() self.cache.close() logger.warning( @@ -388,7 +381,7 @@ def run_before_graph_execution(self, *, graph: HamiltonGraph, **kwargs): self.cache_vars = [n.name for n in graph.nodes] def run_to_execute_node( - self, *, node_name: str, node_callable: Any, node_kwargs: Dict[str, Any], **kwargs + self, *, node_name: str, node_callable: Any, node_kwargs: dict[str, Any], **kwargs ): """Create cache key based on node callable hash (equiv. 
to HamiltonNode.version) and the node inputs (`node_kwargs`).If key in cache (cache hit), load result; else (cache miss), @@ -416,7 +409,7 @@ def run_to_execute_node( return node_callable(**node_kwargs) def run_after_node_execution( - self, *, node_name: str, node_kwargs: Dict[str, Any], result: Any, **kwargs + self, *, node_name: str, node_kwargs: dict[str, Any], result: Any, **kwargs ): """If `run_to_execute_node` was a cache miss (hash stored in `used_nodes_hash`), store the computed result in cache @@ -444,7 +437,7 @@ def run_before_node_execution(self, *args, **kwargs): pass @staticmethod - def create_key(node_hash: str, node_inputs: Dict[str, Any]) -> str: + def create_key(node_hash: str, node_inputs: dict[str, Any]) -> str: """Pickle objects into bytes then get their hash value""" digest = hashlib.sha256() digest.update(node_hash.encode()) @@ -504,12 +497,12 @@ def __init__(self, check_input: bool = True, check_output: bool = True): def run_before_node_execution( self, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], node_return_type: type, - task_id: Optional[str], + task_id: str | None, run_id: str, - node_input_types: Dict[str, Any], + node_input_types: dict[str, Any], **future_kwargs: Any, ): """Checks that the result type matches the expected node return type.""" @@ -523,13 +516,13 @@ def run_before_node_execution( def run_after_node_execution( self, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], node_return_type: type, result: Any, - error: Optional[Exception], + error: Exception | None, success: bool, - task_id: Optional[str], + task_id: str | None, run_id: str, **future_kwargs: Any, ): @@ -601,7 +594,7 @@ class GracefulErrorAdapter(NodeExecutionMethod): def __init__( self, - error_to_catch: Type[Exception], + error_to_catch: type[Exception], sentinel_value: Any = SENTINEL_DEFAULT, try_all_parallel: bool = True, allow_injection: bool = True, @@ -729,7 +722,7 @@ def run_to_execute_node( self, *, node_callable: Any, - node_kwargs: Dict[str, Any], + node_kwargs: dict[str, Any], is_expand: bool, is_collect: bool, **future_kwargs: Any, diff --git a/hamilton/models.py b/hamilton/models.py index 543eeede6..7b8bf8597 100644 --- a/hamilton/models.py +++ b/hamilton/models.py @@ -16,7 +16,7 @@ # under the License. import abc -from typing import Any, Dict, List +from typing import Any import pandas as pd @@ -36,7 +36,7 @@ def __init__(self, config_parameters: Any, name: str): self._name = name @abc.abstractmethod - def get_dependents(self) -> List[str]: + def get_dependents(self) -> list[str]: """Gets the names/types of the inputs to this transform. :return: A list of columns on which this model depends. """ @@ -51,7 +51,7 @@ def compute(self, **inputs: Any) -> Any: pass @property - def config_parameters(self) -> Dict[str, Any]: + def config_parameters(self) -> dict[str, Any]: """Accessor for configuration parameters""" return self._config_parameters diff --git a/hamilton/node.py b/hamilton/node.py index 36e62183f..0fa52d61b 100644 --- a/hamilton/node.py +++ b/hamilton/node.py @@ -15,11 +15,13 @@ # specific language governing permissions and limitations # under the License. 
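For context on the CacheAdapter hunks above: create_key builds the cache key by hashing the node's version hash together with its pickled inputs. A minimal sketch of that pickle-then-hash idea follows; the iteration over node_inputs and its ordering are assumptions, since the hunk only shows the start of the method.

    import hashlib
    import pickle
    from typing import Any

    def create_key(node_hash: str, node_inputs: dict[str, Any]) -> str:
        """Sketch: fold a node's version hash and its pickled inputs into one digest."""
        digest = hashlib.sha256()
        digest.update(node_hash.encode())
        for name in sorted(node_inputs):  # assumed: a stable ordering over input names
            digest.update(pickle.dumps((name, node_inputs[name])))
        return digest.hexdigest()

Identical code plus identical inputs therefore map to the same key, which is what makes the cache hit/miss check in run_to_execute_node work.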
+import builtins import inspect import sys import typing +from collections.abc import Callable from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any import typing_inspect @@ -58,21 +60,21 @@ def from_parameter(param: inspect.Parameter): return DependencyType.OPTIONAL -class Node(object): +class Node: """Object representing a node of computation.""" def __init__( self, name: str, - typ: Type, + typ: type, doc_string: str = "", callabl: Callable = None, node_source: NodeType = NodeType.STANDARD, - input_types: Dict[str, Union[Type, Tuple[Type, DependencyType]]] = None, - tags: Dict[str, Any] = None, - namespace: Tuple[str, ...] = (), - originating_functions: Optional[Tuple[Callable, ...]] = None, - optional_values: Optional[Dict[str, Any]] = None, + input_types: dict[str, type | tuple[type, DependencyType]] = None, + tags: dict[str, Any] = None, + namespace: tuple[str, ...] = (), + originating_functions: tuple[Callable, ...] | None = None, + optional_values: dict[str, Any] | None = None, ): """Constructor for our Node object. @@ -154,7 +156,7 @@ def collect_dependency(self) -> str: return key @property - def namespace(self) -> Tuple[str, ...]: + def namespace(self) -> tuple[str, ...]: return self._namespace @property @@ -162,11 +164,11 @@ def documentation(self) -> str: return self._doc @property - def input_types(self) -> Dict[Any, Tuple[Any, DependencyType]]: + def input_types(self) -> dict[Any, tuple[Any, DependencyType]]: return self._input_types @property - def default_parameter_values(self) -> Dict[str, Any]: + def default_parameter_values(self) -> dict[str, Any]: """Only returns parameters for which we have optional values.""" return self._default_parameter_values @@ -208,19 +210,19 @@ def node_role(self): return self._node_source @property - def dependencies(self) -> List["Node"]: + def dependencies(self) -> list["Node"]: return self._dependencies @property - def depended_on_by(self) -> List["Node"]: + def depended_on_by(self) -> list["Node"]: return self._depended_on_by @property - def tags(self) -> Dict[str, str]: + def tags(self) -> dict[str, str]: return self._tags @property - def originating_functions(self) -> Optional[Tuple[Callable, ...]]: + def originating_functions(self) -> tuple[Callable, ...] | None: """Gives all functions from which this node was created. None if the data is not available (it is user-defined, or we have not added it yet). Note that this can be multiple in the case of subdags (the subdag function + the other function). In that case, @@ -362,7 +364,7 @@ def copy(self, include_refs: bool = True) -> "Node": return self.copy_with(include_refs) def reassign_inputs( - self, input_names: Dict[str, Any] = None, input_values: Dict[str, Any] = None + self, input_names: dict[str, Any] = None, input_values: dict[str, Any] = None ) -> "Node": """Reassigns the input names of a node. Useful for applying a node to a separate input if needed. Note that things can get a @@ -397,7 +399,7 @@ async def async_function(**kwargs): return out def transform_output( - self, __transform: Callable[[Dict[str, Any], Any], Any], __output_type: Type[Any] + self, __transform: Callable[[dict[str, Any], Any], Any], __output_type: builtins.type[Any] ) -> "Node": """Applies a transformation on the output of the node, returning a new node. Also modifies the type. 
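On the `import builtins` added above: Node exposes a `type` attribute (this patch itself reads `node_to_validate.type` in the pydantic plugin), so inside the class body the bare name `type` is presumably shadowed and transform_output has to spell its annotation as `builtins.type[Any]`. A small illustration of that shadowing problem, under that assumption; the Example class is hypothetical, not the actual Node code.

    import builtins

    class Example:
        @property
        def type(self) -> str:  # shadows the builtin `type` for the rest of the class body
            return "example"

        # Writing `output_type: type[int]` here would subscript the property object above
        # (annotations are evaluated inside the class body), so the builtin is named explicitly:
        def transform(self, output_type: builtins.type[int]) -> builtins.type[int]:
            return output_type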
@@ -415,7 +417,7 @@ def new_callable(**kwargs) -> Any: def matches_query( - tags: Dict[str, Union[str, List[str]]], query_dict: Dict[str, Optional[Union[str, List[str]]]] + tags: dict[str, str | list[str]], query_dict: dict[str, str | list[str] | None] ) -> bool: """Check whether a set of node tags matches the query based on tags. diff --git a/hamilton/plugins/dlt_extensions.py b/hamilton/plugins/dlt_extensions.py index a00ea9e6e..190623708 100644 --- a/hamilton/plugins/dlt_extensions.py +++ b/hamilton/plugins/dlt_extensions.py @@ -16,7 +16,8 @@ # under the License. import dataclasses -from typing import Any, Collection, Dict, Iterable, Literal, Optional, Sequence, Tuple, Type +from collections.abc import Collection, Iterable, Sequence +from typing import Any, Literal try: import dlt @@ -63,10 +64,10 @@ def name(cls) -> str: return "dlt" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [pd.DataFrame] - def load_data(self, type_: Type) -> Tuple[pd.DataFrame, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[pd.DataFrame, dict[str, Any]]: """Creates a pipeline and conduct `extract` and `normalize` steps. Then, "load packages" are read with pandas """ @@ -105,18 +106,18 @@ class DltDestinationSaver(DataSaver): pipeline: dlt.Pipeline table_name: str - primary_key: Optional[str] = None - write_disposition: Optional[Literal["skip", "append", "replace", "merge"]] = None - columns: Optional[Sequence[TColumnSchema]] = None - schema: Optional[Schema] = None - loader_file_format: Optional[TLoaderFileFormat] = None + primary_key: str | None = None + write_disposition: Literal["skip", "append", "replace", "merge"] | None = None + columns: Sequence[TColumnSchema] | None = None + schema: Schema | None = None + loader_file_format: TLoaderFileFormat | None = None @classmethod def name(cls) -> str: return "dlt" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return DATAFRAME_TYPES def _get_kwargs(self) -> dict: @@ -133,7 +134,7 @@ def _get_kwargs(self) -> dict: return kwargs # TODO get pyarrow table from polars, dask, etc. 
- def save_data(self, data) -> Dict[str, Any]: + def save_data(self, data) -> dict[str, Any]: """ ref: https://dlthub.com/docs/dlt-ecosystem/verified-sources/arrow-pandas """ diff --git a/hamilton/plugins/h_dask.py b/hamilton/plugins/h_dask.py index 175cf07e1..2e2c2da96 100644 --- a/hamilton/plugins/h_dask.py +++ b/hamilton/plugins/h_dask.py @@ -131,7 +131,7 @@ def __init__( self.compute_at_end = compute_at_end @staticmethod - def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool: + def check_input_type(node_type: type, input_value: typing.Any) -> bool: # NOTE: the type of dask Delayed is unknown until they are computed if isinstance(input_value, Delayed): return True @@ -142,14 +142,14 @@ def check_input_type(node_type: typing.Type, input_value: typing.Any) -> bool: return htypes.check_input_type(node_type, input_value) @staticmethod - def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type) -> bool: + def check_node_type_equivalence(node_type: type, input_type: type) -> bool: if node_type == dask.array.Array and input_type == pd.Series: return True elif node_type == dask.dataframe.Series and input_type == pd.Series: return True return node_type == input_type - def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any: + def execute_node(self, node: node.Node, kwargs: dict[str, typing.Any]) -> typing.Any: """Function that is called as we walk the graph to determine how to execute a hamilton function. :param node: the node from the graph. @@ -168,7 +168,7 @@ def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> dask_key_name=dask_key_name, # this is what shows up in the dask console ) - def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: + def build_result(self, **outputs: dict[str, typing.Any]) -> typing.Any: """Builds the result and brings it back to this running process. :param outputs: the dictionary of key -> Union[delayed object reference | value] @@ -198,7 +198,7 @@ def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any: class DaskDataFrameResult(base.ResultMixin): @staticmethod - def build_result(**outputs: typing.Dict[str, typing.Any]) -> typing.Any: + def build_result(**outputs: dict[str, typing.Any]) -> typing.Any: """Builds a dask dataframe from the outputs. This has some assumptions: diff --git a/hamilton/plugins/h_ddog.py b/hamilton/plugins/h_ddog.py index 84ba5ef3c..8d61639a5 100644 --- a/hamilton/plugins/h_ddog.py +++ b/hamilton/plugins/h_ddog.py @@ -17,7 +17,7 @@ import logging from types import ModuleType -from typing import Any, Dict, List, Optional +from typing import Any from hamilton import graph as h_graph from hamilton import lifecycle, node @@ -57,7 +57,7 @@ def __init__(self, root_name: str, include_causal_links: bool = False, service: self.node_span_cache = {} # Cache of run_id -> [task_id, node_id] -> span. We use this to open/close general traces @staticmethod - def _serialize_span_dict(span_dict: Dict[str, Span]): + def _serialize_span_dict(span_dict: dict[str, Span]): """Serializes to a readable format. We're not propogating span links (see note above on causal links), but that's fine (for now). We have to do this as passing spans back and forth is frowned upon. 
@@ -75,7 +75,7 @@ def _serialize_span_dict(span_dict: Dict[str, Span]): } @staticmethod - def _deserialize_span_dict(serialized_repr: Dict[str, dict]) -> Dict[str, context.Context]: + def _deserialize_span_dict(serialized_repr: dict[str, dict]) -> dict[str, context.Context]: """Note that we deserialize as contexts, as passing spans is not supported (the child should never terminate the parent span). @@ -120,7 +120,7 @@ def __setstate__(self, state): } @staticmethod - def _sanitize_tags(tags: Dict[str, Any]) -> Dict[str, str]: + def _sanitize_tags(tags: dict[str, Any]) -> dict[str, str]: """Sanitizes tags to be strings, just in case. :param tags: Node tags. @@ -147,9 +147,9 @@ def run_before_node_execution( self, *, node_name: str, - node_kwargs: Dict[str, Any], - node_tags: Dict[str, Any], - task_id: Optional[str], + node_kwargs: dict[str, Any], + node_tags: dict[str, Any], + task_id: str | None, run_id: str, **future_kwargs: Any, ): @@ -191,8 +191,8 @@ def run_after_node_execution( self, *, node_name: str, - error: Optional[Exception], - task_id: Optional[str], + error: Exception | None, + task_id: str | None, run_id: str, **future_kwargs: Any, ): @@ -215,7 +215,7 @@ def run_after_node_execution( span.__exit__(exc_type, exc_value, tb) def run_after_graph_execution( - self, *, error: Optional[Exception], run_id: str, **future_kwargs: Any + self, *, error: Exception | None, run_id: str, **future_kwargs: Any ): """Runs after graph execution. Garbage collects + finishes the root span. @@ -322,9 +322,9 @@ def run_before_node_execution( self, *, node_name: str, - node_kwargs: Dict[str, Any], - node_tags: Dict[str, Any], - task_id: Optional[str], + node_kwargs: dict[str, Any], + node_tags: dict[str, Any], + task_id: str | None, run_id: str, **future_kwargs: Any, ): @@ -350,8 +350,8 @@ def run_after_node_execution( self, *, node_name: str, - error: Optional[Exception], - task_id: Optional[str], + error: Exception | None, + task_id: str | None, run_id: str, **future_kwargs: Any, ): @@ -368,7 +368,7 @@ def run_after_node_execution( ) def run_after_graph_execution( - self, *, error: Optional[Exception], run_id: str, **future_kwargs: Any + self, *, error: Exception | None, run_id: str, **future_kwargs: Any ): """Runs after graph execution. Garbage collects + finishes the root span. @@ -433,7 +433,7 @@ def __init__( ) async def post_graph_construct( - self, graph: h_graph.FunctionGraph, modules: List[ModuleType], config: Dict[str, Any] + self, graph: h_graph.FunctionGraph, modules: list[ModuleType], config: dict[str, Any] ) -> None: """Runs after graph construction. This is a no-op for this plugin. @@ -447,9 +447,9 @@ async def pre_graph_execute( self, run_id: str, graph: h_graph.FunctionGraph, - final_vars: List[str], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + final_vars: list[str], + inputs: dict[str, Any], + overrides: dict[str, Any], ) -> None: """Runs before graph execution -- sets the state so future ones can reference it. @@ -462,7 +462,7 @@ async def pre_graph_execute( self._impl.run_before_graph_execution(run_id=run_id) async def pre_node_execute( - self, run_id: str, node_: node.Node, kwargs: Dict[str, Any], task_id: Optional[str] = None + self, run_id: str, node_: node.Node, kwargs: dict[str, Any], task_id: str | None = None ) -> None: """Runs before a node's execution. Sets up/stores spans. 
@@ -484,9 +484,9 @@ async def post_node_execute( run_id: str, node_: node.Node, success: bool, - error: Optional[Exception], + error: Exception | None, result: Any, - task_id: Optional[str] = None, + task_id: str | None = None, **future_kwargs: dict, ) -> None: """Runs after a node's execution -- completes the span. @@ -508,8 +508,8 @@ async def post_graph_execute( run_id: str, graph: h_graph.FunctionGraph, success: bool, - error: Optional[Exception], - results: Optional[Dict[str, Any]], + error: Exception | None, + results: dict[str, Any] | None, ) -> None: """Runs after graph execution. Garbage collects + finishes the root span. diff --git a/hamilton/plugins/h_diskcache.py b/hamilton/plugins/h_diskcache.py index 481a38f91..b98f504d8 100644 --- a/hamilton/plugins/h_diskcache.py +++ b/hamilton/plugins/h_diskcache.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict, List, Union +from typing import Any import diskcache @@ -29,11 +29,11 @@ def _bytes_to_mb(kb: int) -> float: return kb / (1024**2) -def evict_all_except(nodes_to_keep: Dict[str, node.Node], cache: diskcache.Cache) -> int: +def evict_all_except(nodes_to_keep: dict[str, node.Node], cache: diskcache.Cache) -> int: """Evicts all nodes and node version except those passed. Remaining nodes may have multiple entries for different input values """ - nodes_history: Dict[str, List[str]] = cache.get(key=DiskCacheAdapter.nodes_history_key) # type: ignore + nodes_history: dict[str, list[str]] = cache.get(key=DiskCacheAdapter.nodes_history_key) # type: ignore new_nodes_history = dict() eviction_counter = 0 @@ -94,15 +94,15 @@ class DiskCacheAdapter( nodes_history_key: str = "_nodes_history" def __init__( - self, cache_vars: Union[List[str], None] = None, cache_path: str = ".", **cache_settings + self, cache_vars: list[str] | None = None, cache_path: str = ".", **cache_settings ): self.cache_vars = cache_vars or [] self.cache_path = cache_path self.cache = diskcache.Cache(directory=cache_path, **cache_settings) - self.nodes_history: Dict[str, List[str]] = self.cache.get( + self.nodes_history: dict[str, list[str]] = self.cache.get( key=DiskCacheAdapter.nodes_history_key, default=dict() ) # type: ignore - self.used_nodes_hash: Dict[str, str] = dict() + self.used_nodes_hash: dict[str, str] = dict() logger.warning( "The `DiskCacheAdapter` is deprecated and will be removed in Hamilton 2.0. " @@ -117,7 +117,7 @@ def run_before_graph_execution(self, *, graph: graph_types.HamiltonGraph, **kwar self.cache_vars = [n.name for n in graph.nodes] def run_to_execute_node( - self, *, node_name: str, node_callable: Any, node_kwargs: Dict[str, Any], **kwargs + self, *, node_name: str, node_callable: Any, node_kwargs: dict[str, Any], **kwargs ): """Create hash key then use cached value if exist""" if node_name not in self.cache_vars: diff --git a/hamilton/plugins/h_experiments/data_model.py b/hamilton/plugins/h_experiments/data_model.py index 091687a98..0375200dc 100644 --- a/hamilton/plugins/h_experiments/data_model.py +++ b/hamilton/plugins/h_experiments/data_model.py @@ -16,7 +16,7 @@ # under the License. 
import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, create_model, model_validator @@ -53,7 +53,7 @@ class RunMetadata(BaseModel): inputs: dict overrides: dict materialized: list[NodeMaterializer] - graph_version: Optional[int] = None + graph_version: int | None = None @model_validator(mode="before") def pre_root(cls, v: dict[str, Any]): diff --git a/hamilton/plugins/h_experiments/hook.py b/hamilton/plugins/h_experiments/hook.py index 437011b2e..1acced968 100644 --- a/hamilton/plugins/h_experiments/hook.py +++ b/hamilton/plugins/h_experiments/hook.py @@ -25,7 +25,7 @@ import uuid from dataclasses import asdict, dataclass from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any from hamilton import graph_types, lifecycle from hamilton.plugins.h_experiments.cache import JsonCache @@ -77,7 +77,7 @@ class NodeImplementation: class NodeInput: name: str value: Any - default_value: Optional[Any] + default_value: Any | None @dataclass @@ -144,8 +144,8 @@ def run_before_graph_execution( self, *, graph: graph_types.HamiltonGraph, - inputs: Dict[str, Any], - overrides: Dict[str, Any], + inputs: dict[str, Any], + overrides: dict[str, Any], **kwargs, ): """Store execution metadata: graph hash, inputs, overrides""" diff --git a/hamilton/plugins/h_experiments/server.py b/hamilton/plugins/h_experiments/server.py index c60f16713..f73ab3e21 100644 --- a/hamilton/plugins/h_experiments/server.py +++ b/hamilton/plugins/h_experiments/server.py @@ -19,7 +19,6 @@ import itertools import json import os -from typing import Union import pandas as pd from fastapi import FastAPI @@ -134,8 +133,8 @@ class RunFilter(BaseModel): @app.get("/api/runs", response_model=FastUI, response_model_exclude_none=True) def runs_overview( - experiment: Union[str, None] = None, - graph_version: Union[int, None] = None, + experiment: str | None = None, + graph_version: int | None = None, ) -> list[AnyComponent]: """RunOverview page with filters for the table""" @@ -352,9 +351,7 @@ def artifact_tabs(run: RunMetadata) -> list[AnyComponent]: @app.get("/api/artifacts/{run_id}", response_model=FastUI, response_model_exclude_none=True) -def run_artifacts( - run_id: str, artifact_id: int = 0, page: Union[int, None] = None -) -> list[AnyComponent]: +def run_artifacts(run_id: str, artifact_id: int = 0, page: int | None = None) -> list[AnyComponent]: """Individual Run > Artifact""" run = run_lookup()[run_id] diff --git a/hamilton/plugins/h_kedro.py b/hamilton/plugins/h_kedro.py index 9c612913e..15e9bbe6e 100644 --- a/hamilton/plugins/h_kedro.py +++ b/hamilton/plugins/h_kedro.py @@ -16,7 +16,7 @@ # under the License. import inspect -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any from kedro.pipeline.node import Node as KNode from kedro.pipeline.pipeline import Pipeline as KPipeline @@ -27,28 +27,28 @@ from hamilton.node import Node as HNode -def expand_k_node(base_node: HNode, outputs: List[str]) -> List[HNode]: +def expand_k_node(base_node: HNode, outputs: list[str]) -> list[HNode]: """Manually apply `@extract_fields()` on a Hamilton node.Node for a Kedro node that specifies >1 `outputs`. 
The number of nodes == len(outputs) + 1 because it includes the `base_node` """ - def _convert_output_from_tuple_to_dict(node_result: Any, node_kwargs: Dict[str, Any]): + def _convert_output_from_tuple_to_dict(node_result: Any, node_kwargs: dict[str, Any]): return {out: v for out, v in zip(outputs, node_result, strict=False)} # NOTE isinstance(Any, type) is False for Python < 3.11 extractor = extract_fields(fields={out: Any for out in outputs}) func = base_node.originating_functions[0] - if issubclass(func.__annotations__["return"], Tuple): - base_node = base_node.transform_output(_convert_output_from_tuple_to_dict, Dict) - func.__annotations__["return"] = Dict + if issubclass(func.__annotations__["return"], tuple): + base_node = base_node.transform_output(_convert_output_from_tuple_to_dict, dict) + func.__annotations__["return"] = dict extractor.validate(func) return list(extractor.transform_node(base_node, {}, func)) -def k_node_to_h_nodes(node: KNode) -> List[HNode]: +def k_node_to_h_nodes(node: KNode) -> list[HNode]: """Convert a Kedro node to a list of Hamilton nodes. If the Kedro node specifies 1 output, generate 1 Hamilton node. If it generate >1 output, generate len(outputs) + 1 to include the base node + extracted fields. @@ -73,7 +73,7 @@ def k_node_to_h_nodes(node: KNode) -> List[HNode]: output_type = func_sig.return_annotation if output_type is None: # manually creating `hamilton.node.Node` doesn't accept `typ=None` - output_type = Type[None] # NoneType is introduced in Python 3.10 + output_type = type[None] # NoneType is introduced in Python 3.10 base_node = HNode( name=base_node_name, @@ -104,7 +104,7 @@ def k_node_to_h_nodes(node: KNode) -> List[HNode]: def kedro_pipeline_to_driver( *pipelines: KPipeline, - builder: Optional[driver.Builder] = None, + builder: driver.Builder | None = None, ) -> driver.Driver: """Convert one or mode Kedro `Pipeline` to a Hamilton `Driver`. Pass a Hamilton `Builder` to include lifecycle adapters in your `Driver`. diff --git a/hamilton/plugins/h_logging.py b/hamilton/plugins/h_logging.py index beba4db40..48f91bed6 100644 --- a/hamilton/plugins/h_logging.py +++ b/hamilton/plugins/h_logging.py @@ -19,19 +19,12 @@ import logging import sys +from collections.abc import Mapping, MutableMapping from contextvars import ContextVar from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Mapping, - MutableMapping, - Optional, - Set, - Tuple, - Union, ) from hamilton.graph_types import HamiltonNode @@ -65,16 +58,16 @@ class _LoggingContext: """Represents the current logging context.""" - graph: Optional[str] = None - node: Optional[str] = None - task: Optional[str] = None + graph: str | None = None + node: str | None = None + task: str | None = None # Context variables for context-aware logging _local_context = ContextVar("context", default=_LoggingContext()) # noqa: B039 -def get_logger(name: Optional[str] = None) -> "ContextLogger": +def get_logger(name: str | None = None) -> "ContextLogger": """Returns a context-aware logger for the specified name (created if necessary). :param name: Name of the logger, defaults to root logger if not provided. 
@@ -102,7 +95,7 @@ class ContextLogger(LoggerAdapter): @override def process( self, msg: str, kwargs: MutableMapping[str, Any] - ) -> Tuple[str, MutableMapping[str, Any]]: + ) -> tuple[str, MutableMapping[str, Any]]: # Ensure that the extra fields are passed through correctly kwargs["extra"] = {**(self.extra or {}), **(kwargs.get("extra") or {})} @@ -166,10 +159,10 @@ class LoggingAdapter( and the execution of each *node* as `DEBUG`. """ - def __init__(self, logger: Union[str, logging.Logger, None] = None) -> None: + def __init__(self, logger: str | logging.Logger | None = None) -> None: # Precompute or overridden nodes - self._inputs_nodes: Set[str] = set() - self._override_nodes: Set[str] = set() + self._inputs_nodes: set[str] = set() + self._override_nodes: set[str] = set() if logger is None: self.logger = logging.getLogger(__name__) @@ -187,8 +180,8 @@ def __init__(self, logger: Union[str, logging.Logger, None] = None) -> None: def run_before_graph_execution( self, *, - inputs: Optional[Dict[str, Any]], - overrides: Optional[Dict[str, Any]], + inputs: dict[str, Any] | None, + overrides: dict[str, Any] | None, run_id: str, **future_kwargs: Any, ): @@ -210,7 +203,7 @@ def run_before_graph_execution( self.logger.info("Using overrides %s", names) @override - def run_after_task_grouping(self, *, run_id: str, task_ids: List[str], **future_kwargs): + def run_after_task_grouping(self, *, run_id: str, task_ids: list[str], **future_kwargs): self.logger.info("Dynamic DAG detected; task-based logging is enabled") @override @@ -223,7 +216,7 @@ def run_before_task_submission( *, run_id: str, task_id: str, - spawning_task_id: Optional[str], + spawning_task_id: str | None, **future_kwargs, ): # Set context before logging @@ -240,7 +233,7 @@ def run_before_task_execution( *, task_id: str, run_id: str, - nodes: List[HamiltonNode], + nodes: list[HamiltonNode], **future_kwargs, ): # Set context before logging @@ -263,8 +256,8 @@ def run_before_node_execution( self, *, node_name: str, - node_kwargs: Dict[str, Any], - task_id: Optional[str], + node_kwargs: dict[str, Any], + task_id: str | None, run_id: str, **future_kwargs: Any, ): @@ -286,9 +279,9 @@ def run_after_node_execution( self, *, node_name: str, - error: Optional[Exception], + error: Exception | None, success: bool, - task_id: Optional[str], + task_id: str | None, run_id: str, **future_kwargs: Any, ): @@ -339,9 +332,9 @@ def run_after_task_return( *, run_id: str, task_id: str, - nodes: List[Node], + nodes: list[Node], success: bool, - error: Optional[Exception], + error: Exception | None, **future_kwargs, ): # Hard reset context before logging @@ -396,15 +389,15 @@ class AsyncLoggingAdapter(GraphExecutionHook, BasePreNodeExecute, BasePostNodeEx submitted. It cannot currently log the exact moment the async node begins execution. 
""" - def __init__(self, logger: Union[str, logging.Logger, None] = None) -> None: + def __init__(self, logger: str | logging.Logger | None = None) -> None: self._impl = LoggingAdapter(logger) @override def run_before_graph_execution( self, *, - inputs: Dict[str, Any], - overrides: Dict[str, Any], + inputs: dict[str, Any], + overrides: dict[str, Any], run_id: str, **future_kwargs: Any, ): @@ -412,7 +405,7 @@ def run_before_graph_execution( @override def pre_node_execute( - self, *, run_id: str, node_: Node, kwargs: Dict[str, Any], task_id: Optional[str] = None + self, *, run_id: str, node_: Node, kwargs: dict[str, Any], task_id: str | None = None ): # NOTE: We call the base synchronous method here in order to approximate when the async task # has bee submitted. This is a workaround until further work is done on the async adapter. @@ -436,11 +429,11 @@ async def post_node_execute( *, run_id: str, node_: Node, - kwargs: Dict[str, Any], + kwargs: dict[str, Any], success: bool, - error: Optional[Exception], + error: Exception | None, result: Any, - task_id: Optional[str] = None, + task_id: str | None = None, ): self._impl.run_after_node_execution( node_name=node_.name, diff --git a/hamilton/plugins/h_mlflow.py b/hamilton/plugins/h_mlflow.py index 21723c039..aec340960 100644 --- a/hamilton/plugins/h_mlflow.py +++ b/hamilton/plugins/h_mlflow.py @@ -18,7 +18,7 @@ import logging import pickle import warnings -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any import mlflow import mlflow.data @@ -50,7 +50,7 @@ logger = logging.getLogger(__name__) -def get_path_from_metadata(metadata: dict) -> Union[str, None]: +def get_path_from_metadata(metadata: dict) -> str | None: """Retrieve the `path` attribute from DataSaver output metadata""" path = None if "path" in metadata: @@ -74,16 +74,16 @@ class MLFlowTracker( def __init__( self, - tracking_uri: Optional[str] = None, - registry_uri: Optional[str] = None, - artifact_location: Optional[str] = None, + tracking_uri: str | None = None, + registry_uri: str | None = None, + artifact_location: str | None = None, experiment_name: str = "Hamilton", - experiment_tags: Optional[dict] = None, - experiment_description: Optional[str] = None, - run_id: Optional[str] = None, - run_name: Optional[str] = None, - run_tags: Optional[dict] = None, - run_description: Optional[str] = None, + experiment_tags: dict | None = None, + experiment_description: str | None = None, + run_id: str | None = None, + run_name: str | None = None, + run_tags: dict | None = None, + run_description: str | None = None, log_system_metrics: bool = False, ): """Configure the MLFlow client and experiment for the lifetime of the tracker @@ -144,8 +144,8 @@ def run_before_graph_execution( self, *, run_id: str, - final_vars: List[str], - inputs: Dict[str, Any], + final_vars: list[str], + inputs: dict[str, Any], graph: graph_types.HamiltonGraph, **kwargs, ): @@ -192,7 +192,7 @@ def run_after_node_execution( self, *, node_name: str, - node_return_type: Type, + node_return_type: type, node_tags: dict, node_kwargs: dict, result: Any, diff --git a/hamilton/plugins/h_narwhals.py b/hamilton/plugins/h_narwhals.py index e11df64bc..73b95f690 100644 --- a/hamilton/plugins/h_narwhals.py +++ b/hamilton/plugins/h_narwhals.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Dict, Optional, Type, Union +from typing import Any import narwhals as nw @@ -55,10 +55,10 @@ def run_to_execute_node( self, *, node_name: str, - node_tags: Dict[str, Any], + node_tags: dict[str, Any], node_callable: Any, - node_kwargs: Dict[str, Any], - task_id: Optional[str], + node_kwargs: dict[str, Any], + task_id: str | None, **future_kwargs: Any, ) -> Any: """This method is responsible for executing the node and returning the result. @@ -109,7 +109,7 @@ class NarwhalsDataFrameResultBuilder(api.ResultBuilder): ) """ - def __init__(self, result_builder: Union[api.ResultBuilder, api.LegacyResultMixin]): + def __init__(self, result_builder: api.ResultBuilder | api.LegacyResultMixin): self.result_builder = result_builder def build_result(self, **outputs: Any) -> Any: @@ -127,7 +127,7 @@ def build_result(self, **outputs: Any) -> Any: return self.result_builder.build_result(**de_narwhaled_outputs) - def output_type(self) -> Type: + def output_type(self) -> type: """Returns the output type of this result builder :return: the type that this creates """ diff --git a/hamilton/plugins/h_openlineage.py b/hamilton/plugins/h_openlineage.py index a186b29ef..526dbd344 100644 --- a/hamilton/plugins/h_openlineage.py +++ b/hamilton/plugins/h_openlineage.py @@ -19,7 +19,7 @@ import sys import traceback from datetime import datetime, timezone -from typing import Any, Dict, List, Optional +from typing import Any import attr from openlineage.client import OpenLineageClient, event_v2, facet_v2 @@ -35,9 +35,9 @@ class HamiltonFacet(facet_v2.RunFacet): hamilton_run_id: str = attr.ib() graph_version: str = attr.ib() - final_vars: List[str] = attr.ib() - inputs: List[str] = attr.ib() - overrides: List[str] = attr.ib() + final_vars: list[str] = attr.ib() + inputs: list[str] = attr.ib() + overrides: list[str] = attr.ib() def get_stack_trace(exception): @@ -204,9 +204,9 @@ def pre_graph_execute( self, run_id: str, graph: h_graph.FunctionGraph, - final_vars: List[str], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + final_vars: list[str], + inputs: dict[str, Any], + overrides: dict[str, Any], ): """ Emits a Run START event. @@ -258,7 +258,7 @@ def pre_graph_execute( self.client.emit(run_event) def pre_node_execute( - self, run_id: str, node_: node.Node, kwargs: Dict[str, Any], task_id: Optional[str] = None + self, run_id: str, node_: node.Node, kwargs: dict[str, Any], task_id: str | None = None ): """No event emitted.""" pass @@ -267,11 +267,11 @@ def post_node_execute( self, run_id: str, node_: node.Node, - kwargs: Dict[str, Any], + kwargs: dict[str, Any], success: bool, - error: Optional[Exception], - result: Optional[Any], - task_id: Optional[str] = None, + error: Exception | None, + result: Any | None, + task_id: str | None = None, ): """ Run Event: will emit a RUNNING event with updates on input/outputs. @@ -347,8 +347,8 @@ def post_graph_execute( run_id: str, graph: h_graph.FunctionGraph, success: bool, - error: Optional[Exception], - results: Optional[Dict[str, Any]], + error: Exception | None, + results: dict[str, Any] | None, ): """Emits a Run COMPLETE or FAIL event. 
diff --git a/hamilton/plugins/h_opentelemetry.py b/hamilton/plugins/h_opentelemetry.py index b298cca06..fc73892a7 100644 --- a/hamilton/plugins/h_opentelemetry.py +++ b/hamilton/plugins/h_opentelemetry.py @@ -17,8 +17,9 @@ import json import logging +from collections.abc import Collection from contextvars import ContextVar -from typing import Any, Collection, Dict, List, Optional, Tuple +from typing import Any logger = logging.getLogger(__name__) @@ -39,10 +40,10 @@ # We have to keep track of tokens for the span # As OpenTel has some weird behavior around context managers, we have to account for the latest ones we started # This way we can pop one off and know where to set the current one (as the parent, when the next one ends) -token_stack = ContextVar[Optional[List[Tuple[object, Span]]]]("token_stack", default=None) +token_stack = ContextVar[list[tuple[object, Span]] | None]("token_stack", default=None) -def _exit_span(exc: Optional[Exception] = None): +def _exit_span(exc: Exception | None = None): """Ditto with _enter_span, but for exiting the span. Pops the token off the stack and detaches the context.""" stack = token_stack.get()[:] token, span = stack.pop() @@ -83,7 +84,7 @@ class OpenTelemetryTracer(NodeExecutionHook, GraphExecutionHook, TaskExecutionHo This works by logging to OpenTelemetry, and setting the span processor to be the right one (that knows about the tracker). """ - def __init__(self, tracer_name: Optional[str] = None, tracer: Optional[trace.Tracer] = None): + def __init__(self, tracer_name: str | None = None, tracer: trace.Tracer | None = None): if tracer_name and tracer: raise ValueError( f"Only pass in one of tracer_name or tracer, not both, got: tracer_name={tracer_name} and tracer={tracer}" @@ -102,7 +103,7 @@ def run_before_graph_execution( self, *, graph: HamiltonGraph, - final_vars: List[str], + final_vars: list[str], inputs: dict, overrides: dict, execution_path: Collection[str], @@ -143,9 +144,9 @@ def run_before_task_execution( self, *, task_id: str, - nodes: List[HamiltonNode], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + nodes: list[HamiltonNode], + inputs: dict[str, Any], + overrides: dict[str, Any], **kwargs, ): attributes = { @@ -156,11 +157,11 @@ def run_before_task_execution( task_span = _enter_span(task_id, self.tracer) task_span.set_attributes(attributes) - def run_after_task_execution(self, *, error: Optional[Exception], **kwargs): + def run_after_task_execution(self, *, error: Exception | None, **kwargs): _exit_span(error) - def run_after_node_execution(self, *, error: Optional[Exception], **kwargs): + def run_after_node_execution(self, *, error: Exception | None, **kwargs): _exit_span(error) - def run_after_graph_execution(self, *, error: Optional[Exception], **kwargs): + def run_after_graph_execution(self, *, error: Exception | None, **kwargs): _exit_span(error) diff --git a/hamilton/plugins/h_pandas.py b/hamilton/plugins/h_pandas.py index f6bdb4ac1..d0098443d 100644 --- a/hamilton/plugins/h_pandas.py +++ b/hamilton/plugins/h_pandas.py @@ -16,8 +16,9 @@ # under the License. import sys +from collections.abc import Callable, Collection from types import ModuleType -from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union, get_type_hints +from typing import Any, get_type_hints _sys_version_info = sys.version_info _version_tuple = (_sys_version_info.major, _sys_version_info.minor, _sys_version_info.micro) @@ -124,13 +125,13 @@ def final_df(initial_df: pd.DataFrame, ...) 
-> pd.DataFrame: def __init__( self, - *load_from: Union[Callable, ModuleType], - columns_to_pass: List[str] = None, + *load_from: Callable | ModuleType, + columns_to_pass: list[str] = None, pass_dataframe_as: str = None, on_input: str = None, - select: List[str] = None, + select: list[str] = None, namespace: str = None, - config_required: List[str] = None, + config_required: list[str] = None, ): """Instantiates a ``@with_columns`` decorator. @@ -169,8 +170,8 @@ def __init__( ) def _create_column_nodes( - self, fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]] - ) -> List[node.Node]: + self, fn: Callable, inject_parameter: str, params: dict[str, type[type]] + ) -> list[node.Node]: output_type = params[inject_parameter] def temp_fn(**kwargs) -> Any: @@ -190,8 +191,8 @@ def temp_fn(**kwargs) -> Any: return out_nodes[1:] def get_initial_nodes( - self, fn: Callable, params: Dict[str, Type[Type]] - ) -> Tuple[str, Collection[node.Node]]: + self, fn: Callable, params: dict[str, type[type]] + ) -> tuple[str, Collection[node.Node]]: """Selects the correct dataframe and optionally extracts out columns.""" inject_parameter = _default_inject_parameter(fn=fn, target_dataframe=self.target_dataframe) with_columns_base.validate_dataframe( @@ -209,7 +210,7 @@ def get_initial_nodes( return inject_parameter, initial_nodes - def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: + def get_subdag_nodes(self, fn: Callable, config: dict[str, Any]) -> Collection[node.Node]: return subdag.collect_nodes(config, self.subdag_functions) def chain_subdag_nodes( diff --git a/hamilton/plugins/h_pandera.py b/hamilton/plugins/h_pandera.py index e754ccc1a..4dbe36929 100644 --- a/hamilton/plugins/h_pandera.py +++ b/hamilton/plugins/h_pandera.py @@ -16,7 +16,6 @@ # under the License. import typing -from typing import List import pandera from pandera import typing as pa_typing @@ -88,7 +87,7 @@ def foo() -> pd.DataFrame: self.importance = importance self.target = target - def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]: + def get_validators(self, node_to_validate: node.Node) -> list[dq_base.DataValidator]: """Gets validators for the node. Delegates to the standard check_output(schema=...) decorator. :param node_to_validate: Node to validate diff --git a/hamilton/plugins/h_polars.py b/hamilton/plugins/h_polars.py index 79557cf5b..ad60ff948 100644 --- a/hamilton/plugins/h_polars.py +++ b/hamilton/plugins/h_polars.py @@ -16,8 +16,9 @@ # under the License. import sys +from collections.abc import Callable, Collection from types import ModuleType -from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union, get_type_hints +from typing import Any, get_type_hints import polars as pl @@ -62,9 +63,7 @@ class PolarsDataFrameResult(base.ResultMixin): Note: this is just a first attempt at something for Polars. Think it should handle more? Come chat/open a PR! """ - def build_result( - self, **outputs: Dict[str, Union[pl.Series, pl.DataFrame, Any]] - ) -> pl.DataFrame: + def build_result(self, **outputs: dict[str, pl.Series | pl.DataFrame | Any]) -> pl.DataFrame: """This is the method that Hamilton will call to build the final result. It will pass in the results of the requested outputs that you passed in to the execute() method. @@ -88,7 +87,7 @@ def build_result( # happen for mixed outputs that include scalars for example. 
return pl.DataFrame(outputs) - def output_type(self) -> Type: + def output_type(self) -> type: return pl.DataFrame @@ -174,13 +173,13 @@ def final_df(initial_df: pl.DataFrame) -> pl.DataFrame: def __init__( self, - *load_from: Union[Callable, ModuleType], - columns_to_pass: List[str] = None, + *load_from: Callable | ModuleType, + columns_to_pass: list[str] = None, pass_dataframe_as: str = None, on_input: str = None, - select: List[str] = None, + select: list[str] = None, namespace: str = None, - config_required: List[str] = None, + config_required: list[str] = None, ): """Instantiates a ``@with_columns`` decorator. @@ -219,8 +218,8 @@ def __init__( ) def _create_column_nodes( - self, fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]] - ) -> List[node.Node]: + self, fn: Callable, inject_parameter: str, params: dict[str, type[type]] + ) -> list[node.Node]: output_type = params[inject_parameter] def temp_fn(**kwargs) -> Any: @@ -240,8 +239,8 @@ def temp_fn(**kwargs) -> Any: return out_nodes[1:] def get_initial_nodes( - self, fn: Callable, params: Dict[str, Type[Type]] - ) -> Tuple[str, Collection[node.Node]]: + self, fn: Callable, params: dict[str, type[type]] + ) -> tuple[str, Collection[node.Node]]: """Selects the correct dataframe and optionally extracts out columns.""" inject_parameter = _default_inject_parameter(fn=fn, target_dataframe=self.target_dataframe) with_columns_base.validate_dataframe( @@ -259,7 +258,7 @@ def get_initial_nodes( return inject_parameter, initial_nodes - def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: + def get_subdag_nodes(self, fn: Callable, config: dict[str, Any]) -> Collection[node.Node]: return subdag.collect_nodes(config, self.subdag_functions) def chain_subdag_nodes( diff --git a/hamilton/plugins/h_polars_lazyframe.py b/hamilton/plugins/h_polars_lazyframe.py index 9a8951ac4..92175953c 100644 --- a/hamilton/plugins/h_polars_lazyframe.py +++ b/hamilton/plugins/h_polars_lazyframe.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. +from collections.abc import Callable, Collection from types import ModuleType -from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union, get_type_hints +from typing import Any, get_type_hints import polars as pl @@ -51,9 +52,7 @@ class PolarsLazyFrameResult(base.ResultMixin): Note: this is just a first attempt at something for Polars. Think it should handle more? Come chat/open a PR! """ - def build_result( - self, **outputs: Dict[str, Union[pl.Series, pl.LazyFrame, Any]] - ) -> pl.LazyFrame: + def build_result(self, **outputs: dict[str, pl.Series | pl.LazyFrame | Any]) -> pl.LazyFrame: """This is the method that Hamilton will call to build the final result. It will pass in the results of the requested outputs that you passed in to the execute() method. 
@@ -68,7 +67,7 @@ def build_result( return value return pl.LazyFrame(outputs) - def output_type(self) -> Type: + def output_type(self) -> type: return pl.LazyFrame @@ -151,13 +150,13 @@ def final_df(initial_df: pl.LazyFrame) -> pl.LazyFrame: def __init__( self, - *load_from: Union[Callable, ModuleType], - columns_to_pass: List[str] = None, + *load_from: Callable | ModuleType, + columns_to_pass: list[str] = None, pass_dataframe_as: str = None, on_input: str = None, - select: List[str] = None, + select: list[str] = None, namespace: str = None, - config_required: List[str] = None, + config_required: list[str] = None, ): """Instantiates a ``@with_columns`` decorator. @@ -196,8 +195,8 @@ def __init__( ) def _create_column_nodes( - self, fn: Callable, inject_parameter: str, params: Dict[str, Type[Type]] - ) -> List[node.Node]: + self, fn: Callable, inject_parameter: str, params: dict[str, type[type]] + ) -> list[node.Node]: output_type = params[inject_parameter] def temp_fn(**kwargs) -> Any: @@ -217,8 +216,8 @@ def temp_fn(**kwargs) -> Any: return out_nodes[1:] def get_initial_nodes( - self, fn: Callable, params: Dict[str, Type[Type]] - ) -> Tuple[str, Collection[node.Node]]: + self, fn: Callable, params: dict[str, type[type]] + ) -> tuple[str, Collection[node.Node]]: """Selects the correct dataframe and optionally extracts out columns.""" inject_parameter = _default_inject_parameter(fn=fn, target_dataframe=self.target_dataframe) @@ -237,7 +236,7 @@ def get_initial_nodes( return inject_parameter, initial_nodes - def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: + def get_subdag_nodes(self, fn: Callable, config: dict[str, Any]) -> Collection[node.Node]: return subdag.collect_nodes(config, self.subdag_functions) def chain_subdag_nodes( diff --git a/hamilton/plugins/h_pyarrow.py b/hamilton/plugins/h_pyarrow.py index 17b60b243..8f8c6bdc2 100644 --- a/hamilton/plugins/h_pyarrow.py +++ b/hamilton/plugins/h_pyarrow.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Type +from typing import Any import pyarrow from pyarrow.interchange import from_dataframe @@ -38,7 +38,7 @@ class PyarrowTableResult(ResultBuilder): - duckdb results """ - def output_type(self) -> Type: + def output_type(self) -> type: return pyarrow.Table def build_result(self, **outputs: Any) -> Any: diff --git a/hamilton/plugins/h_pydantic.py b/hamilton/plugins/h_pydantic.py index d3779965c..209b76c62 100644 --- a/hamilton/plugins/h_pydantic.py +++ b/hamilton/plugins/h_pydantic.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
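The result-builder hunks above (PolarsDataFrameResult, PolarsLazyFrameResult, PyarrowTableResult) all follow the same contract: build_result receives the requested outputs as keyword arguments and assembles the final object, while output_type() advertises what that object will be. A small sketch of the polars case, with made-up column names:

    import polars as pl

    from hamilton.plugins.h_polars import PolarsDataFrameResult

    builder = PolarsDataFrameResult()
    df = builder.build_result(
        a=pl.Series("a", [1, 2, 3]),
        b=pl.Series("b", [4.0, 5.0, 6.0]),
    )
    assert isinstance(df, pl.DataFrame)  # output_type() advertises pl.DataFrame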
-from typing import List from pydantic import BaseModel @@ -117,7 +116,7 @@ def foo() -> MyModel: self.importance = importance self.target = target - def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]: + def get_validators(self, node_to_validate: node.Node) -> list[dq_base.DataValidator]: output_type = node_to_validate.type if not custom_subclass_check(output_type, BaseModel): raise InvalidDecoratorException( diff --git a/hamilton/plugins/h_ray.py b/hamilton/plugins/h_ray.py index f017ff89c..05bfcf402 100644 --- a/hamilton/plugins/h_ray.py +++ b/hamilton/plugins/h_ray.py @@ -49,7 +49,7 @@ def new_fn(*args, **kwargs): return fn -def parse_ray_remote_options_from_tags(tags: typing.Dict[str, str]) -> typing.Dict[str, typing.Any]: +def parse_ray_remote_options_from_tags(tags: dict[str, str]) -> dict[str, typing.Any]: """DRY helper to parse ray.remote(**options) from Hamilton Tags Tags are added to nodes via the @ray_remote_options decorator @@ -114,7 +114,7 @@ class RayGraphAdapter( def __init__( self, result_builder: base.ResultMixin, - ray_init_config: typing.Dict[str, typing.Any] = None, + ray_init_config: dict[str, typing.Any] = None, shutdown_ray_on_completion: bool = False, ): """Constructor @@ -139,14 +139,14 @@ def __init__( ray.init(**ray_init_config) @staticmethod - def do_validate_input(node_type: typing.Type, input_value: typing.Any) -> bool: + def do_validate_input(node_type: type, input_value: typing.Any) -> bool: # NOTE: the type of a raylet is unknown until they are computed if isinstance(input_value, ray._raylet.ObjectRef): return True return htypes.check_input_type(node_type, input_value) @staticmethod - def do_check_edge_types_match(type_from: typing.Type, type_to: typing.Type) -> bool: + def do_check_edge_types_match(type_from: type, type_to: type) -> bool: return type_from == type_to def do_remote_execute( @@ -154,7 +154,7 @@ def do_remote_execute( *, execute_lifecycle_for_node: typing.Callable, node: node.Node, - **kwargs: typing.Dict[str, typing.Any], + **kwargs: dict[str, typing.Any], ) -> typing.Any: """Function that is called as we walk the graph to determine how to execute a hamilton function. @@ -165,7 +165,7 @@ def do_remote_execute( ray_options = parse_ray_remote_options_from_tags(node.tags) return ray.remote(raify(execute_lifecycle_for_node)).options(**ray_options).remote(**kwargs) - def do_build_result(self, outputs: typing.Dict[str, typing.Any]) -> typing.Any: + def do_build_result(self, outputs: dict[str, typing.Any]) -> typing.Any: """Builds the result and brings it back to this running process. :param outputs: the dictionary of key -> Union[ray object reference | value] @@ -196,7 +196,7 @@ class RayTaskExecutor(executors.TaskExecutor): def __init__( self, num_cpus: int = None, - ray_init_config: typing.Dict[str, typing.Any] = None, + ray_init_config: dict[str, typing.Any] = None, skip_init: bool = False, ): """Creates a ray task executor. Note this will likely take in more parameters. This is diff --git a/hamilton/plugins/h_rich.py b/hamilton/plugins/h_rich.py index 765f34e15..3eb777aaf 100644 --- a/hamilton/plugins/h_rich.py +++ b/hamilton/plugins/h_rich.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. 
-from typing import Any, Collection, List +from collections.abc import Collection +from typing import Any import rich.progress @@ -109,7 +110,7 @@ def run_after_graph_execution(self, **kwargs: Any): self._progress.stop() # in case progress thread is lagging @override - def run_after_task_grouping(self, *, task_ids: List[str], **kwargs): + def run_after_task_grouping(self, *, task_ids: list[str], **kwargs): # Change the total of the task group to the number of tasks in the group self._progress.update(self._progress.task_ids[0], total=len(task_ids)) self._task_based = True diff --git a/hamilton/plugins/h_schema.py b/hamilton/plugins/h_schema.py index 5b7ffed3f..c1a8a7daa 100644 --- a/hamilton/plugins/h_schema.py +++ b/hamilton/plugins/h_schema.py @@ -19,8 +19,9 @@ import functools import json import logging +from collections.abc import Mapping from pathlib import Path -from typing import Any, Dict, Literal, Mapping, NamedTuple, Union +from typing import Any, Literal, NamedTuple import pyarrow import pyarrow.ipc @@ -49,7 +50,7 @@ class DiffResult(NamedTuple): def _diff_mappings( current: Mapping[str, Any], reference: Mapping[str, Any] -) -> Dict[str, DiffResult]: +) -> dict[str, DiffResult]: """Generate the diff for all fields of two mappings. example: @@ -355,9 +356,7 @@ def _get_spark_schema(df, **kwargs) -> pyarrow.Schema: # ongoing polars discussion: https://github.com/pola-rs/polars/issues/15600 -def get_dataframe_schema( - df: Union[h_databackends.DATAFRAME_TYPES], node: HamiltonNode -) -> pyarrow.Schema: +def get_dataframe_schema(df: h_databackends.DATAFRAME_TYPES, node: HamiltonNode) -> pyarrow.Schema: """Get pyarrow schema of a node result and store node metadata on the pyarrow schema.""" schema = _get_arrow_schema(df) metadata = dict( @@ -368,12 +367,12 @@ def get_dataframe_schema( return schema.with_metadata(metadata) -def load_schema(path: Union[str, Path]) -> pyarrow.Schema: +def load_schema(path: str | Path) -> pyarrow.Schema: """Load pyarrow schema from disk using IPC deserialization""" return pyarrow.ipc.read_schema(path) -def save_schema(path: Union[str, Path], schema: pyarrow.Schema) -> None: +def save_schema(path: str | Path, schema: pyarrow.Schema) -> None: """Save pyarrow schema to disk using IPC serialization""" Path(path).write_bytes(schema.serialize()) @@ -400,8 +399,8 @@ def __init__( "warn": log a warning with the schema diff "fail": raise an exception with the schema diff """ - self.schemas: Dict[str, pyarrow.Schema] = {} - self.reference_schemas: Dict[str, pyarrow.Schema] = {} + self.schemas: dict[str, pyarrow.Schema] = {} + self.reference_schemas: dict[str, pyarrow.Schema] = {} self.schema_diffs: dict = {} self.schema_dir = schema_dir self.check = check @@ -415,7 +414,7 @@ def __init__( Path(schema_dir).mkdir(parents=True, exist_ok=True) @property - def json_schemas(self) -> Dict[str, dict]: + def json_schemas(self) -> dict[str, dict]: """Return schemas collected during the run""" return { node_name: pyarrow_schema_to_json(schema) for node_name, schema in self.schemas.items() @@ -432,7 +431,7 @@ def get_schema_path(self, node_name: str) -> Path: return Path(self.schema_dir, node_name).with_suffix(".schema") def run_before_graph_execution( - self, *, graph: HamiltonGraph, inputs: Dict[str, Any], overrides: Dict[str, Any], **kwargs + self, *, graph: HamiltonGraph, inputs: dict[str, Any], overrides: dict[str, Any], **kwargs ): """Store schemas of inputs and overrides nodes that are tables or columns.""" self.h_graph = graph diff --git a/hamilton/plugins/h_slack.py 
b/hamilton/plugins/h_slack.py index 9dcb6d1b0..28a65e663 100644 --- a/hamilton/plugins/h_slack.py +++ b/hamilton/plugins/h_slack.py @@ -16,7 +16,7 @@ # under the License. import traceback -from typing import Any, Dict, Optional +from typing import Any from slack_sdk import WebClient @@ -64,8 +64,8 @@ def _send_message(self, message: str): def run_before_node_execution( self, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], node_return_type: type, **future_kwargs: Any, ): @@ -75,13 +75,13 @@ def run_before_node_execution( def run_after_node_execution( self, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], node_return_type: type, result: Any, - error: Optional[Exception], + error: Exception | None, success: bool, - task_id: Optional[str], + task_id: str | None, run_id: str, **future_kwargs: Any, ): diff --git a/hamilton/plugins/h_spark.py b/hamilton/plugins/h_spark.py index 76b802d49..576cd4c75 100644 --- a/hamilton/plugins/h_spark.py +++ b/hamilton/plugins/h_spark.py @@ -18,8 +18,9 @@ import functools import inspect import logging +from collections.abc import Callable, Collection from types import CodeType, FunctionType, ModuleType -from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any import numpy as np import pandas as pd @@ -48,7 +49,7 @@ class KoalasDataFrameResult(base.ResultMixin): """Mixin for building a koalas dataframe from the result""" @staticmethod - def build_result(**outputs: Dict[str, Any]) -> ps.DataFrame: + def build_result(**outputs: dict[str, Any]) -> ps.DataFrame: """Right now this class is just used for signaling the return type.""" pass @@ -121,7 +122,7 @@ def __init__(self, spark_session, result_builder: base.ResultMixin, spine_column self.spine_column = spine_column @staticmethod - def check_input_type(node_type: Type, input_value: Any) -> bool: + def check_input_type(node_type: type, input_value: Any) -> bool: """Function to equate an input value, with expected node type. We need this to equate pandas and koalas objects/types. @@ -141,7 +142,7 @@ def check_input_type(node_type: Type, input_value: Any) -> bool: return htypes.check_input_type(node_type, input_value) @staticmethod - def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool: + def check_node_type_equivalence(node_type: type, input_type: type) -> bool: """Function to help equate pandas with koalas types. :param node_type: the declared node type. @@ -158,7 +159,7 @@ def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool: return True return node_type == input_type - def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any: + def execute_node(self, node: node.Node, kwargs: dict[str, Any]) -> Any: """Function that is called as we walk the graph to determine how to execute a hamilton function. :param node: the node from the graph. 
@@ -167,7 +168,7 @@ def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any: """ return node.callable(**kwargs) - def build_result(self, **outputs: Dict[str, Any]) -> Union[pd.DataFrame, ps.DataFrame, dict]: + def build_result(self, **outputs: dict[str, Any]) -> pd.DataFrame | ps.DataFrame | dict: if isinstance(self.result_builder, base.DictResult): return self.result_builder.build_result(**outputs) # we don't use the actual function for building right now, we use this hacky equivalent @@ -181,7 +182,7 @@ def build_result(self, **outputs: Dict[str, Any]) -> Union[pd.DataFrame, ps.Data return df -def numpy_to_spark_type(numpy_type: Type) -> types.DataType: +def numpy_to_spark_type(numpy_type: type) -> types.DataType: """Function to convert a numpy type to a Spark type. :param numpy_type: the numpy type to convert. @@ -207,7 +208,7 @@ def numpy_to_spark_type(numpy_type: Type) -> types.DataType: raise ValueError("Unsupported NumPy type: " + str(numpy_type)) -def python_to_spark_type(python_type: Type[Union[int, float, bool, str, bytes]]) -> types.DataType: +def python_to_spark_type(python_type: type[int | float | bool | str | bytes]) -> types.DataType: """Function to convert a Python type to a Spark type. :param python_type: the Python type to convert. @@ -245,7 +246,7 @@ def get_spark_type(return_type: Any) -> types.DataType: ) -def _get_pandas_annotations(node_: node.Node, bound_parameters: Dict[str, Any]) -> Dict[str, bool]: +def _get_pandas_annotations(node_: node.Node, bound_parameters: dict[str, Any]) -> dict[str, bool]: """Given a function, return a dictionary of the parameters that are annotated as pandas series. :param hamilton_udf: the function to check. @@ -266,10 +267,10 @@ def _get_type_from_annotation(annotation: Any) -> Any: def _determine_parameters_to_bind( actual_kwargs: dict, - df_columns: Set[str], - node_input_types: Dict[str, Tuple], + df_columns: set[str], + node_input_types: dict[str, tuple], node_name: str, -) -> Tuple[Dict[str, Any], Dict[str, Any]]: +) -> tuple[dict[str, Any], dict[str, Any]]: """Function that we use to bind inputs to the function, or determine we should pull them from the dataframe. It does two things: @@ -300,7 +301,7 @@ def _determine_parameters_to_bind( return params_from_df, bind_parameters -def _inspect_kwargs(kwargs: Dict[str, Any]) -> Tuple[DataFrame, Dict[str, Any]]: +def _inspect_kwargs(kwargs: dict[str, Any]) -> tuple[DataFrame, dict[str, Any]]: """Inspects kwargs, removes any dataframes, and returns the (presumed single) dataframe, with remaining kwargs. :param kwargs: the inputs to the function. @@ -317,7 +318,7 @@ def _inspect_kwargs(kwargs: Dict[str, Any]) -> Tuple[DataFrame, Dict[str, Any]]: return df, actual_kwargs -def _format_pandas_udf(func_name: str, ordered_params: List[str]) -> str: +def _format_pandas_udf(func_name: str, ordered_params: list[str]) -> str: formatting_params = { "name": func_name, "return_type": "pd.Series", @@ -331,7 +332,7 @@ def {name}({params}) -> {return_type}: return func_string -def _format_udf(func_name: str, ordered_params: List[str]) -> str: +def _format_udf(func_name: str, ordered_params: list[str]) -> str: formatting_params = { "name": func_name, "params": ", ".join(ordered_params), @@ -346,8 +347,8 @@ def {name}({params}): def _fabricate_spark_function( node_: node.Node, - params_to_bind: Dict[str, Any], - params_from_df: Dict[str, Any], + params_to_bind: dict[str, Any], + params_from_df: dict[str, Any], pandas_udf: bool, ) -> FunctionType: """Fabricates a spark compatible UDF. 
We have to do this as we don't actually have a funtion @@ -383,7 +384,7 @@ def _fabricate_spark_function( return FunctionType(func_code, {**globals(), **{"partial_fn": partial_fn}}, func_name) -def _lambda_udf(df: DataFrame, node_: node.Node, actual_kwargs: Dict[str, Any]) -> DataFrame: +def _lambda_udf(df: DataFrame, node_: node.Node, actual_kwargs: dict[str, Any]) -> DataFrame: """Function to create a lambda UDF for a function. This functions does the following: @@ -457,14 +458,14 @@ def __init__(self): self.call_count = 0 @staticmethod - def check_input_type(node_type: Type, input_value: Any) -> bool: + def check_input_type(node_type: type, input_value: Any) -> bool: """If the input is a pyspark dataframe, skip, else delegate the check.""" if isinstance(input_value, DataFrame): return True return htypes.check_input_type(node_type, input_value) @staticmethod - def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool: + def check_node_type_equivalence(node_type: type, input_type: type) -> bool: """Checks for the htype.column annotation and deals with it.""" # Good Cases: # [pd.Series, int] -> [pd.Series, int] @@ -480,7 +481,7 @@ def check_node_type_equivalence(node_type: Type, input_type: Type) -> bool: series_to_primitive = False return exact_match or series_to_series or series_to_primitive - def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any: + def execute_node(self, node: node.Node, kwargs: dict[str, Any]) -> Any: """Given a node to execute, process it and apply a UDF if applicable. :param node: the node we're processing. @@ -511,7 +512,7 @@ def execute_node(self, node: node.Node, kwargs: Dict[str, Any]) -> Any: logger.debug("%s, After, %s", node.name, df.columns) return df - def build_result(self, **outputs: Dict[str, Any]) -> DataFrame: + def build_result(self, **outputs: dict[str, Any]) -> DataFrame: """Builds the result and brings it back to this running process. :param outputs: the dictionary of key -> Union[ray object reference | value] @@ -540,9 +541,9 @@ def sparkify_node_with_udf( node_: node.Node, linear_df_dependency_name: str, base_df_dependency_name: str, - base_df_dependency_param: Optional[str], - dependent_columns_in_group: Set[str], - dependent_columns_from_dataframe: Set[str], + base_df_dependency_param: str | None, + dependent_columns_in_group: set[str], + dependent_columns_from_dataframe: set[str], ) -> node.Node: """ """ """Turns a node into a spark node. 
This does the following: @@ -573,8 +574,8 @@ def sparkify_node_with_udf( def new_callable( __linear_df_dependency_name: str = linear_df_dependency_name, __base_df_dependency_name: str = base_df_dependency_name, - __dependent_columns_in_group: Set[str] = dependent_columns_in_group, - __dependent_columns_from_dataframe: Set[str] = dependent_columns_from_dataframe, + __dependent_columns_in_group: set[str] = dependent_columns_in_group, + __dependent_columns_from_dataframe: set[str] = dependent_columns_from_dataframe, __base_df_dependency_param: str = base_df_dependency_param, __node: node.Node = node_, **kwargs, @@ -630,7 +631,7 @@ def new_callable( def derive_dataframe_parameter( - param_types: Dict[str, Type], requested_parameter: str, location_name: Callable + param_types: dict[str, type], requested_parameter: str, location_name: Callable ) -> str: dataframe_parameters = { param for param, val in param_types.items() if custom_subclass_check(val, DataFrame) @@ -731,7 +732,7 @@ def __init__(self, *columns: str): self._columns = columns def transform_node( - self, node_: node.Node, config: Dict[str, Any], fn: Callable + self, node_: node.Node, config: dict[str, Any], fn: Callable ) -> Collection[node.Node]: """Generates nodes for the `@require_columns` decorator. @@ -772,7 +773,7 @@ def new_callable(__input_types=node_.input_types, **kwargs): # if it returns a column, we just turn it into a withColumn expression if custom_subclass_check(node_.type, Column): - def transform_output(output: Column, kwargs: Dict[str, Any]) -> DataFrame: + def transform_output(output: Column, kwargs: dict[str, Any]) -> DataFrame: return kwargs[param].withColumn(node_.name, output) node_out = node_out.transform_output(transform_output, DataFrame) @@ -789,7 +790,7 @@ def validate(self, fn: Callable): _derive_first_dataframe_parameter_from_fn(fn) @staticmethod - def _extract_dataframe_params(node_: node.Node) -> List[str]: + def _extract_dataframe_params(node_: node.Node) -> list[str]: """Extracts the dataframe parameters from a node. :param node_: Node to extract from @@ -829,9 +830,9 @@ def sparkify_node( node_: node.Node, linear_df_dependency_name: str, base_df_dependency_name: str, - base_df_param_name: Optional[str], - dependent_columns_from_upstream: Set[str], - dependent_columns_from_dataframe: Set[str], + base_df_param_name: str | None, + dependent_columns_from_upstream: set[str], + dependent_columns_from_dataframe: set[str], ) -> node.Node: """Transforms a pyspark node into a node that can be run as part of a `with_columns` group. This is only for non-UDF nodes that have already been transformed by `@transforms`. @@ -890,7 +891,7 @@ def new_callable(__callable=node_.callable, **kwargs) -> Any: return node_ -def _identify_upstream_dataframe_nodes(nodes: List[node.Node]) -> List[str]: +def _identify_upstream_dataframe_nodes(nodes: list[node.Node]) -> list[str]: """Gives the upstream dataframe name. This is the only ps.DataFrame parameter not produced from within the subdag. 
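Alongside the union rewrite, the import blocks in these files re-source Callable, Collection, Mapping and friends from collections.abc, since the typing aliases have been deprecated since Python 3.9. A small runnable sketch of that import style, with names invented for the example:

from collections.abc import Callable, Collection, Mapping

def apply_all(fns: Collection[Callable[[int], int]], value: int) -> Mapping[str, int]:
    # Map each function's name to its result; the abc classes support [] since 3.9.
    return {fn.__name__: fn(value) for fn in fns}

print(apply_all([abs, int.bit_length], -3))  # {'abs': 3, 'bit_length': 2}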
@@ -921,14 +922,14 @@ def _identify_upstream_dataframe_nodes(nodes: List[node.Node]) -> List[str]: class with_columns(with_columns_base): def __init__( self, - *load_from: Union[Callable, ModuleType], - columns_to_pass: List[str] = None, + *load_from: Callable | ModuleType, + columns_to_pass: list[str] = None, pass_dataframe_as: str = None, on_input: str = None, - select: List[str] = None, + select: list[str] = None, namespace: str = None, mode: str = "append", - config_required: List[str] = None, + config_required: list[str] = None, ): """Initializes a with_columns decorator for spark. This allows you to efficiently run groups of map operations on a dataframe, represented as pandas/primitives UDFs. This @@ -1027,7 +1028,7 @@ def final_df(initial_df: ps.DataFrame) -> ps.DataFrame: self.mode = mode @staticmethod - def _prep_nodes(initial_nodes: List[node.Node]) -> List[node.Node]: + def _prep_nodes(initial_nodes: list[node.Node]) -> list[node.Node]: """Prepares nodes by decorating "default" UDFs with transform. This allows us to use the sparkify_node function in transforms for both the default ones and the decorated ones. @@ -1047,7 +1048,7 @@ def _prep_nodes(initial_nodes: List[node.Node]) -> List[node.Node]: @staticmethod def create_selector_node( - upstream_name: str, columns: List[str], node_name: str = "select" + upstream_name: str, columns: list[str], node_name: str = "select" ) -> node.Node: """Creates a selector node. The sole job of this is to select just the specified columns. Note this is a utility function that's only called here. @@ -1070,7 +1071,7 @@ def new_callable(**kwargs) -> DataFrame: @staticmethod def create_drop_node( - upstream_name: str, columns: List[str], node_name: str = "select" + upstream_name: str, columns: list[str], node_name: str = "select" ) -> node.Node: """Creates a drop node. The sole job of this is to drop just the specified columns. Note this is a utility function that's only called here. @@ -1091,7 +1092,7 @@ def new_callable(**kwargs) -> DataFrame: input_types={upstream_name: DataFrame}, ) - def _validate_dataframe_subdag_parameter(self, nodes: List[node.Node], fn_name: str): + def _validate_dataframe_subdag_parameter(self, nodes: list[node.Node], fn_name: str): all_upstream_dataframe_nodes = _identify_upstream_dataframe_nodes(nodes) initial_schema = set(self.initial_schema) if self.initial_schema is not None else set() candidates_for_upstream_dataframe = set(all_upstream_dataframe_nodes) - set(initial_schema) @@ -1125,12 +1126,12 @@ def _validate_dataframe_subdag_parameter(self, nodes: List[node.Node], fn_name: f"Instead, we found: {upstream_dependency}." 
) - def required_config(self) -> List[str]: + def required_config(self) -> list[str]: return self.config_required def get_initial_nodes( - self, fn: Callable, params: Dict[str, Type[Type]] - ) -> Tuple[str, Collection[node.Node]]: + self, fn: Callable, params: dict[str, type[type]] + ) -> tuple[str, Collection[node.Node]]: inject_parameter = _derive_first_dataframe_parameter_from_fn(fn=fn) with_columns_base.validate_dataframe( fn=fn, @@ -1142,7 +1143,7 @@ def get_initial_nodes( initial_nodes = [] return inject_parameter, initial_nodes - def get_subdag_nodes(self, fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]: + def get_subdag_nodes(self, fn: Callable, config: dict[str, Any]) -> Collection[node.Node]: initial_nodes = subdag.collect_nodes(config, self.subdag_functions) transformed_nodes = with_columns._prep_nodes(initial_nodes) @@ -1237,12 +1238,12 @@ def validate(self, fn: Callable): class select(with_columns): def __init__( self, - *load_from: Union[Callable, ModuleType], - columns_to_pass: List[str] = None, + *load_from: Callable | ModuleType, + columns_to_pass: list[str] = None, pass_dataframe_as: str = None, - output_cols: List[str] = None, + output_cols: list[str] = None, namespace: str = None, - config_required: List[str] = None, + config_required: list[str] = None, ): """Initializes a select decorator for spark. This allows you to efficiently run groups of map operations on a dataframe, represented as pandas/primitives UDFs. This diff --git a/hamilton/plugins/h_threadpool.py b/hamilton/plugins/h_threadpool.py index f01559ff6..db8a20373 100644 --- a/hamilton/plugins/h_threadpool.py +++ b/hamilton/plugins/h_threadpool.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. +from collections.abc import Callable from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, Dict, List, Type +from typing import Any from hamilton import lifecycle, node from hamilton.lifecycle import base @@ -74,7 +75,7 @@ def __init__( ) self.result_builder = result_builder - def input_types(self) -> List[Type[Type]]: + def input_types(self) -> list[type[type]]: """Gives the applicable types to this result builder. This is optional for backwards compatibility, but is recommended. @@ -84,7 +85,7 @@ def input_types(self) -> List[Type[Type]]: # result builder doesn't make sense. return [Any] - def output_type(self) -> Type: + def output_type(self) -> type: """Returns the output type of this result builder :return: the type that this creates """ @@ -97,7 +98,7 @@ def do_remote_execute( *, execute_lifecycle_for_node: Callable, node: node.Node, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> Any: """Function that submits the passed in function to the ThreadPoolExecutor to be executed after wrapping it with the _new_fn function. diff --git a/hamilton/plugins/h_tqdm.py b/hamilton/plugins/h_tqdm.py index 16a38444b..a056721db 100644 --- a/hamilton/plugins/h_tqdm.py +++ b/hamilton/plugins/h_tqdm.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. 
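One caveat worth keeping in mind when reading these hunks (a general Python note, not a statement about this project's support matrix): `X | Y` is only deferred where it appears as an annotation under `from __future__ import annotations`; anywhere it is actually evaluated, such as a module-level alias, it needs Python 3.10, and built-in generics need 3.9. A sketch:

from __future__ import annotations  # PEP 563: annotations are stored as strings

from pathlib import Path

class Reader:
    # Lazy under the future import, so this class body parses on 3.7+;
    # without it, evaluating `str | Path | None` requires Python 3.10.
    path: str | Path | None = None

# An assignment is an ordinary expression and is always evaluated,
# so a module-level alias like this still requires 3.10+ regardless:
PathArg = str | Path | None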
-from typing import Any, Collection, Dict, List, Optional +from collections.abc import Collection +from typing import Any import tqdm @@ -78,9 +79,9 @@ def run_before_graph_execution( self, *, graph: graph_types.HamiltonGraph, - final_vars: List[str], - inputs: Dict[str, Any], - overrides: Dict[str, Any], + final_vars: list[str], + inputs: dict[str, Any], + overrides: dict[str, Any], execution_path: Collection[str], **future_kwargs: Any, ): @@ -98,10 +99,10 @@ def run_before_node_execution( self, *, node_name: str, - node_tags: Dict[str, Any], - node_kwargs: Dict[str, Any], + node_tags: dict[str, Any], + node_kwargs: dict[str, Any], node_return_type: type, - task_id: Optional[str], + task_id: str | None, **future_kwargs: Any, ): name_display = self._get_node_name_display(node_name) diff --git a/hamilton/plugins/h_vaex.py b/hamilton/plugins/h_vaex.py index 88356f48e..7dc1eb679 100644 --- a/hamilton/plugins/h_vaex.py +++ b/hamilton/plugins/h_vaex.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Type, Union +from typing import Any import numpy as np import pandas as pd @@ -54,7 +54,7 @@ class VaexDataFrameResult(base.ResultMixin): def build_result( self, - **outputs: Dict[str, Union[vaex.expression.Expression, vaex.dataframe.DataFrame, Any]], + **outputs: dict[str, vaex.expression.Expression | vaex.dataframe.DataFrame | Any], ): """This is the method that Hamilton will call to build the final result. It will pass in the results of the requested outputs that @@ -65,9 +65,9 @@ def build_result( """ # We split all outputs into DataFrames, arrays and scalars - dfs: List[vaex.dataframe.DataFrame] = [] # Vaex DataFrames from outputs - arrays: Dict[str, np.ndarray] = dict() # arrays from outputs - scalars: Dict[str, Any] = dict() # scalars from outputs + dfs: list[vaex.dataframe.DataFrame] = [] # Vaex DataFrames from outputs + arrays: dict[str, np.ndarray] = dict() # arrays from outputs + scalars: dict[str, Any] = dict() # scalars from outputs for name, value in outputs.items(): if isinstance(value, vaex.dataframe.DataFrame): @@ -125,5 +125,5 @@ def build_result( return vaex.concat(dfs) - def output_type(self) -> Type: + def output_type(self) -> type: return vaex.dataframe.DataFrame diff --git a/hamilton/plugins/huggingface_extensions.py b/hamilton/plugins/huggingface_extensions.py index b0ba3772c..94ca9206a 100644 --- a/hamilton/plugins/huggingface_extensions.py +++ b/hamilton/plugins/huggingface_extensions.py @@ -16,19 +16,11 @@ # under the License. import dataclasses +from collections.abc import Collection, Mapping, Sequence from os import PathLike from typing import ( Any, BinaryIO, - Collection, - Dict, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, ) try: @@ -68,29 +60,29 @@ class HuggingFaceDSLoader(DataLoader): """Data loader for hugging face datasets. 
Uses load_data method.""" path: str - dataset_name: Optional[str] = None # this can't be `name` because it clashes with `.name()` - data_dir: Optional[str] = None - data_files: Optional[Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]] = None - split: Optional[str] = None - cache_dir: Optional[str] = None - features: Optional[Features] = None - download_config: Optional[DownloadConfig] = None - download_mode: Optional[Union[DownloadMode, str]] = None - verification_mode: Optional[Union[VerificationMode, str]] = None + dataset_name: str | None = None # this can't be `name` because it clashes with `.name()` + data_dir: str | None = None + data_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None + split: str | None = None + cache_dir: str | None = None + features: Features | None = None + download_config: DownloadConfig | None = None + download_mode: DownloadMode | str | None = None + verification_mode: VerificationMode | str | None = None ignore_verifications = "deprecated" - keep_in_memory: Optional[bool] = None + keep_in_memory: bool | None = None save_infos: bool = False - revision: Optional[Union[str, Version]] = None - token: Optional[Union[bool, str]] = None + revision: str | Version | None = None + token: bool | str | None = None use_auth_token = "deprecated" task = "deprecated" streaming: bool = False - num_proc: Optional[int] = None - storage_options: Optional[Dict] = None - config_kwargs: Optional[Dict] = None + num_proc: int | None = None + storage_options: dict | None = None + config_kwargs: dict | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return list(HF_types) def _get_loading_kwargs(self) -> dict: @@ -98,7 +90,7 @@ def _get_loading_kwargs(self) -> dict: kwargs = dataclasses.asdict(self) # we send path separately del kwargs["path"] - config_kwargs: Optional[dict] = kwargs.pop("config_kwargs", None) + config_kwargs: dict | None = kwargs.pop("config_kwargs", None) if config_kwargs: # add config kwargs as needed. kwargs.update(config_kwargs) @@ -108,7 +100,7 @@ def _get_loading_kwargs(self) -> dict: return kwargs - def load_data(self, type_: Type) -> Tuple[Union[HF_types], Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[HF_types, dict[str, Any]]: """Loads the data set given the path and class values.""" ds = load_dataset(self.path, **self._get_loading_kwargs()) is_dataset = isinstance(ds, Dataset) @@ -128,16 +120,16 @@ def name(cls) -> str: class HuggingFaceDSParquetSaver(DataSaver): """Saves a Huggingface dataset to parquet.""" - path_or_buf: Union[PathLike, BinaryIO] - batch_size: Optional[int] = None - parquet_writer_kwargs: Optional[dict] = None + path_or_buf: PathLike | BinaryIO + batch_size: int | None = None + parquet_writer_kwargs: dict | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return list(HF_types) @classmethod - def applies_to(cls, type_: Type[Type]) -> bool: + def applies_to(cls, type_: type[type]) -> bool: return type_ in HF_types def _get_saving_kwargs(self) -> dict: @@ -145,14 +137,14 @@ def _get_saving_kwargs(self) -> dict: kwargs = dataclasses.asdict(self) # we put path_or_buff as a positional argument del kwargs["path_or_buf"] - parquet_writer_kwargs: Optional[dict] = kwargs.pop("parquet_writer_kwargs", None) + parquet_writer_kwargs: dict | None = kwargs.pop("parquet_writer_kwargs", None) if parquet_writer_kwargs: # add config kwargs as needed. 
kwargs.update(parquet_writer_kwargs) return kwargs - def save_data(self, ds: Union[HF_types]) -> Dict[str, Any]: + def save_data(self, ds: HF_types) -> dict[str, Any]: """Saves the data to parquet.""" is_dataset = isinstance(ds, Dataset) ds.to_parquet(self.path_or_buf, **self._get_saving_kwargs()) @@ -197,14 +189,14 @@ class HuggingFaceDSLanceDBSaver(DataSaver): db_client: lancedb.DBConnection table_name: str - columns_to_write: List[str] = None # None means all. + columns_to_write: list[str] = None # None means all. write_batch_size: int = 100 @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return list(HF_types) - def save_data(self, ds: Union[HF_types]) -> Dict[str, Any]: + def save_data(self, ds: HF_types) -> dict[str, Any]: """This batches writes to lancedb.""" ds.map( _batch_write, diff --git a/hamilton/plugins/ibis_extensions.py b/hamilton/plugins/ibis_extensions.py index e080ea03b..66fa493e0 100644 --- a/hamilton/plugins/ibis_extensions.py +++ b/hamilton/plugins/ibis_extensions.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Type +from typing import Any from hamilton import registry @@ -65,7 +65,7 @@ def arg(cls) -> str: return "schema" @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: return issubclass(datatype, ir.Table) def description(self) -> str: diff --git a/hamilton/plugins/jupyter_magic.py b/hamilton/plugins/jupyter_magic.py index 4cf81e929..f41a3c1d0 100644 --- a/hamilton/plugins/jupyter_magic.py +++ b/hamilton/plugins/jupyter_magic.py @@ -21,7 +21,7 @@ import os from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any from IPython.core.magic import Magics, cell_magic, line_magic, magics_class from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring @@ -32,7 +32,7 @@ from hamilton import ad_hoc_utils, driver -def get_assigned_variables(module_node: ast.Module) -> Set[str]: +def get_assigned_variables(module_node: ast.Module) -> set[str]: """Get the set of variable names assigned in a AST Module""" assigned_vars = set() @@ -50,7 +50,7 @@ def visit_node(ast_node): return assigned_vars -def execute_and_get_assigned_values(shell: InteractiveShellApp, cell: str) -> Dict[str, Any]: +def execute_and_get_assigned_values(shell: InteractiveShellApp, cell: str) -> dict[str, Any]: """Execute source code from a cell in the user namespace and collect the values of all assigned variables into a dictionary. """ @@ -155,7 +155,7 @@ def determine_notebook_type() -> str: return "default" -def parse_known_argstring(magic_func, argstring) -> Tuple[argparse.Namespace, List[str]]: +def parse_known_argstring(magic_func, argstring) -> tuple[argparse.Namespace, list[str]]: """IPython magic arguments parsing doesn't allow unknown args. Used instead of IPython.core.magic_arguments.parse_argstring @@ -186,7 +186,7 @@ def __init__(self, **kwargs): self.notebook_env = determine_notebook_type() self.incremental_cells_state = defaultdict(dict) - def resolve_unknown_args_cell_to_module(self, unknown: List[str]): + def resolve_unknown_args_cell_to_module(self, unknown: list[str]): """Handle unknown arguments. 
It won't make the magic execution fail.""" # deprecated in V2 because it's less useful since `%%cell_to_module` can execute itself @@ -211,7 +211,7 @@ def resolve_unknown_args_cell_to_module(self, unknown: List[str]): if any(arg in ("-h", "--help") for arg in unknown): print(help(self.cell_to_module)) - def resolve_config_arg(self, config_arg) -> Union[bool, dict]: + def resolve_config_arg(self, config_arg) -> bool | dict: # default case: didn't receive `-c/--config`. Set an empty dict if config_arg is None: config = {} diff --git a/hamilton/plugins/kedro_extensions.py b/hamilton/plugins/kedro_extensions.py index 64b7738da..94e6f027b 100644 --- a/hamilton/plugins/kedro_extensions.py +++ b/hamilton/plugins/kedro_extensions.py @@ -16,7 +16,8 @@ # under the License. import dataclasses -from typing import Any, Collection, Dict, Optional, Tuple, Type +from collections.abc import Collection +from typing import Any from kedro.io import DataCatalog @@ -51,10 +52,10 @@ class KedroSaver(DataSaver): catalog: DataCatalog @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [Any] - def save_data(self, data: Any) -> Dict[str, Any]: + def save_data(self, data: Any) -> dict[str, Any]: self.catalog.save(self.dataset_name, data) return dict(success=True) @@ -87,13 +88,13 @@ class KedroLoader(DataLoader): dataset_name: str catalog: DataCatalog - version: Optional[str] = None + version: str | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [Any] - def load_data(self, type_: Type) -> Tuple[Any, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[Any, dict[str, Any]]: data = self.catalog.load(self.dataset_name, self.version) metadata = dict(dataset_name=self.dataset_name, version=self.version) return data, metadata diff --git a/hamilton/plugins/lightgbm_extensions.py b/hamilton/plugins/lightgbm_extensions.py index 742e50d24..f25e27a85 100644 --- a/hamilton/plugins/lightgbm_extensions.py +++ b/hamilton/plugins/lightgbm_extensions.py @@ -16,8 +16,9 @@ # under the License. 
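The loader/saver hunks also swap typing.Type for the built-in type and return plain tuple/dict generics. A minimal sketch of that shape; the function names echo the methods in this diff but the bodies are invented and are not the Hamilton DataLoader interface:

from collections.abc import Collection
from typing import Any

def applicable_types() -> Collection[type]:
    # Built-in `type` replaces typing.Type; the list literal satisfies Collection.
    return [dict, list]

def load_data(type_: type) -> tuple[Any, dict[str, Any]]:
    # Return (data, metadata), mirroring the loaders above; works for no-arg constructors.
    data = type_()
    metadata = {"requested_type": type_.__name__}
    return data, metadata

data, meta = load_data(dict)  # ({}, {'requested_type': 'dict'})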
import dataclasses +from collections.abc import Collection from pathlib import Path -from typing import Any, Collection, Dict, Literal, Optional, Tuple, Type, Union +from typing import Any, Literal, Union try: import lightgbm @@ -37,16 +38,16 @@ class LightGBMFileWriter(DataSaver): """Write LighGBM models and boosters to a file""" - path: Union[str, Path] - num_iteration: Optional[int] = None + path: str | Path + num_iteration: int | None = None start_iteration: int = 0 importance_type: Literal["split", "gain"] = "split" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return LIGHTGBM_MODEL_TYPES - def save_data(self, data: LIGHTGBM_MODEL_TYPES_ANNOTATION) -> Dict[str, Any]: + def save_data(self, data: LIGHTGBM_MODEL_TYPES_ANNOTATION) -> dict[str, Any]: if isinstance(data, lightgbm.LGBMModel): data = data.booster_ @@ -67,15 +68,15 @@ def name(cls) -> str: class LightGBMFileReader(DataLoader): """Load LighGBM models and boosters from a file""" - path: Union[str, Path] + path: str | Path @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return LIGHTGBM_MODEL_TYPES def load_data( - self, type_: Type - ) -> Tuple[Union[lightgbm.Booster, lightgbm.CVBooster], Dict[str, Any]]: + self, type_: type + ) -> tuple[lightgbm.Booster | lightgbm.CVBooster, dict[str, Any]]: model = type_(model_file=self.path) metadata = utils.get_file_metadata(self.path) return model, metadata diff --git a/hamilton/plugins/matplotlib_extensions.py b/hamilton/plugins/matplotlib_extensions.py index aaf9ba9bf..ccec535fa 100644 --- a/hamilton/plugins/matplotlib_extensions.py +++ b/hamilton/plugins/matplotlib_extensions.py @@ -16,8 +16,9 @@ # under the License. import dataclasses +from collections.abc import Collection from os import PathLike -from typing import IO, Any, Collection, Dict, List, Optional, Tuple, Type, Union +from typing import IO, Any try: from matplotlib.artist import Artist @@ -37,20 +38,20 @@ class MatplotlibWriter(DataSaver): ref: https://matplotlib.org/stable/api/figure_api.html#matplotlib.figure.Figure """ - path: Union[str, PathLike, IO] - dpi: Optional[Union[float, str]] = None - format: Optional[str] = None - metadata: Optional[Dict] = None - bbox_inches: Optional[Union[str, Bbox]] = None - pad_inches: Optional[Union[float, str]] = None - facecolor: Optional[Union[str, float, Tuple]] = None - edgecolor: Optional[Union[str, float, Tuple]] = None - backend: Optional[str] = None - orientation: Optional[str] = None - papertype: Optional[str] = None - transparent: Optional[bool] = None - bbox_extra_artists: Optional[List[Artist]] = None - pil_kwargs: Optional[Dict] = None + path: str | PathLike | IO + dpi: float | str | None = None + format: str | None = None + metadata: dict | None = None + bbox_inches: str | Bbox | None = None + pad_inches: float | str | None = None + facecolor: str | float | tuple | None = None + edgecolor: str | float | tuple | None = None + backend: str | None = None + orientation: str | None = None + papertype: str | None = None + transparent: bool | None = None + bbox_extra_artists: list[Artist] | None = None + pil_kwargs: dict | None = None def _get_saving_kwargs(self) -> dict: kwargs = {} @@ -81,13 +82,13 @@ def _get_saving_kwargs(self) -> dict: return kwargs - def save_data(self, data: Figure) -> Dict[str, Any]: + def save_data(self, data: Figure) -> dict[str, Any]: data.savefig(fname=self.path, **self._get_saving_kwargs()) # TODO make utils.get_file_metadata() 
safer for when self.path is IO type return utils.get_file_metadata(self.path) @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [Figure] @classmethod diff --git a/hamilton/plugins/mlflow_extensions.py b/hamilton/plugins/mlflow_extensions.py index cddbe96c2..8a4dd930b 100644 --- a/hamilton/plugins/mlflow_extensions.py +++ b/hamilton/plugins/mlflow_extensions.py @@ -17,8 +17,9 @@ import dataclasses import pathlib +from collections.abc import Collection from types import ModuleType -from typing import Any, Collection, Dict, Literal, Optional, Tuple, Type, Union +from typing import Any, Literal try: import mlflow @@ -40,11 +41,11 @@ class MLFlowModelSaver(DataSaver): :param mlflow_kwargs: Arguments for `.log_model()`. Can be flavor-specific. """ - path: Union[str, pathlib.Path] = "model" - register_as: Optional[str] = None - flavor: Optional[Union[str, ModuleType]] = None - run_id: Optional[str] = None - mlflow_kwargs: Dict[str, Any] = None + path: str | pathlib.Path = "model" + register_as: str | None = None + flavor: str | ModuleType | None = None + run_id: str | None = None + mlflow_kwargs: dict[str, Any] = None def __post_init__(self): self.mlflow_kwargs = self.mlflow_kwargs or {} @@ -54,10 +55,10 @@ def name(cls) -> str: return "mlflow" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [Any] - def save_data(self, data) -> Dict[str, Any]: + def save_data(self, data) -> dict[str, Any]: if self.flavor: flavor = self.flavor else: @@ -125,15 +126,15 @@ class MLFlowModelLoader(DataLoader): :param mlflow_kwargs: Arguments for `.load_model()`. Can be flavor-specific. """ - model_uri: Optional[str] = None + model_uri: str | None = None mode: Literal["tracking", "registry"] = "tracking" - run_id: Optional[str] = None - path: Union[str, pathlib.Path] = "model" - model_name: Optional[str] = None - version: Optional[Union[str, int]] = None - version_alias: Optional[str] = None - flavor: Optional[Union[ModuleType, str]] = None - mlflow_kwargs: Dict[str, Any] = None + run_id: str | None = None + path: str | pathlib.Path = "model" + model_name: str | None = None + version: str | int | None = None + version_alias: str | None = None + flavor: ModuleType | str | None = None + mlflow_kwargs: dict[str, Any] = None # __post_init__ is required to set kwargs as empty dict because # can't set: kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) @@ -170,10 +171,10 @@ def name(cls) -> str: return "mlflow" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [Any] - def load_data(self, type_: Type) -> Tuple[Any, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[Any, dict[str, Any]]: model_info = mlflow.models.model.get_model_info(self.model_uri) metadata = {k.strip("_"): v for k, v in model_info.__dict__.items()} diff --git a/hamilton/plugins/numpy_extensions.py b/hamilton/plugins/numpy_extensions.py index c4061576b..18b058c64 100644 --- a/hamilton/plugins/numpy_extensions.py +++ b/hamilton/plugins/numpy_extensions.py @@ -17,7 +17,8 @@ import dataclasses import pathlib -from typing import IO, Any, Collection, Dict, Optional, Tuple, Type, Union +from collections.abc import Collection +from typing import IO, Any try: import numpy as np @@ -37,11 +38,11 @@ class NumpyNpyWriter(DataSaver): ref: https://numpy.org/doc/stable/reference/routines.io.html """ - path: Union[str, pathlib.Path, IO] - 
allow_pickle: Optional[bool] = None - fix_imports: Optional[bool] = None + path: str | pathlib.Path | IO + allow_pickle: bool | None = None + fix_imports: bool | None = None - def save_data(self, data: np.ndarray) -> Dict[str, Any]: + def save_data(self, data: np.ndarray) -> dict[str, Any]: np.save( file=self.path, arr=data, @@ -51,7 +52,7 @@ def save_data(self, data: np.ndarray) -> Dict[str, Any]: return utils.get_file_metadata(self.path) @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [np.ndarray] @classmethod @@ -65,17 +66,17 @@ class NumpyNpyReader(DataLoader): ref: https://numpy.org/doc/stable/reference/routines.io.html """ - path: Union[str, pathlib.Path, IO] - mmap_mode: Optional[str] = None - allow_pickle: Optional[bool] = None - fix_imports: Optional[bool] = None + path: str | pathlib.Path | IO + mmap_mode: str | None = None + allow_pickle: bool | None = None + fix_imports: bool | None = None encoding: Literal["ASCII", "latin1", "bytes"] = "ASCII" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [np.ndarray] - def load_data(self, type_: Type) -> Tuple[np.ndarray, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[np.ndarray, dict[str, Any]]: array = np.load( file=self.path, mmap_mode=self.mmap_mode, diff --git a/hamilton/plugins/pandas_extensions.py b/hamilton/plugins/pandas_extensions.py index b4ce2786d..0937cce9d 100644 --- a/hamilton/plugins/pandas_extensions.py +++ b/hamilton/plugins/pandas_extensions.py @@ -18,31 +18,27 @@ import abc import csv import dataclasses -from collections.abc import Hashable +from collections.abc import Callable, Collection, Hashable, Iterator from datetime import datetime from io import BufferedReader, BytesIO, StringIO from pathlib import Path -from typing import Any, Callable, Collection, Dict, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, TypeAlias try: import pandas as pd except ImportError as e: raise NotImplementedError("Pandas is not installed.") from e +from collections.abc import Iterable, Mapping, Sequence from typing import Literal -try: - from collections.abc import Iterable, Mapping, Sequence -except ImportError: - from collections import Iterable, Mapping, Sequence - try: import fsspec import pyarrow.fs - FILESYSTEM_TYPE = Optional[Union[pyarrow.fs.FileSystem, fsspec.spec.AbstractFileSystem]] + FILESYSTEM_TYPE = pyarrow.fs.FileSystem | fsspec.spec.AbstractFileSystem | None except ImportError: - FILESYSTEM_TYPE = Optional[Type] + FILESYSTEM_TYPE = type | None from sqlite3 import Connection @@ -56,9 +52,9 @@ DATAFRAME_TYPE = pd.DataFrame COLUMN_TYPE = pd.Series -JSONSerializable = Optional[Union[str, float, bool, List, Dict]] -IndexLabel = Optional[Union[Hashable, Iterator[Hashable]]] -Dtype = Union[ExtensionDtype, NpDtype] +JSONSerializable: TypeAlias = str | float | bool | list | dict | None +IndexLabel: TypeAlias = Hashable | Iterator[Hashable] | None +Dtype: TypeAlias = ExtensionDtype | NpDtype @registry.get_column.register(pd.DataFrame) @@ -91,15 +87,15 @@ class DataFrameDataLoader(DataLoader, DataSaver, abc.ABC): we are good to go.""" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] @abc.abstractmethod - def load_data(self, type_: Type[DATAFRAME_TYPE]) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type[DATAFRAME_TYPE]) -> tuple[DATAFRAME_TYPE, dict[str, 
Any]]: pass @abc.abstractmethod - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: pass @@ -111,65 +107,64 @@ class PandasCSVReader(DataLoader): """ # the filepath_or_buffer param will be changed to path for backwards compatibility - path: Union[str, Path, BytesIO, BufferedReader] + path: str | Path | BytesIO | BufferedReader # kwargs - sep: Union[str, None] = "," - delimiter: Optional[str] = None - header: Union[Sequence, int, Literal["infer"], None] = "infer" - names: Optional[Sequence] = None - index_col: Optional[Union[Hashable, Sequence, Literal[False]]] = None - usecols: Optional[Union[List[Hashable], Callable, tuple]] = None - dtype: Optional[Union[Dtype, Dict[Hashable, Dtype]]] = None - engine: Optional[Literal["c", "python", "pyarrow", "python-fwf"]] = None - converters: Optional[Mapping] = None - true_values: Optional[List] = None - false_values: Optional[List] = None - skipinitialspace: Optional[bool] = False - skiprows: Optional[Union[List[int], int, Callable[[Hashable], bool]]] = None + sep: str | None = "," + delimiter: str | None = None + header: Sequence | int | Literal["infer"] | None = "infer" + names: Sequence | None = None + index_col: Hashable | Sequence | Literal[False] | None = None + usecols: list[Hashable] | Callable | tuple | None = None + dtype: Dtype | dict[Hashable, Dtype] | None = None + engine: Literal["c", "python", "pyarrow", "python-fwf"] | None = None + converters: Mapping | None = None + true_values: list | None = None + false_values: list | None = None + skipinitialspace: bool | None = False + skiprows: list[int] | int | Callable[[Hashable], bool] | None = None skipfooter: int = 0 - nrows: Optional[int] = None - na_values: Optional[Union[Hashable, Iterable, Mapping]] = None + nrows: int | None = None + na_values: Hashable | Iterable | Mapping | None = None keep_default_na: bool = True na_filter: bool = True verbose: bool = False skip_blank_lines: bool = True - parse_dates: Optional[Union[bool, Sequence, None]] = False + parse_dates: bool | Sequence | None | None = False keep_date_col: bool = False - date_format: Optional[str] = None + date_format: str | None = None dayfirst: bool = False cache_dates: bool = True iterator: bool = False - chunksize: Optional[int] = None - compression: Optional[ - Union[Literal["infer", "gzip", "bz2", "zip", "xz", "zstd", "tar"], Dict[str, Any]] - ] = "infer" - thousands: Optional[str] = None + chunksize: int | None = None + compression: ( + Literal["infer", "gzip", "bz2", "zip", "xz", "zstd", "tar"] | dict[str, Any] | None + ) = "infer" + thousands: str | None = None decimal: str = "." 
- lineterminator: Optional[str] = None - quotechar: Optional[str] = None + lineterminator: str | None = None + quotechar: str | None = None quoting: int = 0 doublequote: bool = True - escapechar: Optional[str] = None - comment: Optional[str] = None + escapechar: str | None = None + comment: str | None = None encoding: str = "utf-8" - encoding_errors: Union[ - Literal["strict", "ignore", "replace", "backslashreplace", "surrogateescape"], - str, - ] = "strict" - dialect: Optional[Union[str, csv.Dialect]] = None - on_bad_lines: Union[Literal["error", "warn", "skip"], Callable] = "error" + encoding_errors: ( + Literal["strict", "ignore", "replace", "backslashreplace", "surrogateescape"] | str + ) = "strict" + dialect: str | csv.Dialect | None = None + on_bad_lines: Literal["error", "warn", "skip"] | Callable = "error" delim_whitespace: bool = False low_memory: bool = True memory_map: bool = False - float_precision: Optional[Literal["high", "legacy", "round_trip"]] = None - storage_options: Optional[Dict[str, Any]] = None + float_precision: Literal["high", "legacy", "round_trip"] | None = None + storage_options: dict[str, Any] | None = None dtype_backend: Literal["pyarrow", "numpy_nullable"] = "numpy_nullable" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: kwargs = {} if self.sep is not None: kwargs["sep"] = self.sep @@ -262,7 +257,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pd.read_csv(self.path, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.path, df) return df, metadata @@ -278,36 +273,36 @@ class PandasCSVWriter(DataSaver): Maps to https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html """ - path: Union[str, Path, BytesIO, BufferedReader] + path: str | Path | BytesIO | BufferedReader # kwargs - sep: Union[str, None] = "," + sep: str | None = "," na_rep: str = "" - float_format: Optional[Union[str, Callable]] = None - columns: Optional[Sequence] = None - header: Optional[Union[bool, List[str]]] = True - index: Optional[bool] = False - index_label: Optional[IndexLabel] = None + float_format: str | Callable | None = None + columns: Sequence | None = None + header: bool | list[str] | None = True + index: bool | None = False + index_label: IndexLabel | None = None mode: str = "w" - encoding: Optional[str] = None - compression: Optional[ - Union[Literal["infer", "gzip", "bz2", "zip", "xz", "zstd", "tar"], Dict[str, Any]] - ] = "infer" - quoting: Optional[int] = None - quotechar: Optional[str] = '"' - lineterminator: Optional[str] = None - chunksize: Optional[int] = None - date_format: Optional[str] = None + encoding: str | None = None + compression: ( + Literal["infer", "gzip", "bz2", "zip", "xz", "zstd", "tar"] | dict[str, Any] | None + ) = "infer" + quoting: int | None = None + quotechar: str | None = '"' + lineterminator: str | None = None + chunksize: int | None = None + date_format: str | None = None doublequote: bool = True - escapechar: Optional[str] = None + escapechar: str | None = None decimal: str = "." 
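The module-level aliases in pandas_extensions (JSONSerializable, IndexLabel, Dtype) gain an explicit TypeAlias marker along with the union rewrite. A small sketch of that pattern with invented alias names; TypeAlias ships with typing from 3.10, which these evaluated unions need anyway:

from typing import Any, TypeAlias

JSONValue: TypeAlias = str | float | bool | list | dict | None
StorageOptions: TypeAlias = dict[str, Any] | None

def read(path: str, storage_options: StorageOptions = None) -> JSONValue:
    # Stub; the aliases above keep long unions readable at call sites.
    ...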
errors: str = "strict" - storage_options: Optional[Dict[str, Any]] = None + storage_options: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_saving_kwargs(self) -> Dict[str, Any]: + def _get_saving_kwargs(self) -> dict[str, Any]: # Puts kwargs in a dict kwargs = {} if self.sep is not None: @@ -353,7 +348,7 @@ def _get_saving_kwargs(self) -> Dict[str, Any]: return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: data.to_csv(self.path, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.path, data) @@ -368,18 +363,18 @@ class PandasParquetReader(DataLoader): Maps to https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html#pandas.read_parquet """ - path: Union[str, Path, BytesIO, BufferedReader] + path: str | Path | BytesIO | BufferedReader # kwargs engine: Literal["auto", "pyarrow", "fastparquet"] = "auto" - columns: Optional[List[str]] = None - storage_options: Optional[Dict[str, Any]] = None + columns: list[str] | None = None + storage_options: dict[str, Any] | None = None use_nullable_dtypes: bool = False dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable" - filesystem: Optional[str] = None - filters: Optional[Union[List[Tuple], List[List[Tuple]]]] = None + filesystem: str | None = None + filters: list[tuple] | list[list[tuple]] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -401,7 +396,7 @@ def _get_loading_kwargs(self): return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the pickle df = pd.read_parquet(self.path, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.path, df) @@ -418,20 +413,20 @@ class PandasParquetWriter(DataSaver): Maps to https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_parquet.html#pandas.DataFrame.to_parquet """ - path: Union[str, Path, BytesIO, BufferedReader] + path: str | Path | BytesIO | BufferedReader # kwargs engine: Literal["auto", "pyarrow", "fastparquet"] = "auto" - compression: Optional[str] = "snappy" - index: Optional[bool] = None - partition_cols: Optional[List[str]] = None - storage_options: Optional[Dict[str, Any]] = None - extra_kwargs: Optional[Dict[str, Any]] = None + compression: str | None = "snappy" + index: bool | None = None + partition_cols: list[str] | None = None + storage_options: dict[str, Any] | None = None + extra_kwargs: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_saving_kwargs(self) -> Dict[str, Any]: + def _get_saving_kwargs(self) -> dict[str, Any]: # Puts kwargs in a dict kwargs = {} if self.engine is not None: @@ -448,7 +443,7 @@ def _get_saving_kwargs(self) -> Dict[str, Any]: kwargs.update(self.extra_kwargs) return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: data.to_parquet(self.path, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.path, data) @@ -463,20 +458,20 @@ class 
PandasPickleReader(DataLoader): Maps to https://pandas.pydata.org/docs/reference/api/pandas.read_pickle.html#pandas.read_pickle """ - filepath_or_buffer: Union[str, Path, BytesIO, BufferedReader] = None - path: Union[str, Path, BytesIO, BufferedReader] = ( + filepath_or_buffer: str | Path | BytesIO | BufferedReader = None + path: str | Path | BytesIO | BufferedReader = ( None # alias for `filepath_or_buffer` to keep reading/writing args symmetric. ) # kwargs: - compression: Union[str, Dict[str, Any], None] = "infer" - storage_options: Optional[Dict[str, Any]] = None + compression: str | dict[str, Any] | None = "infer" + storage_options: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: # Returns type for which data loader is available return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: # Puts kwargs in a dict kwargs = {} if self.compression is not None: @@ -485,7 +480,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: kwargs["storage_options"] = self.storage_options return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the pickle df = pd.read_pickle(self.filepath_or_buffer, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.filepath_or_buffer, df) @@ -518,17 +513,17 @@ class PandasPickleWriter(DataSaver): Maps to https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_pickle.html#pandas.DataFrame.to_pickle """ - path: Union[str, Path, BytesIO, BufferedReader] + path: str | Path | BytesIO | BufferedReader # kwargs: - compression: Union[str, Dict[str, Any], None] = "infer" + compression: str | dict[str, Any] | None = "infer" protocol: int = pickle_protocol_default - storage_options: Optional[Dict[str, Any]] = None + storage_options: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_saving_kwargs(self) -> Dict[str, Any]: + def _get_saving_kwargs(self) -> dict[str, Any]: # Puts kwargs in a dict kwargs = {} if self.compression is not None: @@ -539,7 +534,7 @@ def _get_saving_kwargs(self) -> Dict[str, Any]: kwargs["storage_options"] = self.storage_options return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: data.to_pickle(self.path, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.path, data) @@ -559,31 +554,31 @@ class PandasJsonReader(DataLoader): Should map to https://pandas.pydata.org/docs/reference/api/pandas.read_json.html """ - filepath_or_buffer: Union[str, Path, BytesIO, BufferedReader] + filepath_or_buffer: str | Path | BytesIO | BufferedReader # kwargs - chunksize: Optional[int] = None - compression: Optional[Union[str, Dict[str, Any]]] = "infer" - convert_axes: Optional[bool] = None - convert_dates: Union[bool, List[str]] = True - date_unit: Optional[str] = None - dtype: Optional[Union[Dtype, Dict[Hashable, Dtype]]] = None - dtype_backend: Optional[str] = None - encoding: Optional[str] = None - encoding_errors: Optional[str] = "strict" + chunksize: int | None = None + compression: str | dict[str, Any] | None = "infer" + convert_axes: bool | None = None + convert_dates: bool | 
list[str] = True + date_unit: str | None = None + dtype: Dtype | dict[Hashable, Dtype] | None = None + dtype_backend: str | None = None + encoding: str | None = None + encoding_errors: str | None = "strict" engine: str = "ujson" keep_default_dates: bool = True lines: bool = False - nrows: Optional[int] = None - orient: Optional[str] = None + nrows: int | None = None + orient: str | None = None precise_float: bool = False - storage_options: Optional[Dict[str, Any]] = None + storage_options: dict[str, Any] | None = None typ: str = "frame" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: kwargs = {} if self.chunksize is not None: kwargs["chunksize"] = self.chunksize @@ -621,7 +616,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: kwargs["typ"] = self.typ return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pd.read_json(self.filepath_or_buffer, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.filepath_or_buffer, df) return df, metadata @@ -642,23 +637,23 @@ class PandasJsonWriter(DataSaver): Should map to https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_json.html """ - filepath_or_buffer: Union[str, Path, BytesIO, BufferedReader] + filepath_or_buffer: str | Path | BytesIO | BufferedReader # kwargs compression: str = "infer" date_format: str = "epoch" date_unit: str = "ms" - default_handler: Optional[Callable[[Any], JSONSerializable]] = None + default_handler: Callable[[Any], JSONSerializable] | None = None double_precision: int = 10 force_ascii: bool = True - index: Optional[bool] = None + index: bool | None = None indent: int = 0 lines: bool = False mode: str = "w" - orient: Optional[str] = None - storage_options: Optional[Dict[str, Any]] = None + orient: str | None = None + storage_options: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_saving_kwargs(self): @@ -689,7 +684,7 @@ def _get_saving_kwargs(self): kwargs["storage_options"] = self.storage_options return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: data.to_json(self.filepath_or_buffer, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.filepath_or_buffer, data) @@ -711,22 +706,22 @@ class PandasSqlReader(DataLoader): """ query_or_table: str - db_connection: Union[str, Connection] # can pass in SQLAlchemy engine/connection + db_connection: str | Connection # can pass in SQLAlchemy engine/connection # kwarg - chunksize: Optional[int] = None + chunksize: int | None = None coerce_float: bool = True - columns: Optional[List[str]] = None - dtype: Optional[Union[Dtype, Dict[Hashable, Dtype]]] = None - dtype_backend: Optional[str] = None - index_col: Optional[Union[str, List[str]]] = None - params: Optional[Union[List, Tuple, Dict]] = None - parse_dates: Optional[Union[List, Dict]] = None + columns: list[str] | None = None + dtype: Dtype | dict[Hashable, Dtype] | None = None + dtype_backend: str | None = None + index_col: str | list[str] | None = None + params: list | tuple | dict | None = None + parse_dates: list | dict | None = None @classmethod - 
def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: kwargs = {} if self.chunksize is not None: kwargs["chunksize"] = self.chunksize @@ -746,7 +741,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: kwargs["parse_dates"] = self.parse_dates return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pd.read_sql(self.query_or_table, self.db_connection, **self._get_loading_kwargs()) sql_metadata = utils.get_sql_metadata(self.query_or_table, df) df_metadata = utils.get_dataframe_metadata(df) @@ -772,19 +767,19 @@ class PandasSqlWriter(DataSaver): table_name: str db_connection: Any # can pass in SQLAlchemy engine/connection # kwargs - chunksize: Optional[int] = None - dtype: Optional[Union[Dtype, Dict[Hashable, Dtype]]] = None + chunksize: int | None = None + dtype: Dtype | dict[Hashable, Dtype] | None = None if_exists: str = "fail" index: bool = True - index_label: Optional[IndexLabel] = None - method: Optional[Union[str, Callable]] = None - schema: Optional[str] = None + index_label: IndexLabel | None = None + method: str | Callable | None = None + schema: str | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_saving_kwargs(self) -> Dict[str, Any]: + def _get_saving_kwargs(self) -> dict[str, Any]: kwargs = {} if self.chunksize is not None: kwargs["chunksize"] = self.chunksize @@ -802,7 +797,7 @@ def _get_saving_kwargs(self) -> Dict[str, Any]: kwargs["schema"] = self.schema return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: results = data.to_sql(self.table_name, self.db_connection, **self._get_saving_kwargs()) sql_metadata = utils.get_sql_metadata(self.table_name, results) df_metadata = utils.get_dataframe_metadata(data) @@ -821,29 +816,29 @@ class PandasXmlReader(DataLoader): Requires `lxml`. 
See https://pandas.pydata.org/docs/getting_started/install.html#xml """ - path_or_buffer: Union[str, Path, BytesIO, BufferedReader] + path_or_buffer: str | Path | BytesIO | BufferedReader # kwargs - xpath: Optional[str] = "./*" - namespace: Optional[Dict[str, str]] = None - elems_only: Optional[bool] = False - attrs_only: Optional[bool] = False - names: Optional[List[str]] = None - dtype: Optional[Dict[str, Any]] = None - converters: Optional[Dict[Union[int, str], Any]] = None - parse_dates: Union[bool, List[Union[int, str, List[List], Dict[str, List[int]]]]] = False - encoding: Optional[str] = "utf-8" + xpath: str | None = "./*" + namespace: dict[str, str] | None = None + elems_only: bool | None = False + attrs_only: bool | None = False + names: list[str] | None = None + dtype: dict[str, Any] | None = None + converters: dict[int | str, Any] | None = None + parse_dates: bool | list[int | str | list[list] | dict[str, list[int]]] = False + encoding: str | None = "utf-8" parser: str = "lxml" - stylesheet: Union[str, Path, BytesIO, BufferedReader] = None - iterparse: Optional[Dict[str, List[str]]] = None - compression: Union[str, Dict[str, Any], None] = "infer" - storage_options: Optional[Dict[str, Any]] = None + stylesheet: str | Path | BytesIO | BufferedReader = None + iterparse: dict[str, list[str]] | None = None + compression: str | dict[str, Any] | None = "infer" + storage_options: dict[str, Any] | None = None dtype_backend: str = "numpy_nullable" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: kwargs = {} if self.xpath is not None: kwargs["xpath"] = self.xpath @@ -881,7 +876,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: kwargs["dtype_backend"] = self.dtype_backend return kwargs - def load_data(self, type: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the xml df = pd.read_xml(self.path_or_buffer, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.path_or_buffer, df) @@ -900,26 +895,26 @@ class PandasXmlWriter(DataSaver): Requires `lxml`. See https://pandas.pydata.org/docs/getting_started/install.html#xml. 
""" - path_or_buffer: Union[str, Path, BytesIO, BufferedReader] + path_or_buffer: str | Path | BytesIO | BufferedReader # kwargs index: bool = True root_name: str = "data" row_name: str = "row" - na_rep: Optional[str] = None - attr_cols: Optional[List[str]] = None - elems_cols: Optional[List[str]] = None - namespaces: Optional[Dict[str, str]] = None - prefix: Optional[str] = None + na_rep: str | None = None + attr_cols: list[str] | None = None + elems_cols: list[str] | None = None + namespaces: dict[str, str] | None = None + prefix: str | None = None encoding: str = "utf-8" xml_declaration: bool = True pretty_print: bool = True parser: str = "lxml" - stylesheet: Optional[Union[str, Path, BytesIO, BufferedReader]] = None - compression: Union[str, Dict[str, Any], None] = "infer" - storage_options: Optional[Dict[str, Any]] = None + stylesheet: str | Path | BytesIO | BufferedReader | None = None + compression: str | dict[str, Any] | None = "infer" + storage_options: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_saving_kwargs(self): @@ -956,7 +951,7 @@ def _get_saving_kwargs(self): kwargs["storage_options"] = self.storage_options return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: data.to_xml(self.path_or_buffer, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.path_or_buffer, data) @@ -971,31 +966,31 @@ class PandasHtmlReader(DataLoader): Maps to https://pandas.pydata.org/docs/reference/api/pandas.read_html.html """ - io: Union[str, Path, BytesIO, BufferedReader] + io: str | Path | BytesIO | BufferedReader # kwargs - match: Optional[str] = ".+" - flavor: Optional[Union[str, Sequence]] = None - header: Optional[Union[int, Sequence]] = None - index_col: Optional[Union[int, Sequence]] = None - skiprows: Optional[Union[int, Sequence, slice]] = None - attrs: Optional[Dict[str, str]] = None - parse_dates: Optional[bool] = None - thousands: Optional[str] = "," - encoding: Optional[str] = None + match: str | None = ".+" + flavor: str | Sequence | None = None + header: int | Sequence | None = None + index_col: int | Sequence | None = None + skiprows: int | Sequence | slice | None = None + attrs: dict[str, str] | None = None + parse_dates: bool | None = None + thousands: str | None = "," + encoding: str | None = None decimal: str = "." 
- converters: Optional[Dict[Any, Any]] = None + converters: dict[Any, Any] | None = None na_values: Iterable = None keep_default_na: bool = True displayed_only: bool = True - extract_links: Optional[Literal["header", "footer", "body", "all"]] = None + extract_links: Literal["header", "footer", "body", "all"] | None = None dtype_backend: Literal["pyarrow", "numpy_nullable"] = "numpy_nullable" - storage_options: Optional[Dict[str, Any]] = None + storage_options: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: kwargs = {} if self.match is not None: kwargs["match"] = self.match @@ -1034,7 +1029,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: return kwargs - def load_data(self, type: Type) -> Tuple[List[DATAFRAME_TYPE], Dict[str, Any]]: + def load_data(self, type: type) -> tuple[list[DATAFRAME_TYPE], dict[str, Any]]: # Loads the data and returns the df and metadata of the xml df = pd.read_html(self.io, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.io, df[0]) @@ -1051,33 +1046,33 @@ class PandasHtmlWriter(DataSaver): Should map to https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_html.html#pandas.DataFrame.to_html """ - buf: Union[str, Path, StringIO, None] = None + buf: str | Path | StringIO | None = None # kwargs - columns: Optional[List[str]] = None - col_space: Optional[Union[str, int, List, Dict]] = None - header: Optional[bool] = True - index: Optional[bool] = True - na_rep: Optional[str] = "NaN" - formatters: Optional[Union[List, Tuple, Dict]] = None - float_format: Optional[str] = None - sparsify: Optional[bool] = True - index_names: Optional[bool] = True + columns: list[str] | None = None + col_space: str | int | list | dict | None = None + header: bool | None = True + index: bool | None = True + na_rep: str | None = "NaN" + formatters: list | tuple | dict | None = None + float_format: str | None = None + sparsify: bool | None = True + index_names: bool | None = True justify: str = None - max_rows: Optional[int] = None - max_cols: Optional[int] = None + max_rows: int | None = None + max_cols: int | None = None show_dimensions: bool = False decimal: str = "." 
bold_rows: bool = True - classes: Union[str, List[str], Tuple, None] = None - escape: Optional[bool] = True + classes: str | list[str] | tuple | None = None + escape: bool | None = True notebook: Literal[True, False] = False border: int = None - table_id: Optional[str] = None + table_id: str | None = None render_links: bool = False - encoding: Optional[str] = "utf-8" + encoding: str | None = "utf-8" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_saving_kwargs(self): @@ -1129,7 +1124,7 @@ def _get_saving_kwargs(self): return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: data.to_html(self.buf, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.buf, data) @@ -1144,27 +1139,27 @@ class PandasStataReader(DataLoader): Maps to https://pandas.pydata.org/docs/reference/api/pandas.read_stata.html#pandas.read_stata """ - filepath_or_buffer: Union[str, Path, BytesIO, BufferedReader] + filepath_or_buffer: str | Path | BytesIO | BufferedReader # kwargs convert_dates: bool = True convert_categoricals: bool = True - index_col: Optional[str] = None + index_col: str | None = None convert_missing: bool = False preserve_dtypes: bool = True - columns: Optional[Sequence] = None + columns: Sequence | None = None order_categoricals: bool = True - chunksize: Optional[int] = None + chunksize: int | None = None iterator: bool = False - compression: Union[ - Dict[str, Any], Literal["infer", "gzip", "bz2", "zip", "xz", "zstd", "tar"] - ] = "infer" - storage_options: Optional[Dict[str, Any]] = None + compression: dict[str, Any] | Literal["infer", "gzip", "bz2", "zip", "xz", "zstd", "tar"] = ( + "infer" + ) + storage_options: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: kwargs = {} if self.convert_dates is not None: kwargs["convert_dates"] = self.convert_dates @@ -1191,7 +1186,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: return kwargs - def load_data(self, type: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the xml df = pd.read_stata(self.filepath_or_buffer, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.filepath_or_buffer, df) @@ -1208,24 +1203,24 @@ class PandasStataWriter(DataSaver): Should map to https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_stata.html """ - path: Union[str, Path, BufferedReader] = None + path: str | Path | BufferedReader = None # kwargs - convert_dates: Optional[Dict[Hashable, str]] = None + convert_dates: dict[Hashable, str] | None = None write_index: bool = True - byteorder: Optional[str] = None - time_stamp: Optional[datetime] = None - data_label: Optional[str] = None - variable_labels: Optional[Dict[Hashable, str]] = None + byteorder: str | None = None + time_stamp: datetime | None = None + data_label: str | None = None + variable_labels: dict[Hashable, str] | None = None version: Literal[114, 117, 118, 119] = 114 - convert_strl: Optional[str] = None - compression: Union[ - Dict[str, Any], Literal["infer", "gzip", "bz2", "zip", "xz", "zstd", "tar"] - ] = "infer" - 
storage_options: Optional[Dict[str, Any]] = None - value_labels: Optional[Dict[Hashable, str]] = None + convert_strl: str | None = None + compression: dict[str, Any] | Literal["infer", "gzip", "bz2", "zip", "xz", "zstd", "tar"] = ( + "infer" + ) + storage_options: dict[str, Any] | None = None + value_labels: dict[Hashable, str] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_saving_kwargs(self): @@ -1255,7 +1250,7 @@ def _get_saving_kwargs(self): return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: data.to_stata(self.path, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.path, data) @@ -1270,18 +1265,18 @@ class PandasFeatherReader(DataLoader): Maps to https://pandas.pydata.org/docs/reference/api/pandas.read_feather.html """ - path: Union[str, Path, BytesIO, BufferedReader] + path: str | Path | BytesIO | BufferedReader # kwargs - columns: Optional[Sequence] = None + columns: Sequence | None = None use_threads: bool = True - storage_options: Optional[Dict[str, Any]] = None + storage_options: dict[str, Any] | None = None dtype_backend: Literal["pyarrow", "numpy_nullable"] = "numpy_nullable" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: kwargs = {} if self.columns is not None: kwargs["columns"] = self.columns @@ -1294,7 +1289,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: return kwargs - def load_data(self, type: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the xml df = pd.read_feather(self.path, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.path, df) @@ -1314,16 +1309,16 @@ class PandasFeatherWriter(DataSaver): Requires `lz4` https://pypi.org/project/lz4/ """ - path: Union[str, Path, BytesIO, BufferedReader] + path: str | Path | BytesIO | BufferedReader # kwargs - dest: Optional[str] = None + dest: str | None = None compression: Literal["zstd", "lz4", "uncompressed"] = None - compression_level: Optional[int] = None - chunksize: Optional[int] = None - version: Optional[int] = 2 + compression_level: int | None = None + chunksize: int | None = None + version: int | None = 2 @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_saving_kwargs(self): @@ -1341,7 +1336,7 @@ def _get_saving_kwargs(self): return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: data.to_feather(self.path, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.path, data) @@ -1357,17 +1352,17 @@ class PandasORCReader(DataLoader): Maps to: https://pandas.pydata.org/docs/reference/api/pandas.read_orc.html#pandas.read_orc """ - path: Union[str, Path, BytesIO, BufferedReader] + path: str | Path | BytesIO | BufferedReader # kwargs - columns: Optional[List[str]] = None + columns: list[str] | None = None dtype_backend: Literal["pyarrow", "numpy_nullable"] = "numpy_nullable" - filesystem: Optional[FILESYSTEM_TYPE] = None + filesystem: FILESYSTEM_TYPE | 
None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: kwargs = {} if self.columns is not None: kwargs["columns"] = self.columns @@ -1378,7 +1373,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: return kwargs - def load_data(self, type: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the orc df = pd.read_orc(self.path, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.path, df) @@ -1396,14 +1391,14 @@ class PandasORCWriter(DataSaver): Maps to: https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_orc.html """ - path: Union[str, Path, BytesIO, BufferedReader] + path: str | Path | BytesIO | BufferedReader # kwargs engine: Literal["pyarrow"] = "pyarrow" - index: Optional[bool] = None - engine_kwargs: Optional[Union[Dict[str, Any], None]] = None + index: bool | None = None + engine_kwargs: dict[str, Any] | None | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_saving_kwargs(self): @@ -1417,7 +1412,7 @@ def _get_saving_kwargs(self): return kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: data.to_orc(self.path, **self._get_saving_kwargs()) return utils.get_file_and_dataframe_metadata(self.path, data) @@ -1432,44 +1427,44 @@ class PandasExcelReader(DataLoader): Maps to https://pandas.pydata.org/docs/reference/api/pandas.read_excel.html """ - path: Union[str, Path, BytesIO, BufferedReader] = None + path: str | Path | BytesIO | BufferedReader = None # kwargs: # inspect.get_type_hints doesn't work with type aliases, # which are used in pandas.read_excel. # So we have to list all the arguments in plain code. 
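# Illustrative sketch (not from this patch): a minimal, self-contained example of
# the dataclass-to-kwargs pattern the readers in this file follow, using a
# hypothetical `MiniExcelReader`. It only shows how dataclasses.asdict() feeds
# pd.read_excel once non-kwarg fields are dropped; the field names are invented.
from __future__ import annotations

import dataclasses
from typing import Any

import pandas as pd


@dataclasses.dataclass
class MiniExcelReader:
    path: str
    # kwargs forwarded verbatim to pd.read_excel
    sheet_name: str | int | list[int | str] | None = 0
    skiprows: int | None = None

    def _get_loading_kwargs(self) -> dict[str, Any]:
        # Dump every field, then remove the ones that are not read_excel kwargs.
        kwargs = dataclasses.asdict(self)
        del kwargs["path"]
        return kwargs

    def load_data(self) -> pd.DataFrame:
        # Assumes a single sheet; a list/None sheet_name would make read_excel
        # return a dict of DataFrames instead.
        return pd.read_excel(self.path, **self._get_loading_kwargs())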
- sheet_name: Union[str, int, List[Union[int, str]], None] = 0 - header: Union[int, Sequence, None] = 0 - names: Optional[Sequence] = None - index_col: Union[int, str, Sequence, None] = None - usecols: Union[int, str, Sequence, Sequence, Callable[[str], bool], None] = None - dtype: Union[Dtype, Dict[Hashable, Dtype], None] = None - engine: Optional[Literal["xlrd", "openpyxl", "odf", "pyxlsb", "calamine"]] = None - converters: Union[Dict[str, Callable], Dict[int, Callable], None] = None - true_values: Optional[Iterable] = None - false_values: Optional[Iterable] = None - skiprows: Union[Sequence, int, Callable[[int], object], None] = None - nrows: Optional[int] = None + sheet_name: str | int | list[int | str] | None = 0 + header: int | Sequence | None = 0 + names: Sequence | None = None + index_col: int | str | Sequence | None = None + usecols: int | str | Sequence | Sequence | Callable[[str], bool] | None = None + dtype: Dtype | dict[Hashable, Dtype] | None = None + engine: Literal["xlrd", "openpyxl", "odf", "pyxlsb", "calamine"] | None = None + converters: dict[str, Callable] | dict[int, Callable] | None = None + true_values: Iterable | None = None + false_values: Iterable | None = None + skiprows: Sequence | int | Callable[[int], object] | None = None + nrows: int | None = None na_values = None # in pandas.read_excel there are not type hints for na_values keep_default_na: bool = True na_filter: bool = True verbose: bool = False - parse_dates: Union[List[Union[int, str]], Dict[str, List[Union[int, str]]], bool] = False + parse_dates: list[int | str] | dict[str, list[int | str]] | bool = False # date_parser: Optional[Callable] # date_parser is deprecated since pandas=2.0.0 - date_format: Union[Dict[Hashable, str], str, None] = None - thousands: Optional[str] = None + date_format: dict[Hashable, str] | str | None = None + thousands: str | None = None decimal: str = "." - comment: Optional[str] = None + comment: str | None = None skipfooter: int = 0 - storage_options: Optional[Dict[str, Any]] = None + storage_options: dict[str, Any] | None = None dtype_backend: Literal["pyarrow", "numpy_nullable"] = "numpy_nullable" - engine_kwargs: Optional[Dict[str, Any]] = None + engine_kwargs: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: # Returns type for which data loader is available return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: # Puts kwargs in a dict kwargs = dataclasses.asdict(self) @@ -1484,7 +1479,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the excel file df = pd.read_excel(self.path, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.path, df) @@ -1502,36 +1497,36 @@ class PandasExcelWriter(DataSaver): Additional parameters passed to https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_excel.html """ - path: Union[str, Path, BytesIO] + path: str | Path | BytesIO # kwargs: # inspect.get_type_hints doesn't work with type aliases, # which are used in pandas.DataFrame.to_excel. 
# So we have to list all the arguments in plain code sheet_name: str = "Sheet1" na_rep: str = "" - float_format: Optional[str] = None - columns: Optional[Sequence] = None - header: Union[Sequence, bool] = True + float_format: str | None = None + columns: Sequence | None = None + header: Sequence | bool = True index: bool = True - index_label: Optional[IndexLabel] = None + index_label: IndexLabel | None = None startrow: int = 0 startcol: int = 0 - engine: Optional[Literal["openpyxl", "xlsxwriter"]] = None + engine: Literal["openpyxl", "xlsxwriter"] | None = None merge_cells: bool = True inf_rep: str = "inf" - freeze_panes: Optional[Tuple[int, int]] = None - storage_options: Optional[Dict[str, Any]] = None - engine_kwargs: Optional[Dict[str, Any]] = None - mode: Optional[Literal["w", "a"]] = "w" - if_sheet_exists: Optional[Literal["error", "new", "replace", "overlay"]] = None + freeze_panes: tuple[int, int] | None = None + storage_options: dict[str, Any] | None = None + engine_kwargs: dict[str, Any] | None = None + mode: Literal["w", "a"] | None = "w" + if_sheet_exists: Literal["error", "new", "replace", "overlay"] | None = None datetime_format: str = None date_format: str = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_saving_kwargs(self) -> Dict[str, Any]: + def _get_saving_kwargs(self) -> dict[str, Any]: # Puts kwargs in a dict kwargs = dataclasses.asdict(self) @@ -1562,7 +1557,7 @@ def _get_saving_kwargs(self) -> Dict[str, Any]: return writer_kwargs, to_excel_kwargs - def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE) -> dict[str, Any]: writer_kwargs, to_excel_kwargs = self._get_saving_kwargs() with pd.ExcelWriter(self.path, **writer_kwargs) as writer: @@ -1580,62 +1575,62 @@ class PandasTableReader(DataLoader): Maps to https://pandas.pydata.org/docs/reference/api/pandas.read_table.html """ - filepath_or_buffer: Union[str, Path, BytesIO, BufferedReader] + filepath_or_buffer: str | Path | BytesIO | BufferedReader # kwargs - sep: Union[str, None] = None - delimiter: Optional[str] = None - header: Union[int, Sequence, str, None] = "infer" - names: Optional[Sequence] = None - index_col: Union[int, str, Sequence, None] = None - usecols: Union[Sequence, None] = None - dtype: Union[Dtype, Dict[Hashable, Dtype], None] = None - engine: Optional[Literal["c", "python", "pyarrow"]] = None - converters: Optional[Dict[Hashable, Callable]] = None - true_values: Optional[Iterable] = None - false_values: Optional[Iterable] = None + sep: str | None = None + delimiter: str | None = None + header: int | Sequence | str | None = "infer" + names: Sequence | None = None + index_col: int | str | Sequence | None = None + usecols: Sequence | None = None + dtype: Dtype | dict[Hashable, Dtype] | None = None + engine: Literal["c", "python", "pyarrow"] | None = None + converters: dict[Hashable, Callable] | None = None + true_values: Iterable | None = None + false_values: Iterable | None = None skipinitialspace: bool = False - skiprows: Optional[Union[List[int], int, List[Callable]]] = None + skiprows: list[int] | int | list[Callable] | None = None skipfooter: int = 0 - nrows: Optional[int] = None - na_values: Optional[Union[Hashable, Iterable, Dict[Hashable, Iterable]]] = None + nrows: int | None = None + na_values: Hashable | Iterable | dict[Hashable, Iterable] | None = None keep_default_na: bool = True na_filter: bool = True verbose: bool = False 
skip_blank_lines: bool = True - parse_dates: Union[List[Union[int, str]], Dict[str, List[Union[int, str]]], bool] = False + parse_dates: list[int | str] | dict[str, list[int | str]] | bool = False infer_datetime_format: bool = False keep_date_col: bool = False - date_parser: Optional[Callable] = None - date_format: Optional[Union[str, str]] = None + date_parser: Callable | None = None + date_format: str | str | None = None dayfirst: bool = False cache_dates: bool = True iterator: bool = False - chunksize: Optional[int] = None - compression: Union[str, Dict] = "infer" - thousands: Optional[str] = None + chunksize: int | None = None + compression: str | dict = "infer" + thousands: str | None = None decimal: str = "." - lineterminator: Optional[str] = None - quotechar: Optional[str] = '"' + lineterminator: str | None = None + quotechar: str | None = '"' quoting: int = 0 doublequote: bool = True - escapechar: Optional[str] = None - comment: Optional[str] = None - encoding: Optional[str] = None - encoding_errors: Optional[str] = "strict" - dialect: Optional[str] = None - on_bad_lines: Union[Literal["error", "warn", "skip"], Callable] = "error" + escapechar: str | None = None + comment: str | None = None + encoding: str | None = None + encoding_errors: str | None = "strict" + dialect: str | None = None + on_bad_lines: Literal["error", "warn", "skip"] | Callable = "error" delim_whitespace: bool = False low_memory: bool = True memory_map: bool = False - float_precision: Optional[Literal["high", "legacy", "round_trip"]] = None - storage_options: Optional[Dict] = None + float_precision: Literal["high", "legacy", "round_trip"] | None = None + storage_options: dict | None = None dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: # Puts kwargs in a dict kwargs = dataclasses.asdict(self) @@ -1645,7 +1640,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the table df = pd.read_table(self.filepath_or_buffer, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.filepath_or_buffer, df) @@ -1662,18 +1657,18 @@ class PandasFWFReader(DataLoader): Maps to https://pandas.pydata.org/docs/reference/api/pandas.read_fwf.html """ - filepath_or_buffer: Union[str, Path, BytesIO, BufferedReader] + filepath_or_buffer: str | Path | BytesIO | BufferedReader # kwargs - colspecs: Union[str, List[Tuple[int, int]], Tuple[int, int]] = "infer" - widths: Optional[List[int]] = None + colspecs: str | list[tuple[int, int]] | tuple[int, int] = "infer" + widths: list[int] | None = None infer_nrows: int = 100 dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: # Puts kwargs in a dict kwargs = dataclasses.asdict(self) @@ -1683,7 +1678,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def 
load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the fwf file df = pd.read_fwf(self.filepath_or_buffer, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.filepath_or_buffer, df) @@ -1700,17 +1695,17 @@ class PandasSPSSReader(DataLoader): Maps to https://pandas.pydata.org/docs/reference/api/pandas.read_spss.html """ - path: Union[str, Path] + path: str | Path # kwargs - usecols: Optional[Union[List[Hashable], Callable[[str], bool]]] = None + usecols: list[Hashable] | Callable[[str], bool] | None = None convert_categoricals: bool = True dtype_backend: Literal["pyarrow", "numpy_nullable"] = "numpy_nullable" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def _get_loading_kwargs(self) -> Dict[str, Any]: + def _get_loading_kwargs(self) -> dict[str, Any]: # Puts kwargs in a dict kwargs = dataclasses.asdict(self) @@ -1720,7 +1715,7 @@ def _get_loading_kwargs(self) -> Dict[str, Any]: return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: # Loads the data and returns the df and metadata of the spss file df = pd.read_spss(self.path, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.path, df) diff --git a/hamilton/plugins/plotly_extensions.py b/hamilton/plugins/plotly_extensions.py index 432ab7bd7..5ff74747a 100644 --- a/hamilton/plugins/plotly_extensions.py +++ b/hamilton/plugins/plotly_extensions.py @@ -17,7 +17,8 @@ import dataclasses import pathlib -from typing import IO, Any, Collection, Dict, List, Optional, Type, Union +from collections.abc import Collection +from typing import IO, Any try: import plotly.graph_objects @@ -35,11 +36,11 @@ class PlotlyStaticWriter(DataSaver): ref: https://plotly.com/python/static-image-export/ """ - path: Union[str, pathlib.Path, IO] - format: Optional[str] = None - width: Optional[int] = None - height: Optional[int] = None - scale: Optional[Union[int, float]] = None + path: str | pathlib.Path | IO + format: str | None = None + width: int | None = None + height: int | None = None + scale: int | float | None = None validate: bool = True engine: str = "auto" @@ -60,12 +61,12 @@ def _get_saving_kwargs(self) -> dict: return kwargs - def save_data(self, data: plotly.graph_objects.Figure) -> Dict[str, Any]: + def save_data(self, data: plotly.graph_objects.Figure) -> dict[str, Any]: data.write_image(file=self.path, **self._get_saving_kwargs()) return utils.get_file_metadata(self.path) @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [plotly.graph_objects.Figure] @classmethod @@ -79,21 +80,21 @@ class PlotlyInteractiveWriter(DataSaver): ref: https://plotly.com/python/interactive-html-export/ """ - path: Union[str, pathlib.Path, IO] - config: Optional[Dict] = None + path: str | pathlib.Path | IO + config: dict | None = None auto_play: bool = True - include_plotlyjs: Union[bool, str] = ( + include_plotlyjs: bool | str = ( True # or "cdn", "directory", "require", "False", "other string .js" ) - include_mathjax: Union[bool, str] = False # "cdn", "string .js" - post_script: Union[str, List[str], None] = None + include_mathjax: bool | str = False # "cdn", "string .js" + post_script: str | list[str] | None = None full_html: bool = True - animation_opts: Optional[Dict] = 
None - default_width: Union[int, float, str] = "100%" - default_height: Union[int, float, str] = "100%" + animation_opts: dict | None = None + default_width: int | float | str = "100%" + default_height: int | float | str = "100%" validate: bool = True auto_open: bool = True - div_id: Optional[str] = None + div_id: str | None = None def _get_saving_kwargs(self) -> dict: kwargs = {} @@ -123,12 +124,12 @@ def _get_saving_kwargs(self) -> dict: kwargs["div_id"] = self.div_id return kwargs - def save_data(self, data: plotly.graph_objects.Figure) -> Dict[str, Any]: + def save_data(self, data: plotly.graph_objects.Figure) -> dict[str, Any]: data.write_html(file=self.path, **self._get_saving_kwargs()) return utils.get_file_metadata(self.path) @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [plotly.graph_objects.Figure] @classmethod diff --git a/hamilton/plugins/polars_extensions.py b/hamilton/plugins/polars_extensions.py index 53f0f963d..0a8de7386 100644 --- a/hamilton/plugins/polars_extensions.py +++ b/hamilton/plugins/polars_extensions.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Type +from typing import Any from packaging import version @@ -27,7 +27,7 @@ try: from xlsxwriter.workbook import Workbook except ImportError: - Workbook = Type + Workbook = type try: import polars as pl diff --git a/hamilton/plugins/polars_lazyframe_extensions.py b/hamilton/plugins/polars_lazyframe_extensions.py index 6e53e478d..b06f2db94 100644 --- a/hamilton/plugins/polars_lazyframe_extensions.py +++ b/hamilton/plugins/polars_lazyframe_extensions.py @@ -16,21 +16,13 @@ # under the License. import dataclasses +from collections.abc import Collection, Mapping, Sequence from io import BytesIO from pathlib import Path from typing import ( Any, BinaryIO, - Collection, - Dict, - List, - Mapping, - Optional, - Sequence, TextIO, - Tuple, - Type, - Union, ) try: @@ -48,7 +40,7 @@ if has_alias and hasattr(pl.type_aliases, "CsvEncoding"): from polars.type_aliases import CsvEncoding else: - CsvEncoding = Type + CsvEncoding = type # import these types to make type hinting work from polars.datatypes import DataType, DataTypeClass # noqa: F401 @@ -95,17 +87,17 @@ class PolarsScanCSVReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_csv.html """ - file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + file: str | TextIO | BytesIO | Path | BinaryIO | bytes # kwargs: has_header: bool = True - columns: Union[Sequence[int], Sequence[str]] = None + columns: Sequence[int] | Sequence[str] = None new_columns: Sequence[str] = None separator: str = "," comment_char: str = None quote_char: str = '"' skip_rows: int = 0 - dtypes: Union[Mapping[str, Any], Sequence[Any]] = None - null_values: Union[str, Sequence[str], Dict[str, str]] = None + dtypes: Mapping[str, Any] | Sequence[Any] = None + null_values: str | Sequence[str] | dict[str, str] = None missing_utf8_is_empty_string: bool = False ignore_errors: bool = False try_parse_dates: bool = False @@ -113,11 +105,11 @@ class PolarsScanCSVReader(DataLoader): infer_schema_length: int = 100 batch_size: int = 8192 n_rows: int = None - encoding: Union[CsvEncoding, str] = "utf8" + encoding: CsvEncoding | str = "utf8" low_memory: bool = False rechunk: bool = True use_pyarrow: bool = False - storage_options: Dict[str, Any] = None + storage_options: dict[str, Any] = None skip_rows_after_header: int = 0 row_count_name: str = None 
row_count_offset: int = 0 @@ -177,10 +169,10 @@ def _get_loading_kwargs(self): return kwargs @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.scan_csv(self.file, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.file, df) @@ -197,13 +189,13 @@ class PolarsScanParquetReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_parquet.html """ - file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + file: str | TextIO | BytesIO | Path | BinaryIO | bytes # kwargs: - columns: Union[List[int], List[str]] = None + columns: list[int] | list[str] = None n_rows: int = None use_pyarrow: bool = False memory_map: bool = True - storage_options: Dict[str, Any] = None + storage_options: dict[str, Any] = None parallel: Any = "auto" row_count_name: str = None row_count_offset: int = 0 @@ -212,7 +204,7 @@ class PolarsScanParquetReader(DataLoader): rechunk: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -237,7 +229,7 @@ def _get_loading_kwargs(self): kwargs["rechunk"] = self.rechunk return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.scan_parquet(self.file, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.file, df) return df, metadata @@ -254,19 +246,19 @@ class PolarsScanFeatherReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_ipc.html """ - source: Union[str, BinaryIO, BytesIO, Path, bytes] + source: str | BinaryIO | BytesIO | Path | bytes # kwargs: - columns: Optional[Union[List[str], List[int]]] = None - n_rows: Optional[int] = None + columns: list[str] | list[int] | None = None + n_rows: int | None = None use_pyarrow: bool = False memory_map: bool = True - storage_options: Optional[Dict[str, Any]] = None - row_count_name: Optional[str] = None + storage_options: dict[str, Any] | None = None + row_count_name: str | None = None row_count_offset: int = 0 rechunk: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -287,7 +279,7 @@ def _get_loading_kwargs(self): kwargs["rechunk"] = self.rechunk return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.scan_ipc(self.source, **self._get_loading_kwargs()) metadata = utils.get_file_metadata(self.source) return df, metadata diff --git a/hamilton/plugins/polars_post_1_0_0_extensions.py b/hamilton/plugins/polars_post_1_0_0_extensions.py index 25c0f6826..dab1a3973 100644 --- a/hamilton/plugins/polars_post_1_0_0_extensions.py +++ b/hamilton/plugins/polars_post_1_0_0_extensions.py @@ -16,28 +16,20 @@ # under the License. 
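# Illustrative sketch of the optional-dependency fallback used in the modules
# below, assuming a hypothetical extra `somelib`; nothing here comes from the
# patch itself. Binding the missing name to the builtin `type` keeps annotations
# that reference it (e.g. `Workbook | BytesIO | Path | str`) importable.
from __future__ import annotations

from io import BytesIO
from pathlib import Path

try:
    from somelib import Workbook  # hypothetical optional dependency
except ImportError:
    Workbook = type  # placeholder so the name always exists at import time


def describe_target(target: Workbook | BytesIO | Path | str) -> str:
    # With postponed evaluation (the __future__ import above) the union in the
    # signature is stored as a string and never evaluated at runtime, so the
    # placeholder only matters to tools that introspect the annotations.
    return type(target).__name__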
import dataclasses +from collections.abc import Collection, Mapping, Sequence from io import BytesIO, IOBase, TextIOWrapper from pathlib import Path from typing import ( Any, BinaryIO, - Collection, - Dict, - List, Literal, - Mapping, - Optional, - Sequence, TextIO, - Tuple, - Type, - Union, ) try: from xlsxwriter.workbook import Workbook except ImportError: - Workbook = Type + Workbook = type import polars as pl from polars._typing import ConnectionOrCursor @@ -58,14 +50,14 @@ Selector = type(_selector_proxy_) else: # Stub for older polars versions - Selector = Type + Selector = type # for polars 0.18.0 we need to check what to do. from polars._typing import CsvEncoding, SchemaDefinition -CsvQuoteStyle = Type +CsvQuoteStyle = type -IpcCompression = Type +IpcCompression = type from hamilton import registry from hamilton.io import utils @@ -81,18 +73,18 @@ class PolarsCSVReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_csv.html """ - file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + file: str | TextIO | BytesIO | Path | BinaryIO | bytes # kwargs: has_header: bool = True include_header: bool = True - columns: Union[Sequence[int], Sequence[str]] = None + columns: Sequence[int] | Sequence[str] = None new_columns: Sequence[str] = None separator: str = "," comment_char: str = None quote_char: str = '"' skip_rows: int = 0 - dtypes: Union[Mapping[str, Any], Sequence[Any]] = None - null_values: Union[str, Sequence[str], Dict[str, str]] = None + dtypes: Mapping[str, Any] | Sequence[Any] = None + null_values: str | Sequence[str] | dict[str, str] = None missing_utf8_is_empty_string: bool = False ignore_errors: bool = False try_parse_dates: bool = False @@ -100,11 +92,11 @@ class PolarsCSVReader(DataLoader): infer_schema_length: int = 100 batch_size: int = 8192 n_rows: int = None - encoding: Union[CsvEncoding, str] = "utf8" + encoding: CsvEncoding | str = "utf8" low_memory: bool = False rechunk: bool = True use_pyarrow: bool = False - storage_options: Dict[str, Any] = None + storage_options: dict[str, Any] = None skip_rows_after_header: int = 0 row_count_name: str = None row_count_offset: int = 0 @@ -113,7 +105,7 @@ class PolarsCSVReader(DataLoader): raise_if_empty: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -172,7 +164,7 @@ def _get_loading_kwargs(self): kwargs["raise_if_empty"] = self.raise_if_empty return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_csv(self.file, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.file, df) @@ -189,7 +181,7 @@ class PolarsCSVWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_csv.html """ - file: Union[BytesIO, TextIOWrapper, str, Path] + file: BytesIO | TextIOWrapper | str | Path # kwargs: include_header: bool = True separator: str = "," @@ -204,7 +196,7 @@ class PolarsCSVWriter(DataSaver): quote_style: CsvQuoteStyle = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -235,7 +227,7 @@ def _get_saving_kwargs(self): kwargs["quote_style"] = self.quote_style return kwargs - def save_data(self, data: 
Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() data.write_csv(self.file, **self._get_saving_kwargs()) @@ -252,23 +244,23 @@ class PolarsParquetReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_parquet.html """ - file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + file: str | TextIO | BytesIO | Path | BinaryIO | bytes # kwargs: - columns: Union[List[int], List[str]] = None + columns: list[int] | list[str] = None n_rows: int = None use_pyarrow: bool = False memory_map: bool = True - storage_options: Dict[str, Any] = None + storage_options: dict[str, Any] = None parallel: Any = "auto" row_count_name: str = None row_count_offset: int = 0 low_memory: bool = False - pyarrow_options: Dict[str, Any] = None + pyarrow_options: dict[str, Any] = None use_statistics: bool = True rechunk: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -299,7 +291,7 @@ def _get_loading_kwargs(self): kwargs["rechunk"] = self.rechunk return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_parquet(self.file, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.file, df) return df, metadata @@ -315,17 +307,17 @@ class PolarsParquetWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_parquet.html """ - file: Union[BytesIO, TextIOWrapper, str, Path] + file: BytesIO | TextIOWrapper | str | Path # kwargs: compression: Any = "zstd" compression_level: int = None statistics: bool = False row_group_size: int = None use_pyarrow: bool = False - pyarrow_options: Dict[str, Any] = None + pyarrow_options: dict[str, Any] = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -344,7 +336,7 @@ def _get_saving_kwargs(self): kwargs["pyarrow_options"] = self.pyarrow_options return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() @@ -364,19 +356,19 @@ class PolarsFeatherReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_ipc.html """ - source: Union[str, BinaryIO, BytesIO, Path, bytes] + source: str | BinaryIO | BytesIO | Path | bytes # kwargs: - columns: Optional[Union[List[str], List[int]]] = None - n_rows: Optional[int] = None + columns: list[str] | list[int] | None = None + n_rows: int | None = None use_pyarrow: bool = False memory_map: bool = True - storage_options: Optional[Dict[str, Any]] = None - row_count_name: Optional[str] = None + storage_options: dict[str, Any] | None = None + row_count_name: str | None = None row_count_offset: int = 0 rechunk: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -399,7 +391,7 @@ def _get_loading_kwargs(self): kwargs["rechunk"] = self.rechunk return kwargs - 
def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_ipc(self.source, **self._get_loading_kwargs()) metadata = utils.get_file_metadata(self.source) return df, metadata @@ -416,12 +408,12 @@ class PolarsFeatherWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_ipc.html """ - file: Optional[Union[BinaryIO, BytesIO, str, Path]] = None + file: BinaryIO | BytesIO | str | Path | None = None # kwargs: compression: IpcCompression = "uncompressed" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -430,7 +422,7 @@ def _get_saving_kwargs(self): kwargs["compression"] = self.compression return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() data.write_ipc(self.file, **self._get_saving_kwargs()) @@ -447,13 +439,13 @@ class PolarsAvroReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_avro.html """ - file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + file: str | TextIO | BytesIO | Path | BinaryIO | bytes # kwargs: - columns: Union[List[int], List[str], None] = None - n_rows: Union[int, None] = None + columns: list[int] | list[str] | None = None + n_rows: int | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -464,7 +456,7 @@ def _get_loading_kwargs(self): kwargs["n_rows"] = self.n_rows return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_avro(self.file, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.file, df) return df, metadata @@ -480,12 +472,12 @@ class PolarsAvroWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_avro.html """ - file: Union[BytesIO, TextIOWrapper, str, Path] + file: BytesIO | TextIOWrapper | str | Path # kwargs: compression: Any = "uncompressed" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -494,7 +486,7 @@ def _get_saving_kwargs(self): kwargs["compression"] = self.compression return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() @@ -513,12 +505,12 @@ class PolarsJSONReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_json.html """ - source: Union[str, Path, IOBase, bytes] + source: str | Path | IOBase | bytes schema: SchemaDefinition = None schema_overrides: SchemaDefinition = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -529,7 +521,7 @@ def _get_loading_kwargs(self): kwargs["schema_overrides"] 
= self.schema_overrides return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_json(self.source, **self._get_loading_kwargs()) metadata = utils.get_file_metadata(self.source) return df, metadata @@ -546,13 +538,13 @@ class PolarsJSONWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_json.html """ - file: Union[IOBase, str, Path] + file: IOBase | str | Path @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() @@ -571,12 +563,12 @@ class PolarsNDJSONReader(DataLoader): Should map to https://docs.pola.rs/api/python/stable/reference/api/polars.read_ndjson.html """ - source: Union[str, Path, IOBase, bytes] + source: str | Path | IOBase | bytes schema: SchemaDefinition = None schema_overrides: SchemaDefinition = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -587,7 +579,7 @@ def _get_loading_kwargs(self): kwargs["schema_overrides"] = self.schema_overrides return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_ndjson(self.source, **self._get_loading_kwargs()) metadata = utils.get_file_metadata(self.source) return df, metadata @@ -604,13 +596,13 @@ class PolarsNDJSONWriter(DataSaver): Should map to https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_ndjson.html """ - file: Union[IOBase, str, Path] + file: IOBase | str | Path @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() @@ -629,18 +621,18 @@ class PolarsSpreadsheetReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_excel.html """ - source: Union[str, Path, IOBase, bytes] + source: str | Path | IOBase | bytes # kwargs: - sheet_id: Union[int, Sequence[int], None] = None - sheet_name: Union[str, List[str], Tuple[str], None] = None + sheet_id: int | Sequence[int] | None = None + sheet_name: str | list[str] | tuple[str] | None = None engine: Literal["xlsx2csv", "openpyxl", "pyxlsb", "odf", "xlrd", "xlsxwriter"] = "xlsx2csv" - engine_options: Union[Dict[str, Any], None] = None - read_options: Union[Dict[str, Any], None] = None - schema_overrides: Union[Dict[str, Any], None] = None + engine_options: dict[str, Any] | None = None + read_options: dict[str, Any] | None = None + schema_overrides: dict[str, Any] | None = None raise_if_empty: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -661,7 +653,7 @@ def _get_loading_kwargs(self): kwargs["raise_if_empty"] = self.raise_if_empty return 
kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_excel(self.source, **self._get_loading_kwargs()) metadata = utils.get_file_metadata(self.source) return df, metadata @@ -684,43 +676,38 @@ class PolarsSpreadsheetWriter(DataSaver): from polars._typing import ColumnTotalsDefinition, RowTotalsDefinition from polars.datatypes import DataType, DataTypeClass - workbook: Union[Workbook, BytesIO, Path, str] - worksheet: Union[str, None] = None + workbook: Workbook | BytesIO | Path | str + worksheet: str | None = None # kwargs: - position: Union[Tuple[int, int], str] = "A1" - table_style: Union[str, Dict[str, Any], None] = None - table_name: Union[str, None] = None - column_formats: Union[ - Mapping[Union[str, Tuple[str, ...]], Union[str, Mapping[str, str]]], None - ] = None - dtype_formats: Union[Dict[Union[DataType, DataTypeClass], str], None] = None - conditional_formats: Union[ - Mapping[ - Union[str, Collection[str]], - Union[str, Union[Mapping[str, Any], Sequence[Union[str, Mapping[str, Any]]]]], - ], - None, - ] = None - header_format: Union[Dict[str, Any], None] = None - column_totals: Union[ColumnTotalsDefinition, None] = None - column_widths: Union[Mapping[str, Union[Tuple[str, ...], int]], int, None] = None - row_totals: Union[RowTotalsDefinition, None] = None - row_heights: Union[Dict[Union[int, Tuple[int, ...]], int], int, None] = None - sparklines: Union[Dict[str, Union[Sequence[str], Dict[str, Any]]], None] = None - formulas: Union[Dict[str, Union[str, Dict[str, str]]], None] = None + position: tuple[int, int] | str = "A1" + table_style: str | dict[str, Any] | None = None + table_name: str | None = None + column_formats: Mapping[str | tuple[str, ...], str | Mapping[str, str]] | None = None + dtype_formats: dict[DataType | DataTypeClass, str] | None = None + conditional_formats: ( + Mapping[str | Collection[str], str | Mapping[str, Any] | Sequence[str | Mapping[str, Any]]] + | None + ) = None + header_format: dict[str, Any] | None = None + column_totals: ColumnTotalsDefinition | None = None + column_widths: Mapping[str, tuple[str, ...] 
| int] | int | None = None + row_totals: RowTotalsDefinition | None = None + row_heights: dict[int | tuple[int, ...], int] | int | None = None + sparklines: dict[str, Sequence[str] | dict[str, Any]] | None = None + formulas: dict[str, str | dict[str, str]] | None = None float_precision: int = 3 include_header: bool = True autofilter: bool = True autofit: bool = False - hidden_columns: Union[Sequence[str], str, None] = None + hidden_columns: Sequence[str] | str | None = None hide_gridlines: bool = None - sheet_zoom: Union[int, None] = None - freeze_panes: Union[ - str, Tuple[int, int], Tuple[str, int, int], Tuple[int, int, int, int], None - ] = None + sheet_zoom: int | None = None + freeze_panes: ( + str | tuple[int, int] | tuple[str, int, int] | tuple[int, int, int, int] | None + ) = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -769,7 +756,7 @@ def _get_saving_kwargs(self): kwargs["freeze_panes"] = self.freeze_panes return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() @@ -788,16 +775,16 @@ class PolarsDatabaseReader(DataLoader): """ query: str - connection: Union[ConnectionOrCursor, str] + connection: ConnectionOrCursor | str # kwargs: iter_batches: bool = False - batch_size: Union[int, None] = None - schema_overrides: Union[Dict[str, Any], None] = None - infer_schema_length: Union[int, None] = None - execute_options: Union[Dict[str, Any], None] = None + batch_size: int | None = None + schema_overrides: dict[str, Any] | None = None + infer_schema_length: int | None = None + execute_options: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -814,7 +801,7 @@ def _get_loading_kwargs(self): kwargs["execute_options"] = self.execute_options return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_database( query=self.query, connection=self.connection, @@ -835,12 +822,12 @@ class PolarsDatabaseWriter(DataSaver): """ table_name: str - connection: Union[ConnectionOrCursor, str] + connection: ConnectionOrCursor | str if_table_exists: Literal["fail", "replace", "append"] = "fail" engine: Literal["auto", "sqlalchemy", "adbc"] = "sqlalchemy" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -851,7 +838,7 @@ def _get_saving_kwargs(self): kwargs["engine"] = self.engine return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() diff --git a/hamilton/plugins/polars_pre_1_0_0_extension.py b/hamilton/plugins/polars_pre_1_0_0_extension.py index 80605bbbf..4dd92dc43 100644 --- a/hamilton/plugins/polars_pre_1_0_0_extension.py +++ b/hamilton/plugins/polars_pre_1_0_0_extension.py @@ -16,28 +16,20 @@ # under the License. 
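# Illustrative sketch (assumed names, not from the patch) of the save_data
# convention the polars writers share: accept either an eager DataFrame or a
# LazyFrame, collect the lazy case, then write.
from __future__ import annotations

from pathlib import Path

import polars as pl


def write_any(data: pl.DataFrame | pl.LazyFrame, path: str | Path) -> dict:
    if isinstance(data, pl.LazyFrame):
        data = data.collect()  # materialize the lazy query before writing
    data.write_csv(path)
    return {"path": str(path), "rows": data.height}


# Usage: both calls write the same columns; the second filters lazily first.
frame = pl.DataFrame({"a": [1, 2, 3]})
write_any(frame, "eager.csv")
write_any(frame.lazy().filter(pl.col("a") > 1), "lazy.csv")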
import dataclasses +from collections.abc import Collection, Mapping, Sequence from io import BytesIO, IOBase, TextIOWrapper from pathlib import Path from typing import ( Any, BinaryIO, - Collection, - Dict, - List, Literal, - Mapping, - Optional, - Sequence, TextIO, - Tuple, - Type, - Union, ) try: from xlsxwriter.workbook import Workbook except ImportError: - Workbook = Type + Workbook = type try: import polars as pl @@ -54,15 +46,15 @@ if has_alias and hasattr(pl.type_aliases, "CsvEncoding"): from polars.type_aliases import CsvEncoding, SchemaDefinition else: - CsvEncoding = Type + CsvEncoding = type if has_alias and hasattr(pl.type_aliases, "CsvQuoteStyle"): from polars.type_aliases import CsvQuoteStyle else: - CsvQuoteStyle = Type + CsvQuoteStyle = type if has_alias and hasattr(pl.type_aliases, "IpcCompression"): from polars.type_aliases import IpcCompression else: - IpcCompression = Type + IpcCompression = type from hamilton import registry from hamilton.io import utils @@ -90,18 +82,18 @@ class PolarsCSVReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_csv.html """ - file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + file: str | TextIO | BytesIO | Path | BinaryIO | bytes # kwargs: has_header: bool = True include_header: bool = True - columns: Union[Sequence[int], Sequence[str]] = None + columns: Sequence[int] | Sequence[str] = None new_columns: Sequence[str] = None separator: str = "," comment_char: str = None quote_char: str = '"' skip_rows: int = 0 - dtypes: Union[Mapping[str, Any], Sequence[Any]] = None - null_values: Union[str, Sequence[str], Dict[str, str]] = None + dtypes: Mapping[str, Any] | Sequence[Any] = None + null_values: str | Sequence[str] | dict[str, str] = None missing_utf8_is_empty_string: bool = False ignore_errors: bool = False try_parse_dates: bool = False @@ -109,11 +101,11 @@ class PolarsCSVReader(DataLoader): infer_schema_length: int = 100 batch_size: int = 8192 n_rows: int = None - encoding: Union[CsvEncoding, str] = "utf8" + encoding: CsvEncoding | str = "utf8" low_memory: bool = False rechunk: bool = True use_pyarrow: bool = False - storage_options: Dict[str, Any] = None + storage_options: dict[str, Any] = None skip_rows_after_header: int = 0 row_count_name: str = None row_count_offset: int = 0 @@ -122,7 +114,7 @@ class PolarsCSVReader(DataLoader): raise_if_empty: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -181,7 +173,7 @@ def _get_loading_kwargs(self): kwargs["raise_if_empty"] = self.raise_if_empty return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_csv(self.file, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.file, df) @@ -198,7 +190,7 @@ class PolarsCSVWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_csv.html """ - file: Union[BytesIO, TextIOWrapper, str, Path] + file: BytesIO | TextIOWrapper | str | Path # kwargs: has_header: bool = True separator: str = "," @@ -213,7 +205,7 @@ class PolarsCSVWriter(DataSaver): quote_style: CsvQuoteStyle = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ 
-244,7 +236,7 @@ def _get_saving_kwargs(self): kwargs["quote_style"] = self.quote_style return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() data.write_csv(self.file, **self._get_saving_kwargs()) @@ -261,23 +253,23 @@ class PolarsParquetReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_parquet.html """ - file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + file: str | TextIO | BytesIO | Path | BinaryIO | bytes # kwargs: - columns: Union[List[int], List[str]] = None + columns: list[int] | list[str] = None n_rows: int = None use_pyarrow: bool = False memory_map: bool = True - storage_options: Dict[str, Any] = None + storage_options: dict[str, Any] = None parallel: Any = "auto" row_count_name: str = None row_count_offset: int = 0 low_memory: bool = False - pyarrow_options: Dict[str, Any] = None + pyarrow_options: dict[str, Any] = None use_statistics: bool = True rechunk: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -308,7 +300,7 @@ def _get_loading_kwargs(self): kwargs["rechunk"] = self.rechunk return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_parquet(self.file, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.file, df) return df, metadata @@ -324,17 +316,17 @@ class PolarsParquetWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_parquet.html """ - file: Union[BytesIO, TextIOWrapper, str, Path] + file: BytesIO | TextIOWrapper | str | Path # kwargs: compression: Any = "zstd" compression_level: int = None statistics: bool = False row_group_size: int = None use_pyarrow: bool = False - pyarrow_options: Dict[str, Any] = None + pyarrow_options: dict[str, Any] = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -353,7 +345,7 @@ def _get_saving_kwargs(self): kwargs["pyarrow_options"] = self.pyarrow_options return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() @@ -373,19 +365,19 @@ class PolarsFeatherReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_ipc.html """ - source: Union[str, BinaryIO, BytesIO, Path, bytes] + source: str | BinaryIO | BytesIO | Path | bytes # kwargs: - columns: Optional[Union[List[str], List[int]]] = None - n_rows: Optional[int] = None + columns: list[str] | list[int] | None = None + n_rows: int | None = None use_pyarrow: bool = False memory_map: bool = True - storage_options: Optional[Dict[str, Any]] = None - row_count_name: Optional[str] = None + storage_options: dict[str, Any] | None = None + row_count_name: str | None = None row_count_offset: int = 0 rechunk: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def 
_get_loading_kwargs(self): @@ -408,7 +400,7 @@ def _get_loading_kwargs(self): kwargs["rechunk"] = self.rechunk return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_ipc(self.source, **self._get_loading_kwargs()) metadata = utils.get_file_metadata(self.source) return df, metadata @@ -425,12 +417,12 @@ class PolarsFeatherWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_ipc.html """ - file: Optional[Union[BinaryIO, BytesIO, str, Path]] = None + file: BinaryIO | BytesIO | str | Path | None = None # kwargs: compression: IpcCompression = "uncompressed" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -439,7 +431,7 @@ def _get_saving_kwargs(self): kwargs["compression"] = self.compression return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() data.write_ipc(self.file, **self._get_saving_kwargs()) @@ -456,13 +448,13 @@ class PolarsAvroReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_avro.html """ - file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes] + file: str | TextIO | BytesIO | Path | BinaryIO | bytes # kwargs: - columns: Union[List[int], List[str], None] = None - n_rows: Union[int, None] = None + columns: list[int] | list[str] | None = None + n_rows: int | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -473,7 +465,7 @@ def _get_loading_kwargs(self): kwargs["n_rows"] = self.n_rows return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_avro(self.file, **self._get_loading_kwargs()) metadata = utils.get_file_and_dataframe_metadata(self.file, df) return df, metadata @@ -489,12 +481,12 @@ class PolarsAvroWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_avro.html """ - file: Union[BytesIO, TextIOWrapper, str, Path] + file: BytesIO | TextIOWrapper | str | Path # kwargs: compression: Any = "uncompressed" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -503,7 +495,7 @@ def _get_saving_kwargs(self): kwargs["compression"] = self.compression return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() @@ -522,12 +514,12 @@ class PolarsJSONReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_json.html """ - source: Union[str, Path, IOBase, bytes] + source: str | Path | IOBase | bytes schema: SchemaDefinition = None schema_overrides: SchemaDefinition = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: 
return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -538,7 +530,7 @@ def _get_loading_kwargs(self): kwargs["schema_overrides"] = self.schema_overrides return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_json(self.source, **self._get_loading_kwargs()) metadata = utils.get_file_metadata(self.source) return df, metadata @@ -555,12 +547,12 @@ class PolarsJSONWriter(DataSaver): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_json.html """ - file: Union[IOBase, str, Path] + file: IOBase | str | Path pretty: bool = False row_oriented: bool = False @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -571,7 +563,7 @@ def _get_saving_kwargs(self): kwargs["row_oriented"] = self.row_oriented return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() @@ -590,18 +582,18 @@ class PolarsSpreadsheetReader(DataLoader): Should map to https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_excel.html """ - source: Union[str, Path, IOBase, bytes] + source: str | Path | IOBase | bytes # kwargs: - sheet_id: Union[int, Sequence[int], None] = None - sheet_name: Union[str, List[str], Tuple[str], None] = None + sheet_id: int | Sequence[int] | None = None + sheet_name: str | list[str] | tuple[str] | None = None engine: Literal["xlsx2csv", "openpyxl", "pyxlsb", "odf", "xlrd", "xlsxwriter"] = "xlsx2csv" - engine_options: Union[Dict[str, Any], None] = None - read_options: Union[Dict[str, Any], None] = None - schema_overrides: Union[Dict[str, Any], None] = None + engine_options: dict[str, Any] | None = None + read_options: dict[str, Any] | None = None + schema_overrides: dict[str, Any] | None = None raise_if_empty: bool = True @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -622,7 +614,7 @@ def _get_loading_kwargs(self): kwargs["raise_if_empty"] = self.raise_if_empty return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_excel(self.source, **self._get_loading_kwargs()) metadata = utils.get_file_metadata(self.source) return df, metadata @@ -645,43 +637,38 @@ class PolarsSpreadsheetWriter(DataSaver): from polars.datatypes import DataType, DataTypeClass from polars.type_aliases import ColumnTotalsDefinition, RowTotalsDefinition - workbook: Union[Workbook, BytesIO, Path, str] - worksheet: Union[str, None] = None + workbook: Workbook | BytesIO | Path | str + worksheet: str | None = None # kwargs: - position: Union[Tuple[int, int], str] = "A1" - table_style: Union[str, Dict[str, Any], None] = None - table_name: Union[str, None] = None - column_formats: Union[ - Mapping[Union[str, Tuple[str, ...]], Union[str, Mapping[str, str]]], None - ] = None - dtype_formats: Union[Dict[Union[DataType, DataTypeClass], str], None] = None - conditional_formats: Union[ - Mapping[ - Union[str, Collection[str]], - Union[str, Union[Mapping[str, Any], Sequence[Union[str, Mapping[str, Any]]]]], - ], - None, - 
] = None - header_format: Union[Dict[str, Any], None] = None - column_totals: Union[ColumnTotalsDefinition, None] = None - column_widths: Union[Mapping[str, Union[Tuple[str, ...], int]], int, None] = None - row_totals: Union[RowTotalsDefinition, None] = None - row_heights: Union[Dict[Union[int, Tuple[int, ...]], int], int, None] = None - sparklines: Union[Dict[str, Union[Sequence[str], Dict[str, Any]]], None] = None - formulas: Union[Dict[str, Union[str, Dict[str, str]]], None] = None + position: tuple[int, int] | str = "A1" + table_style: str | dict[str, Any] | None = None + table_name: str | None = None + column_formats: Mapping[str | tuple[str, ...], str | Mapping[str, str]] | None = None + dtype_formats: dict[DataType | DataTypeClass, str] | None = None + conditional_formats: ( + Mapping[str | Collection[str], str | Mapping[str, Any] | Sequence[str | Mapping[str, Any]]] + | None + ) = None + header_format: dict[str, Any] | None = None + column_totals: ColumnTotalsDefinition | None = None + column_widths: Mapping[str, tuple[str, ...] | int] | int | None = None + row_totals: RowTotalsDefinition | None = None + row_heights: dict[int | tuple[int, ...], int] | int | None = None + sparklines: dict[str, Sequence[str] | dict[str, Any]] | None = None + formulas: dict[str, str | dict[str, str]] | None = None float_precision: int = 3 include_header: bool = True autofilter: bool = True autofit: bool = False - hidden_columns: Union[Sequence[str], str, None] = None + hidden_columns: Sequence[str] | str | None = None hide_gridlines: bool = None - sheet_zoom: Union[int, None] = None - freeze_panes: Union[ - str, Tuple[int, int], Tuple[str, int, int], Tuple[int, int, int, int], None - ] = None + sheet_zoom: int | None = None + freeze_panes: ( + str | tuple[int, int] | tuple[str, int, int] | tuple[int, int, int, int] | None + ) = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -730,7 +717,7 @@ def _get_saving_kwargs(self): kwargs["freeze_panes"] = self.freeze_panes return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() @@ -752,13 +739,13 @@ class PolarsDatabaseReader(DataLoader): connection: str # kwargs: iter_batches: bool = False - batch_size: Union[int, None] = None - schema_overrides: Union[Dict[str, Any], None] = None - infer_schema_length: Union[int, None] = None - execute_options: Union[Dict[str, Any], None] = None + batch_size: int | None = None + schema_overrides: dict[str, Any] | None = None + infer_schema_length: int | None = None + execute_options: dict[str, Any] | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [DATAFRAME_TYPE] def _get_loading_kwargs(self): @@ -775,7 +762,7 @@ def _get_loading_kwargs(self): kwargs["execute_options"] = self.execute_options return kwargs - def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[DATAFRAME_TYPE, dict[str, Any]]: df = pl.read_database( query=self.query, connection=self.connection, @@ -801,7 +788,7 @@ class PolarsDatabaseWriter(DataSaver): engine: Literal["auto", "sqlalchemy", "adbc"] = "sqlalchemy" @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> 
Collection[type]: return [DATAFRAME_TYPE, pl.LazyFrame] def _get_saving_kwargs(self): @@ -812,7 +799,7 @@ def _get_saving_kwargs(self): kwargs["engine"] = self.engine return kwargs - def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + def save_data(self, data: DATAFRAME_TYPE | pl.LazyFrame) -> dict[str, Any]: if isinstance(data, pl.LazyFrame): data = data.collect() diff --git a/hamilton/plugins/pydantic_extensions.py b/hamilton/plugins/pydantic_extensions.py index 810959123..e0029e28e 100644 --- a/hamilton/plugins/pydantic_extensions.py +++ b/hamilton/plugins/pydantic_extensions.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Type +from typing import Any from hamilton.data_quality import base, default_validators from hamilton.htypes import custom_subclass_check @@ -51,13 +51,13 @@ class PydanticModelValidator(base.BaseDefaultValidator): :param arbitrary_types_allowed: Whether arbitrary types are allowed in the model """ - def __init__(self, model: Type[BaseModel], importance: str): + def __init__(self, model: type[BaseModel], importance: str): super(PydanticModelValidator, self).__init__(importance) self.model = model self._model_adapter = TypeAdapter(model) @classmethod - def applies_to(cls, datatype: Type[Type]) -> bool: + def applies_to(cls, datatype: type[type]) -> bool: # In addition to checking for a subclass of BaseModel, we also check for dict # as this is the standard 'de-serialized' format of pydantic models in python return custom_subclass_check(datatype, BaseModel) or custom_subclass_check(datatype, dict) diff --git a/hamilton/plugins/sklearn_plot_extensions.py b/hamilton/plugins/sklearn_plot_extensions.py index dfb1d6a13..27aca35ce 100644 --- a/hamilton/plugins/sklearn_plot_extensions.py +++ b/hamilton/plugins/sklearn_plot_extensions.py @@ -16,8 +16,9 @@ # under the License. 
import dataclasses +from collections.abc import Collection from os import PathLike -from typing import Any, Collection, Dict, Optional, Type, Union +from typing import Any, Union try: import sklearn.inspection @@ -63,24 +64,24 @@ @dataclasses.dataclass class SklearnPlotSaver(DataSaver): - path: Union[str, PathLike] + path: str | PathLike # kwargs dpi: float = 200 format: str = "png" - metadata: Optional[dict] = None + metadata: dict | None = None bbox_inches: str = None pad_inches: float = 0.1 - backend: Optional[str] = None + backend: str | None = None papertype: str = None transparent: bool = None - bbox_extra_artists: Optional[list] = None - pil_kwargs: Optional[dict] = None + bbox_extra_artists: list | None = None + pil_kwargs: dict | None = None @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return SKLEARN_PLOT_TYPES - def _get_saving_kwargs(self) -> Dict[str, Any]: + def _get_saving_kwargs(self) -> dict[str, Any]: kwargs = {} if self.dpi is not None: kwargs["dpi"] = self.dpi @@ -104,7 +105,7 @@ def _get_saving_kwargs(self) -> Dict[str, Any]: kwargs["pil_kwargs"] = self.pil_kwargs return kwargs - def save_data(self, data: SKLEARN_PLOT_TYPES_ANNOTATION) -> Dict[str, Any]: + def save_data(self, data: SKLEARN_PLOT_TYPES_ANNOTATION) -> dict[str, Any]: if isinstance(data, pyplot.Figure): figure = data else: diff --git a/hamilton/plugins/spark_extensions.py b/hamilton/plugins/spark_extensions.py index df1cd15f3..593fe2728 100644 --- a/hamilton/plugins/spark_extensions.py +++ b/hamilton/plugins/spark_extensions.py @@ -17,7 +17,8 @@ import abc import dataclasses -from typing import Any, Collection, Dict, Tuple, Type +from collections.abc import Collection +from typing import Any try: import pyspark.sql as ps @@ -41,11 +42,11 @@ class SparkDataFrameDataLoader(DataLoader): spark: SparkSession @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [ps.DataFrame] @abc.abstractmethod - def load_data(self, type_: Type[DataFrame]) -> Tuple[ps.DataFrame, Dict[str, Any]]: + def load_data(self, type_: type[DataFrame]) -> tuple[ps.DataFrame, dict[str, Any]]: pass @@ -56,7 +57,7 @@ class CSVDataLoader(SparkDataFrameDataLoader): header: bool = True sep: str = "," - def load_data(self, type_: Type[DataFrame]) -> Tuple[ps.DataFrame, Dict[str, Any]]: + def load_data(self, type_: type[DataFrame]) -> tuple[ps.DataFrame, dict[str, Any]]: return ( self.spark.read.csv(self.path, header=self.header, sep=self.sep, inferSchema=True), utils.get_file_metadata(self.path), @@ -73,7 +74,7 @@ class ParquetDataLoader(SparkDataFrameDataLoader): # We can always make that a list of strings, or make a multiple reader (.multicsv) - def load_data(self, type_: Type[DataFrame]) -> Tuple[ps.DataFrame, Dict[str, Any]]: + def load_data(self, type_: type[DataFrame]) -> tuple[ps.DataFrame, dict[str, Any]]: return self.spark.read.parquet(self.path), utils.get_file_metadata(self.path) @classmethod diff --git a/hamilton/plugins/xgboost_extensions.py b/hamilton/plugins/xgboost_extensions.py index 6d537a8b6..97e97e18a 100644 --- a/hamilton/plugins/xgboost_extensions.py +++ b/hamilton/plugins/xgboost_extensions.py @@ -16,8 +16,9 @@ # under the License. 
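The sklearn, spark, and polars adapters touched in this diff all share one dataclass shape: fields carry the reader or writer kwargs, applicable_types advertises what the adapter produces or consumes, and load_data/save_data return the object plus a metadata dict. A stripped-down sketch under that assumption; ExampleCSVReader is hypothetical and not one of the registered plugins:

    import dataclasses
    from collections.abc import Collection
    from typing import Any

    import polars as pl

    from hamilton.io import utils
    from hamilton.io.data_adapters import DataLoader


    @dataclasses.dataclass
    class ExampleCSVReader(DataLoader):
        """Hypothetical loader that mirrors the structure of the adapters above."""

        file: str
        separator: str = ","

        @classmethod
        def name(cls) -> str:
            return "example_csv"

        @classmethod
        def applicable_types(cls) -> Collection[type]:
            return [pl.DataFrame]

        def load_data(self, type_: type) -> tuple[pl.DataFrame, dict[str, Any]]:
            df = pl.read_csv(self.file, separator=self.separator)
            return df, utils.get_file_and_dataframe_metadata(self.file, df)


    # Expected usage shape (assuming DataLoader requires no further abstract methods):
    # ExampleCSVReader("data.csv").load_data(pl.DataFrame)  # -> (DataFrame, metadata dict)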
import dataclasses +from collections.abc import Collection from os import PathLike -from typing import Any, Collection, Dict, Tuple, Type, Union +from typing import Any, Union try: import xgboost @@ -39,13 +40,13 @@ class XGBoostJsonWriter(DataSaver): See differences with pickle format: https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html """ - path: Union[str, PathLike] + path: str | PathLike @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return XGBOOST_MODEL_TYPES - def save_data(self, data: XGBOOST_MODEL_TYPES_ANNOTATION) -> Dict[str, Any]: + def save_data(self, data: XGBOOST_MODEL_TYPES_ANNOTATION) -> dict[str, Any]: data.save_model(self.path) return utils.get_file_metadata(self.path) @@ -60,13 +61,13 @@ class XGBoostJsonReader(DataLoader): See differences with pickle format: https://xgboost.readthedocs.io/en/stable/tutorials/saving_model.html """ - path: Union[str, bytearray, PathLike] + path: str | bytearray | PathLike @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return XGBOOST_MODEL_TYPES - def load_data(self, type_: Type) -> Tuple[XGBOOST_MODEL_TYPES_ANNOTATION, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[XGBOOST_MODEL_TYPES_ANNOTATION, dict[str, Any]]: model = type_() model.load_model(self.path) metadata = utils.get_file_metadata(self.path) diff --git a/hamilton/plugins/yaml_extensions.py b/hamilton/plugins/yaml_extensions.py index 466171963..df53db1f6 100644 --- a/hamilton/plugins/yaml_extensions.py +++ b/hamilton/plugins/yaml_extensions.py @@ -22,7 +22,8 @@ import dataclasses import pathlib -from typing import Any, Collection, Dict, Tuple, Type, Union +from collections.abc import Collection +from typing import Any, Union from hamilton import registry from hamilton.io.data_adapters import DataLoader, DataSaver @@ -34,17 +35,17 @@ @dataclasses.dataclass class YAMLDataLoader(DataLoader): - path: Union[str, pathlib.Path] + path: str | pathlib.Path @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [*PrimitiveTypes] @classmethod def name(cls) -> str: return "yaml" - def load_data(self, type_: Type) -> Tuple[AcceptedTypes, Dict[str, Any]]: + def load_data(self, type_: type) -> tuple[AcceptedTypes, dict[str, Any]]: path = self.path if isinstance(self.path, str): path = pathlib.Path(self.path) @@ -55,17 +56,17 @@ def load_data(self, type_: Type) -> Tuple[AcceptedTypes, Dict[str, Any]]: @dataclasses.dataclass class YAMLDataSaver(DataSaver): - path: Union[str, pathlib.Path] + path: str | pathlib.Path @classmethod - def applicable_types(cls) -> Collection[Type]: + def applicable_types(cls) -> Collection[type]: return [*PrimitiveTypes] @classmethod def name(cls) -> str: return "yaml" - def save_data(self, data: AcceptedTypes) -> Dict[str, Any]: + def save_data(self, data: AcceptedTypes) -> dict[str, Any]: path = self.path if isinstance(path, str): path = pathlib.Path(path) diff --git a/hamilton/registry.py b/hamilton/registry.py index 487ddd918..ae1fc8dac 100644 --- a/hamilton/registry.py +++ b/hamilton/registry.py @@ -22,7 +22,7 @@ import logging import os import pathlib -from typing import Any, Dict, Literal, Optional, Tuple, Type, get_args +from typing import Any, Literal, get_args logger = logging.getLogger(__name__) @@ -51,14 +51,14 @@ "mlflow", "pydantic", ] -HAMILTON_EXTENSIONS: Tuple[ExtensionName, ...] 
= get_args(ExtensionName) +HAMILTON_EXTENSIONS: tuple[ExtensionName, ...] = get_args(ExtensionName) HAMILTON_AUTOLOAD_ENV = "HAMILTON_AUTOLOAD_EXTENSIONS" # NOTE the variable DEFAULT_CONFIG_LOCAITON is redundant with `hamilton.telemetry` # but this `registry` module must avoid circular imports DEFAULT_CONFIG_LOCATION = pathlib.Path("~/.hamilton.conf").expanduser() # This is a dictionary of extension name -> dict with dataframe and column types. -DF_TYPE_AND_COLUMN_TYPES: Dict[str, Dict[str, Type]] = {} +DF_TYPE_AND_COLUMN_TYPES: dict[str, dict[str, type]] = {} COLUMN_TYPE = "column_type" DATAFRAME_TYPE = "dataframe_type" @@ -166,7 +166,7 @@ def config_disable_autoload(): config.write(f) -def register_types(extension_name: str, dataframe_type: Type, column_type: Optional[Type]): +def register_types(extension_name: str, dataframe_type: type, column_type: type | None): """Registers the dataframe and column types for the extension. Note that column types are optional as some extensions may not have a column type (E.G. spark). In this case, this is not included @@ -207,7 +207,7 @@ def fill_with_scalar(df: Any, column_name: str, scalar_value: Any) -> Any: raise NotImplementedError() -def get_column_type_from_df_type(dataframe_type: Type) -> Type: +def get_column_type_from_df_type(dataframe_type: type) -> type: """Function to cycle through the registered extensions and return the column type for the dataframe type. :param dataframe_type: the dataframe type to find the column type for. @@ -239,7 +239,7 @@ def register_adapter(adapter: Any): SAVER_REGISTRY[adapter.name()].append(adapter) -def get_registered_dataframe_types() -> Dict[str, Type]: +def get_registered_dataframe_types() -> dict[str, type]: """Returns a dictionary of extension name -> dataframe type. :return: the dictionary. @@ -250,7 +250,7 @@ def get_registered_dataframe_types() -> Dict[str, Type]: } -def get_registered_column_types() -> Dict[str, Type]: +def get_registered_column_types() -> dict[str, type]: """Returns a dictionary of extension name -> column type. :return: the dictionary. diff --git a/hamilton/telemetry.py b/hamilton/telemetry.py index b967834cc..f6623d5e6 100644 --- a/hamilton/telemetry.py +++ b/hamilton/telemetry.py @@ -40,7 +40,6 @@ import threading import traceback import uuid -from typing import Dict, List, Optional from urllib import request try: @@ -180,12 +179,12 @@ def create_start_event_json( number_of_nodes: int, number_of_modules: int, number_of_config_items: int, - decorators_used: Dict[str, int], + decorators_used: dict[str, int], graph_adapter_used: str, - lifecycle_adapters_used: List[str], + lifecycle_adapters_used: list[str], result_builder_used: str, driver_run_id: uuid.UUID, - error: Optional[str], + error: str | None, graph_executor_class: str, ): """Creates the start event JSON. @@ -232,7 +231,7 @@ def create_end_event_json( number_of_overrides: int, number_of_inputs: int, driver_run_id: uuid.UUID, - error: Optional[str], + error: str | None, ): """Creates the end event JSON. @@ -458,7 +457,7 @@ def sanitize_error(exc_type, exc_value, exc_traceback) -> str: return "FAILED_TO_SANITIZE_ERROR" -def get_all_adapters_names(adapter: lifecycle_base.LifecycleAdapterSet) -> List[str]: +def get_all_adapters_names(adapter: lifecycle_base.LifecycleAdapterSet) -> list[str]: """Gives a list of all adapter names in the LifecycleAdapterSet. Simply a loop over the adapters it contains. 
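The XGBoost JSON adapters a few hunks above reduce to xgboost's own save_model/load_model round trip, with load_data instantiating whichever model class the downstream node requested. Roughly, with a hypothetical path:

    import xgboost

    model = xgboost.XGBRegressor(n_estimators=5)
    model.fit([[0.0], [1.0], [2.0]], [0.0, 1.0, 2.0])
    model.save_model("model.json")      # what XGBoostJsonWriter.save_data delegates to

    restored = xgboost.XGBRegressor()   # load_data calls type_() first ...
    restored.load_model("model.json")   # ... then loads the saved weights into it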
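In registry.py, the extension names live in a Literal type and the runtime tuple is derived from it with get_args, so the names the type checker sees and the list the code iterates over cannot drift apart. The same pattern in miniature, with illustrative names:

    from typing import Literal, get_args

    Color = Literal["red", "green", "blue"]
    COLORS: tuple[Color, ...] = get_args(Color)  # ("red", "green", "blue") at runtime


    def is_known(name: str) -> bool:
        return name in COLORS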
diff --git a/plugin_tests/h_dask/test_h_dask.py b/plugin_tests/h_dask/test_h_dask.py
index b9203ccb0..abea8596a 100644
--- a/plugin_tests/h_dask/test_h_dask.py
+++ b/plugin_tests/h_dask/test_h_dask.py
@@ -207,9 +207,7 @@ def test_smoke_screen_module(client):


 @pytest.mark.parametrize("outputs, expected", dd_test_cases, ids=dd_test_case_ids)
-def test_DDFR_build_result_pandas(
-    client, outputs: typing.Dict[str, typing.Any], expected: dd.DataFrame
-):
+def test_DDFR_build_result_pandas(client, outputs: dict[str, typing.Any], expected: dd.DataFrame):
     """Tests using pandas objects works"""
     actual = h_dask.DaskDataFrameResult.build_result(**outputs)
     actual_pdf = actual.compute().convert_dtypes(dtype_backend="pyarrow")
@@ -218,9 +216,7 @@ def test_DDFR_build_result_pandas(


 @pytest.mark.parametrize("outputs, expected", dd_test_cases, ids=dd_test_case_ids)
-def test_DDFR_build_result_dask(
-    client, outputs: typing.Dict[str, typing.Any], expected: dd.DataFrame
-):
+def test_DDFR_build_result_dask(client, outputs: dict[str, typing.Any], expected: dd.DataFrame):
     """Tests that using dask objects works."""
     dask_outputs = {}
     for k, v in outputs.items():
diff --git a/pyproject.toml b/pyproject.toml
index 839e6cd1f..bb6b03ff9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -231,7 +231,8 @@ extend-select = [
     "TC", # Move type-only imports to a type-checking block.
     "TID", # Helps you write tidier imports.
 #    "TRY", # Prevent exception handling anti-patterns
-#    "UP", # pyupgrade
+    "UP006", # pyupgrade
+    "UP045", # pyupgrade
     "W", # pycodestyle warnings
 ]
 extend-ignore = [
diff --git a/scripts/add_license_headers.py b/scripts/add_license_headers.py
index c7b5d48ba..29fe8ccec 100755
--- a/scripts/add_license_headers.py
+++ b/scripts/add_license_headers.py
@@ -21,7 +21,6 @@
 import json
 import sys
 from pathlib import Path
-from typing import List

 # Base Apache 2 license text (without comment characters)
 # This is used by all formatters below to generate file-type-specific headers
@@ -45,17 +44,17 @@
 ]


-def format_hash_comment(lines: List[str]) -> str:
+def format_hash_comment(lines: list[str]) -> str:
     """Format license as # comments (for Python, Shell, etc.)."""
     return "\n".join(f"# {line}" if line else "#" for line in lines) + "\n\n"


-def format_dash_comment(lines: List[str]) -> str:
+def format_dash_comment(lines: list[str]) -> str:
     """Format license as -- comments (for SQL)."""
     return "\n".join(f"-- {line}" if line else "--" for line in lines) + "\n\n"


-def format_c_style_comment(lines: List[str]) -> str:
+def format_c_style_comment(lines: list[str]) -> str:
     """Format license as /* */ comments (for TypeScript, JavaScript, etc.)."""
     formatted_lines = ["/*"]
     for line in lines:
@@ -64,7 +63,7 @@ def format_c_style_comment(lines: List[str]) -> str:
     return "\n".join(formatted_lines) + "\n\n"


-def format_html_comment(lines: List[str]) -> str:
+def format_html_comment(lines: list[str]) -> str:
     """Format license as HTML comments (for Markdown)."""
     formatted_lines = ["
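The pyproject.toml hunk enables two specific pyupgrade rules rather than the whole UP family: UP006 flags typing.List/Dict/Tuple/Type style generics and UP045 flags Optional[X] annotations, which matches the mechanical rewrites in this diff. Plain Union[X, Y] (UP007) appears to stay off, which would explain why a few files still import Union. Roughly what ruff check --fix does once these rules are on, shown with a hypothetical function:

    from typing import Dict, List, Optional  # flagged once UP006/UP045 are enabled


    def scale_before(xs: List[int], factor: Optional[int] = None) -> Dict[str, int]:  # UP006 x2, UP045
        return {"total": sum(xs) * (factor if factor is not None else 1)}


    # After `ruff check --fix`, the annotations are rewritten in place to:
    def scale_after(xs: list[int], factor: int | None = None) -> dict[str, int]:
        return {"total": sum(xs) * (factor if factor is not None else 1)}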
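Each formatter in scripts/add_license_headers.py takes the raw license lines and returns a ready-to-prepend comment block for one family of file types. The hash-comment variant shown above behaves like this; the two sample lines stand in for the full Apache header text:

    def format_hash_comment(lines: list[str]) -> str:
        """Format license as # comments (for Python, Shell, etc.)."""
        return "\n".join(f"# {line}" if line else "#" for line in lines) + "\n\n"


    sample = ["Licensed to the Apache Software Foundation (ASF) under one", "", "or more contributor license agreements."]
    print(format_hash_comment(sample), end="")
    # # Licensed to the Apache Software Foundation (ASF) under one
    # #
    # # or more contributor license agreements.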