@@ -107,6 +107,7 @@ def __init__(self, y_true, *y_pred,
107107 self .num_samples = num_samples
108108 self .n_jobs = n_jobs
109109 self .use_tqdm = use_tqdm
110+ self .sorting_func = np .linalg .norm
110111 self ._init ()
111112
112113 def _init (self ):
@@ -139,11 +140,14 @@ def __sklearn_clone__(self):
139140 ins = klass (** params )
140141 ins .predictions = dict (self .predictions )
141142 ins ._statistic_samples ._samples = self .statistic_samples ._samples
143+ ins .sorting_func = self .sorting_func
142144 return ins
143145
144146 def __repr__ (self ):
145147 """Prediction statistics with standard error in parenthesis"""
146- return f"<{ self .__class__ .__name__ } >\n { self } "
148+ arg = 'score_func' if self .error_func is None else 'error_func'
149+ func_name = self .statistic_func .__name__
150+ return f"<{ self .__class__ .__name__ } ({ arg } ={ func_name } )>\n { self } "
147151
148152 def __str__ (self ):
149153 """Prediction statistics with standard error in parenthesis"""
@@ -152,7 +156,14 @@ def __str__(self):
152156 output = ["Statistic with its standard error (se)" ]
153157 output .append ("statistic (se)" )
154158 for key , value in self .statistic .items ():
155- output .append (f'{ value :0.4f} ({ se [key ]:0.4f} ) <= { key } ' )
159+ if isinstance (value , float ):
160+ desc = f'{ value :0.4f} ({ se [key ]:0.4f} ) <= { key } '
161+ else :
162+ desc = [f'{ v :0.4f} ({ k :0.4f} )'
163+ for v , k in zip (value , se [key ])]
164+ desc = ', ' .join (desc )
165+ desc = f'{ desc } <= { key } '
166+ output .append (desc )
156167 return "\n " .join (output )
157168
158169 def __call__ (self , y_pred , name = None ):
@@ -202,6 +213,7 @@ def difference(self, wrt_to: str=None):
202213 diff_ins = Difference (statistic_samples = clone (self .statistic_samples ),
203214 statistic = self .statistic ,
204215 best = self .best [0 ])
216+ diff_ins .sorting_func = self .sorting_func
205217 diff_ins .statistic_samples .calls = diff
206218 diff_ins .statistic_samples .info ['best' ] = self .best [0 ]
207219 return diff_ins
@@ -214,10 +226,20 @@ def best(self):
214226 return self ._best
215227 except AttributeError :
216228 statistic = [(k , v ) for k , v in self .statistic .items ()]
217- statistic = sorted (statistic , key = lambda x : x [1 ],
229+ statistic = sorted (statistic ,
230+ key = lambda x : self .sorting_func (x [1 ]),
218231 reverse = self .statistic_samples .BiB )
219232 self ._best = statistic [0 ]
220233 return self ._best
234+
235+ @property
236+ def sorting_func (self ):
237+ """Rank systems when multiple performances are used"""
238+ return self ._sorting_func
239+
240+ @sorting_func .setter
241+ def sorting_func (self , value ):
242+ self ._sorting_func = value
221243
222244 @property
223245 def statistic (self ):
@@ -241,7 +263,8 @@ def statistic(self):
241263
242264 data = sorted ([(k , self .statistic_func (self .y_true , v ))
243265 for k , v in self .predictions .items ()],
244- key = lambda x : x [1 ], reverse = self .statistic_samples .BiB )
266+ key = lambda x : self .sorting_func (x [1 ]),
267+ reverse = self .statistic_samples .BiB )
245268 return dict (data )
246269
247270 @property
@@ -419,19 +442,36 @@ class Difference:
419442 best :str = None
420443 statistic :dict = None
421444
445+ @property
446+ def sorting_func (self ):
447+ """Rank systems when multiple performances are used"""
448+ return self ._sorting_func
449+
450+ @sorting_func .setter
451+ def sorting_func (self , value ):
452+ self ._sorting_func = value
453+
422454 def __repr__ (self ):
423455 """p-value"""
424456 return f"<{ self .__class__ .__name__ } >\n { self } "
425457
426458 def __str__ (self ):
427459 """p-value"""
428460 output = [f"difference p-values w.r.t { self .best } " ]
429- for k , v in self .p_value ().items ():
430- output .append (f'{ v :0.4f} <= { k } ' )
461+ for key , value in self .p_value ().items ():
462+ if isinstance (value , float ):
463+ output .append (f'{ value :0.4f} <= { key } ' )
464+ else :
465+ desc = [f'{ v :0.4f} ' for v in value ]
466+ desc = ', ' .join (desc )
467+ desc = f'{ desc } <= { key } '
431468 return "\n " .join (output )
432469
433- def p_value (self ):
470+ def p_value (self , right : bool = True ):
434471 """Compute p_value of the differences
472+
473+ :param right: Estimate the p-value using :math:`\\ text{sample} \\ geq 2\\ delta`
474+ :type right: bool
435475
436476 >>> from sklearn.svm import LinearSVC
437477 >>> from sklearn.ensemble import RandomForestClassifier
@@ -452,10 +492,20 @@ def p_value(self):
452492 """
453493 values = []
454494 sign = 1 if self .statistic_samples .BiB else - 1
495+ ndim = self .statistic [self .best ].ndim
455496 for k , v in self .statistic_samples .calls .items ():
456497 delta = 2 * sign * (self .statistic [self .best ] - self .statistic [k ])
457- values .append ((k , (v > delta ).mean ()))
458- values .sort (key = lambda x : x [1 ])
498+ if ndim == 0 :
499+ if right :
500+ values .append ((k , (v >= delta ).mean ()))
501+ else :
502+ values .append ((k , (v <= 0 ).mean ()))
503+ else :
504+ if right :
505+ values .append ((k , (v >= delta ).mean (axis = 0 )))
506+ else :
507+ values .append ((k , (v <= 0 ).mean (axis = 0 )))
508+ values .sort (key = lambda x : self .sorting_func (x [1 ]))
459509 return dict (values )
460510
461511 def plot (self , ** kwargs ):
0 commit comments