Skip to content

Commit 526a55b

Browse files
authored
Diff (3)
1 parent 7dc9766 commit 526a55b

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

CompStats/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
__version__ = '0.1.7'
14+
__version__ = '0.1.8'
1515
from CompStats.bootstrap import StatisticSamples
1616
from CompStats.measurements import CI, SE, difference_p_value
1717
from CompStats.performance import performance, difference, all_differences, plot_performance, plot_difference

CompStats/interface.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -604,19 +604,40 @@ def p_value(self, right:bool=True):
604604
def dataframe(self, value_name:str='Score',
605605
var_name:str='Best',
606606
alg_legend:str='Algorithm',
607-
perf_names:str=None):
607+
sig_legend:str='Significant',
608+
perf_names:str=None,
609+
CI:float=0.05):
608610
"""Dataframe"""
609611
if perf_names is None and isinstance(self.best, np.ndarray):
610612
perf_names = [f'{alg}({k})'
611613
for k, alg in enumerate(self.best)]
612-
return dataframe(self, value_name=value_name,
613-
var_name=var_name,
614-
alg_legend=alg_legend,
615-
perf_names=perf_names)
614+
df = dataframe(self, value_name=value_name,
615+
var_name=var_name,
616+
alg_legend=alg_legend,
617+
perf_names=perf_names)
618+
df[sig_legend] = False
619+
if isinstance(self.best, str):
620+
for name, p in self.p_value().items():
621+
if p >= CI:
622+
continue
623+
df.loc[df[alg_legend] == name, sig_legend] = True
624+
else:
625+
p_values = self.p_value()
626+
systems = list(p_values.keys())
627+
p_values = np.array([p_values[k] for k in systems])
628+
for name, p_value in zip(perf_names, p_values.T):
629+
mask = df[var_name] == name
630+
for alg, p in zip(systems, p_value):
631+
if p >= CI:
632+
continue
633+
_ = mask & (df[alg_legend] == alg)
634+
df.loc[_, sig_legend] = True
635+
return df
616636

617637
def plot(self, value_name:str='Difference',
618638
var_name:str='Best',
619639
alg_legend:str='Algorithm',
640+
sig_legend:str='Significant',
620641
perf_names:list=None,
621642
CI:float=0.05,
622643
kind:str='point', linestyle:str='none',
@@ -644,7 +665,10 @@ def plot(self, value_name:str='Difference',
644665
import seaborn as sns
645666
df = self.dataframe(value_name=value_name,
646667
var_name=var_name,
647-
alg_legend=alg_legend, perf_names=perf_names)
668+
alg_legend=alg_legend,
669+
sig_legend=sig_legend,
670+
perf_names=perf_names,
671+
CI=CI)
648672
title = var_name
649673
if var_name not in df.columns:
650674
var_name = None
@@ -653,7 +677,9 @@ def plot(self, value_name:str='Difference',
653677
f_grid = sns.catplot(df, x=value_name, errorbar=ci,
654678
y=alg_legend, col=var_name,
655679
kind=kind, linestyle=linestyle,
656-
col_wrap=col_wrap, capsize=capsize, **kwargs)
680+
col_wrap=col_wrap, capsize=capsize,
681+
hue=sig_legend,
682+
**kwargs)
657683
if set_refline:
658684
f_grid.refline(x=0)
659685
if isinstance(self.best, str):

0 commit comments

Comments
 (0)