We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1c1bf24 commit e7e87b2Copy full SHA for e7e87b2
src/imitation/rewards/common.py
@@ -132,7 +132,7 @@ def compute_train_stats(
132
_n_gen_or_1 = max(1, n_generated)
133
generated_acc = _n_pred_gen / float(_n_gen_or_1)
134
135
- label_dist = th.distributions.Bernoulli(disc_logits_gen_is_high)
+ label_dist = th.distributions.Bernoulli(logits=disc_logits_gen_is_high)
136
entropy = th.mean(label_dist.entropy())
137
138
pairs = [
0 commit comments