diff --git a/clean_slides/chart_render.py b/clean_slides/chart_render.py index 2f4bcc6..1c5205b 100644 --- a/clean_slides/chart_render.py +++ b/clean_slides/chart_render.py @@ -9,8 +9,9 @@ from __future__ import annotations +from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, cast +from typing import Protocol, cast from pptx.oxml.xmlchemy import OxmlElement from pptx.slide import Slide @@ -19,9 +20,24 @@ from .chart_engine.spec_utils import object_list, optional_str_list, str_key_dict from .charts import ChartEngine from .pptx_access import ( + chart_first_plot, + chart_series, + chart_xml_element, + paragraph_font, + plot_data_labels, + point_fill_fore_color, + point_fill_solid, + point_line_fill_background, + series_points, + set_chart_has_legend, + set_font_size, + set_plot_has_data_labels, set_text_frame_text, + shape_chart, shape_has_text_frame, shape_text_frame, + slide_add_chart, + slide_size_emu, text_frame_paragraphs, text_frame_text, ) @@ -44,6 +60,22 @@ class ChartGroup: max_col: int +class _XmlElementLike(Protocol): + def iter(self, tag: str) -> Iterable[object]: ... + + def find(self, path: str) -> object | None: ... + + def append(self, element: object) -> None: ... + + def insert(self, index: int, element: object) -> None: ... + + def addprevious(self, element: object) -> None: ... + + +class _XmlSettable(Protocol): + def set(self, key: str, value: str) -> None: ... + + def _iterable_objects(value: object) -> list[object]: """Coerce list/tuple/set payloads into ``list[object]``.""" if isinstance(value, set): @@ -481,14 +513,14 @@ def _rewrite_overlay_value_label_texts( paragraphs = text_frame_paragraphs(text_frame) if paragraphs: first_paragraph = paragraphs[0] - font = getattr(first_paragraph, "font", None) + font = paragraph_font(first_paragraph) if font is not None: - font.size = Pt(font_size_pt) + set_font_size(font, Pt(font_size_pt)) label_idx += 1 -def _set_label_nowrap(chart: Any) -> None: +def _set_label_nowrap(chart: object) -> None: """Set ``wrap="none"`` on every data-label ``bodyPr`` in the chart. Without this, PowerPoint (and LibreOffice) may wrap short labels @@ -498,14 +530,22 @@ def _set_label_nowrap(chart: Any) -> None: ns_a = "http://schemas.openxmlformats.org/drawingml/2006/main" ns_c = "http://schemas.openxmlformats.org/drawingml/2006/chart" - chart_el = chart._element - for dlbls in chart_el.iter(f"{{{ns_c}}}dLbls"): - tx_pr = dlbls.find(f"{{{ns_c}}}txPr") - if tx_pr is None: + chart_el_obj = chart_xml_element(chart) + if chart_el_obj is None: + return + chart_el = cast(_XmlElementLike, chart_el_obj) + + for dlbls_obj in chart_el.iter(f"{{{ns_c}}}dLbls"): + dlbls = cast(_XmlElementLike, dlbls_obj) + tx_pr_obj = dlbls.find(f"{{{ns_c}}}txPr") + if tx_pr_obj is None: continue - body_pr = tx_pr.find(f"{{{ns_a}}}bodyPr") - if body_pr is not None: - body_pr.set("wrap", "none") + tx_pr = cast(_XmlElementLike, tx_pr_obj) + body_pr_obj = tx_pr.find(f"{{{ns_a}}}bodyPr") + if body_pr_obj is None: + continue + body_pr = cast(_XmlSettable, body_pr_obj) + body_pr.set("wrap", "none") # Visual gap between bar tip and label left edge, as a fraction of chart width. @@ -513,7 +553,7 @@ def _set_label_nowrap(chart: Any) -> None: def _set_horizontal_label_offsets( - chart: Any, + chart: object, values: list[float], axis_max: float, plot_w: float, @@ -535,7 +575,10 @@ def _set_horizontal_label_offsets( """ from lxml import etree - chart_el = chart._element + chart_el_obj = chart_xml_element(chart) + if chart_el_obj is None: + return + chart_el = cast(_XmlElementLike, chart_el_obj) ns_c = "http://schemas.openxmlformats.org/drawingml/2006/chart" # The OOXML structure is: @@ -544,29 +587,35 @@ def _set_horizontal_label_offsets( # python-pptx creates dLbls at the barChart level. We need a ser-level # dLbls container to hold per-point dLbl elements with manual offsets. - bar_chart = chart_el.find(f".//{{{ns_c}}}barChart") - if bar_chart is None: + bar_chart_obj = chart_el.find(f".//{{{ns_c}}}barChart") + if bar_chart_obj is None: return - ser = bar_chart.find(f"{{{ns_c}}}ser") - if ser is None: + bar_chart = cast(_XmlElementLike, bar_chart_obj) + + ser_obj = bar_chart.find(f"{{{ns_c}}}ser") + if ser_obj is None: return + ser = cast(_XmlElementLike, ser_obj) # Create ser-level dLbls container (or reuse if it already exists) - ser_dlbls = ser.find(f"{{{ns_c}}}dLbls") - if ser_dlbls is None: + ser_dlbls_obj = ser.find(f"{{{ns_c}}}dLbls") + if ser_dlbls_obj is None: ser_dlbls_el = OxmlElement("c:dLbls") # Insert before cat/val - insert_before: Any | None = None + insert_before_obj: object | None = None for tag_suffix in ("cat", "val", "shape", "extLst"): - insert_before = ser.find(f"{{{ns_c}}}{tag_suffix}") - if insert_before is not None: + insert_before_obj = ser.find(f"{{{ns_c}}}{tag_suffix}") + if insert_before_obj is not None: break - if insert_before is not None: + if insert_before_obj is not None: + insert_before = cast(_XmlElementLike, insert_before_obj) insert_before.addprevious(ser_dlbls_el) else: ser.append(ser_dlbls_el) - ser_dlbls = ser.find(f"{{{ns_c}}}dLbls") - assert ser_dlbls is not None + ser_dlbls_obj = ser.find(f"{{{ns_c}}}dLbls") + if ser_dlbls_obj is None: + return + ser_dlbls = cast(_XmlElementLike, ser_dlbls_obj) # Character width estimate: ~0.6 × font size per character char_width_emu = font_size_pt * 0.6 * _EMU_PER_PT @@ -659,18 +708,24 @@ def _set_horizontal_label_offsets( ser_dlbls.append(shared_el) -def _delete_auto_title(chart: Any) -> None: +def _delete_auto_title(chart: object) -> None: """Suppress the auto-generated chart title (series name watermark).""" ns_c = "http://schemas.openxmlformats.org/drawingml/2006/chart" - chart_el = chart._element - ns = {"c": ns_c} - auto_title = chart_el.find(".//c:autoTitleDeleted", ns) - if auto_title is not None: - auto_title.set("val", "1") + chart_el_obj = chart_xml_element(chart) + if chart_el_obj is None: + return + + chart_el = cast(_XmlElementLike, chart_el_obj) + auto_title_obj = chart_el.find(".//{" + ns_c + "}autoTitleDeleted") + if auto_title_obj is None: + return + + auto_title = cast(_XmlSettable, auto_title_obj) + auto_title.set("val", "1") def _apply_point_colors( - chart: Any, + chart: object, charts_module: ChartEngine, point_colors: list[str | None], series_idx: int, @@ -679,15 +734,25 @@ def _apply_point_colors( if not point_colors: return - series = chart.series[series_idx] + series_list = chart_series(chart) + if series_idx < 0 or series_idx >= len(series_list): + return + + series = series_list[series_idx] + points = series_points(series) for point_idx, color in enumerate(point_colors): - if not color: + if not color or point_idx >= len(points): continue - point = series.points[point_idx] - point.format.fill.solid() - applied = bool(charts_module.apply_color(point.format.fill.fore_color, color)) + + point = points[point_idx] + point_fill_solid(point) + fore_color = point_fill_fore_color(point) + if fore_color is None: + continue + + applied = bool(charts_module.apply_color(fore_color, color)) if applied: - point.format.line.fill.background() + point_line_fill_background(point) # --------------------------------------------------------------------------- @@ -785,7 +850,8 @@ def _render_bar_group( chart_type, chart_data, style = charts_module.build_bar_payload(chart_spec) - chart_frame = cast(Any, slide.shapes).add_chart( + chart_frame = slide_add_chart( + slide, chart_type, Emu(x), Emu(y), @@ -793,9 +859,14 @@ def _render_bar_group( Emu(h), chart_data, ) - chart = chart_frame.chart + if chart_frame is None: + return - chart.has_legend = False + chart = shape_chart(chart_frame) + if chart is None: + return + + set_chart_has_legend(chart, False) _delete_auto_title(chart) series_colors = optional_str_list(style.get("series_colors", [])) @@ -810,9 +881,16 @@ def _render_bar_group( if has_labels: data_cfg = str_key_dict(chart_spec.get("data_labels")) - plot = chart.plots[0] - plot.has_data_labels = True - charts_module.apply_data_label_style(plot.data_labels, data_cfg) + plot = chart_first_plot(chart) + if plot is None: + return + + set_plot_has_data_labels(plot, True) + labels = plot_data_labels(plot) + if labels is None: + return + + charts_module.apply_data_label_style(labels, data_cfg) _set_label_nowrap(chart) if group.chart_def.dir == "horizontal": @@ -869,7 +947,8 @@ def _render_waterfall_group( chart_type, chart_data, style = charts_module.build_waterfall_payload(chart_spec) - chart_frame = cast(Any, slide.shapes).add_chart( + chart_frame = slide_add_chart( + slide, chart_type, Emu(x), Emu(y), @@ -877,9 +956,14 @@ def _render_waterfall_group( Emu(h), chart_data, ) - chart = chart_frame.chart + if chart_frame is None: + return + + chart = shape_chart(chart_frame) + if chart is None: + return - chart.has_legend = False + set_chart_has_legend(chart, False) _delete_auto_title(chart) # Apply series colors (offset series + value series) @@ -895,9 +979,9 @@ def _render_waterfall_group( if group.chart_def.colors: offset_idx_obj = wf_meta.get("offset_series_idx") offset_idx = offset_idx_obj if isinstance(offset_idx_obj, int) else 0 - series = cast(list[Any], chart.series) + series_list = chart_series(chart) target_idx: int | None = None - for idx in range(len(series)): + for idx in range(len(series_list)): if idx != offset_idx: target_idx = idx break @@ -914,9 +998,9 @@ def _render_waterfall_group( # Pre-compute desired label texts (supports format strings like {:,.0f}). label_texts = _waterfall_overlay_label_texts(wf_meta, group.chart_def.format) - prs_part = cast(Any, slide.part.package).presentation_part - prs_obj = prs_part.presentation - slide_size: tuple[int, int] = (int(prs_obj.slide_width), int(prs_obj.slide_height)) + slide_size = slide_size_emu(slide) + if slide_size is None: + return before_shape_count = len(slide.shapes) charts_module.add_waterfall_overlays( diff --git a/clean_slides/pptx_access.py b/clean_slides/pptx_access.py index d5a1d46..772f63a 100644 --- a/clean_slides/pptx_access.py +++ b/clean_slides/pptx_access.py @@ -158,6 +158,149 @@ def chart_xml_space(chart: object) -> object | None: return getattr(chart, "_chartSpace", None) +def chart_xml_element(chart: object) -> object | None: + """Return underlying OOXML chart element when available.""" + return getattr(chart, "_element", None) + + +class _AddChartCallable(Protocol): + def __call__( + self, + chart_type: object, + x: object, + y: object, + cx: object, + cy: object, + chart_data: object, + ) -> object: ... + + +def slide_add_chart( + slide: object, + chart_type: object, + x: object, + y: object, + cx: object, + cy: object, + chart_data: object, +) -> object | None: + """Add a chart to a slide when supported and return the chart frame.""" + shapes = getattr(slide, "shapes", None) + add_chart = getattr(shapes, "add_chart", None) + if not callable(add_chart): + return None + add_chart_fn = cast(_AddChartCallable, add_chart) + return add_chart_fn(chart_type, x, y, cx, cy, chart_data) + + +class _MutableChartLegend(Protocol): + has_legend: bool + + +def set_chart_has_legend(chart: object, has_legend: bool) -> None: + """Set chart legend visibility.""" + mutable_chart = cast(_MutableChartLegend, chart) + mutable_chart.has_legend = has_legend + + +def chart_series(chart: object) -> list[object]: + """Return chart series objects.""" + return _iter_objects(getattr(chart, "series", None)) + + +def chart_plots(chart: object) -> list[object]: + """Return chart plot objects.""" + return _iter_objects(getattr(chart, "plots", None)) + + +def chart_first_plot(chart: object) -> object | None: + """Return chart first plot when present.""" + plots = chart_plots(chart) + return plots[0] if plots else None + + +class _MutablePlotDataLabels(Protocol): + has_data_labels: bool + + +class _PlotDataLabels(Protocol): + @property + def data_labels(self) -> object: ... + + +def set_plot_has_data_labels(plot: object, has_data_labels: bool) -> None: + """Set data-label visibility for a chart plot.""" + mutable_plot = cast(_MutablePlotDataLabels, plot) + mutable_plot.has_data_labels = has_data_labels + + +def plot_data_labels(plot: object) -> object | None: + """Return plot data-label collection when available.""" + return getattr(cast(_PlotDataLabels, plot), "data_labels", None) + + +def series_points(series: object) -> list[object]: + """Return points from a chart series.""" + return _iter_objects(getattr(series, "points", None)) + + +def point_fill_solid(point: object) -> None: + """Convert point fill to solid when supported.""" + point_format = getattr(point, "format", None) + fill = getattr(point_format, "fill", None) + solid = getattr(fill, "solid", None) + if callable(solid): + solid() + + +def point_fill_fore_color(point: object) -> object | None: + """Return point fill foreground color when available.""" + point_format = getattr(point, "format", None) + fill = getattr(point_format, "fill", None) + return getattr(fill, "fore_color", None) + + +def point_line_fill_background(point: object) -> None: + """Set point line fill background when supported.""" + point_format = getattr(point, "format", None) + line = getattr(point_format, "line", None) + fill = getattr(line, "fill", None) + background = getattr(fill, "background", None) + if callable(background): + background() + + +def slide_size_emu(slide: object) -> tuple[int, int] | None: + """Return presentation slide size for a slide in EMU.""" + part = getattr(slide, "part", None) + package = getattr(part, "package", None) + presentation_part = getattr(package, "presentation_part", None) + presentation = getattr(presentation_part, "presentation", None) + width = getattr(presentation, "slide_width", None) + height = getattr(presentation, "slide_height", None) + if width is None or height is None: + return None + try: + return (int(width), int(height)) + except (TypeError, ValueError): + return None + + +def paragraph_font(paragraph: object) -> object | None: + """Return paragraph font object when available.""" + return getattr(paragraph, "font", None) + + +class _MutableFont(Protocol): + size: object + + +def set_font_size(font: object, size: object) -> None: + """Set font size for a paragraph font object.""" + mutable_font = cast(_MutableFont, font) + mutable_font.size = size + + def shape_has_connector_endpoints(shape: object) -> bool: """Return whether shape exposes connector endpoints.""" return getattr(shape, "begin_x", None) is not None and getattr(shape, "end_x", None) is not None diff --git a/tests/test_pptx_access.py b/tests/test_pptx_access.py index c106cd8..a6a32fc 100644 --- a/tests/test_pptx_access.py +++ b/tests/test_pptx_access.py @@ -1,15 +1,28 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from clean_slides.pptx_access import ( + chart_first_plot, + chart_plots, + chart_series, chart_series_names, chart_type_value, + chart_xml_element, chart_xml_space, iter_shapes, iter_slides, + paragraph_font, paragraph_xml_element, + plot_data_labels, + point_fill_fore_color, + point_fill_solid, + point_line_fill_background, presentation_chart_types, + series_points, + set_chart_has_legend, + set_font_size, + set_plot_has_data_labels, set_text_frame_text, shape_chart, shape_chart_type, @@ -21,11 +34,21 @@ shape_text_frame, shape_text_frame_text, shape_xml_element, + slide_add_chart, + slide_size_emu, text_frame_paragraphs, text_frame_text, ) +def _object_list() -> list[object]: + return [] + + +def _chart_add_calls() -> list[tuple[object, object, object, object, object, object]]: + return [] + + @dataclass class _Series: name: str @@ -34,8 +57,11 @@ class _Series: @dataclass class _Chart: chart_type: int - series: list[_Series] + series: list[object] _chartSpace: object | None = None + _element: object | None = None + plots: list[object] = field(default_factory=_object_list) + has_legend: bool = True @dataclass @@ -72,6 +98,113 @@ class _Presentation: slides: list[_Slide] +@dataclass +class _Fill: + fore_color: object + solid_calls: int = 0 + + def solid(self) -> None: + self.solid_calls += 1 + + +@dataclass +class _LineFill: + background_calls: int = 0 + + def background(self) -> None: + self.background_calls += 1 + + +@dataclass +class _Line: + fill: _LineFill + + +@dataclass +class _PointFormat: + fill: _Fill + line: _Line + + +@dataclass +class _Point: + format: _PointFormat + + +@dataclass +class _SeriesWithPoints: + points: list[_Point] + + +@dataclass +class _Plot: + has_data_labels: bool = False + data_labels: object | None = None + + +@dataclass +class _ChartFrame: + has_chart: bool + chart: _Chart + + +@dataclass +class _ShapesWithAddChart: + chart: _Chart + calls: list[tuple[object, object, object, object, object, object]] = field( + default_factory=_chart_add_calls + ) + + def add_chart( + self, + chart_type: object, + x: object, + y: object, + cx: object, + cy: object, + chart_data: object, + ) -> _ChartFrame: + self.calls.append((chart_type, x, y, cx, cy, chart_data)) + return _ChartFrame(has_chart=True, chart=self.chart) + + +@dataclass +class _PresentationRoot: + slide_width: int + slide_height: int + + +@dataclass +class _PresentationPart: + presentation: _PresentationRoot + + +@dataclass +class _Package: + presentation_part: _PresentationPart + + +@dataclass +class _Part: + package: _Package + + +@dataclass +class _SlideWithPart: + shapes: _ShapesWithAddChart + part: _Part + + +@dataclass +class _Font: + size: object | None = None + + +@dataclass +class _ParagraphWithFont: + font: _Font + + def test_presentation_chart_types_collects_chart_values() -> None: prs = _Presentation( slides=[ @@ -131,3 +264,66 @@ def test_shape_and_text_frame_helpers_cover_text_placeholder_connector() -> None assert text_frame.text == "Updated" assert chart_series_names(chart) == ["Revenue"] + + +def test_chart_mutation_and_slide_helpers_cover_dynamic_graphic_access() -> None: + point = _Point( + format=_PointFormat( + fill=_Fill(fore_color={"kind": "accent"}), + line=_Line(fill=_LineFill()), + ) + ) + series = _SeriesWithPoints(points=[point]) + plot = _Plot(data_labels={"k": "v"}) + chart = _Chart( + chart_type=57, + series=[series], + _element={"tag": "chart"}, + plots=[plot], + ) + + shapes = _ShapesWithAddChart(chart=chart) + slide = _SlideWithPart( + shapes=shapes, + part=_Part( + package=_Package( + presentation_part=_PresentationPart( + presentation=_PresentationRoot(slide_width=1280, slide_height=720) + ) + ) + ), + ) + + frame = slide_add_chart(slide, chart_type=57, x=1, y=2, cx=3, cy=4, chart_data={"d": 1}) + assert frame is not None + assert shape_chart(frame) is chart + assert len(shapes.calls) == 1 + + assert chart_xml_element(chart) == {"tag": "chart"} + set_chart_has_legend(chart, False) + assert chart.has_legend is False + + assert chart_series(chart) == [series] + assert series_points(series) == [point] + + point_fill_solid(point) + assert point.format.fill.solid_calls == 1 + assert point_fill_fore_color(point) == {"kind": "accent"} + + point_line_fill_background(point) + assert point.format.line.fill.background_calls == 1 + + assert chart_plots(chart) == [plot] + assert chart_first_plot(chart) is plot + + set_plot_has_data_labels(plot, True) + assert plot.has_data_labels is True + assert plot_data_labels(plot) == {"k": "v"} + + assert slide_size_emu(slide) == (1280, 720) + + paragraph = _ParagraphWithFont(font=_Font()) + font = paragraph_font(paragraph) + assert font is not None + set_font_size(font, 11) + assert paragraph.font.size == 11