Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion ax/analysis/healthcheck/healthcheck_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

# pyre-strict

from __future__ import annotations

import json
from enum import IntEnum

import pandas as pd
from ax.core.analysis_card import AnalysisCard
from ax.analysis.analysis import ErrorAnalysisCard
from ax.core.analysis_card import AnalysisCard, AnalysisCardBase


class HealthcheckStatus(IntEnum):
Expand All @@ -18,6 +21,13 @@ class HealthcheckStatus(IntEnum):
WARNING = 2


# Healthchecks that provide valuable progress info even when passing
PRIORITY_HEALTHCHECKS: set[str] = {
"BaselineImprovementAnalysis",
"EarlyStoppingAnalysis",
}


class HealthcheckAnalysisCard(AnalysisCard):
def get_status(self) -> HealthcheckStatus:
return HealthcheckStatus(json.loads(self.blob)["status"])
Expand Down Expand Up @@ -49,3 +59,49 @@ def create_healthcheck_analysis_card(
}
),
)


# Status order for sorting: FAIL first, then WARNING, then PASS
_STATUS_SORT_ORDER: dict[HealthcheckStatus, int] = {
HealthcheckStatus.FAIL: 1,
HealthcheckStatus.WARNING: 2,
HealthcheckStatus.PASS: 3,
}


def sort_healthcheck_cards(
cards: list[AnalysisCardBase],
) -> list[AnalysisCardBase]:
"""
Sort healthcheck cards by severity and priority.

Order:
1. ErrorAnalysisCard (errors during computation)
2. FAIL status
3. WARNING status
4. PASS status with priority (BaselineImprovement, EarlyStopping, etc.)
5. PASS status (rest)

Args:
cards: List of analysis cards (typically HealthcheckAnalysisCard or
ErrorAnalysisCard instances).

Returns:
Sorted list of cards.
"""

def sort_key(card: AnalysisCardBase) -> tuple[int, int, str]:
if isinstance(card, ErrorAnalysisCard):
return (0, 0, card.name)

if isinstance(card, HealthcheckAnalysisCard):
return (
_STATUS_SORT_ORDER[card.get_status()],
0 if card.name in PRIORITY_HEALTHCHECKS else 1,
card.name,
)

# Fallback for type safety (unreachable in practice)
return (4, 1, card.name)

return sorted(cards, key=sort_key)
51 changes: 51 additions & 0 deletions ax/analysis/healthcheck/tests/test_healthcheck_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import pandas as pd
from ax.analysis.analysis import ErrorAnalysisCard
from ax.analysis.healthcheck.healthcheck_analysis import (
create_healthcheck_analysis_card,
HealthcheckStatus,
sort_healthcheck_cards,
)
from ax.core.analysis_card import AnalysisCardBase
from ax.utils.common.testutils import TestCase


def _card(name: str, status: HealthcheckStatus) -> AnalysisCardBase:
return create_healthcheck_analysis_card(
name=name, title=name, subtitle=name, df=pd.DataFrame(), status=status
)


def _error(name: str) -> AnalysisCardBase:
return ErrorAnalysisCard(
name=name, title=name, subtitle=name, df=pd.DataFrame(), blob=""
)


class TestHealthcheckAnalysis(TestCase):
def test_sort_ordering(self) -> None:
cards: list[AnalysisCardBase] = [
_card("RegularAnalysis", HealthcheckStatus.PASS),
_card("WarningAnalysis", HealthcheckStatus.WARNING),
_error("ErrorAnalysis"),
_card("BaselineImprovementAnalysis", HealthcheckStatus.PASS),
_card("FailAnalysis", HealthcheckStatus.FAIL),
]
result = sort_healthcheck_cards(cards)

self.assertEqual(
[c.name for c in result],
[
"ErrorAnalysis",
"FailAnalysis",
"WarningAnalysis",
"BaselineImprovementAnalysis",
"RegularAnalysis",
],
)
15 changes: 4 additions & 11 deletions ax/analysis/overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, final

from ax.adapter.base import Adapter
from ax.analysis.analysis import Analysis, ErrorAnalysisCard
from ax.analysis.analysis import Analysis
from ax.analysis.diagnostics import DiagnosticAnalysis
from ax.analysis.healthcheck.baseline_improvement import BaselineImprovementAnalysis
from ax.analysis.healthcheck.can_generate_candidates import (
Expand All @@ -19,7 +19,7 @@
ConstraintsFeasibilityAnalysis,
)
from ax.analysis.healthcheck.early_stopping_healthcheck import EarlyStoppingAnalysis
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckAnalysisCard
from ax.analysis.healthcheck.healthcheck_analysis import sort_healthcheck_cards
from ax.analysis.healthcheck.metric_fetching_errors import MetricFetchingErrorsAnalysis
from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
Expand Down Expand Up @@ -247,21 +247,14 @@ def compute(
if analyis is not None
]

non_passing_health_checks = [
card
for card in health_check_cards
if (isinstance(card, HealthcheckAnalysisCard) and not card.is_passing())
or isinstance(card, ErrorAnalysisCard)
]

health_checks_group = (
AnalysisCardGroup(
name="HealthchecksAnalysis",
title=HEALTH_CHECK_CARDGROUP_TITLE,
subtitle=HEALTH_CHECK_CARDGROUP_SUBTITLE,
children=non_passing_health_checks,
children=sort_healthcheck_cards(health_check_cards),
)
if len(non_passing_health_checks) > 0
if len(health_check_cards) > 0
else None
)

Expand Down