-
Notifications
You must be signed in to change notification settings - Fork 183
Add Precision-Recall curve in probscores (PR_curve) #531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
1152a30
140f04b
22e1a74
47542cd
bb30bad
1964065
a4a44db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,10 @@ | |
| ROC_curve_init | ||
| ROC_curve_accum | ||
| ROC_curve_compute | ||
| PR_curve | ||
| PR_curve_init | ||
| PR_curve_accum | ||
| PR_curve_compute | ||
| """ | ||
|
|
||
| import numpy as np | ||
|
|
@@ -421,3 +425,142 @@ def ROC_curve_compute(ROC, compute_area=False): | |
| return POFD_vals, POD_vals, area | ||
| else: | ||
| return POFD_vals, POD_vals | ||
|
|
||
|
|
||
| def PR_curve(P_f, X_o, X_min, n_prob_thrs=10, compute_area=False): | ||
| """ | ||
| Compute the Precision–Recall (PR) curve and optionally its area. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| P_f : array_like | ||
| Forecasted probabilities for exceeding the threshold X_min. | ||
| Non-finite values are ignored. | ||
| X_o : array_like | ||
| Observed values. Non-finite values are ignored. | ||
| X_min : float | ||
| Precipitation intensity threshold for yes/no prediction. | ||
| n_prob_thrs : int, optional | ||
| Number of probability thresholds to evaluate. | ||
| The interval [0, 1] is divided into n_prob_thrs evenly spaced values. | ||
| compute_area : bool, optional | ||
| If True, compute the area under the PR curve using trapezoidal integration. | ||
|
|
||
| Returns | ||
| ------- | ||
| out : tuple | ||
| (precision_vals, recall_vals) for each probability threshold. | ||
| If compute_area is True, return (precision_vals, recall_vals, area), | ||
| where area is the trapezoidal estimate of the PR curve area. | ||
| """ | ||
| P_f = P_f.copy() | ||
| X_o = X_o.copy() | ||
| pr = PR_curve_init(X_min, n_prob_thrs) | ||
| PR_curve_accum(pr, P_f, X_o) | ||
| return PR_curve_compute(pr, X_o, X_min, compute_area) | ||
|
|
||
|
|
||
| def PR_curve_init(X_min, n_prob_thrs=10): | ||
| """ | ||
| Initialize a Precision–Recall curve object. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X_min : float | ||
| Precipitation intensity threshold for yes/no prediction. | ||
| n_prob_thrs : int, optional | ||
| Number of probability thresholds to evaluate. | ||
|
|
||
| Returns | ||
| ------- | ||
| PR : dict | ||
| Dictionary containing counters for hits, misses, false alarms, | ||
| correct negatives, and the probability thresholds. | ||
| Keys: | ||
| - "X_min" : threshold value | ||
| - "hits", "misses", "false_alarms", "corr_neg" : arrays of counts | ||
| - "prob_thrs" : array of evenly spaced thresholds in [0, 1] | ||
| """ | ||
| PR = {} | ||
| PR["X_min"] = X_min | ||
| PR["hits"] = np.zeros(n_prob_thrs, dtype=int) | ||
| PR["misses"] = np.zeros(n_prob_thrs, dtype=int) | ||
| PR["false_alarms"] = np.zeros(n_prob_thrs, dtype=int) | ||
| PR["corr_neg"] = np.zeros(n_prob_thrs, dtype=int) | ||
| PR["prob_thrs"] = np.linspace(0.0, 1.0, int(n_prob_thrs)) | ||
| return PR | ||
|
|
||
|
|
||
| def PR_curve_accum(PR, P_f, X_o): | ||
| """ | ||
| Accumulate forecast–observation pairs into the PR object. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| PR : dict | ||
| A PR curve object created with PR_curve_init. | ||
| P_f : array_like | ||
| Forecasted probabilities for exceeding X_min. | ||
| X_o : array_like | ||
| Observed values. | ||
|
|
||
| Notes | ||
| ----- | ||
| Updates the counters (hits, misses, false alarms, correct negatives) | ||
| for each probability threshold in PR["prob_thrs"]. | ||
| """ | ||
| mask = np.logical_and(np.isfinite(P_f), np.isfinite(X_o)) | ||
| P_f = P_f[mask] | ||
| X_o = X_o[mask] | ||
| for i, p in enumerate(PR["prob_thrs"]): | ||
| forecast_yes = P_f >= p | ||
| obs_yes = X_o >= PR["X_min"] | ||
| PR["hits"][i] += np.sum(np.logical_and(forecast_yes, obs_yes)) | ||
| PR["misses"][i] += np.sum(np.logical_and(~forecast_yes, obs_yes)) | ||
| PR["false_alarms"][i] += np.sum(np.logical_and(forecast_yes, ~obs_yes)) | ||
| PR["corr_neg"][i] += np.sum(np.logical_and(~forecast_yes, ~obs_yes)) | ||
|
|
||
|
|
||
| def PR_curve_compute(PR, X_o, X_min, compute_area=False): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| """ | ||
| Compute precision and recall values from the PR object. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| PR : dict | ||
| A PR curve object created with PR_curve_init. | ||
| X_o : array_like | ||
| Observed values (used only if prevalence or area is computed). | ||
| X_min : float | ||
| Precipitation intensity threshold for yes/no prediction. | ||
| compute_area : bool, optional | ||
| If True, compute the area under the PR curve. | ||
|
|
||
| Returns | ||
| ------- | ||
| out : tuple | ||
| (precision_vals, recall_vals) for each probability threshold. | ||
| If compute_area is True, return (precision_vals, recall_vals, area), | ||
| where area is the trapezoidal estimate of the PR curve area. | ||
| """ | ||
| precision_vals = [] | ||
| recall_vals = [] | ||
|
|
||
| for i in range(len(PR["prob_thrs"])): | ||
| hits = PR["hits"][i] | ||
| misses = PR["misses"][i] | ||
| false_alarms = PR["false_alarms"][i] | ||
|
|
||
| recall = hits / (hits + misses) if (hits + misses) > 0 else 0.0 | ||
| precision = hits / (hits + false_alarms) if (hits + false_alarms) > 0 else 1.0 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the convention using
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you comment on how this is handled in scikit-learn for example? |
||
|
|
||
| recall_vals.append(recall) | ||
| precision_vals.append(precision) | ||
|
|
||
| if compute_area: | ||
| # Sort by recall before integration | ||
| recall_sorted, precision_sorted = zip(*sorted(zip(recall_vals, precision_vals))) | ||
| area = np.trapz(precision_sorted, recall_sorted) | ||
| return precision_vals, recall_vals, area | ||
| else: | ||
| return precision_vals, recall_vals | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this line computing
obs_yesis constant and can be moved outside of the loop