Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 136 additions & 52 deletions clean_slides/chart_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, cast
from typing import Protocol, cast

from pptx.oxml.xmlchemy import OxmlElement
from pptx.slide import Slide
Expand All @@ -19,9 +20,24 @@
from .chart_engine.spec_utils import object_list, optional_str_list, str_key_dict
from .charts import ChartEngine
from .pptx_access import (
chart_first_plot,
chart_series,
chart_xml_element,
paragraph_font,
plot_data_labels,
point_fill_fore_color,
point_fill_solid,
point_line_fill_background,
series_points,
set_chart_has_legend,
set_font_size,
set_plot_has_data_labels,
set_text_frame_text,
shape_chart,
shape_has_text_frame,
shape_text_frame,
slide_add_chart,
slide_size_emu,
text_frame_paragraphs,
text_frame_text,
)
Expand All @@ -44,6 +60,22 @@ class ChartGroup:
max_col: int


class _XmlElementLike(Protocol):
def iter(self, tag: str) -> Iterable[object]: ...

def find(self, path: str) -> object | None: ...

def append(self, element: object) -> None: ...

def insert(self, index: int, element: object) -> None: ...

def addprevious(self, element: object) -> None: ...


class _XmlSettable(Protocol):
def set(self, key: str, value: str) -> None: ...


def _iterable_objects(value: object) -> list[object]:
"""Coerce list/tuple/set payloads into ``list[object]``."""
if isinstance(value, set):
Expand Down Expand Up @@ -481,14 +513,14 @@ def _rewrite_overlay_value_label_texts(
paragraphs = text_frame_paragraphs(text_frame)
if paragraphs:
first_paragraph = paragraphs[0]
font = getattr(first_paragraph, "font", None)
font = paragraph_font(first_paragraph)
if font is not None:
font.size = Pt(font_size_pt)
set_font_size(font, Pt(font_size_pt))

label_idx += 1


def _set_label_nowrap(chart: Any) -> None:
def _set_label_nowrap(chart: object) -> None:
"""Set ``wrap="none"`` on every data-label ``bodyPr`` in the chart.

Without this, PowerPoint (and LibreOffice) may wrap short labels
Expand All @@ -498,22 +530,30 @@ def _set_label_nowrap(chart: Any) -> None:

ns_a = "http://schemas.openxmlformats.org/drawingml/2006/main"
ns_c = "http://schemas.openxmlformats.org/drawingml/2006/chart"
chart_el = chart._element
for dlbls in chart_el.iter(f"{{{ns_c}}}dLbls"):
tx_pr = dlbls.find(f"{{{ns_c}}}txPr")
if tx_pr is None:
chart_el_obj = chart_xml_element(chart)
if chart_el_obj is None:
return
chart_el = cast(_XmlElementLike, chart_el_obj)

for dlbls_obj in chart_el.iter(f"{{{ns_c}}}dLbls"):
dlbls = cast(_XmlElementLike, dlbls_obj)
tx_pr_obj = dlbls.find(f"{{{ns_c}}}txPr")
if tx_pr_obj is None:
continue
body_pr = tx_pr.find(f"{{{ns_a}}}bodyPr")
if body_pr is not None:
body_pr.set("wrap", "none")
tx_pr = cast(_XmlElementLike, tx_pr_obj)
body_pr_obj = tx_pr.find(f"{{{ns_a}}}bodyPr")
if body_pr_obj is None:
continue
body_pr = cast(_XmlSettable, body_pr_obj)
body_pr.set("wrap", "none")


# Visual gap between bar tip and label left edge, as a fraction of chart width.
_HORIZONTAL_LABEL_VISUAL_GAP: float = 0.02


def _set_horizontal_label_offsets(
chart: Any,
chart: object,
values: list[float],
axis_max: float,
plot_w: float,
Expand All @@ -535,7 +575,10 @@ def _set_horizontal_label_offsets(
"""
from lxml import etree

chart_el = chart._element
chart_el_obj = chart_xml_element(chart)
if chart_el_obj is None:
return
chart_el = cast(_XmlElementLike, chart_el_obj)
ns_c = "http://schemas.openxmlformats.org/drawingml/2006/chart"

# The OOXML structure is:
Expand All @@ -544,29 +587,35 @@ def _set_horizontal_label_offsets(
# python-pptx creates dLbls at the barChart level. We need a ser-level
# dLbls container to hold per-point dLbl elements with manual offsets.

bar_chart = chart_el.find(f".//{{{ns_c}}}barChart")
if bar_chart is None:
bar_chart_obj = chart_el.find(f".//{{{ns_c}}}barChart")
if bar_chart_obj is None:
return
ser = bar_chart.find(f"{{{ns_c}}}ser")
if ser is None:
bar_chart = cast(_XmlElementLike, bar_chart_obj)

ser_obj = bar_chart.find(f"{{{ns_c}}}ser")
if ser_obj is None:
return
ser = cast(_XmlElementLike, ser_obj)

# Create ser-level dLbls container (or reuse if it already exists)
ser_dlbls = ser.find(f"{{{ns_c}}}dLbls")
if ser_dlbls is None:
ser_dlbls_obj = ser.find(f"{{{ns_c}}}dLbls")
if ser_dlbls_obj is None:
ser_dlbls_el = OxmlElement("c:dLbls")
# Insert before cat/val
insert_before: Any | None = None
insert_before_obj: object | None = None
for tag_suffix in ("cat", "val", "shape", "extLst"):
insert_before = ser.find(f"{{{ns_c}}}{tag_suffix}")
if insert_before is not None:
insert_before_obj = ser.find(f"{{{ns_c}}}{tag_suffix}")
if insert_before_obj is not None:
break
if insert_before is not None:
if insert_before_obj is not None:
insert_before = cast(_XmlElementLike, insert_before_obj)
insert_before.addprevious(ser_dlbls_el)
else:
ser.append(ser_dlbls_el)
ser_dlbls = ser.find(f"{{{ns_c}}}dLbls")
assert ser_dlbls is not None
ser_dlbls_obj = ser.find(f"{{{ns_c}}}dLbls")
if ser_dlbls_obj is None:
return
ser_dlbls = cast(_XmlElementLike, ser_dlbls_obj)

# Character width estimate: ~0.6 × font size per character
char_width_emu = font_size_pt * 0.6 * _EMU_PER_PT
Expand Down Expand Up @@ -659,18 +708,24 @@ def _set_horizontal_label_offsets(
ser_dlbls.append(shared_el)


def _delete_auto_title(chart: Any) -> None:
def _delete_auto_title(chart: object) -> None:
"""Suppress the auto-generated chart title (series name watermark)."""
ns_c = "http://schemas.openxmlformats.org/drawingml/2006/chart"
chart_el = chart._element
ns = {"c": ns_c}
auto_title = chart_el.find(".//c:autoTitleDeleted", ns)
if auto_title is not None:
auto_title.set("val", "1")
chart_el_obj = chart_xml_element(chart)
if chart_el_obj is None:
return

chart_el = cast(_XmlElementLike, chart_el_obj)
auto_title_obj = chart_el.find(".//{" + ns_c + "}autoTitleDeleted")
if auto_title_obj is None:
return

auto_title = cast(_XmlSettable, auto_title_obj)
auto_title.set("val", "1")


def _apply_point_colors(
chart: Any,
chart: object,
charts_module: ChartEngine,
point_colors: list[str | None],
series_idx: int,
Expand All @@ -679,15 +734,25 @@ def _apply_point_colors(
if not point_colors:
return

series = chart.series[series_idx]
series_list = chart_series(chart)
if series_idx < 0 or series_idx >= len(series_list):
return

series = series_list[series_idx]
points = series_points(series)
for point_idx, color in enumerate(point_colors):
if not color:
if not color or point_idx >= len(points):
continue
point = series.points[point_idx]
point.format.fill.solid()
applied = bool(charts_module.apply_color(point.format.fill.fore_color, color))

point = points[point_idx]
point_fill_solid(point)
fore_color = point_fill_fore_color(point)
if fore_color is None:
continue

applied = bool(charts_module.apply_color(fore_color, color))
if applied:
point.format.line.fill.background()
point_line_fill_background(point)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -785,17 +850,23 @@ def _render_bar_group(

chart_type, chart_data, style = charts_module.build_bar_payload(chart_spec)

chart_frame = cast(Any, slide.shapes).add_chart(
chart_frame = slide_add_chart(
slide,
chart_type,
Emu(x),
Emu(y),
Emu(w),
Emu(h),
chart_data,
)
chart = chart_frame.chart
if chart_frame is None:
return

chart.has_legend = False
chart = shape_chart(chart_frame)
if chart is None:
return

set_chart_has_legend(chart, False)
_delete_auto_title(chart)

series_colors = optional_str_list(style.get("series_colors", []))
Expand All @@ -810,9 +881,16 @@ def _render_bar_group(

if has_labels:
data_cfg = str_key_dict(chart_spec.get("data_labels"))
plot = chart.plots[0]
plot.has_data_labels = True
charts_module.apply_data_label_style(plot.data_labels, data_cfg)
plot = chart_first_plot(chart)
if plot is None:
return

set_plot_has_data_labels(plot, True)
labels = plot_data_labels(plot)
if labels is None:
return

charts_module.apply_data_label_style(labels, data_cfg)
_set_label_nowrap(chart)

if group.chart_def.dir == "horizontal":
Expand Down Expand Up @@ -869,17 +947,23 @@ def _render_waterfall_group(

chart_type, chart_data, style = charts_module.build_waterfall_payload(chart_spec)

chart_frame = cast(Any, slide.shapes).add_chart(
chart_frame = slide_add_chart(
slide,
chart_type,
Emu(x),
Emu(y),
Emu(w),
Emu(h),
chart_data,
)
chart = chart_frame.chart
if chart_frame is None:
return

chart = shape_chart(chart_frame)
if chart is None:
return

chart.has_legend = False
set_chart_has_legend(chart, False)
_delete_auto_title(chart)

# Apply series colors (offset series + value series)
Expand All @@ -895,9 +979,9 @@ def _render_waterfall_group(
if group.chart_def.colors:
offset_idx_obj = wf_meta.get("offset_series_idx")
offset_idx = offset_idx_obj if isinstance(offset_idx_obj, int) else 0
series = cast(list[Any], chart.series)
series_list = chart_series(chart)
target_idx: int | None = None
for idx in range(len(series)):
for idx in range(len(series_list)):
if idx != offset_idx:
target_idx = idx
break
Expand All @@ -914,9 +998,9 @@ def _render_waterfall_group(
# Pre-compute desired label texts (supports format strings like {:,.0f}).
label_texts = _waterfall_overlay_label_texts(wf_meta, group.chart_def.format)

prs_part = cast(Any, slide.part.package).presentation_part
prs_obj = prs_part.presentation
slide_size: tuple[int, int] = (int(prs_obj.slide_width), int(prs_obj.slide_height))
slide_size = slide_size_emu(slide)
if slide_size is None:
return

before_shape_count = len(slide.shapes)
charts_module.add_waterfall_overlays(
Expand Down
Loading