
Commit f883ca9

ENH & MTN & FIX

- Fix bentkus_p_value calculation
- Fix the higher_is_better logic and move it to a single place
- Implement a unit test for BinaryClassificationRiskControl
- Fix the parametrization of an existing test
1 parent 0c8c12d commit f883ca9

4 files changed: +54, −7 lines


mapie/control_risk/p_values.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def compute_hoeffdding_bentkus_p_value(
     )
     factor = 1 if binary else np.e
     bentkus_p_value = factor * binom.cdf(
-        np.ceil(n_obs_repeat * r_hat_repeat), n_obs, alpha_repeat
+        np.ceil(n_obs_repeat * r_hat_repeat), n_obs_repeat, alpha_repeat
     )
     hb_p_value = np.where(
         bentkus_p_value > hoeffding_p_value,
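
For context, the corrected call passes the broadcast number of observations as the second argument of scipy.stats.binom.cdf, i.e. the trial-count parameter of the binomial distribution. Below is a minimal, hedged sketch of the Bentkus term for a single scalar case; bentkus_term is a hypothetical helper written for illustration, while the library version above operates on the broadcast *_repeat arrays.

import numpy as np
from scipy.stats import binom

def bentkus_term(r_hat: float, n_obs: int, alpha: float, binary: bool = False) -> float:
    # binom.cdf(k, n, p): probability of at most k successes out of n trials
    # with success probability p. The second argument must be the number of
    # observations, which is what the fix above restores.
    factor = 1 if binary else np.e
    return factor * binom.cdf(np.ceil(n_obs * r_hat), n_obs, alpha)

# Illustrative values only: empirical risk 0.12 on 500 calibration points, target level 0.1.
print(bentkus_term(r_hat=0.12, n_obs=500, alpha=0.1))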

mapie/risk_control.py

Lines changed: 0 additions & 2 deletions
@@ -743,8 +743,6 @@ def get_value_and_effective_sample_size(
                 in zip(risk_occurrences, risk_conditions)
                 if risk_condition)
             risk_value = risk_sum / effective_sample_size
-            if self.higher_is_better:
-                risk_value = 1 - risk_value
             return risk_value, effective_sample_size
         return None
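
After this removal, the getter returns the raw metric value together with its effective sample size, leaving any "higher is better" flipping to the caller (see risk_control_draft.py below). A standalone sketch of that contract for a precision-style risk, using a hypothetical helper name rather than the library implementation:

import numpy as np

def precision_value_and_effective_n(y_true, y_pred):
    # Effective sample size for precision: number of positive predictions.
    n_eff = int(np.sum(y_pred == 1))
    if n_eff == 0:
        return None  # mirrors the `return None` branch above
    value = float(np.sum((y_true == 1) & (y_pred == 1)) / n_eff)
    return value, n_eff  # raw value, no `1 - value` flip here anymore

print(precision_value_and_effective_n(np.array([1, 0, 1, 0]), np.array([1, 1, 0, 0])))
# (0.5, 2)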

mapie/risk_control_draft.py

Lines changed: 8 additions & 2 deletions
@@ -30,7 +30,7 @@ def __init__(
         self._predict_function = predict_function
         self._risk = risk
         self._best_predict_param_choice = best_predict_param_choice
-        self._alpha = 1 - target_level
+        self._target_level = target_level
         self._delta = 1 - confidence_level
 
         self._thresholds: NDArray[float] = np.linspace(0, 0.99, 100)
@@ -56,9 +56,15 @@ def calibrate(self, X_calibrate: ArrayLike, y_calibrate: ArrayLike) -> None:
             ) for predictions in predictions_per_threshold]
         )
 
+        if self._risk.higher_is_better:
+            risks_and_eff_sizes[:, 0] = 1 - risks_and_eff_sizes[:, 0]
+            alpha = self._target_level
+        else:
+            alpha = 1 - self._target_level
+
         valid_thresholds_index = ltt_procedure(
             risks_and_eff_sizes[:, 0],
-            np.array([self._alpha]),
+            np.array([alpha]),
             self._delta,
             risks_and_eff_sizes[:, 1],
             True,
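
The same transformation, pulled out of calibrate() for readability: for a higher-is-better metric the observed values are flipped into risks before being passed to ltt_procedure, and alpha is derived from target_level as in the branch above. A self-contained sketch mirroring that diff; risks_and_alpha is a hypothetical helper, not part of the module:

import numpy as np

def risks_and_alpha(values: np.ndarray, target_level: float, higher_is_better: bool):
    # Mirrors the branch added in calibrate() above.
    if higher_is_better:
        return 1 - values, target_level
    return values, 1 - target_level

risks, alpha = risks_and_alpha(
    np.array([0.91, 0.88, 0.95]), target_level=0.9, higher_is_better=True
)
print(risks, alpha)  # [0.09 0.12 0.05] 0.9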

mapie/tests/test_risk_control.py

Lines changed: 45 additions & 2 deletions
@@ -11,10 +11,17 @@
 from sklearn.pipeline import Pipeline, make_pipeline
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.utils.validation import check_is_fitted
+from sklearn.metrics import precision_score, recall_score, accuracy_score
 from typing_extensions import TypedDict
 
 from numpy.typing import NDArray
-from mapie.risk_control import PrecisionRecallController
+from mapie.risk_control import (
+    PrecisionRecallController,
+    precision,
+    recall,
+    accuracy,
+    BinaryClassificationRisk,
+)
 
 Params = TypedDict(
     "Params",
@@ -260,7 +267,7 @@ def test_predict_output_shape(
         X,
         alpha=alpha,
         bound=args["bound"],
-        delta=.1
+        delta=delta
     )
     n_alpha = len(alpha) if hasattr(alpha, "__len__") else 1
     assert y_pred.shape == y.shape
@@ -808,3 +815,39 @@ def test_method_none_recall() -> None:
     )
     mapie_clf.fit(X_toy, y_toy)
     assert mapie_clf.method == "crc"
+
+
+# The following test is voluntarily agnostic
+# to the specific binary classification risk control implementation.
+@pytest.mark.parametrize(
+    "risk_instance, metric_func, effective_sample_func",
+    [
+        (precision, precision_score, lambda y_true, y_pred: np.sum(y_pred == 1)),
+        (recall, recall_score, lambda y_true, y_pred: np.sum(y_true == 1)),
+        (accuracy, accuracy_score, lambda y_true, y_pred: len(y_true)),
+    ],
+)
+@pytest.mark.parametrize(
+    "y_true, y_pred",
+    [
+        (np.array([1, 0, 1, 0]), np.array([1, 1, 0, 0])),
+        (np.array([1, 1, 0, 0]), np.array([1, 1, 1, 0])),
+        (np.array([0, 0, 0, 0]), np.array([0, 1, 0, 1])),
+    ],
+)
+def test_binary_classification_risk(
+    risk_instance: BinaryClassificationRisk,
+    metric_func,
+    effective_sample_func,
+    y_true,
+    y_pred
+):
+    result = risk_instance.get_value_and_effective_sample_size(y_true, y_pred)
+    if effective_sample_func(y_true, y_pred) == 0:
+        assert result is None
+    else:
+        value, n = result
+        expected_value = metric_func(y_true, y_pred)
+        expected_n = effective_sample_func(y_true, y_pred)
+        assert np.isclose(value, expected_value)
+        assert n == expected_n
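
As a worked instance of the first parametrization above: for the precision risk with y_true = [1, 0, 1, 0] and y_pred = [1, 1, 0, 0], the effective sample size is the number of positive predictions and the value matches sklearn's precision_score. Something like `pytest mapie/tests/test_risk_control.py -k binary_classification_risk` should select the new test.

import numpy as np
from sklearn.metrics import precision_score

y_true = np.array([1, 0, 1, 0])
y_pred = np.array([1, 1, 0, 0])

expected_n = int(np.sum(y_pred == 1))             # 2 positive predictions
expected_value = precision_score(y_true, y_pred)  # 1 true positive out of 2 -> 0.5
print(expected_value, expected_n)                 # 0.5 2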
