@@ -970,16 +970,22 @@ def plot_mean_and_cis_by_date(df: pd.DataFrame, item_to_show: str, group_trials_
970970
971971
972972def plot_filter_model_variables (corr_mat_list :list , norm_contribution_df :pd .DataFrame , ** kwargs ) -> plt .Axes :
973+ """ Plot the mean correlation matrix and the mean contribution of each variable.
974+ corr_mat_list: list of correlation matrices
975+ norm_contribution_df: DataFrame with the normalized contribution of each variable
976+ """
973977 fig , ax = plt .subplots (2 , 1 , figsize = (10 , 10 ))
974- X = corr_mat_list [0 ].index
978+ # get the variable names from the first correlation matrix
979+ X = corr_mat_list [0 ].index
980+ # Calculate the mean correlation matrix
975981 corr_mat_mean = np .mean (np .stack (corr_mat_list ), axis = 0 )
976982 # # Create a mask for the upper triangle
977983 # mask = np.triu(np.ones_like(corr_mat_mean, dtype=bool), k=1)
978984 sns .heatmap (corr_mat_mean , ax = ax [0 ], cmap = 'coolwarm' , annot = True , fmt = ".2f" )
979985 ax [0 ].set_xticklabels (X , rotation = 16 , ha = "right" , rotation_mode = "anchor" )
980986 ax [0 ].set_yticklabels (X , rotation = 8 , ha = "right" , rotation_mode = "anchor" )
981987 ax [0 ].set_title ("Mean Correlation Matrix" )
982-
988+ # Plot the mean contribution of each variable
983989 norm_contribution_df .mean (axis = 1 ).sort_values ().plot (kind = 'barh' , ax = ax [1 ])
984990 ax [1 ].set_xlabel ('Mean Contribution' )
985991 plt .tight_layout ()
0 commit comments