forked from treder/MVPA-Light
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample4_classify_timextime.m
More file actions
117 lines (92 loc) · 3.51 KB
/
example4_classify_timextime.m
File metadata and controls
117 lines (92 loc) · 3.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
%%% Time generalisation example using the function mv_classify_timextime.
%%% In this function, we need data with a time dimension [samples x
%%% features x time points]. Then, cross-validation is run separately for
%%% every combination of training time point and test time point. The
%%% result is a matrix of classification performance scores, one score for
%%% every combination of training and test times.
clear all

% Fetch the example dataset 'epoched3' (located in the /examples folder).
[dat, clabel] = load_example_data('epoched3', 0);

%% Setup configuration struct
% Build the configuration for time x time classification in a single call:
% LDA as classifier, 'demean' normalisation (alternative: 'none') and
% accuracy as performance metric.
% NOTE(review): no cross-validation fields are set here, so the
% cross-validation scheme is presumably whatever mv_classify_timextime uses
% by default -- confirm the number of folds and repetitions against the
% toolbox documentation.
cfg = struct('classifier', 'lda', ...
             'normalise',  'demean', ...
             'metric',     'accuracy');
[acc, result_acc] = mv_classify_timextime(cfg, dat.trial, clabel);

% Run the same classification a second time, this time calculating the area
% under the ROC curve (AUC) as performance metric.
cfg.metric = 'auc';
[auc, result_auc] = mv_classify_timextime(cfg, dat.trial, clabel);
%% Plot time generalisation matrix
% Alternative using the result structs (2nd and 3rd argument are optional):
% close all
% mv_plot_result(result_acc, dat.time, dat.time)

% Both axes of the time x time matrix are labelled with the same time
% vector, because training and test data come from the same dataset.
cfg_plot   = [];
cfg_plot.x = dat.time;
cfg_plot.y = dat.time;

% One figure per metric: {score matrix, figure title}.
panels = {acc, 'Accuracy'; auc, 'AUC'};
for iPanel = 1:size(panels, 1)
    figure
    mv_plot_2D(cfg_plot, panels{iPanel, 1});
    colormap jet
    title(panels{iPanel, 2})
end
%% Compare with and without cross-validation
% Cross-validated accuracy (result_acc) was already calculated above. Here
% the analysis is repeated without cross-validation, and both results are
% plotted together for comparison.
%
% The settings are changed on a copy of cfg rather than on cfg itself, so
% that cv = 'none' does not leak into later sections that reuse cfg.
cfg_noCV        = cfg;
cfg_noCV.cv     = 'none';
cfg_noCV.metric = 'accuracy';
[acc_noCV, result_acc_noCV] = mv_classify_timextime(cfg_noCV, dat.trial, clabel);
mv_plot_result({result_acc, result_acc_noCV}, dat.time, dat.time)
%% Compare accuracy/AUC when no normalisation is performed
% Reuses cfg from the sections above; only the normalisation and the metric
% are changed here. Note that acc and auc are overwritten.
cfg.normalise = 'none';

cfg.metric = 'accuracy';
acc = mv_classify_timextime(cfg, dat.trial, clabel);

cfg.metric = 'auc';
auc = mv_classify_timextime(cfg, dat.trial, clabel);

% One figure per metric, using the axis settings already stored in
% cfg_plot: {score matrix, figure title}.
panels = {acc, 'Accuracy'; auc, 'AUC'};
for iPanel = 1:size(panels, 1)
    figure
    mv_plot_2D(cfg_plot, panels{iPanel, 1});
    colormap jet
    title(panels{iPanel, 2})
end
%% Generalisation with two datasets
% The classifier is trained on one dataset, and tested on another dataset.
% As two datasets, two different subjects are taken.

% Subject 'epoched3' (dat, clabel).
[dat, clabel] = load_example_data('epoched3', 0);

% Subject 'epoched1' (dat2, clabel2).
% NOTE(review): the second argument (0) passed for epoched3 is omitted
% here -- confirm that both datasets are meant to be loaded with different
% settings.
[dat2, clabel2] = load_example_data('epoched1');

cfg = [];
cfg.classifier = 'lda';
cfg.normalise = 'zscore'; % 'demean' 'none' 'zscore'
cfg.metric = 'acc';

% Train on epoched3, test on epoched1
[acc31, result31] = mv_classify_timextime(cfg, dat.trial, clabel, dat2.trial, clabel2);

% Reverse the analysis: train the classifier on epoched1, test on epoched3
[acc13, result13]= mv_classify_timextime(cfg, dat2.trial, clabel2, dat.trial, clabel);

% Train AND test on epoched1 (overfitting!)
[acc11, result11]= mv_classify_timextime(cfg, dat2.trial, clabel2, dat2.trial, clabel2);

% In each plot below, y is set to the time axis of the training dataset
% and x to the time axis of the test dataset.
figure
cfg_plot =[];
cfg_plot.y = dat.time; cfg_plot.x = dat2.time;
mv_plot_2D(cfg_plot, acc31 );
colormap jet
title('Training on epoched3, testing on epoched1')

figure
cfg_plot.x = dat.time; cfg_plot.y = dat2.time;
mv_plot_2D(cfg_plot, acc13 );
colormap jet
title('Training on epoched1, testing on epoched3')

figure
% acc11 was trained AND tested on epoched1 (dat2), so both axes must use
% dat2.time (not dat.time, which belongs to epoched3).
cfg_plot.x = dat2.time; cfg_plot.y = dat2.time;
mv_plot_2D(cfg_plot, acc11 );
colormap jet
title('Training AND testing on epoched1 (overfitting!)')

% close all
% mv_plot_result({result13, result31}, dat.time, dat.time)