-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathonline_plot.py
More file actions
53 lines (44 loc) · 1.74 KB
/
online_plot.py
File metadata and controls
53 lines (44 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import logging

import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from village.classes.plot import OnlinePlotFigureManager
class Online_Plot(OnlinePlotFigureManager):
    """Two-panel online figure built from the session's trial DataFrame.

    Left panel: scatter of ``trial`` against ``TRIAL_START``.
    Right panel: ``trial_type`` per ``trial`` for the most recent trials,
    coloured by ``correct``, with a rolling accuracy line.

    ``self.fig`` is provided by ``OnlinePlotFigureManager`` (not visible here).
    """

    # Used to report (instead of silently hiding) failures while drawing a panel.
    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        super().__init__()
        # One row, two columns: timing on the left, side/correct on the right.
        self.ax1 = self.fig.add_subplot(1, 2, 1)
        self.ax2 = self.fig.add_subplot(1, 2, 2)

    def update_plot(self, df: pd.DataFrame) -> None:
        """Redraw both panels from *df*.

        Each panel is drawn independently: if one fails, that panel shows a
        placeholder message and a warning with the traceback is logged, while
        the other panel is still drawn.
        """
        try:
            self.make_timing_plot(df, self.ax1)
        except Exception:
            self._logger.warning("Could not draw the timing plot", exc_info=True)
            self.make_error_plot(self.ax1)
        try:
            self.make_trial_side_and_correct_plot(df, self.ax2)
        except Exception:
            self._logger.warning(
                "Could not draw the side/correct plot", exc_info=True
            )
            self.make_error_plot(self.ax2)
        self.fig.tight_layout()

    def make_timing_plot(self, df: pd.DataFrame, ax: plt.Axes) -> None:
        """Clear *ax* and scatter ``trial`` (y) against ``TRIAL_START`` (x)."""
        ax.clear()
        df.plot(kind="scatter", x="TRIAL_START", y="trial", ax=ax)

    def make_trial_side_and_correct_plot(
        self, df: pd.DataFrame, ax: plt.Axes
    ) -> None:
        """Draw the side/correct panel on *ax* (returned Axes is discarded)."""
        _ = self.plot_side_correct_performance(df, ax)

    def make_error_plot(self, ax: plt.Axes) -> None:
        """Clear *ax* and show a centred "Could not create plot" message."""
        ax.clear()
        ax.text(
            0.5,
            0.5,
            "Could not create plot",
            horizontalalignment="center",
            verticalalignment="center",
            # Axes coordinates, so (0.5, 0.5) is the centre regardless of data.
            transform=ax.transAxes,
        )

    @staticmethod
    def plot_side_correct_performance(
        df: pd.DataFrame,
        ax: plt.Axes,
        n_trials: int = 100,
        window: int = 10,
    ) -> plt.Axes:
        """Plot trial type and correctness for the last *n_trials* trials.

        A staticmethod so that ``self.plot_side_correct_performance(df, ax)``
        works. It uses no instance state, and without the decorator ``self``
        would be bound to ``df`` and the call would raise.

        Args:
            df: Trial table. Columns ``trial``, ``trial_type`` and ``correct``
                are used.
            ax: Axes to draw on. It is cleared first.
            n_trials: Number of most recent rows to show.
            window: Window size, in trials, of the rolling accuracy line.

        Returns:
            The same *ax*.
        """
        ax.clear()
        recent = df.tail(n_trials)
        sns.scatterplot(
            data=recent, x="trial", y="trial_type", hue="correct", ax=ax
        )
        # make sure the y axis ticks are ascending, inverting the y axis
        ax.invert_yaxis()
        # Rolling mean of `correct`, drawn at the same `trial` x-values as the
        # scatter. A bare Series would be drawn at its 0..n-1 index instead.
        # float() maps bools/ints to 0/1 and keeps NaN rather than raising.
        accuracy = recent["correct"].astype(float).rolling(window).mean()
        ax.plot(recent["trial"].to_numpy(), accuracy.to_numpy(), "r")
        return ax