Skip to content

Commit 75eb6c4

Browse files
mjanusz authored and copybara-github committed
Log volume name in the eval summaries.
PiperOrigin-RevId: 854145174
1 parent fd56e84 commit 75eb6c4

File tree

3 files changed

+253
-80
lines changed

3 files changed

+253
-80
lines changed

ffn/training/examples.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,8 @@ def get_example(load_example, eval_tracker: tracker.EvalTracker,
7070
assert predicted.base is seed
7171
yield predicted, patches, labels, weights
7272

73-
eval_tracker.add_patch(full_labels, seed, loss_weights, coord)
73+
eval_tracker.add_patch(full_labels, seed, loss_weights, coord,
74+
volume_name=volname)
7475

7576

7677
ExampleGenerator = Iterable[tuple[np.ndarray, np.ndarray, np.ndarray,

ffn/training/tracker.py

Lines changed: 163 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -17,16 +17,17 @@
1717
import collections
1818
import enum
1919
import io
20-
from typing import Optional, Sequence
20+
from typing import Any, Sequence
2121

22+
from absl import logging
2223
import numpy as np
23-
2424
import PIL
2525
import PIL.Image
2626
import PIL.ImageDraw
27+
import PIL.ImageFont
2728
from scipy import special
28-
2929
import tensorflow.compat.v1 as tf
30+
3031
from . import mask
3132
from . import variables
3233

@@ -62,20 +63,26 @@ class FovStat(enum.IntEnum):
6263
class EvalTracker:
6364
"""Tracks eval results over multiple training steps."""
6465

65-
def __init__(self,
66-
eval_shape: list[int],
67-
shifts: Sequence[tuple[int, int, int]]):
66+
def __init__(
67+
self, eval_shape: list[int], shifts: Sequence[tuple[int, int, int]]
68+
):
6869
# TODO(mjanusz): Remove this TFv1 code once no longer used.
6970
if not tf.executing_eagerly():
7071
self.eval_labels = tf.compat.v1.placeholder(
71-
tf.float32, [1] + eval_shape + [1], name='eval_labels')
72+
tf.float32, [1] + eval_shape + [1], name='eval_labels'
73+
)
7274
self.eval_preds = tf.compat.v1.placeholder(
73-
tf.float32, [1] + eval_shape + [1], name='eval_preds')
75+
tf.float32, [1] + eval_shape + [1], name='eval_preds'
76+
)
7477
self.eval_weights = tf.compat.v1.placeholder(
75-
tf.float32, [1] + eval_shape + [1], name='eval_weights')
78+
tf.float32, [1] + eval_shape + [1], name='eval_weights'
79+
)
7680
self.eval_loss = tf.reduce_mean(
77-
self.eval_weights * tf.nn.sigmoid_cross_entropy_with_logits(
78-
logits=self.eval_preds, labels=self.eval_labels))
81+
self.eval_weights
82+
* tf.nn.sigmoid_cross_entropy_with_logits(
83+
logits=self.eval_preds, labels=self.eval_labels
84+
)
85+
)
7986
self.sess = None
8087
self.eval_threshold = special.logit(0.9)
8188
self._eval_shape = eval_shape # zyx
@@ -138,12 +145,15 @@ def track_weights(self, weights: np.ndarray):
138145
self.fov_stats.value[FovStat.MASKED_VOXELS] += np.sum(weights == 0.0)
139146
self.fov_stats.value[FovStat.WEIGHTS_SUM] += np.sum(weights)
140147

141-
def record_move(self, wanted: bool, executed: bool,
142-
offset_xyz: Sequence[int]):
148+
def record_move(
149+
self, wanted: bool, executed: bool, offset_xyz: Sequence[int]
150+
):
143151
"""Records an FFN FOV move."""
144152
r = int(np.linalg.norm(offset_xyz))
145-
assert r in self.moves_by_r, ('%d not in %r' %
146-
(r, list(self.moves_by_r.keys())))
153+
assert r in self.moves_by_r, '%d not in %r' % (
154+
r,
155+
list(self.moves_by_r.keys()),
156+
)
147157

148158
if wanted:
149159
if executed:
@@ -156,9 +166,15 @@ def record_move(self, wanted: bool, executed: bool,
156166
self.moves.value[MoveType.SPURIOUS] += 1
157167
self.moves_by_r[r].value[MoveType.SPURIOUS] += 1
158168

159-
def slice_image(self, coord: np.ndarray, labels: np.ndarray,
160-
predicted: np.ndarray, weights: np.ndarray,
161-
slice_axis: int) -> tf.Summary.Value:
169+
def slice_image(
170+
self,
171+
coord: np.ndarray,
172+
labels: np.ndarray,
173+
predicted: np.ndarray,
174+
weights: np.ndarray,
175+
slice_axis: int,
176+
volume_name: str | bytes | Sequence[Any] | np.ndarray | None = None,
177+
) -> tf.Summary.Value:
162178
"""Builds a tf.Summary showing a slice of an object mask.
163179
164180
The object mask slice is shown side by side with the corresponding
@@ -172,6 +188,7 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
172188
slice_axis: axis in the middle of which to place the cutting plane for
173189
which the summary image will be generated, valid values are 2 ('x'), 1
174190
('y'), and 0 ('z').
191+
volume_name: name of the volume to be displayed on the image.
175192
176193
Returns:
177194
tf.Summary.Value object with the image.
@@ -191,14 +208,37 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
191208

192209
im = PIL.Image.fromarray(
193210
np.repeat(
194-
np.concatenate([labels, predicted, weights], axis=1)[...,
195-
np.newaxis],
211+
np.concatenate([labels, predicted, weights], axis=1)[
212+
..., np.newaxis
213+
],
196214
3,
197-
axis=2), 'RGB')
215+
axis=2,
216+
),
217+
'RGB',
218+
)
198219
draw = PIL.ImageDraw.Draw(im)
199220

200221
x, y, z = coord.squeeze()
201-
draw.text((1, 1), '%d %d %d' % (x, y, z), fill='rgb(255,64,64)')
222+
text = f'{x},{y},{z}'
223+
if volume_name is not None:
224+
if (
225+
isinstance(volume_name, (list, tuple, np.ndarray))
226+
and len(volume_name) == 1
227+
):
228+
volume_name = volume_name[0]
229+
230+
if isinstance(volume_name, bytes):
231+
volume_name = volume_name.decode('utf-8')
232+
233+
text += f'\n{volume_name}'
234+
235+
try:
236+
font = PIL.ImageFont.truetype('LiberationSans-Regular.ttf', 10)
237+
# font = PIL.ImageFont.load_default()
238+
except (IOError, ValueError):
239+
font = PIL.ImageFont.load_default()
240+
241+
draw.text((1, 1), text, fill='rgb(255,64,64)', font=font)
202242
del draw
203243

204244
im.save(buf, 'PNG')
@@ -212,14 +252,19 @@ def slice_image(self, coord: np.ndarray, labels: np.ndarray,
212252
height=h,
213253
width=w * 3,
214254
colorspace=3, # RGB
215-
encoded_image_string=buf.getvalue()))
216-
217-
def add_patch(self,
218-
labels: np.ndarray,
219-
predicted: np.ndarray,
220-
weights: np.ndarray,
221-
coord: Optional[np.ndarray] = None,
222-
image_summaries: bool = True):
255+
encoded_image_string=buf.getvalue(),
256+
),
257+
)
258+
259+
def add_patch(
260+
self,
261+
labels: np.ndarray,
262+
predicted: np.ndarray,
263+
weights: np.ndarray,
264+
coord: np.ndarray | None = None,
265+
image_summaries: bool = True,
266+
volume_name: str | None = None,
267+
):
223268
"""Evaluates single-object segmentation quality."""
224269

225270
predicted = mask.crop_and_pad(predicted, (0, 0, 0), self._eval_shape)
@@ -228,15 +273,21 @@ def add_patch(self,
228273

229274
if not tf.executing_eagerly():
230275
assert self.sess is not None
231-
loss, = self.sess.run(
232-
[self.eval_loss], {
276+
(loss,) = self.sess.run(
277+
[self.eval_loss],
278+
{
233279
self.eval_labels: labels,
234280
self.eval_preds: predicted,
235-
self.eval_weights: weights
236-
})
281+
self.eval_weights: weights,
282+
},
283+
)
237284
else:
238-
loss = tf.reduce_mean(weights * tf.nn.sigmoid_cross_entropy_with_logits(
239-
logits=predicted, labels=labels))
285+
loss = tf.reduce_mean(
286+
weights
287+
* tf.nn.sigmoid_cross_entropy_with_logits(
288+
logits=predicted, labels=labels
289+
)
290+
)
240291

241292
self.loss.value[:] += loss
242293
self.num_voxels.value[VoxelType.TOTAL] += labels.size
@@ -247,23 +298,29 @@ def add_patch(self,
247298
pred_bg = np.logical_not(pred_mask)
248299
true_bg = np.logical_not(true_mask)
249300

250-
self.prediction_counts.value[PredictionType.TP] += np.sum(pred_mask
251-
& true_mask)
301+
self.prediction_counts.value[PredictionType.TP] += np.sum(
302+
pred_mask & true_mask
303+
)
252304
self.prediction_counts.value[PredictionType.TN] += np.sum(pred_bg & true_bg)
253-
self.prediction_counts.value[PredictionType.FP] += np.sum(pred_mask
254-
& true_bg)
255-
self.prediction_counts.value[PredictionType.FN] += np.sum(pred_bg
256-
& true_mask)
305+
self.prediction_counts.value[PredictionType.FP] += np.sum(
306+
pred_mask & true_bg
307+
)
308+
self.prediction_counts.value[PredictionType.FN] += np.sum(
309+
pred_bg & true_mask
310+
)
257311
self.num_patches.value[:] += 1
258312

259313
if image_summaries:
260314
predicted = special.expit(predicted)
261315
self.images_xy.append(
262-
self.slice_image(coord, labels, predicted, weights, 0))
316+
self.slice_image(coord, labels, predicted, weights, 0, volume_name)
317+
)
263318
self.images_xz.append(
264-
self.slice_image(coord, labels, predicted, weights, 1))
319+
self.slice_image(coord, labels, predicted, weights, 1, volume_name)
320+
)
265321
self.images_yz.append(
266-
self.slice_image(coord, labels, predicted, weights, 2))
322+
self.slice_image(coord, labels, predicted, weights, 2, volume_name)
323+
)
267324

268325
def _compute_classification_metrics(self, prediction_counts, prefix):
269326
"""Computes standard classification metrics."""
@@ -276,19 +333,21 @@ def _compute_classification_metrics(self, prediction_counts, prefix):
276333
recall = tp / max(tp + fn, 1)
277334

278335
if precision > 0 or recall > 0:
279-
f1 = (2.0 * precision * recall / (precision + recall))
336+
f1 = 2.0 * precision * recall / (precision + recall)
280337
else:
281338
f1 = 0.0
282339

283340
return [
284341
tf.Summary.Value(
285342
tag='%s/accuracy' % prefix,
286-
simple_value=(tp + tn) / max(tp + tn + fp + fn, 1)),
343+
simple_value=(tp + tn) / max(tp + tn + fp + fn, 1),
344+
),
287345
tf.Summary.Value(tag='%s/precision' % prefix, simple_value=precision),
288346
tf.Summary.Value(tag='%s/recall' % prefix, simple_value=recall),
289347
tf.Summary.Value(
290-
tag='%s/specificity' % prefix, simple_value=tn / max(tn + fp, 1)),
291-
tf.Summary.Value(tag='%s/f1' % prefix, simple_value=f1)
348+
tag='%s/specificity' % prefix, simple_value=tn / max(tn + fp, 1)
349+
),
350+
tf.Summary.Value(tag='%s/f1' % prefix, simple_value=f1),
292351
]
293352

294353
def get_summaries(self) -> list[tf.Summary.Value]:
@@ -308,49 +367,74 @@ def get_summaries(self) -> list[tf.Summary.Value]:
308367
move_summaries.append(
309368
tf.Summary.Value(
310369
tag='moves/all/%s' % mt.name.lower(),
311-
simple_value=self.moves.tf_value[mt] / total_moves))
312-
313-
summaries = [
314-
tf.Summary.Value(
315-
tag='fov/masked_voxel_fraction',
316-
simple_value=(self.fov_stats.tf_value[FovStat.MASKED_VOXELS] /
317-
self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])),
318-
tf.Summary.Value(
319-
tag='fov/average_weight',
320-
simple_value=(self.fov_stats.tf_value[FovStat.WEIGHTS_SUM] /
321-
self.fov_stats.tf_value[FovStat.TOTAL_VOXELS])),
322-
tf.Summary.Value(
323-
tag='masked_voxel_fraction',
324-
simple_value=(self.num_voxels.tf_value[VoxelType.MASKED] /
325-
self.num_voxels.tf_value[VoxelType.TOTAL])),
326-
tf.Summary.Value(
327-
tag='eval/patch_loss',
328-
simple_value=self.loss.tf_value[0] / self.num_patches.tf_value[0]),
329-
tf.Summary.Value(
330-
tag='eval/patches', simple_value=self.num_patches.tf_value[0]),
331-
tf.Summary.Value(tag='moves/total', simple_value=total_moves)
332-
] + move_summaries + (
333-
list(self.meshes) + list(self.images_xy) + list(self.images_xz) +
334-
list(self.images_yz))
370+
simple_value=self.moves.tf_value[mt] / total_moves,
371+
)
372+
)
373+
374+
summaries = (
375+
[
376+
tf.Summary.Value(
377+
tag='fov/masked_voxel_fraction',
378+
simple_value=(
379+
self.fov_stats.tf_value[FovStat.MASKED_VOXELS]
380+
/ self.fov_stats.tf_value[FovStat.TOTAL_VOXELS]
381+
),
382+
),
383+
tf.Summary.Value(
384+
tag='fov/average_weight',
385+
simple_value=(
386+
self.fov_stats.tf_value[FovStat.WEIGHTS_SUM]
387+
/ self.fov_stats.tf_value[FovStat.TOTAL_VOXELS]
388+
),
389+
),
390+
tf.Summary.Value(
391+
tag='masked_voxel_fraction',
392+
simple_value=(
393+
self.num_voxels.tf_value[VoxelType.MASKED]
394+
/ self.num_voxels.tf_value[VoxelType.TOTAL]
395+
),
396+
),
397+
tf.Summary.Value(
398+
tag='eval/patch_loss',
399+
simple_value=self.loss.tf_value[0]
400+
/ self.num_patches.tf_value[0],
401+
),
402+
tf.Summary.Value(
403+
tag='eval/patches', simple_value=self.num_patches.tf_value[0]
404+
),
405+
tf.Summary.Value(tag='moves/total', simple_value=total_moves),
406+
]
407+
+ move_summaries
408+
+ (
409+
list(self.meshes)
410+
+ list(self.images_xy)
411+
+ list(self.images_xz)
412+
+ list(self.images_yz)
413+
)
414+
)
335415

336416
summaries.extend(
337-
self._compute_classification_metrics(self.prediction_counts,
338-
'eval/all'))
417+
self._compute_classification_metrics(self.prediction_counts, 'eval/all')
418+
)
339419

340420
for r, r_moves in self.moves_by_r.items():
341421
total_moves = sum(r_moves.tf_value)
342422
summaries.extend([
343423
tf.Summary.Value(
344424
tag='moves/r=%d/correct' % r,
345-
simple_value=r_moves.tf_value[MoveType.CORRECT] / total_moves),
425+
simple_value=r_moves.tf_value[MoveType.CORRECT] / total_moves,
426+
),
346427
tf.Summary.Value(
347428
tag='moves/r=%d/spurious' % r,
348-
simple_value=r_moves.tf_value[MoveType.SPURIOUS] / total_moves),
429+
simple_value=r_moves.tf_value[MoveType.SPURIOUS] / total_moves,
430+
),
349431
tf.Summary.Value(
350432
tag='moves/r=%d/missed' % r,
351-
simple_value=r_moves.tf_value[MoveType.MISSED] / total_moves),
433+
simple_value=r_moves.tf_value[MoveType.MISSED] / total_moves,
434+
),
352435
tf.Summary.Value(
353-
tag='moves/r=%d/total' % r, simple_value=total_moves)
436+
tag='moves/r=%d/total' % r, simple_value=total_moves
437+
),
354438
])
355439

356440
return summaries

0 commit comments

Comments (0)