Skip to content

Commit 8148e27

Browse files
queelius and claude committed
Clean up v0.5.0: extract helpers, remove dead eps field, delete stale artifacts
- Add require_params(), get_delta() helpers and parameter length validation to R/utils.R; replace inline boilerplate across dfr_dist.R, diagnostics.R, and distributions.R
- Remove dead eps field from dfr_dist constructor and tests
- Simplify inv_cdf (remove unnecessary do.call wrapper)
- Update sampler roxygen to reflect inverse CDF implementation
- Add left-censoring note to distributions.Rd
- Add test-utils.R with 14 tests for the new helpers
- Delete stale vignette .R tangling artifacts and .html builds

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3b7b5a5 commit 8148e27

16 files changed

+138
-1616
lines changed

R/dfr_dist.R

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#' @param rate A function that computes the hazard rate at time `t`.
88
#' @param par The parameters of the distribution. Defaults to `NULL`,
99
#' which means that the parameters are unknown.
10-
#' @param eps The epsilon update for numerical integration. Defaults to 0.01.
1110
#' @param ob_col The column name for observation times in data frames.
1211
#' Defaults to "t".
1312
#' @param delta_col The column name for event indicators in data frames.
@@ -25,14 +24,13 @@
2524
#' If NULL, falls back to numerical Hessian via numDeriv::hessian.
2625
#' @return A `dfr_dist` object that inherits from `likelihood_model`.
2726
#' @export
28-
dfr_dist <- function(rate, par = NULL, eps = 0.01,
27+
dfr_dist <- function(rate, par = NULL,
2928
ob_col = "t", delta_col = "delta",
3029
cum_haz_rate = NULL, score_fn = NULL,
3130
hess_fn = NULL) {
3231
structure(
3332
list(rate = rate,
3433
par = par,
35-
eps = eps,
3634
ob_col = ob_col,
3735
delta_col = delta_col,
3836
cum_haz_rate = cum_haz_rate,
@@ -87,13 +85,11 @@ inv_cdf.dfr_dist <- function(x, ...) {
8785
cdf_fn <- cdf(x, ...)
8886
function(p, par = NULL, ...) {
8987
par <- get_params(par, x$par)
90-
uniroot_args <- list(
91-
f = function(t) {
92-
cdf_fn(t, par, ...) - p
93-
},
88+
uniroot(
89+
f = function(t) cdf_fn(t, par, ...) - p,
9490
interval = c(0, 1e3),
95-
extendInt = "upX")
96-
do.call(uniroot, uniroot_args)$root
91+
extendInt = "upX"
92+
)$root
9793
}
9894
}
9995

@@ -109,23 +105,16 @@ params.dfr_dist <- function(x, ...) {
109105
}
110106

111107
#' Sampling function for `dfr_dist` objects.
112-
#'
113-
#' Since S(t,par) = exp(-cum_hz(t,par)), we can sample from the
114-
#' distribution by letting t = 0 (or some other positive number if
115-
#' we want to condition on T > t_min), sampling from an exponential
116-
#' distribution with `lambda = rate(t, par)`, and then rejecting
117-
#' the sample if `runif(1) > S(t, par)`. If accepted, add that
118-
#' observation to the sample, otherwise reject it, let `t = t + eps`
119-
#' where `eps` is some small number, and repeat. We continue this
120-
#' process until we have `n` observations for the sample.
108+
#'
109+
#' Uses inverse CDF sampling: generates uniform random values and
110+
#' transforms them through the quantile function (inverse CDF).
121111
#'
122112
#' @param x The object to obtain the sampler of.
123-
#' @param ... Additional arguments to pass into the survival function
113+
#' @param ... Additional arguments to pass into the inverse CDF constructor.
124114
#' @return A function that samples from the distribution. It accepts
125-
#' `n`, the number of samples to take, `t` is the time at which to start
126-
#' sampling, `par` are the parameters of the distribution, and `eps` is
127-
#' the update for numerical integration. Finally, we pass additional
128-
#' arguments `...` into the hazard function.
115+
#' `n`, the number of samples to take, `par`, the parameters of the
116+
#' distribution, and `...`, additional arguments passed to the quantile
117+
#' function.
129118
#' @importFrom algebraic.dist params surv sampler
130119
#' @importFrom stats runif
131120
#' @export
@@ -313,12 +302,7 @@ loglik.dfr_dist <- function(model, ...) {
313302
par <- get_params(par, model$par)
314303

315304
t <- df[[model$ob_col]]
316-
317-
if (model$delta_col %in% names(df)) {
318-
delta <- df[[model$delta_col]]
319-
} else {
320-
delta <- rep(1, length(t))
321-
}
305+
delta <- get_delta(df, model$delta_col)
322306

323307
ll <- 0
324308

@@ -476,13 +460,7 @@ fit.dfr_dist <- function(object, ...) {
476460
function(df, par = NULL,
477461
method = c("BFGS", "Nelder-Mead", "L-BFGS-B", "CG", "SANN"),
478462
control = list(), ...) {
479-
if (is.null(par)) {
480-
par <- params(object)
481-
if (is.null(par)) {
482-
stop("Initial parameters required: specify 'par' argument ",
483-
"or set parameters in dfr_dist()")
484-
}
485-
}
463+
par <- require_params(par, params(object))
486464

487465
control <- modifyList(list(fnscale = -1), control)
488466
method <- match.arg(method)

R/diagnostics.R

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,7 @@ residuals.dfr_dist <- function(object, data, par = NULL,
7575
type <- match.arg(type)
7676
H <- cum_haz(object, ...)
7777

78-
if (is.null(par)) {
79-
par <- object$par
80-
if (is.null(par)) {
81-
stop("Parameters required: provide via 'par' argument or in distribution object")
82-
}
83-
}
78+
par <- require_params(par, object$par)
8479

8580
if (!object$ob_col %in% names(data)) {
8681
stop(sprintf("Time column '%s' not found in data", object$ob_col))
@@ -93,13 +88,7 @@ residuals.dfr_dist <- function(object, data, par = NULL,
9388
return(H_vals)
9489
}
9590

96-
if (object$delta_col %in% names(data)) {
97-
delta <- data[[object$delta_col]]
98-
} else {
99-
delta <- rep(1, nrow(data))
100-
}
101-
102-
delta - H_vals
91+
get_delta(data, object$delta_col) - H_vals
10392
}
10493

10594
# =============================================================================
@@ -172,12 +161,7 @@ plot.dfr_dist <- function(x, data = NULL, par = NULL,
172161
empirical_col = "steelblue", ...) {
173162
what <- match.arg(what)
174163

175-
if (is.null(par)) {
176-
par <- x$par
177-
if (is.null(par)) {
178-
stop("Parameters required: provide via 'par' argument or in distribution object")
179-
}
180-
}
164+
par <- require_params(par, x$par)
181165

182166
if (is.null(xlim)) {
183167
xlim <- if (!is.null(data) && x$ob_col %in% names(data)) {
@@ -210,11 +194,7 @@ plot.dfr_dist <- function(x, data = NULL, par = NULL,
210194

211195
if (!is.null(data) && empirical && !add && x$ob_col %in% names(data)) {
212196
t <- data[[x$ob_col]]
213-
delta <- if (x$delta_col %in% names(data)) {
214-
data[[x$delta_col]]
215-
} else {
216-
rep(1, length(t))
217-
}
197+
delta <- get_delta(data, x$delta_col)
218198

219199
km <- kaplan_meier(t, delta)
220200

R/distributions.R

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44
#' Each provides the complete specification (rate, cum_haz_rate, score_fn,
55
#' and where practical, hess_fn) for optimal performance.
66
#'
7+
#' @section Left-Censoring Note:
8+
#' The analytical score and Hessian functions provided by these constructors
9+
#' assume event indicators in \{0, 1\} (right-censored and exact observations).
10+
#' For left-censored data (delta = -1), these functions are not applicable and
11+
#' the package automatically falls back to numerical differentiation via
12+
#' \code{numDeriv::grad} and \code{numDeriv::hessian} through the log-likelihood,
13+
#' which handles all censoring types correctly.
14+
#'
715
#' @name distributions
816
#' @family distributions
917
NULL
@@ -67,11 +75,11 @@ dfr_exponential <- function(lambda = NULL) {
6775
par[[1]] * t
6876
},
6977
score_fn = function(df, par, ...) {
70-
delta <- if ("delta" %in% names(df)) df$delta else rep(1, nrow(df))
78+
delta <- get_delta(df)
7179
c(sum(delta == 1) / par[[1]] - sum(df$t))
7280
},
7381
hess_fn = function(df, par, ...) {
74-
delta <- if ("delta" %in% names(df)) df$delta else rep(1, nrow(df))
82+
delta <- get_delta(df)
7583
matrix(-sum(delta == 1) / par[[1]]^2, nrow = 1, ncol = 1)
7684
},
7785
par = lambda
@@ -158,7 +166,7 @@ dfr_weibull <- function(shape = NULL, scale = NULL) {
158166
k <- par[[1]]
159167
sigma <- par[[2]]
160168
t <- df$t
161-
delta <- if ("delta" %in% names(df)) df$delta else rep(1, nrow(df))
169+
delta <- get_delta(df)
162170

163171
n_events <- sum(delta == 1)
164172
t_ratio <- t / sigma
@@ -175,7 +183,7 @@ dfr_weibull <- function(shape = NULL, scale = NULL) {
175183
k <- par[[1]]
176184
sigma <- par[[2]]
177185
t <- df$t
178-
delta <- if ("delta" %in% names(df)) df$delta else rep(1, nrow(df))
186+
delta <- get_delta(df)
179187

180188
n_events <- sum(delta == 1)
181189
t_ratio <- t / sigma
@@ -260,7 +268,7 @@ dfr_gompertz <- function(a = NULL, b = NULL) {
260268
a <- par[[1]]
261269
b <- par[[2]]
262270
t <- df$t
263-
delta <- if ("delta" %in% names(df)) df$delta else rep(1, nrow(df))
271+
delta <- get_delta(df)
264272

265273
exp_bt <- exp(b * t)
266274
n_events <- sum(delta == 1)
@@ -345,7 +353,7 @@ dfr_loglogistic <- function(alpha = NULL, beta = NULL) {
345353
alpha <- par[[1]]
346354
beta <- par[[2]]
347355
t <- df$t
348-
delta <- if ("delta" %in% names(df)) df$delta else rep(1, nrow(df))
356+
delta <- get_delta(df)
349357

350358
t_ratio <- t / alpha
351359
t_ratio_beta <- t_ratio^beta

R/utils.R

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,26 @@ get_params <- function(par, default = NULL) {
22
if (is.null(par) || all(is.na(par))) {
33
return(default)
44
}
5+
if (!is.null(default) && length(par) != length(default)) {
6+
stop(sprintf(
7+
"Parameter length mismatch: got %d, expected %d",
8+
length(par), length(default)
9+
))
10+
}
511
if (!is.null(default)) {
612
par[is.na(par)] <- default[is.na(par)]
713
}
814
par
915
}
16+
17+
# Resolve parameters and insist that a usable set exists.
#
# Delegates to get_params(par, default) to combine the caller-supplied `par`
# with `default`, then signals an error if the result is NULL.
#
# par     - caller-supplied parameters (may be NULL).
# default - fallback parameters, e.g. those stored on the distribution object.
#
# Returns whatever get_params() produced when it is non-NULL; otherwise stops.
require_params <- function(par, default) {
  resolved <- get_params(par, default)
  if (!is.null(resolved)) {
    return(resolved)
  }
  stop("Parameters required: provide via 'par' argument or in distribution object")
}
24+
25+
# Extract the event-indicator vector from a data set.
#
# df        - a data frame, or a named list of equal-length columns.
# delta_col - name of the event-indicator column (default "delta").
#
# Returns df[[delta_col]] when that column is present. Otherwise returns a
# vector of ones, one per observation, so every observation gets indicator 1.
get_delta <- function(df, delta_col = "delta") {
  if (delta_col %in% names(df)) {
    return(df[[delta_col]])
  }
  n_obs <- nrow(df)
  if (is.null(n_obs)) {
    # nrow() is NULL for plain lists; use the first column's length instead so
    # rep() does not fail with "invalid 'times' argument".
    n_obs <- if (length(df) > 0L) length(df[[1L]]) else 0L
  }
  rep(1, n_obs)
}

man/dfr_dist.Rd

Lines changed: 0 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/distributions.Rd

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/sampler.dfr_dist.Rd

Lines changed: 6 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-dfr_dist.R

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,6 @@ test_that("dfr_dist constructor stores parameters", {
3737
expect_equal(dist$par, c(lambda = 0.5))
3838
})
3939

40-
test_that("dfr_dist constructor uses default eps", {
41-
dist <- dfr_dist(
42-
rate = function(t, par, ...) par[1],
43-
par = c(1)
44-
)
45-
46-
expect_equal(dist$eps, 0.01)
47-
})
48-
49-
test_that("dfr_dist constructor accepts custom eps", {
50-
dist <- dfr_dist(
51-
rate = function(t, par, ...) par[1],
52-
par = c(1),
53-
eps = 0.001
54-
)
55-
56-
expect_equal(dist$eps, 0.001)
57-
})
58-
5940
test_that("is_dfr_dist returns TRUE for dfr_dist objects", {
6041
dist <- make_exponential_dfr(lambda = 1)
6142
expect_true(is_dfr_dist(dist))

tests/testthat/test-likelihood_model.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ test_that("fit throws error when no parameters available", {
584584

585585
solver <- fit(dist)
586586
# Should error when calling without par
587-
expect_error(solver(df), "Initial parameters required")
587+
expect_error(solver(df), "Parameters required")
588588
})
589589

590590
# =============================================================================

0 commit comments

Comments
 (0)