重复惩罚疑问 #293
重复惩罚疑问
#293
-
|
在model.py文件中的361行,是对逻辑值进行重复惩罚的操作。 logits[:, list(set(input_ids.tolist()[0]))] /= rp但这里似乎对batch中的每一条数据,都根据第一条数据来判断是否进行重复惩罚了,这样是合理的吗?(本人小白,不太懂) |
Beta Was this translation helpful? Give feedback.
Answered by
jingyaogong
Apr 1, 2025
Replies: 1 comment
-
|
好问题
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0) # batch_size=1
out = self._stream(non_pad, ...)
如果将来要支持真正的批量生成(parallel generation),那么这行代码确实需要修改,每个样本都应该根据自己的历史 tokens 来进行重复惩罚。但目前的实现在当前的使用场景下是正确的。 |
Beta Was this translation helpful? Give feedback.
0 replies
Answer selected by
jingyaogong
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
好问题
input_ids的 shape 是[1, seq_len],因为这个方法是单条数据生成用的。看一下生成函数的调用链:
_stream,每次只传入一条数据:_stream中使用input_ids.tolist()[0]是没有问题的,因为:[0]索引就是获取这唯一的一条数据set()用于获取已生成的 token 集合,用于重复惩罚如果将来要支持真正的批量生成(parallel generation),那么这行代码确实需要修改,每个样本都应该根据自己的历史 tokens 来进行重复惩罚。但目前的实现在当前的使用场景下是正确的。