forked from treder/MVPA-Light
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample2_crossvalidate.m
More file actions
81 lines (63 loc) · 2.88 KB
/
example2_crossvalidate.m
File metadata and controls
81 lines (63 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
%%% In example 1, training and testing was performed on the same data. This
%%% can lead to overfitting and an inflated measure of classification
%%% accuracy. The function mv_crossvalidate implements cross-validation
%%% which controls for overfitting by repeatedly splitting the data into
%%% training and test sets.
close all
% Use 'clear' instead of 'clear all': 'clear all' additionally purges
% compiled functions, MEX links and globals, which slows down subsequent
% runs; clearing workspace variables is all an example script needs.
clear

% Load data (in /examples folder)
[dat, clabel, chans] = load_example_data('epoched3');

% Average activity in 0.6-0.8 interval (see example 1): select the time
% samples falling in the interval, then average across the 3rd (time)
% dimension to obtain a [samples x features] matrix X.
ival_idx = find(dat.time >= 0.6 & dat.time <= 0.8);
X = squeeze(mean(dat.trial(:,:,ival_idx),3));
%% Cross-validation
% Build the configuration for mv_crossvalidate. The classifier is LDA and
% performance is quantified as the area under the ROC curve ('auc'). To
% obtain a realistic estimate of classification performance, 5-fold
% cross-validation (cfg.k = 5) is repeated 10 times (cfg.repeat = 10).
% Class imbalance is handled by undersampling the larger class.
cfg_LDA            = [];
cfg_LDA.classifier = 'lda';
cfg_LDA.metric     = 'auc';
cfg_LDA.cv         = 'kfold';       % alternatives: 'leaveout', 'holdout'
cfg_LDA.k          = 5;             % number of folds
cfg_LDA.repeat     = 10;            % number of repetitions
cfg_LDA.balance    = 'undersample';

% Classifier hyperparameters live in the param substruct. Setting
% lambda = 'auto' (automatic determination of the regularisation
% parameter) is already the default, so in general param only needs to be
% set when deviating from the defaults; it is spelled out here for
% illustration.
cfg_LDA.param        = [];
cfg_LDA.param.lambda = 'auto';

[acc_LDA, result_LDA] = mv_crossvalidate(cfg_LDA, X, clabel);
% Repeat the analysis with Logistic Regression (LR), reusing the identical
% cross-validation settings so the two classifiers are directly comparable.
cfg_LR            = cfg_LDA;   % copy of the CV settings from the LDA run
cfg_LR.classifier = 'logreg';

% Reset the hyperparameter substruct; lambda is again chosen automatically.
cfg_LR.param        = [];
cfg_LR.param.lambda = 'auto';

[acc_LR, result_LR] = mv_crossvalidate(cfg_LR, X, clabel);

% Report both accuracies (in percent).
fprintf('\nClassification accuracy (LDA): %2.2f%%\n', 100*acc_LDA)
fprintf('Classification accuracy (Logreg): %2.2f%%\n', 100*acc_LR)

% Visualise the two results side by side.
h = mv_plot_result({result_LDA, result_LR});
%% Use a binomial test to assess statistical significance of accuracies (ACC)
cfg_stat      = [];
cfg_stat.test = 'binomial';
stat = mv_statistics(cfg_stat, result_LDA);
%% Comparing cross-validation to training and testing on the same data
% Overfitting is most dramatic for small samples, so restrict the data to
% the first few trials and compare accuracy with and without CV.
cfg_LDA.metric = 'accuracy';

% Keep only the first n_keep samples
n_keep  = 29;
y_small = clabel(1:n_keep);
X_small = X(1:n_keep,:);

% Proper approach: k-fold cross-validation
cfg_LDA.cv = 'kfold';
acc_LDA = mv_crossvalidate(cfg_LDA, X_small, y_small);

% Improper approach: evaluate on the very data the classifier was trained on
cfg_LDA.cv = 'none';
acc_train = mv_crossvalidate(cfg_LDA, X_small, y_small);

fprintf('Using %d samples with cross-validation (proper way): %2.2f%%\n', n_keep, 100*acc_LDA)
fprintf('Using %d samples without cross-validation (test on training data): %2.2f%%\n', n_keep, 100*acc_train)