Skip to content

Commit cf4c54b

Browse files
committed
fix: SOMVisualizer with BMU data mapping for plot_all method
- Added a new parameter `bmus_data_map` to the `SOMVisualizer` class to facilitate the mapping of Best Matching Units (BMUs) to their corresponding data indices. - Updated the plotting methods to utilize the `bmus_data_map` for generating hit maps, metric maps, score maps, and rank maps, improving the visualization capabilities of the class. - Enhanced documentation for the new parameter to clarify its purpose and usage within the visualization functions.
1 parent 48b4f04 commit cf4c54b

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

torchsom/visualization/base.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def plot_all(
148148
self,
149149
quantization_errors: list[float],
150150
topographic_errors: list[float],
151+
bmus_data_map: dict[tuple[int, int], list[int]],
151152
data: torch.Tensor,
152153
target: torch.Tensor,
153154
component_names: Optional[list[str]] = None,
@@ -165,6 +166,7 @@ def plot_all(
165166
Args:
166167
quantization_errors (list[float]): List of quantization errors [epochs]
167168
topographic_errors (list[float]): List of topographic errors [epochs]
169+
bmus_data_map (dict[tuple[int, int], list[int]]): Pre-computed BMU to data indices mapping
168170
data (torch.Tensor): Input data tensor [batch_size, n_features]
169171
target (torch.Tensor): Labels tensor for data points [batch_size]
170172
component_names (Optional[list[str]]): Names for each component/feature
@@ -189,15 +191,30 @@ def plot_all(
189191
self._visualizer.plot_hit_map(data, save_path=save_path)
190192
if metric_map:
191193
self._visualizer.plot_metric_map(
192-
data, target, reduction_parameter="mean", save_path=save_path
194+
bmus_data_map=bmus_data_map,
195+
data=data,
196+
target=target,
197+
reduction_parameter="mean",
198+
save_path=save_path,
193199
)
194200
self._visualizer.plot_metric_map(
195-
data, target, reduction_parameter="std", save_path=save_path
201+
bmus_data_map=bmus_data_map,
202+
data=data,
203+
target=target,
204+
reduction_parameter="std",
205+
save_path=save_path,
196206
)
197207
if score_map:
198-
self._visualizer.plot_score_map(data, target, save_path=save_path)
208+
self._visualizer.plot_score_map(
209+
bmus_data_map=bmus_data_map,
210+
target=target,
211+
total_samples=data.shape[0],
212+
save_path=save_path,
213+
)
199214
if rank_map:
200-
self._visualizer.plot_rank_map(data, target, save_path=save_path)
215+
self._visualizer.plot_rank_map(
216+
bmus_data_map=bmus_data_map, target=target, save_path=save_path
217+
)
201218
if component_planes:
202219
self._visualizer.plot_component_planes(
203220
component_names=component_names, save_path=save_path

0 commit comments

Comments
 (0)