-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining_protocol.py
More file actions
153 lines (122 loc) · 6.81 KB
/
training_protocol.py
File metadata and controls
153 lines (122 loc) · 6.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from village.classes.training import Training
# click on the link below to see the documentation about how to create
# tasks, plots and training protocols
# https://braincircuitsbehaviorlab.github.io/village/user_guide/create.html
class TrainingProtocol(Training):
    """
    Training protocol for animal behavior experiments.

    The training protocol is run every time a task finishes, and it decides:
    1. Which task is scheduled next for the subject.
    2. How training variables change based on performance metrics.

    Required methods to implement:
    - __init__: Initialize the training protocol.
    - default_training_settings: Define the initial parameters. It is called
      when a new subject is created.
    - update_training_settings: Update the parameters after each session.

    Optional method:
    - gui_tabs: Organize the settings into custom GUI tabs.
    """

    # Column of ``self.df`` that identifies which session each row belongs to.
    # NOTE(review): assumes ``self.df`` holds one row per trial and carries a
    # "session" column (the per-session "trial"/"correct" aggregation below
    # implies trial-level rows) - confirm against the village data format.
    _SESSION_COLUMN = "session"

    def __init__(self) -> None:
        """Initialize the training protocol."""
        super().__init__()

    def default_training_settings(self) -> None:
        """
        Define all initial training parameters for new subjects.

        This method is called when a new subject is created, and these
        parameters are saved as the initial values for that subject.

        Required parameters:
        - next_task (str): Name of the next task to run.
        - refractary_period (int): Waiting time in seconds between sessions.
        - minimum_duration (int): Minimum time in seconds for the task before
          door2 opens.
        - maximum_duration (int): Maximum time in seconds before the task
          stops automatically.

        Additional parameters:
        You can define any additional parameters your tasks need. They can be
        modified between sessions based on subject performance.
        """
        # Required parameters for any training protocol
        self.settings.next_task = "Habituation"  # Next task to run
        # 4 hours between sessions of the same subject. "refractary" is the
        # required parameter name listed above; do not "correct" the spelling.
        self.settings.refractary_period = 3600 * 4
        self.settings.minimum_duration = 600  # Minimum duration of 10 min
        self.settings.maximum_duration = 900  # Maximum duration of 15 min

        # Task-specific parameters
        # (can be modified between sessions or set when the task is run manually)
        self.settings.reward_amount_ml = 0.08  # Reward volume in milliliters
        self.settings.stage = 1  # Current training stage
        self.settings.light_intensity_high = 255  # High port light intensity (0-255)
        self.settings.light_intensity_low = 50  # Low port light intensity (0-255)
        self.settings.trial_types = [  # Possible trial types
            "left_easy",
            "right_easy",
            "left_hard",
            "right_hard",
        ]
        self.settings.punishment_time = 1  # Time in seconds for punishment
        self.settings.iti_time = 2  # Inter-trial interval in seconds
        self.settings.response_time = 10  # Seconds to respond before timeout

    def _last_sessions(self, task: str, n: int) -> list:
        """
        Return the rows of the last ``n`` sessions of ``task``, oldest first.

        Each element is the slice of ``self.df`` belonging to one session.
        Sessions are ordered by first appearance in ``self.df`` (assumed to be
        chronological). Fewer than ``n`` elements are returned when the subject
        has run ``task`` fewer than ``n`` times.
        """
        df_task = self.df[self.df["task"] == task]
        # unique() keeps order of appearance, so the tail is the latest sessions.
        session_ids = df_task[self._SESSION_COLUMN].unique()[-n:]
        return [
            df_task[df_task[self._SESSION_COLUMN] == session_id]
            for session_id in session_ids
        ]

    def update_training_settings(self) -> None:
        """
        Update training parameters after each session.

        This method is called when a session finishes and determines how the
        subject progresses through the training protocol.

        Available data for decision-making:
        - self.subject (str): Name of the current subject.
        - self.last_task (str): Name of the task that just finished.
        - self.df (pd.DataFrame): All session data for this subject.

        Logic:
        - Move from Habituation to FollowTheLight once the subject has done at
          least 2 Habituation sessions and the latest one reached 100 trials.
        - Reduce the reward amount as training progresses.
        - Advance to stage 2 once the last two FollowTheLight sessions each
          reached at least 85% correct and at least 100 trials.
        """
        if self.last_task == "Habituation":
            sessions = self._last_sessions("Habituation", 2)
            # Require at least two completed Habituation sessions.
            if len(sessions) >= 2:
                # Trial count of the latest session, read from the "trial"
                # value of its final row (assumes trials are numbered from 1).
                trials_last_session = sessions[-1]["trial"].iloc[-1]
                if trials_last_session >= 100:
                    self.settings.next_task = "FollowTheLight"
                    self.settings.reward_amount_ml = 0.07  # Decrease reward

        elif self.last_task == "FollowTheLight":
            sessions = self._last_sessions("FollowTheLight", 2)
            # Require at least two completed FollowTheLight sessions.
            if len(sessions) >= 2:
                df_previous_session, df_last_session = sessions

                # Fraction of correct trials in each session.
                performance_last_session = df_last_session["correct"].mean()
                performance_previous_session = df_previous_session["correct"].mean()
                trials_last_session = df_last_session["trial"].iloc[-1]
                trials_previous_session = df_previous_session["trial"].iloc[-1]

                # Advance to stage 2 if both sessions reached >= 85% correct
                # with at least 100 trials each.
                if (
                    performance_last_session >= 0.85
                    and performance_previous_session >= 0.85
                    and trials_last_session >= 100
                    and trials_previous_session >= 100
                ):
                    self.settings.stage = 2
                    self.settings.reward_amount_ml = 0.05  # Decrease reward

    def gui_tabs(self):
        """
        Define how the settings are organized in the GUI.

        Any setting not listed here is placed in the "General" tab. Every item
        in the lists must match the name of a settings variable exactly. Use a
        'Hide' tab to hide a setting from the GUI.

        ``gui_tabs_restricted`` limits the values a setting can take.
        """
        # NOTE(review): this assignment replaces the method of the same name
        # on the instance with a dict; presumably the village framework reads
        # the attribute afterwards, so the name is kept - confirm before
        # renaming.
        self.gui_tabs = {
            "Port_variables": [
                "reward_amount_ml",
                "light_intensity_high",
                "light_intensity_low",
            ],
            "Other_variables": [
                "stage",
                "trial_types",
                "punishment_time",
                "iti_time",
                "response_time",
            ],
        }
        # Define the possible values for each restricted setting.
        self.gui_tabs_restricted = {
            "trial_types": ["left_easy", "right_easy", "left_hard", "right_hard"],
        }