Skip to content

Commit f4a81c4

Browse files
committed
.
1 parent e827f95 commit f4a81c4

File tree

3 files changed

+58
-25
lines changed

3 files changed

+58
-25
lines changed

lecilab_behavior_analysis/df_transforms.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
import numpy as np
33
import ast
4-
from typing import Tuple
4+
from typing import Tuple, Union
55
import lecilab_behavior_analysis.utils as utils
66

77
def fill_missing_data(df: pd.DataFrame) -> pd.DataFrame:
@@ -421,27 +421,32 @@ def get_triangle_polar_plot_df(df: pd.DataFrame) -> pd.DataFrame:
421421
return df_bias
422422

423423

424-
def get_bias_evolution_df(df: pd.DataFrame, groupby: str) -> pd.DataFrame:
424+
def get_bias_evolution_df(df: pd.DataFrame, groupby: Union[str, list[str]]) -> pd.DataFrame:
425425
"""
426426
Gets how the bias of the animals (alternating, right bias, or left bias)
427427
evolves over time.
428428
429429
Arguments:
430430
df: DataFrame with the data
431-
groupby: str, the column to group by. Can be 'session' or 'trial_group'
431+
groupby: str or list, the column(s) to group by. Can be 'session' or 'trial_group'
432432
433433
Returns:
434434
df: DataFrame with the bias evolution
435435
"""
436-
utils.column_checker(df, required_columns={"roa_choice_numeric", "subject", groupby})
436+
groupby_items = ["subject"]
437+
if isinstance(groupby, str):
438+
groupby_items.append(groupby)
439+
elif isinstance(groupby, list):
440+
groupby_items.extend(groupby)
441+
utils.column_checker(df, required_columns=set(groupby_items + ["roa_choice_numeric"]))
437442
df_anchev = df.copy()
438-
df_anchev = df_anchev.groupby(['subject', groupby])['roa_choice_numeric'].value_counts().reset_index(name='count')
443+
df_anchev = df_anchev.groupby(groupby_items)['roa_choice_numeric'].value_counts().reset_index(name='count')
439444
# transform counts into percentages
440-
df_anchev['percentage'] = df_anchev['count'] / df_anchev.groupby(['subject', groupby])['count'].transform('sum')
445+
df_anchev['percentage'] = df_anchev['count'] / df_anchev.groupby(groupby_items)['count'].transform('sum')
441446

442447
# pivot or melt the dataframe so that each subject and session has a y value, that will be the percentage when the bias
443448
# is 0, and the x value will be the differences between the percentages when the bias is 1 and -1
444-
df_bias_pivot = df_anchev.pivot(index=['subject', groupby], columns='roa_choice_numeric', values='percentage')
449+
df_bias_pivot = df_anchev.pivot(index=groupby_items, columns='roa_choice_numeric', values='percentage')
445450

446451
# fill the NaN values with 0
447452
df_bias_pivot = df_bias_pivot.fillna(0)
@@ -486,6 +491,16 @@ def create_transition_matrix(events: list) -> pd.DataFrame:
486491
# Return the transition matrix as a pandas DataFrame for better readability
487492
return pd.DataFrame(transition_matrix, index=items, columns=items)
488493

494+
495+
def add_visual_stimulus_difference(df_in: pd.DataFrame) -> pd.DataFrame:
496+
df = df_in.copy() # Create a copy to avoid modifying the original DataFrame
497+
utils.column_checker(df_in, required_columns={"visual_stimulus"})
498+
df['visual_stimulus'] = df['visual_stimulus'].apply(ast.literal_eval)
499+
df["visual_stim_difference"] = df["visual_stimulus"].apply(lambda x: x[0] - x[1])
500+
# bin the data every 0.1
501+
df["vis_stim_dif_bin"] = np.round((df["visual_stim_difference"] // 0.1) * 0.1, 1)
502+
return df
503+
489504
# if __name__ == "__main__":
490505
# from lecilab_behavior_analysis.utils import load_example_data
491506
# df = load_example_data("mouse1")

lecilab_behavior_analysis/figure_maker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def subject_progress_figure(df: pd.DataFrame, **kwargs) -> Figure:
2727
# Create a GridSpec with 3 rows and 1 column
2828
rows_gs = gridspec.GridSpec(4, 1, height_ratios=[.7, 1, 1, 1])
2929
# Create separate inner grids for each row with different width ratios
30-
gs1 = gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=rows_gs[0], width_ratios=[1, 4, 1])
30+
gs1 = gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=rows_gs[0], width_ratios=[1, 3, 1])
3131
gs2 = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=rows_gs[1], width_ratios=[1.5, 3])
3232
gs3 = gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=rows_gs[2], width_ratios=[1.5, 1, 3])
3333
gs4 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=rows_gs[3])
@@ -65,7 +65,7 @@ def subject_progress_figure(df: pd.DataFrame, **kwargs) -> Figure:
6565

6666
# generate the calendar plot
6767
dates_df = dft.get_dates_df(df)
68-
cal_image = plots.rasterize_plot(plots.training_calendar_plot(dates_df), dpi=300)
68+
cal_image = plots.rasterize_plot(plots.training_calendar_plot(dates_df), dpi=600)
6969
# paste the calendar plot filling the entire axis
7070
ax_cal.imshow(cal_image)
7171
ax_cal.axis("off")

lecilab_behavior_analysis/plot_testing.ipynb

Lines changed: 34 additions & 16 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)