1717import collections
1818import enum
1919import io
20- from typing import Optional , Sequence
20+ from typing import Any , Sequence
2121
22+ from absl import logging
2223import numpy as np
23-
2424import PIL
2525import PIL .Image
2626import PIL .ImageDraw
27+ import PIL .ImageFont
2728from scipy import special
28-
2929import tensorflow .compat .v1 as tf
30+
3031from . import mask
3132from . import variables
3233
@@ -62,20 +63,26 @@ class FovStat(enum.IntEnum):
6263class EvalTracker :
6364 """Tracks eval results over multiple training steps."""
6465
65- def __init__ (self ,
66- eval_shape : list [int ],
67- shifts : Sequence [ tuple [ int , int , int ]] ):
66+ def __init__ (
67+ self , eval_shape : list [int ], shifts : Sequence [ tuple [ int , int , int ]]
68+ ):
6869 # TODO(mjanusz): Remove this TFv1 code once no longer used.
6970 if not tf .executing_eagerly ():
7071 self .eval_labels = tf .compat .v1 .placeholder (
71- tf .float32 , [1 ] + eval_shape + [1 ], name = 'eval_labels' )
72+ tf .float32 , [1 ] + eval_shape + [1 ], name = 'eval_labels'
73+ )
7274 self .eval_preds = tf .compat .v1 .placeholder (
73- tf .float32 , [1 ] + eval_shape + [1 ], name = 'eval_preds' )
75+ tf .float32 , [1 ] + eval_shape + [1 ], name = 'eval_preds'
76+ )
7477 self .eval_weights = tf .compat .v1 .placeholder (
75- tf .float32 , [1 ] + eval_shape + [1 ], name = 'eval_weights' )
78+ tf .float32 , [1 ] + eval_shape + [1 ], name = 'eval_weights'
79+ )
7680 self .eval_loss = tf .reduce_mean (
77- self .eval_weights * tf .nn .sigmoid_cross_entropy_with_logits (
78- logits = self .eval_preds , labels = self .eval_labels ))
81+ self .eval_weights
82+ * tf .nn .sigmoid_cross_entropy_with_logits (
83+ logits = self .eval_preds , labels = self .eval_labels
84+ )
85+ )
7986 self .sess = None
8087 self .eval_threshold = special .logit (0.9 )
8188 self ._eval_shape = eval_shape # zyx
@@ -138,12 +145,15 @@ def track_weights(self, weights: np.ndarray):
138145 self .fov_stats .value [FovStat .MASKED_VOXELS ] += np .sum (weights == 0.0 )
139146 self .fov_stats .value [FovStat .WEIGHTS_SUM ] += np .sum (weights )
140147
141- def record_move (self , wanted : bool , executed : bool ,
142- offset_xyz : Sequence [int ]):
148+ def record_move (
149+ self , wanted : bool , executed : bool , offset_xyz : Sequence [int ]
150+ ):
143151 """Records an FFN FOV move."""
144152 r = int (np .linalg .norm (offset_xyz ))
145- assert r in self .moves_by_r , ('%d not in %r' %
146- (r , list (self .moves_by_r .keys ())))
153+ assert r in self .moves_by_r , '%d not in %r' % (
154+ r ,
155+ list (self .moves_by_r .keys ()),
156+ )
147157
148158 if wanted :
149159 if executed :
@@ -156,9 +166,15 @@ def record_move(self, wanted: bool, executed: bool,
156166 self .moves .value [MoveType .SPURIOUS ] += 1
157167 self .moves_by_r [r ].value [MoveType .SPURIOUS ] += 1
158168
159- def slice_image (self , coord : np .ndarray , labels : np .ndarray ,
160- predicted : np .ndarray , weights : np .ndarray ,
161- slice_axis : int ) -> tf .Summary .Value :
169+ def slice_image (
170+ self ,
171+ coord : np .ndarray ,
172+ labels : np .ndarray ,
173+ predicted : np .ndarray ,
174+ weights : np .ndarray ,
175+ slice_axis : int ,
176+ volume_name : str | bytes | Sequence [Any ] | np .ndarray | None = None ,
177+ ) -> tf .Summary .Value :
162178 """Builds a tf.Summary showing a slice of an object mask.
163179
164180 The object mask slice is shown side by side with the corresponding
@@ -172,6 +188,7 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
172188 slice_axis: axis in the middle of which to place the cutting plane for
173189 which the summary image will be generated, valid values are 2 ('x'), 1
174190 ('y'), and 0 ('z').
191+ volume_name: name of the volume to be displayed on the image.
175192
176193 Returns:
177194 tf.Summary.Value object with the image.
@@ -191,14 +208,37 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
191208
192209 im = PIL .Image .fromarray (
193210 np .repeat (
194- np .concatenate ([labels , predicted , weights ], axis = 1 )[...,
195- np .newaxis ],
211+ np .concatenate ([labels , predicted , weights ], axis = 1 )[
212+ ..., np .newaxis
213+ ],
196214 3 ,
197- axis = 2 ), 'RGB' )
215+ axis = 2 ,
216+ ),
217+ 'RGB' ,
218+ )
198219 draw = PIL .ImageDraw .Draw (im )
199220
200221 x , y , z = coord .squeeze ()
201- draw .text ((1 , 1 ), '%d %d %d' % (x , y , z ), fill = 'rgb(255,64,64)' )
222+ text = f'{ x } ,{ y } ,{ z } '
223+ if volume_name is not None :
224+ if (
225+ isinstance (volume_name , (list , tuple , np .ndarray ))
226+ and len (volume_name ) == 1
227+ ):
228+ volume_name = volume_name [0 ]
229+
230+ if isinstance (volume_name , bytes ):
231+ volume_name = volume_name .decode ('utf-8' )
232+
233+ text += f'\n { volume_name } '
234+
235+ try :
236+
237+ # font = PIL.ImageFont.load_default()
238+ except (IOError , ValueError ):
239+ font = PIL .ImageFont .load_default ()
240+
241+ draw .text ((1 , 1 ), text , fill = 'rgb(255,64,64)' , font = font )
202242 del draw
203243
204244 im .save (buf , 'PNG' )
@@ -212,14 +252,19 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
212252 height = h ,
213253 width = w * 3 ,
214254 colorspace = 3 , # RGB
215- encoded_image_string = buf .getvalue ()))
216-
217- def add_patch (self ,
218- labels : np .ndarray ,
219- predicted : np .ndarray ,
220- weights : np .ndarray ,
221- coord : Optional [np .ndarray ] = None ,
222- image_summaries : bool = True ):
255+ encoded_image_string = buf .getvalue (),
256+ ),
257+ )
258+
259+ def add_patch (
260+ self ,
261+ labels : np .ndarray ,
262+ predicted : np .ndarray ,
263+ weights : np .ndarray ,
264+ coord : np .ndarray | None = None ,
265+ image_summaries : bool = True ,
266+ volume_name : str | None = None ,
267+ ):
223268 """Evaluates single-object segmentation quality."""
224269
225270 predicted = mask .crop_and_pad (predicted , (0 , 0 , 0 ), self ._eval_shape )
@@ -228,15 +273,21 @@ def add_patch(self,
228273
229274 if not tf .executing_eagerly ():
230275 assert self .sess is not None
231- loss , = self .sess .run (
232- [self .eval_loss ], {
276+ (loss ,) = self .sess .run (
277+ [self .eval_loss ],
278+ {
233279 self .eval_labels : labels ,
234280 self .eval_preds : predicted ,
235- self .eval_weights : weights
236- })
281+ self .eval_weights : weights ,
282+ },
283+ )
237284 else :
238- loss = tf .reduce_mean (weights * tf .nn .sigmoid_cross_entropy_with_logits (
239- logits = predicted , labels = labels ))
285+ loss = tf .reduce_mean (
286+ weights
287+ * tf .nn .sigmoid_cross_entropy_with_logits (
288+ logits = predicted , labels = labels
289+ )
290+ )
240291
241292 self .loss .value [:] += loss
242293 self .num_voxels .value [VoxelType .TOTAL ] += labels .size
@@ -247,23 +298,29 @@ def add_patch(self,
247298 pred_bg = np .logical_not (pred_mask )
248299 true_bg = np .logical_not (true_mask )
249300
250- self .prediction_counts .value [PredictionType .TP ] += np .sum (pred_mask
251- & true_mask )
301+ self .prediction_counts .value [PredictionType .TP ] += np .sum (
302+ pred_mask & true_mask
303+ )
252304 self .prediction_counts .value [PredictionType .TN ] += np .sum (pred_bg & true_bg )
253- self .prediction_counts .value [PredictionType .FP ] += np .sum (pred_mask
254- & true_bg )
255- self .prediction_counts .value [PredictionType .FN ] += np .sum (pred_bg
256- & true_mask )
305+ self .prediction_counts .value [PredictionType .FP ] += np .sum (
306+ pred_mask & true_bg
307+ )
308+ self .prediction_counts .value [PredictionType .FN ] += np .sum (
309+ pred_bg & true_mask
310+ )
257311 self .num_patches .value [:] += 1
258312
259313 if image_summaries :
260314 predicted = special .expit (predicted )
261315 self .images_xy .append (
262- self .slice_image (coord , labels , predicted , weights , 0 ))
316+ self .slice_image (coord , labels , predicted , weights , 0 , volume_name )
317+ )
263318 self .images_xz .append (
264- self .slice_image (coord , labels , predicted , weights , 1 ))
319+ self .slice_image (coord , labels , predicted , weights , 1 , volume_name )
320+ )
265321 self .images_yz .append (
266- self .slice_image (coord , labels , predicted , weights , 2 ))
322+ self .slice_image (coord , labels , predicted , weights , 2 , volume_name )
323+ )
267324
268325 def _compute_classification_metrics (self , prediction_counts , prefix ):
269326 """Computes standard classification metrics."""
@@ -276,19 +333,21 @@ def _compute_classification_metrics(self, prediction_counts, prefix):
276333 recall = tp / max (tp + fn , 1 )
277334
278335 if precision > 0 or recall > 0 :
279- f1 = ( 2.0 * precision * recall / (precision + recall ) )
336+ f1 = 2.0 * precision * recall / (precision + recall )
280337 else :
281338 f1 = 0.0
282339
283340 return [
284341 tf .Summary .Value (
285342 tag = '%s/accuracy' % prefix ,
286- simple_value = (tp + tn ) / max (tp + tn + fp + fn , 1 )),
343+ simple_value = (tp + tn ) / max (tp + tn + fp + fn , 1 ),
344+ ),
287345 tf .Summary .Value (tag = '%s/precision' % prefix , simple_value = precision ),
288346 tf .Summary .Value (tag = '%s/recall' % prefix , simple_value = recall ),
289347 tf .Summary .Value (
290- tag = '%s/specificity' % prefix , simple_value = tn / max (tn + fp , 1 )),
291- tf .Summary .Value (tag = '%s/f1' % prefix , simple_value = f1 )
348+ tag = '%s/specificity' % prefix , simple_value = tn / max (tn + fp , 1 )
349+ ),
350+ tf .Summary .Value (tag = '%s/f1' % prefix , simple_value = f1 ),
292351 ]
293352
294353 def get_summaries (self ) -> list [tf .Summary .Value ]:
@@ -308,49 +367,74 @@ def get_summaries(self) -> list[tf.Summary.Value]:
308367 move_summaries .append (
309368 tf .Summary .Value (
310369 tag = 'moves/all/%s' % mt .name .lower (),
311- simple_value = self .moves .tf_value [mt ] / total_moves ))
312-
313- summaries = [
314- tf .Summary .Value (
315- tag = 'fov/masked_voxel_fraction' ,
316- simple_value = (self .fov_stats .tf_value [FovStat .MASKED_VOXELS ] /
317- self .fov_stats .tf_value [FovStat .TOTAL_VOXELS ])),
318- tf .Summary .Value (
319- tag = 'fov/average_weight' ,
320- simple_value = (self .fov_stats .tf_value [FovStat .WEIGHTS_SUM ] /
321- self .fov_stats .tf_value [FovStat .TOTAL_VOXELS ])),
322- tf .Summary .Value (
323- tag = 'masked_voxel_fraction' ,
324- simple_value = (self .num_voxels .tf_value [VoxelType .MASKED ] /
325- self .num_voxels .tf_value [VoxelType .TOTAL ])),
326- tf .Summary .Value (
327- tag = 'eval/patch_loss' ,
328- simple_value = self .loss .tf_value [0 ] / self .num_patches .tf_value [0 ]),
329- tf .Summary .Value (
330- tag = 'eval/patches' , simple_value = self .num_patches .tf_value [0 ]),
331- tf .Summary .Value (tag = 'moves/total' , simple_value = total_moves )
332- ] + move_summaries + (
333- list (self .meshes ) + list (self .images_xy ) + list (self .images_xz ) +
334- list (self .images_yz ))
370+ simple_value = self .moves .tf_value [mt ] / total_moves ,
371+ )
372+ )
373+
374+ summaries = (
375+ [
376+ tf .Summary .Value (
377+ tag = 'fov/masked_voxel_fraction' ,
378+ simple_value = (
379+ self .fov_stats .tf_value [FovStat .MASKED_VOXELS ]
380+ / self .fov_stats .tf_value [FovStat .TOTAL_VOXELS ]
381+ ),
382+ ),
383+ tf .Summary .Value (
384+ tag = 'fov/average_weight' ,
385+ simple_value = (
386+ self .fov_stats .tf_value [FovStat .WEIGHTS_SUM ]
387+ / self .fov_stats .tf_value [FovStat .TOTAL_VOXELS ]
388+ ),
389+ ),
390+ tf .Summary .Value (
391+ tag = 'masked_voxel_fraction' ,
392+ simple_value = (
393+ self .num_voxels .tf_value [VoxelType .MASKED ]
394+ / self .num_voxels .tf_value [VoxelType .TOTAL ]
395+ ),
396+ ),
397+ tf .Summary .Value (
398+ tag = 'eval/patch_loss' ,
399+ simple_value = self .loss .tf_value [0 ]
400+ / self .num_patches .tf_value [0 ],
401+ ),
402+ tf .Summary .Value (
403+ tag = 'eval/patches' , simple_value = self .num_patches .tf_value [0 ]
404+ ),
405+ tf .Summary .Value (tag = 'moves/total' , simple_value = total_moves ),
406+ ]
407+ + move_summaries
408+ + (
409+ list (self .meshes )
410+ + list (self .images_xy )
411+ + list (self .images_xz )
412+ + list (self .images_yz )
413+ )
414+ )
335415
336416 summaries .extend (
337- self ._compute_classification_metrics (self .prediction_counts ,
338- 'eval/all' ) )
417+ self ._compute_classification_metrics (self .prediction_counts , 'eval/all' )
418+ )
339419
340420 for r , r_moves in self .moves_by_r .items ():
341421 total_moves = sum (r_moves .tf_value )
342422 summaries .extend ([
343423 tf .Summary .Value (
344424 tag = 'moves/r=%d/correct' % r ,
345- simple_value = r_moves .tf_value [MoveType .CORRECT ] / total_moves ),
425+ simple_value = r_moves .tf_value [MoveType .CORRECT ] / total_moves ,
426+ ),
346427 tf .Summary .Value (
347428 tag = 'moves/r=%d/spurious' % r ,
348- simple_value = r_moves .tf_value [MoveType .SPURIOUS ] / total_moves ),
429+ simple_value = r_moves .tf_value [MoveType .SPURIOUS ] / total_moves ,
430+ ),
349431 tf .Summary .Value (
350432 tag = 'moves/r=%d/missed' % r ,
351- simple_value = r_moves .tf_value [MoveType .MISSED ] / total_moves ),
433+ simple_value = r_moves .tf_value [MoveType .MISSED ] / total_moves ,
434+ ),
352435 tf .Summary .Value (
353- tag = 'moves/r=%d/total' % r , simple_value = total_moves )
436+ tag = 'moves/r=%d/total' % r , simple_value = total_moves
437+ ),
354438 ])
355439
356440 return summaries
0 commit comments