Skip to content

Commit e7e87b2

Browse files
authored
compute_train_stats: Fix logits passed in as proba (#273)
Led to an error when I was training.
1 parent 1c1bf24 commit e7e87b2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/imitation/rewards/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def compute_train_stats(
132132
_n_gen_or_1 = max(1, n_generated)
133133
generated_acc = _n_pred_gen / float(_n_gen_or_1)
134134

135-
label_dist = th.distributions.Bernoulli(disc_logits_gen_is_high)
135+
label_dist = th.distributions.Bernoulli(logits=disc_logits_gen_is_high)
136136
entropy = th.mean(label_dist.entropy())
137137

138138
pairs = [

0 commit comments

Comments
 (0)