Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
226318d
Add pretty printing functionality
RNKuhns Mar 13, 2023
aa166f2
Minor improvements to config test coverage
RNKuhns Mar 13, 2023
5ad1707
Add global config
RNKuhns Mar 13, 2023
34a9c5a
Update global config tests
RNKuhns Mar 13, 2023
fc5bee3
Add updated tests of local configs
RNKuhns Mar 13, 2023
d89e7ca
Add global config to docs
RNKuhns Mar 13, 2023
68560d3
Fix user_guide panel for configuration
RNKuhns Mar 13, 2023
ddaa90e
Fix docstring typo
RNKuhns Mar 13, 2023
6e4d330
merge in global config updates
RNKuhns Mar 13, 2023
b8f62d1
Fix pre-commit errors
RNKuhns Mar 13, 2023
b20283b
Add missing __init__
RNKuhns Mar 13, 2023
eeee466
Tweak behavior when invalid values encountered
RNKuhns Mar 13, 2023
91a7290
Remove unused code line
RNKuhns Mar 13, 2023
289d0ab
Move config implementation out of __init__
RNKuhns Mar 13, 2023
d0ba27a
Merge global config updates
RNKuhns Mar 13, 2023
6481944
Add metadata to pass metadata functionality tests
RNKuhns Mar 13, 2023
c47e96d
Use call to local config not global config for pretty printing config…
RNKuhns Mar 13, 2023
3c49f7d
Add more pretty printing test cases
RNKuhns Mar 19, 2023
a7f0de8
remove global config
fkiraly Apr 18, 2023
023970e
remove more config
fkiraly Apr 18, 2023
6f14d9a
more config removed
fkiraly Apr 18, 2023
079d6c7
remove config
fkiraly Apr 18, 2023
ade2c47
Merge branch 'main' into prettyprint-noconfig
fkiraly Apr 18, 2023
e7e94b5
remove global config from tests
fkiraly Apr 18, 2023
0b8f62f
default config
fkiraly Apr 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class name: BaseEstimator
fitted state check - check_is_fitted (raises error if not is_fitted)
"""
import inspect
import re
import warnings
from collections import defaultdict
from copy import deepcopy
Expand All @@ -62,6 +63,7 @@ class name: BaseEstimator
from sklearn.base import BaseEstimator as _BaseEstimator

from skbase._exceptions import NotFittedError
from skbase.base._pretty_printing._object_html_repr import _object_html_repr
from skbase.base._tagmanager import _FlagManager

__author__: List[str] = ["mloning", "RNKuhns", "fkiraly"]
Expand All @@ -74,6 +76,11 @@ class BaseObject(_FlagManager, _BaseEstimator):
Extends scikit-learn's BaseEstimator to include sktime style interface for tags.
"""

_config = {
"display": "diagram",
"print_changed_only": True,
}

def __init__(self):
"""Construct BaseObject."""
self._init_flags(flag_attr_name="_tags")
Expand Down Expand Up @@ -682,6 +689,98 @@ def _components(self, base_class=None):

return comp_dict

def __repr__(self, n_char_max: int = 700):
"""Represent class as string.

This follows the scikit-learn implementation for the string representation
of parameterized objects.

Parameters
----------
n_char_max : int
Maximum (approximate) number of non-blank characters to render. This
can be useful in testing.
"""
from skbase.base._pretty_printing._pprint import _BaseObjectPrettyPrinter

n_max_elements_to_show = 30 # number of elements to show in sequences
# use ellipsis for sequences with a lot of elements
pp = _BaseObjectPrettyPrinter(
compact=True,
indent=1,
indent_at_name=True,
n_max_elements_to_show=n_max_elements_to_show,
changed_only=self.get_config()["print_changed_only"],
)

repr_ = pp.pformat(self)

# Use bruteforce ellipsis when there are a lot of non-blank characters
n_nonblank = len("".join(repr_.split()))
if n_nonblank > n_char_max:
lim = n_char_max // 2 # apprx number of chars to keep on both ends
regex = r"^(\s*\S){%d}" % lim
# The regex '^(\s*\S){%d}' matches from the start of the string
# until the nth non-blank character:
# - ^ matches the start of string
# - (pattern){n} matches n repetitions of pattern
# - \s*\S matches a non-blank char following zero or more blanks
left_match = re.match(regex, repr_)
right_match = re.match(regex, repr_[::-1])
left_lim = left_match.end() if left_match is not None else 0
right_lim = right_match.end() if right_match is not None else 0

if "\n" in repr_[left_lim:-right_lim]:
# The left side and right side aren't on the same line.
# To avoid weird cuts, e.g.:
# categoric...ore',
# we need to start the right side with an appropriate newline
# character so that it renders properly as:
# categoric...
# handle_unknown='ignore',
# so we add [^\n]*\n which matches until the next \n
regex += r"[^\n]*\n"
right_match = re.match(regex, repr_[::-1])
right_lim = right_match.end() if right_match is not None else 0

ellipsis = "..."
if left_lim + len(ellipsis) < len(repr_) - right_lim:
# Only add ellipsis if it results in a shorter repr
repr_ = repr_[:left_lim] + "..." + repr_[-right_lim:]

return repr_

@property
def _repr_html_(self):
"""HTML representation of BaseObject.

This is redundant with the logic of `_repr_mimebundle_`. The latter
should be favorted in the long term, `_repr_html_` is only
implemented for consumers who do not interpret `_repr_mimbundle_`.
"""
if self.get_config()["display"] != "diagram":
raise AttributeError(
"_repr_html_ is only defined when the "
"`display` configuration option is set to 'diagram'."
)
return self._repr_html_inner

def _repr_html_inner(self):
"""Return HTML representation of class.

This function is returned by the @property `_repr_html_` to make
`hasattr(BaseObject, "_repr_html_") return `True` or `False` depending
on `self.get_config()["display"]`.
"""
return _object_html_repr(self)

def _repr_mimebundle_(self, **kwargs):
"""Mime bundle used by jupyter kernels to display instances of BaseObject."""
output = {"text/plain": repr(self)}
if self.get_config()["display"] == "diagram":
output["text/html"] = _object_html_repr(self)
return output


class TagAliaserMixin:
"""Mixin class for tag aliasing and deprecation of old tags.
Expand Down
11 changes: 11 additions & 0 deletions skbase/base/_pretty_printing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
# Many elements of this code were developed in scikit-learn. These elements
# are copyrighted by the scikit-learn developers, BSD-3-Clause License. For
# conditions see https://github.com/scikit-learn/scikit-learn/blob/main/COPYING
"""Functionality for pretty printing BaseObjects."""
from typing import List

__author__: List[str] = ["RNKuhns"]
__all__: List[str] = []
Loading