Skip to content

Commit 0be706a

Browse files
authored
refactor: update rounded_survival_table to use dynamic column names (#332)
* fix: update rounded_survival_table to use dynamic column names * - Rename `rounded_survival_table()` to `_rounded_survival_table()` to mark it as an internal helper function - Add comprehensive type hints to function signature and return type - Add parameter and return type annotations: pd.DataFrame, str - Add detailed NumPy-style docstring documenting parameters, return value, and purpose - Update test imports to use the new private function name
1 parent 72f2970 commit 0be706a

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

acro/acro_tables.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ def survival_plot( # pylint: disable=too-many-arguments
522522
):
523523
"""Create the survival plot according to the status of suppressing."""
524524
if self.suppress:
525-
survival_table = rounded_survival_table(survival_table)
525+
survival_table = _rounded_survival_table(survival_table)
526526
plot = survival_table.plot(y="rounded_survival_fun", xlim=0, ylim=0)
527527
else: # pragma: no cover
528528
plot = survival_func.plot()
@@ -914,14 +914,37 @@ def delete_empty_rows_columns(table: DataFrame) -> tuple[DataFrame, list[str]]:
914914
return (table, comments)
915915

916916

917-
def rounded_survival_table(survival_table):
918-
"""Calculate the rounded surival function."""
917+
def _rounded_survival_table(
918+
survival_table: pd.DataFrame,
919+
num_at_risk_col: str = "num at risk",
920+
num_events_col: str = "num events",
921+
) -> pd.DataFrame:
922+
"""Calculate the rounded survival function.
923+
924+
Internal helper function for survival analysis with disclosure control.
925+
Applies rounding to survival tables to prevent disclosure of small counts.
926+
927+
Parameters
928+
----------
929+
survival_table : pd.DataFrame
930+
The survival table containing survival analysis results.
931+
num_at_risk_col : str, default "num at risk"
932+
Name of the column containing number at risk values.
933+
num_events_col : str, default "num events"
934+
Name of the column containing number of events.
935+
936+
Returns
937+
-------
938+
pd.DataFrame
939+
The survival table with rounded survival function added.
940+
"""
919941
death_censored = (
920-
survival_table["num at risk"].shift(periods=1) - survival_table["num at risk"]
942+
survival_table[num_at_risk_col].shift(periods=1)
943+
- survival_table[num_at_risk_col]
921944
)
922945
death_censored = death_censored.tolist()
923-
survivor = survival_table["num at risk"].tolist()
924-
deaths = survival_table["num events"].tolist()
946+
survivor = survival_table[num_at_risk_col].tolist()
947+
deaths = survival_table[num_events_col].tolist()
925948
rounded_num_of_deaths = []
926949
rounded_num_at_risk = []
927950
sub_total = 0

test/test_initial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import statsmodels.api as sm
1212

1313
from acro import ACRO, acro_tables, add_constant, add_to_acro, record, utils
14-
from acro.acro_tables import rounded_survival_table
14+
from acro.acro_tables import _rounded_survival_table
1515
from acro.record import Records, load_records
1616

1717
# pylint: disable=redefined-outer-name,too-many-lines
@@ -735,7 +735,7 @@ def test_rounded_survival_table():
735735
)
736736

737737
# Apply rounded_survival_table
738-
result = rounded_survival_table(survival_table.copy())
738+
result = _rounded_survival_table(survival_table.copy())
739739

740740
# Check that it has the rounded_survival_fun column
741741
assert "rounded_survival_fun" in result.columns

0 commit comments

Comments
 (0)