Skip to content

Commit cabc6db

Browse files
committed
demonstration for filtering variables
1 parent 5b6cf8b commit cabc6db

File tree

3 files changed

+1258
-82794
lines changed

3 files changed

+1258
-82794
lines changed

lecilab_behavior_analysis/plots.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -970,16 +970,22 @@ def plot_mean_and_cis_by_date(df: pd.DataFrame, item_to_show: str, group_trials_
970970

971971

972972
def plot_filter_model_variables(corr_mat_list:list, norm_contribution_df:pd.DataFrame, **kwargs) -> plt.Axes:
973+
""" Plot the mean correlation matrix and the mean contribution of each variable.
974+
corr_mat_list: list of correlation matrices
975+
norm_contribution_df: DataFrame with the normalized contribution of each variable
976+
"""
973977
fig, ax = plt.subplots(2, 1, figsize=(10, 10))
974-
X = corr_mat_list[0].index
978+
# get the variable names from the first correlation matrix
979+
X = corr_mat_list[0].index
980+
# Calculate the mean correlation matrix
975981
corr_mat_mean = np.mean(np.stack(corr_mat_list), axis=0)
976982
# # Create a mask for the upper triangle
977983
# mask = np.triu(np.ones_like(corr_mat_mean, dtype=bool), k=1)
978984
sns.heatmap(corr_mat_mean, ax=ax[0], cmap='coolwarm', annot=True, fmt=".2f")
979985
ax[0].set_xticklabels(X, rotation=16, ha="right", rotation_mode="anchor")
980986
ax[0].set_yticklabels(X, rotation=8, ha="right", rotation_mode="anchor")
981987
ax[0].set_title("Mean Correlation Matrix")
982-
988+
# Plot the mean contribution of each variable
983989
norm_contribution_df.mean(axis=1).sort_values().plot(kind='barh', ax=ax[1])
984990
ax[1].set_xlabel('Mean Contribution')
985991
plt.tight_layout()

0 commit comments

Comments
 (0)