Skip to content

Commit 02dc6e2

Browse files
authored
Merge pull request #17 from INGEOTEC/develop
Version - 0.1.5
2 parents a647a25 + b335b9f commit 02dc6e2

File tree

5 files changed

+187
-65
lines changed

5 files changed

+187
-65
lines changed

CompStats/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
__version__ = '0.1.4'
14+
__version__ = '0.1.5'
1515
from CompStats.bootstrap import StatisticSamples
1616
from CompStats.measurements import CI, SE, difference_p_value
1717
from CompStats.performance import performance, difference, all_differences, plot_performance, plot_difference

CompStats/interface.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def __init__(self, y_true, *y_pred,
107107
self.num_samples = num_samples
108108
self.n_jobs = n_jobs
109109
self.use_tqdm = use_tqdm
110+
self.sorting_func = np.linalg.norm
110111
self._init()
111112

112113
def _init(self):
@@ -139,11 +140,14 @@ def __sklearn_clone__(self):
139140
ins = klass(**params)
140141
ins.predictions = dict(self.predictions)
141142
ins._statistic_samples._samples = self.statistic_samples._samples
143+
ins.sorting_func = self.sorting_func
142144
return ins
143145

144146
def __repr__(self):
    """Return the class name, the measure in use, and the statistics table.

    The header reads ``<ClassName(score_func=name)>`` when
    ``self.error_func`` is ``None`` and ``<ClassName(error_func=name)>``
    otherwise; the text produced by ``str(self)`` follows on the next line.
    """
    arg = 'score_func' if self.error_func is None else 'error_func'
    func = self.statistic_func
    # Not every callable carries ``__name__`` (e.g. ``functools.partial``
    # objects); fall back to its repr so that ``repr()`` never raises.
    func_name = getattr(func, '__name__', None) or repr(func)
    return f"<{self.__class__.__name__}({arg}={func_name})>\n{self}"
147151

148152
def __str__(self):
149153
"""Prediction statistics with standard error in parenthesis"""
@@ -152,7 +156,14 @@ def __str__(self):
152156
output = ["Statistic with its standard error (se)"]
153157
output.append("statistic (se)")
154158
for key, value in self.statistic.items():
155-
output.append(f'{value:0.4f} ({se[key]:0.4f}) <= {key}')
159+
if isinstance(value, float):
160+
desc = f'{value:0.4f} ({se[key]:0.4f}) <= {key}'
161+
else:
162+
desc = [f'{v:0.4f} ({k:0.4f})'
163+
for v, k in zip(value, se[key])]
164+
desc = ', '.join(desc)
165+
desc = f'{desc} <= {key}'
166+
output.append(desc)
156167
return "\n".join(output)
157168

158169
def __call__(self, y_pred, name=None):
@@ -202,6 +213,7 @@ def difference(self, wrt_to: str=None):
202213
diff_ins = Difference(statistic_samples=clone(self.statistic_samples),
203214
statistic=self.statistic,
204215
best=self.best[0])
216+
diff_ins.sorting_func = self.sorting_func
205217
diff_ins.statistic_samples.calls = diff
206218
diff_ins.statistic_samples.info['best'] = self.best[0]
207219
return diff_ins
@@ -214,10 +226,20 @@ def best(self):
214226
return self._best
215227
except AttributeError:
216228
statistic = [(k, v) for k, v in self.statistic.items()]
217-
statistic = sorted(statistic, key=lambda x: x[1],
229+
statistic = sorted(statistic,
230+
key=lambda x: self.sorting_func(x[1]),
218231
reverse=self.statistic_samples.BiB)
219232
self._best = statistic[0]
220233
return self._best
234+
235+
@property
def sorting_func(self):
    """Callable applied to each statistic before ranking the systems.

    It is used as the sort key when multiple performances are used.
    """
    ranker = self._sorting_func
    return ranker

@sorting_func.setter
def sorting_func(self, value):
    # Store the callable on the private attribute read by the getter.
    setattr(self, '_sorting_func', value)
221243

222244
@property
223245
def statistic(self):
@@ -241,7 +263,8 @@ def statistic(self):
241263

242264
data = sorted([(k, self.statistic_func(self.y_true, v))
243265
for k, v in self.predictions.items()],
244-
key=lambda x: x[1], reverse=self.statistic_samples.BiB)
266+
key=lambda x: self.sorting_func(x[1]),
267+
reverse=self.statistic_samples.BiB)
245268
return dict(data)
246269

247270
@property
@@ -419,19 +442,36 @@ class Difference:
419442
best:str=None
420443
statistic:dict=None
421444

445+
@property
def sorting_func(self):
    """Callable applied to each value before ranking the systems.

    It is used as the sort key when multiple performances are used.
    """
    ranker = self._sorting_func
    return ranker

@sorting_func.setter
def sorting_func(self, value):
    # Store the callable on the private attribute read by the getter.
    setattr(self, '_sorting_func', value)
453+
422454
def __repr__(self):
    """Return the class name in angle brackets followed by the p-value text."""
    header = "<" + self.__class__.__name__ + ">"
    body = f"{self}"
    return header + "\n" + body
425457

426458
def __str__(self):
    """Render the p-values of the differences as text.

    The first line names the reference system (``self.best``). Each
    following line holds the p-value of one system, formatted with four
    decimals and followed by ``<= name``. When the value returned by
    :py:meth:`p_value` for a system is not a ``float`` it is treated as
    an iterable of p-values, which are joined with commas on one line.
    """
    output = [f"difference p-values w.r.t {self.best}"]
    for key, value in self.p_value().items():
        if isinstance(value, float):
            desc = f'{value:0.4f}'
        else:
            desc = ', '.join(f'{v:0.4f}' for v in value)
        # Append in both branches: previously the line built for
        # non-float values was never added to the output.
        output.append(f'{desc} <= {key}')
    return "\n".join(output)
432469

433-
def p_value(self):
470+
def p_value(self, right:bool=True):
434471
"""Compute p_value of the differences
472+
473+
:param right: Estimate the p-value using :math:`\\text{sample} \\geq 2\\delta`
474+
:type right: bool
435475
436476
>>> from sklearn.svm import LinearSVC
437477
>>> from sklearn.ensemble import RandomForestClassifier
@@ -452,10 +492,20 @@ def p_value(self):
452492
"""
453493
values = []
454494
sign = 1 if self.statistic_samples.BiB else -1
495+
ndim = self.statistic[self.best].ndim
455496
for k, v in self.statistic_samples.calls.items():
456497
delta = 2 * sign * (self.statistic[self.best] - self.statistic[k])
457-
values.append((k, (v > delta).mean()))
458-
values.sort(key=lambda x: x[1])
498+
if ndim == 0:
499+
if right:
500+
values.append((k, (v >= delta).mean()))
501+
else:
502+
values.append((k, (v <= 0).mean()))
503+
else:
504+
if right:
505+
values.append((k, (v >= delta).mean(axis=0)))
506+
else:
507+
values.append((k, (v <= 0).mean(axis=0)))
508+
values.sort(key=lambda x: self.sorting_func(x[1]))
459509
return dict(values)
460510

461511
def plot(self, **kwargs):

0 commit comments

Comments
 (0)