diff --git a/clean_slides/chart_engine/builder.py b/clean_slides/chart_engine/builder.py index 46ab28e..7cb3796 100644 --- a/clean_slides/chart_engine/builder.py +++ b/clean_slides/chart_engine/builder.py @@ -3,17 +3,33 @@ from __future__ import annotations import copy -from collections.abc import Mapping, Sequence +from collections.abc import Iterator, Mapping, Sequence from pathlib import Path -from typing import Any, Callable, Union, cast +from typing import Callable, Protocol, Union, cast from pptx import Presentation +from pptx.chart.data import CategoryChartData from pptx.dml.color import RGBColor +from pptx.enum.chart import XL_CHART_TYPE from pptx.enum.shapes import MSO_SHAPE, PP_PLACEHOLDER from pptx.oxml.ns import qn from pptx.util import Emu, Inches, Pt -from ..pptx_access import chart_series_names, chart_xml_space, shape_has_text_frame, slide_charts +from ..pptx_access import ( + chart_first_plot, + chart_part_name, + chart_series, + chart_series_names, + chart_xml_space, + plot_data_labels, + set_chart_has_legend, + set_plot_has_data_labels, + shape_chart, + shape_has_text_frame, + shape_xml_element, + slide_add_chart, + slide_charts, +) from . import annotations as _annotations from . import payloads as _payloads from .colors import apply_color @@ -53,8 +69,174 @@ ChartBox = tuple[int, int, int, int] SpecMap = Mapping[str, object] -AddWaterfallTitleFn = Callable[[Any, ChartBox, str, object], None] -BuildPayloadFn = Callable[[dict[str, object]], tuple[Any, Any, dict[str, object]]] + +class _DataLabelFontLike(Protocol): + size: object + color: object + + +class _DataLabelsLike(Protocol): + number_format: str + number_format_is_linked: bool + font: _DataLabelFontLike + position: object + + +class _DataLabelsShowValue(Protocol): + show_value: bool + + +class _PlaceholderFormatLike(Protocol): + type: object + + +class _PlaceholderLike(Protocol): + placeholder_format: _PlaceholderFormatLike + left: int + top: int + width: int + height: int + text: str + + +class _ShapeXmlElementLike(Protocol): + def getparent(self) -> object | None: ... + + +class _ShapeXmlParentLike(Protocol): + def remove(self, element: object) -> None: ... + + +class _SeriesXmlLike(Protocol): + def find(self, path: str) -> object | None: ... + + def remove(self, element: object) -> None: ... + + def insert(self, index: int, element: object) -> None: ... + + def __len__(self) -> int: ... + + +class _ShapeForeColorLike(Protocol): + rgb: object + + +class _ShapeFillLike(Protocol): + fore_color: _ShapeForeColorLike + + def solid(self) -> None: ... + + +class _ShapeLineFillLike(Protocol): + def background(self) -> None: ... + + +class _ShapeLineLike(Protocol): + fill: _ShapeLineFillLike + + +class _ShapeLike(Protocol): + name: str + fill: _ShapeFillLike + line: _ShapeLineLike + + +class _TextFrameParagraphFontLike(Protocol): + size: object + + +class _TextFrameParagraphLike(Protocol): + font: _TextFrameParagraphFontLike + + +class _TextFrameLike(Protocol): + text: str + paragraphs: Sequence[_TextFrameParagraphLike] + + +class _TextBoxLike(Protocol): + name: str + text_frame: _TextFrameLike + + +class _SlideShapesLike(Protocol): + def add_shape( + self, + shape_type: object, + left: object, + top: object, + width: object, + height: object, + ) -> _ShapeLike: ... + + def add_textbox( + self, + left: object, + top: object, + width: object, + height: object, + ) -> _TextBoxLike: ... + + +class _SlideLike(Protocol): + shapes: _SlideShapesLike + placeholders: Sequence[object] + + +class _SlidesCollectionLike(Protocol): + def __len__(self) -> int: ... + + def __getitem__(self, index: int) -> object: ... + + def add_slide(self, slide_layout: object) -> object: ... + + +class _SlideLayoutsCollectionLike(Protocol): + def __len__(self) -> int: ... + + def __getitem__(self, index: int) -> object: ... + + def __iter__(self) -> Iterator[object]: ... + + +class _PresentationLike(Protocol): + slides: _SlidesCollectionLike + slide_layouts: _SlideLayoutsCollectionLike + slide_width: int + slide_height: int + + def save(self, file: object) -> None: ... + + +class _ChartTitleTextFrameLike(Protocol): + text: str + + +class _ChartTitleLike(Protocol): + text_frame: _ChartTitleTextFrameLike + + +class _ChartLegendLike(Protocol): + include_in_layout: bool + + +class _ChartLike(Protocol): + has_title: bool + chart_title: _ChartTitleLike + legend: _ChartLegendLike + + +class _SeriesLike(Protocol): + has_data_labels: bool + + @property + def data_labels(self) -> object: ... + + +AddWaterfallTitleFn = Callable[[object, ChartBox, str, object], None] +BuildPayloadFn = Callable[ + [dict[str, object]], tuple[XL_CHART_TYPE, CategoryChartData, dict[str, object]] +] def _require_attr(module: object, name: str) -> object: @@ -192,8 +374,8 @@ def apply_chart_template_dlbls( if series_index >= len(target_series): return - template_series_element = template_series[series_index] - target_series_element = target_series[series_index] + template_series_element = cast(_SeriesXmlLike, template_series[series_index]) + target_series_element = cast(_SeriesXmlLike, target_series[series_index]) template_dlbls = template_series_element.find(qn("c:dLbls")) if template_dlbls is None: @@ -216,25 +398,28 @@ def resolve_series_indices(series_spec: object, series_names: Sequence[str]) -> return list(dict.fromkeys(indices)) -def apply_data_label_style(labels: Any, data_cfg: SpecMap) -> None: - labels.number_format = _str(data_cfg.get("format"), DEFAULT_BAR_DATA_LABEL_FORMAT) - labels.number_format_is_linked = False +def apply_data_label_style(labels: object, data_cfg: SpecMap) -> None: + data_labels = cast(_DataLabelsLike, labels) + data_labels.number_format = _str(data_cfg.get("format"), DEFAULT_BAR_DATA_LABEL_FORMAT) + data_labels.number_format_is_linked = False font_size = data_cfg.get("font_size", DEFAULT_BAR_DATA_LABEL_FONT_SIZE) if isinstance(font_size, Pt): - labels.font.size = font_size + data_labels.font.size = font_size else: - labels.font.size = Pt(float(_int(font_size, int(DEFAULT_BAR_DATA_LABEL_FONT_SIZE)))) + data_labels.font.size = Pt(float(_int(font_size, int(DEFAULT_BAR_DATA_LABEL_FONT_SIZE)))) label_position = normalize_label_position(_str_or_none(data_cfg.get("position"))) if label_position is not None: - labels.position = label_position - if hasattr(labels, "show_value"): - labels.show_value = True + data_labels.position = label_position + + if hasattr(data_labels, "show_value"): + labels_with_value = cast(_DataLabelsShowValue, data_labels) + labels_with_value.show_value = True color_value = data_cfg.get("color") if isinstance(color_value, (RGBColor, str)): - apply_color(labels.font.color, color_value) + apply_color(data_labels.font.color, color_value) def chart_box_from_spec(raw: object) -> ChartBox | None: @@ -259,7 +444,7 @@ def chart_box_from_spec(raw: object) -> ChartBox | None: return box -def template_content_box(slide: Any, template_path: Path | None) -> ChartBox | None: +def template_content_box(slide: object, template_path: Path | None) -> ChartBox | None: if template_path is None: return None @@ -292,12 +477,14 @@ def template_content_box(slide: Any, template_path: Path | None) -> ChartBox | N ) -def find_content_placeholder(slide: Any) -> Any | None: - candidates: list[Any] = [] - for placeholder in slide.placeholders: - ph_type = placeholder.placeholder_format.type +def find_content_placeholder(slide: object) -> _PlaceholderLike | None: + slide_like = cast(_SlideLike, slide) + candidates: list[_PlaceholderLike] = [] + for placeholder in slide_like.placeholders: + candidate = cast(_PlaceholderLike, placeholder) + ph_type = candidate.placeholder_format.type if ph_type in (PP_PLACEHOLDER.BODY, PP_PLACEHOLDER.OBJECT): - candidates.append(placeholder) + candidates.append(candidate) if not candidates: return None @@ -305,11 +492,18 @@ def find_content_placeholder(slide: Any) -> Any | None: return max(candidates, key=lambda ph: int(ph.width) * int(ph.height)) -def remove_shape(shape: Any) -> None: - element = shape._element - parent = element.getparent() - if parent is not None: - parent.remove(element) +def remove_shape(shape: object) -> None: + element_obj = shape_xml_element(shape) + if element_obj is None: + return + + element = cast(_ShapeXmlElementLike, element_obj) + parent_obj = element.getparent() + if parent_obj is None: + return + + parent = cast(_ShapeXmlParentLike, parent_obj) + parent.remove(element_obj) def adjust_bar_chart_box_for_overlays( @@ -351,20 +545,23 @@ def adjust_bar_chart_box_for_overlays( return (int(x), int(y), int(cx), int(new_cy)) -def find_layout(prs: Any, name: str | None) -> Any | None: +def find_layout(prs: object, name: str | None) -> object | None: if not name: return None + prs_like = cast(_PresentationLike, prs) target = name.strip().lower() - for layout in prs.slide_layouts: + for layout in prs_like.slide_layouts: layout_name = str(getattr(layout, "name", "")).strip().lower() if layout_name == target: return layout return None -def apply_template_placeholders(slide: Any, title: str | None, subtitle: str | None) -> None: - for placeholder in slide.placeholders: +def apply_template_placeholders(slide: object, title: str | None, subtitle: str | None) -> None: + slide_like = cast(_SlideLike, slide) + for placeholder_obj in slide_like.placeholders: + placeholder = cast(_PlaceholderLike, placeholder_obj) if not shape_has_text_frame(placeholder): continue @@ -377,8 +574,9 @@ def apply_template_placeholders(slide: Any, title: str | None, subtitle: str | N placeholder.text = "" -def add_hidden_anchor(slide: Any) -> None: - shape = slide.shapes.add_shape( +def add_hidden_anchor(slide: object) -> None: + slide_like = cast(_SlideLike, slide) + shape = slide_like.shapes.add_shape( MSO_SHAPE.RECTANGLE, Inches(0.02), Inches(0.02), @@ -391,7 +589,8 @@ def add_hidden_anchor(slide: Any) -> None: shape.line.fill.background() -def add_overlay_labels(slide: Any, categories: Sequence[object], chart_box: ChartBox) -> None: +def add_overlay_labels(slide: object, categories: Sequence[object], chart_box: ChartBox) -> None: + slide_like = cast(_SlideLike, slide) x, y, cx, cy = chart_box label_y = int(y + cy + Inches(0.1)) if not categories: @@ -399,7 +598,7 @@ def add_overlay_labels(slide: Any, categories: Sequence[object], chart_box: Char label_width = int(cx / len(categories)) for idx, label in enumerate(categories): - text_box = slide.shapes.add_textbox( + text_box = slide_like.shapes.add_textbox( int(x + (label_width * idx)), label_y, label_width, @@ -413,11 +612,13 @@ def add_overlay_labels(slide: Any, categories: Sequence[object], chart_box: Char def select_slide( - prs: Any, - slide_layout: Any, + prs: object, + slide_layout: object, use_template: bool, spec: SpecMap, -) -> Any: +) -> object: + prs_like = cast(_PresentationLike, prs) + template_slide_index = spec.get("template_slide_index") append_slide = _bool(spec.get("append_slide", False), False) @@ -427,25 +628,25 @@ def select_slide( except ValueError as exc: raise ValueError("template_slide_index must be an integer (1-based)") from exc - if index < 0 or index >= len(prs.slides): + if index < 0 or index >= len(prs_like.slides): raise ValueError( - f"template_slide_index {template_slide_index} is out of range (1-{len(prs.slides)})" + f"template_slide_index {template_slide_index} is out of range (1-{len(prs_like.slides)})" ) - return prs.slides[index] + return prs_like.slides[index] - if use_template and not append_slide and len(prs.slides) == 1: - return prs.slides[0] + if use_template and not append_slide and len(prs_like.slides) == 1: + return prs_like.slides[0] - if use_template and append_slide and len(prs.slides) == 1: - slide = prs.slides[0] + if use_template and append_slide and len(prs_like.slides) == 1: + slide = prs_like.slides[0] if not slide_charts(slide): return slide - return prs.slides.add_slide(slide_layout) + return prs_like.slides.add_slide(slide_layout) def build_chart( - prs: Any, + prs: object, spec: SpecMap, output_path: Path, template_path: Path | None = None, @@ -453,14 +654,20 @@ def build_chart( save: bool = True, defer_template_copy: bool = False, ) -> list[ChartTemplateReplacement]: + prs_like = cast(_PresentationLike, prs) spec_map = _mapping(spec) - slide_layout = find_layout(prs, layout_name) + slide_layout = find_layout(prs_like, layout_name) if slide_layout is None: - slide_layout = prs.slide_layouts[6] if len(prs.slide_layouts) > 6 else prs.slide_layouts[0] + slide_layout = ( + prs_like.slide_layouts[6] + if len(prs_like.slide_layouts) > 6 + else prs_like.slide_layouts[0] + ) use_template = template_path is not None - slide = select_slide(prs, slide_layout, use_template, spec_map) + slide = select_slide(prs_like, slide_layout, use_template, spec_map) + slide_like = cast(_SlideLike, slide) add_hidden_anchor(slide) @@ -479,7 +686,7 @@ def build_chart( title = _str_or_none(spec_map.get("title")) subtitle = _str_or_none(spec_map.get("subtitle")) - content_placeholder: Any | None = None + content_placeholder: _PlaceholderLike | None = None remove_placeholder_value = spec_map.get("remove_content_placeholder") if remove_placeholder_value is None: remove_placeholder = use_template @@ -522,12 +729,33 @@ def build_chart( chart_box = adjust_bar_chart_box_for_overlays(chart_box, bar_style) x, y, cx, cy = chart_box - chart = slide.shapes.add_chart(chart_type, x, y, cx, cy, chart_data).chart - chart_part = str(chart.part.partname).lstrip("/") + + chart_frame = slide_add_chart( + slide_like, + chart_type, + Emu(x), + Emu(y), + Emu(cx), + Emu(cy), + chart_data, + ) + if chart_frame is None: + return [] + + chart = shape_chart(chart_frame) + if chart is None: + return [] + + chart_part_value = chart_part_name(chart) + if chart_part_value is None: + raise ValueError("Chart part name is unavailable") + chart_part = chart_part_value.lstrip("/") + + chart_like = cast(_ChartLike, chart) if title and not is_waterfall and not use_template: - chart.has_title = True - chart.chart_title.text_frame.text = title + chart_like.has_title = True + chart_like.chart_title.text_frame.text = title show_legend_value = spec_map.get("show_legend") if show_legend_value is None: @@ -535,9 +763,9 @@ def build_chart( else: show_legend = bool(show_legend_value) - chart.has_legend = show_legend - if chart.has_legend: - chart.legend.include_in_layout = False + set_chart_has_legend(chart, show_legend) + if show_legend: + chart_like.legend.include_in_layout = False if is_waterfall: show_segment_labels = _bool(spec_map.get("show_data_labels", True), True) @@ -552,15 +780,23 @@ def build_chart( if series_selector is not None: series_names = chart_series_names(chart) indices = resolve_series_indices(series_selector, series_names) + series_list = chart_series(chart) + for idx in indices: - if 0 <= idx < len(chart.series): - series = chart.series[idx] - series.has_data_labels = True - apply_data_label_style(series.data_labels, data_cfg) + if idx < 0 or idx >= len(series_list): + continue + + series_obj = series_list[idx] + series = cast(_SeriesLike, series_obj) + series.has_data_labels = True + apply_data_label_style(series.data_labels, data_cfg) else: - plot = chart.plots[0] - plot.has_data_labels = True - apply_data_label_style(plot.data_labels, data_cfg) + plot = chart_first_plot(chart) + if plot is not None: + set_plot_has_data_labels(plot, True) + labels = plot_data_labels(plot) + if labels is not None: + apply_data_label_style(labels, data_cfg) apply_series_colors(chart, _string_or_none_list(style.get("series_colors"))) @@ -599,11 +835,17 @@ def build_chart( if _bool(spec_map.get("add_overlay_labels", False), False): chart_box_emu = (x, y, cx, cy) if is_waterfall: + slide_size: tuple[int, int] | None = None + try: + slide_size = (int(prs_like.slide_width), int(prs_like.slide_height)) + except (TypeError, ValueError): + slide_size = None + add_waterfall_overlays( slide, chart_box_emu, waterfall_style, - slide_size=(int(prs.slide_width), int(prs.slide_height)), + slide_size=slide_size, ) elif bar_style: add_bar_overlays(slide, chart_box_emu, bar_style) @@ -624,7 +866,7 @@ def build_chart( ) if save: - prs.save(output_path) + prs_like.save(output_path) if replacements and not defer_template_copy: apply_chart_template_replacements(output_path, replacements) replacements = [] diff --git a/clean_slides/chart_engine/overlay_bar.py b/clean_slides/chart_engine/overlay_bar.py index d88ccd4..3bf1da7 100644 --- a/clean_slides/chart_engine/overlay_bar.py +++ b/clean_slides/chart_engine/overlay_bar.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from pathlib import Path -from typing import Any, Callable, Union, cast +from typing import Callable, Union, cast from pptx.dml.color import RGBColor from pptx.util import Pt @@ -51,8 +51,8 @@ OverlaySpec = Mapping[str, object] MetaSpec = Mapping[str, object] -AddLineAnnotationFn = Callable[[Any, dict[str, object]], None] -AddShapeAnnotationFn = Callable[[Any, dict[str, object]], None] +AddLineAnnotationFn = Callable[[object, dict[str, object]], None] +AddShapeAnnotationFn = Callable[[object, dict[str, object]], None] LoadTemplateFn = Callable[[Path, str], object] @@ -218,7 +218,7 @@ def _plot_layout(value: object) -> dict[str, float]: def add_bar_overlays( - slide: Any, + slide: object, chart_box: tuple[int, int, int, int], meta: MetaSpec, ) -> None: diff --git a/clean_slides/chart_engine/overlay_bar_legend.py b/clean_slides/chart_engine/overlay_bar_legend.py index ba1a728..5aa0650 100644 --- a/clean_slides/chart_engine/overlay_bar_legend.py +++ b/clean_slides/chart_engine/overlay_bar_legend.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from pathlib import Path -from typing import Any, Callable, Union, cast +from typing import Callable, Protocol, Union, cast from pptx.dml.color import RGBColor from pptx.enum.shapes import MSO_SHAPE @@ -23,6 +23,46 @@ TemplateMap = Mapping[str, object] LegendGeometry = Mapping[str, float] + +class _ShapeForeColorLike(Protocol): + theme_color: object + rgb: object + + +class _ShapeFillLike(Protocol): + fore_color: _ShapeForeColorLike + + def solid(self) -> None: ... + + +class _ShapeLineFillLike(Protocol): + def background(self) -> None: ... + + +class _ShapeLineLike(Protocol): + fill: _ShapeLineFillLike + + +class _ShapeLike(Protocol): + fill: _ShapeFillLike + line: _ShapeLineLike + + +class _SlideShapesLike(Protocol): + def add_shape( + self, + shape_type: object, + left: object, + top: object, + width: object, + height: object, + ) -> _ShapeLike: ... + + +class _SlideLike(Protocol): + shapes: _SlideShapesLike + + AddTextLabelFn = Callable[..., object] ResolveTemplateFn = Callable[[PathOrNone, str, object], object] @@ -54,7 +94,7 @@ def _bool(value: object, default: bool) -> bool: def add_bar_legend( - slide: Any, + slide: object, *, overlay: Mapping[str, object], chart_box: tuple[int, int, int, int], @@ -78,6 +118,8 @@ def add_bar_legend( templates: TemplateMap, ) -> None: """Render legend labels and optional color markers.""" + slide_like = cast(_SlideLike, slide) + legend_layout = _optional_str(overlay.get("legend_layout")) legend_align = normalize_alignment(_optional_str(overlay.get("legend_alignment"))) legend_show_markers = _bool(overlay.get("legend_show_markers", True), True) @@ -162,7 +204,7 @@ def add_bar_legend( if legend_show_markers and series_color: marker_x = plot_left + plot_width * (marker_left_ratio + marker_step_ratio * idx) marker_y = legend_y + marker_y_offset - shape = slide.shapes.add_shape( + shape = slide_like.shapes.add_shape( MSO_SHAPE.RECTANGLE, int(marker_x), int(marker_y), diff --git a/clean_slides/chart_engine/overlay_bar_segments.py b/clean_slides/chart_engine/overlay_bar_segments.py index aa6e60d..6887ca6 100644 --- a/clean_slides/chart_engine/overlay_bar_segments.py +++ b/clean_slides/chart_engine/overlay_bar_segments.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from pathlib import Path -from typing import Any, Callable, Union, cast +from typing import Callable, Union, cast from pptx.dml.color import RGBColor @@ -108,7 +108,7 @@ def _coerce_fill_color(fill_value: object, series_color: StrOrNone) -> ColorValu def add_bar_segment_labels( - slide: Any, + slide: object, *, overlay: OverlaySpec, categories: Sequence[object], diff --git a/clean_slides/chart_engine/overlay_bar_totals_categories.py b/clean_slides/chart_engine/overlay_bar_totals_categories.py index 8854764..c259947 100644 --- a/clean_slides/chart_engine/overlay_bar_totals_categories.py +++ b/clean_slides/chart_engine/overlay_bar_totals_categories.py @@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence from pathlib import Path -from typing import Any, Callable, Union, cast +from typing import Callable, Union, cast from pptx.dml.color import RGBColor from pptx.enum.text import PP_ALIGN @@ -66,7 +66,7 @@ def _bar_center(geometry: Geometry, idx: int) -> FloatOrNone: def add_bar_total_labels( - slide: Any, + slide: object, *, overlay: Mapping[str, object], totals: Sequence[FloatOrNone], @@ -176,7 +176,7 @@ def add_bar_total_labels( def add_bar_category_labels( - slide: Any, + slide: object, *, overlay: Mapping[str, object], categories: Sequence[object], diff --git a/clean_slides/chart_engine/overlay_waterfall.py b/clean_slides/chart_engine/overlay_waterfall.py index 20d4d61..635cc49 100644 --- a/clean_slides/chart_engine/overlay_waterfall.py +++ b/clean_slides/chart_engine/overlay_waterfall.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any, Union, cast +from typing import Union, cast from pptx.dml.color import RGBColor from pptx.util import Emu, Pt @@ -110,7 +110,7 @@ def _index_set(value: object) -> set[int]: def add_waterfall_overlays( - slide: Any, + slide: object, chart_box: tuple[int, int, int, int], meta: Mapping[str, object], slide_size: SlideSize = None, diff --git a/clean_slides/chart_engine/overlay_waterfall_connectors.py b/clean_slides/chart_engine/overlay_waterfall_connectors.py index 5a97a80..e483813 100644 --- a/clean_slides/chart_engine/overlay_waterfall_connectors.py +++ b/clean_slides/chart_engine/overlay_waterfall_connectors.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import Any, Union, cast +from typing import Protocol, Union, cast from pptx.dml.color import RGBColor from pptx.enum.shapes import MSO_SHAPE @@ -17,6 +17,44 @@ Geometry = Mapping[str, object] +class _ShapeForeColorLike(Protocol): + rgb: object + + +class _ShapeFillLike(Protocol): + fore_color: _ShapeForeColorLike + + def solid(self) -> None: ... + + +class _ShapeLineFillLike(Protocol): + def background(self) -> None: ... + + +class _ShapeLineLike(Protocol): + fill: _ShapeLineFillLike + + +class _ShapeLike(Protocol): + fill: _ShapeFillLike + line: _ShapeLineLike + + +class _SlideShapesLike(Protocol): + def add_shape( + self, + shape_type: object, + left: object, + top: object, + width: object, + height: object, + ) -> _ShapeLike: ... + + +class _SlideLike(Protocol): + shapes: _SlideShapesLike + + def _geometry_series(geometry: Geometry, key: str) -> list[float]: raw_values = geometry.get(key) if not isinstance(raw_values, list): @@ -40,7 +78,7 @@ def _connector_value(values: Sequence[FloatOrNone], idx: int) -> FloatOrNone: def render_waterfall_connectors( - slide: Any, + slide: object, categories: Sequence[object], connector_values: Sequence[FloatOrNone], geometry: Geometry, @@ -60,11 +98,12 @@ def render_waterfall_connectors( connector_color: ColorValue, ) -> None: """Render connector segments between adjacent waterfall categories.""" + slide_like = cast(_SlideLike, slide) def add_dash_segment(x: float, y: float, width: float, height: float) -> None: if width <= 0 or height <= 0: return - shape = slide.shapes.add_shape( + shape = slide_like.shapes.add_shape( MSO_SHAPE.RECTANGLE, round(x), round(y), diff --git a/clean_slides/chart_engine/overlay_waterfall_data_labels.py b/clean_slides/chart_engine/overlay_waterfall_data_labels.py index 2628997..4cae8b0 100644 --- a/clean_slides/chart_engine/overlay_waterfall_data_labels.py +++ b/clean_slides/chart_engine/overlay_waterfall_data_labels.py @@ -2,13 +2,14 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence +from collections.abc import Iterator, Mapping, Sequence from dataclasses import dataclass -from typing import Any, Union, cast +from typing import Protocol, Union, cast from pptx.oxml.ns import qn from pptx.oxml.xmlchemy import OxmlElement +from ..pptx_access import chart_xml_space from .defaults import ( DEFAULT_WATERFALL_DLABEL_INSIDE_OFFSET_RATIO, DEFAULT_WATERFALL_DLABEL_MIN_INSIDE_RATIO, @@ -37,6 +38,26 @@ LabelLayout = dict[str, float] +class _XmlElementLike(Protocol): + tag: str + + def find(self, path: str) -> object | None: ... + + def findall(self, path: str) -> list[object]: ... + + def append(self, element: object) -> None: ... + + def insert(self, index: int, element: object) -> None: ... + + def index(self, element: object) -> int: ... + + def remove(self, element: object) -> None: ... + + def set(self, key: str, value: str) -> None: ... + + def __iter__(self) -> Iterator[object]: ... + + @dataclass class SegmentLabel: series_idx: int @@ -168,17 +189,25 @@ def _label_at(values: Sequence[FloatOrNone], index: int) -> FloatOrNone: return values[index] -def get_chart_series(chart_space: Any) -> list[Any]: - chart = chart_space.find(qn("c:chart")) - if chart is None: +def get_chart_series(chart_space: object) -> list[object]: + chart_space_el = cast(_XmlElementLike, chart_space) + + chart_obj = chart_space_el.find(qn("c:chart")) + if chart_obj is None: return [] - plot_area = chart.find(qn("c:plotArea")) - if plot_area is None: + chart_el = cast(_XmlElementLike, chart_obj) + + plot_area_obj = chart_el.find(qn("c:plotArea")) + if plot_area_obj is None: return [] - bar_chart = plot_area.find(qn("c:barChart")) - if bar_chart is None: + plot_area = cast(_XmlElementLike, plot_area_obj) + + bar_chart_obj = plot_area.find(qn("c:barChart")) + if bar_chart_obj is None: return [] - return cast(list[Any], bar_chart.findall(qn("c:ser"))) + bar_chart = cast(_XmlElementLike, bar_chart_obj) + + return list(bar_chart.findall(qn("c:ser"))) def measure_label_width(text: str) -> int: @@ -186,36 +215,40 @@ def measure_label_width(text: str) -> int: return int(DEFAULT_WATERFALL_LABEL_WIDTH_BASE + DEFAULT_WATERFALL_LABEL_WIDTH_PER_CHAR * length) -def _ensure_dlbls(series_element: Any) -> Any: - dlbls = series_element.find(qn("c:dLbls")) - if dlbls is not None: - return dlbls +def _ensure_dlbls(series_element: object) -> object: + series_el = cast(_XmlElementLike, series_element) + dlbls_obj = series_el.find(qn("c:dLbls")) + if dlbls_obj is not None: + return dlbls_obj dlbls = OxmlElement("c:dLbls") - insert_at = series_element.find(qn("c:val")) - if insert_at is not None: - series_element.insert(series_element.index(insert_at), dlbls) + insert_at_obj = series_el.find(qn("c:val")) + if insert_at_obj is not None: + series_el.insert(series_el.index(insert_at_obj), dlbls) else: - series_element.append(dlbls) + series_el.append(dlbls) return dlbls -def _set_child_val(parent: Any, tag: str, value: Union[str, int]) -> Any: - elem = parent.find(qn(tag)) - if elem is None: - elem = OxmlElement(tag) - parent.append(elem) +def _set_child_val(parent: object, tag: str, value: Union[str, int]) -> object: + parent_el = cast(_XmlElementLike, parent) + elem_obj = parent_el.find(qn(tag)) + if elem_obj is None: + elem_obj = OxmlElement(tag) + parent_el.append(elem_obj) + + elem = cast(_XmlElementLike, elem_obj) elem.set("val", str(value)) return elem def _add_dlbl( - dlbls: Any, + dlbls: object, point_idx: int, show_val: bool = True, manual_x: Union[float, None] = None, manual_y: Union[float, None] = None, -) -> Any: +) -> object: dlbl = OxmlElement("c:dLbl") idx_el = OxmlElement("c:idx") idx_el.set("val", str(point_idx)) @@ -242,7 +275,9 @@ def _add_dlbl( _set_child_val(dlbl, "c:showSerName", 0) _set_child_val(dlbl, "c:showPercent", 0) _set_child_val(dlbl, "c:showBubbleSize", 0) - dlbls.append(dlbl) + + dlbls_el = cast(_XmlElementLike, dlbls) + dlbls_el.append(dlbl) return dlbl @@ -416,7 +451,7 @@ def apply_waterfall_data_label_layout( "y": dy / plot_height if plot_height else 0.0, } - chart_space = getattr(chart, "_chartSpace", None) + chart_space = chart_xml_space(chart) if chart_space is None: return @@ -429,18 +464,20 @@ def apply_waterfall_data_label_layout( if 0 <= offset_idx < len(series_elements): series_element = series_elements[offset_idx] dlbls = _ensure_dlbls(series_element) + dlbls_el = cast(_XmlElementLike, dlbls) - for child in list(dlbls): + for child_obj in list(dlbls_el): + child = cast(_XmlElementLike, child_obj) if child.tag == qn("c:dLbl"): - dlbls.remove(child) + dlbls_el.remove(child) - _set_child_val(dlbls, "c:dLblPos", "ctr") - _set_child_val(dlbls, "c:showLegendKey", 0) - _set_child_val(dlbls, "c:showVal", 0) - _set_child_val(dlbls, "c:showCatName", 0) - _set_child_val(dlbls, "c:showSerName", 0) - _set_child_val(dlbls, "c:showPercent", 0) - _set_child_val(dlbls, "c:showBubbleSize", 0) + _set_child_val(dlbls_el, "c:dLblPos", "ctr") + _set_child_val(dlbls_el, "c:showLegendKey", 0) + _set_child_val(dlbls_el, "c:showVal", 0) + _set_child_val(dlbls_el, "c:showCatName", 0) + _set_child_val(dlbls_el, "c:showSerName", 0) + _set_child_val(dlbls_el, "c:showPercent", 0) + _set_child_val(dlbls_el, "c:showBubbleSize", 0) default_manual_y = default_dy / plot_height if plot_height else 0.0 @@ -449,7 +486,7 @@ def apply_waterfall_data_label_layout( manual_x = layout.get("x", 0.0) if layout else 0.0 manual_y = layout.get("y", default_manual_y) if layout else default_manual_y _add_dlbl( - dlbls, + dlbls_el, idx, show_val=True, manual_x=manual_x, diff --git a/clean_slides/chart_engine/overlay_waterfall_labels.py b/clean_slides/chart_engine/overlay_waterfall_labels.py index 5b3e600..3ed0ac6 100644 --- a/clean_slides/chart_engine/overlay_waterfall_labels.py +++ b/clean_slides/chart_engine/overlay_waterfall_labels.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import Any, Callable, Union, cast +from typing import Callable, Union, cast from pptx.enum.text import PP_ALIGN @@ -195,7 +195,7 @@ def build_waterfall_value_label_specs( return label_specs -def add_waterfall_value_labels(slide: Any, label_specs: Sequence[LabelSpec]) -> None: +def add_waterfall_value_labels(slide: object, label_specs: Sequence[LabelSpec]) -> None: """Render prepared value-label specs.""" for spec in label_specs: text = spec.get("text") @@ -230,7 +230,7 @@ def add_waterfall_value_labels(slide: Any, label_specs: Sequence[LabelSpec]) -> def add_waterfall_category_labels( - slide: Any, + slide: object, categories: Sequence[object], chart_box: tuple[int, int, int, int], geometry: Geometry, @@ -289,7 +289,7 @@ def add_waterfall_category_labels( def add_waterfall_series_labels( - slide: Any, + slide: object, chart_box: tuple[int, int, int, int], chart_series_names: Sequence[object], segment_values: Mapping[object, object], diff --git a/clean_slides/inspect_pptx.py b/clean_slides/inspect_pptx.py index 092479a..aff631c 100644 --- a/clean_slides/inspect_pptx.py +++ b/clean_slides/inspect_pptx.py @@ -8,7 +8,7 @@ from __future__ import annotations from dataclasses import asdict, dataclass, field -from typing import Any, Protocol, TypedDict, cast +from typing import Protocol, TypedDict, cast from lxml import etree from pptx.enum.shapes import MSO_SHAPE_TYPE @@ -24,13 +24,11 @@ from typing_extensions import TypeGuard from .pptx_access import chart_xml_space, paragraph_xml_element, text_frame_xml_element +from .xml_helpers import XmlElement # ── Helpers ──────────────────────────────────────────────────────────── -XmlElement = Any - - def _is_text_shape(shape: BaseShape) -> TypeGuard[PptxShape]: return isinstance(shape, PptxShape) and shape.has_text_frame