Skip to content
Discussion options

You must be logged in to vote

好问题

  1. input_ids 的 shape 是 [1, seq_len],因为这个方法是单条数据生成用的。

  2. 看一下生成函数的调用链:

  • 对于批量生成,generate它会循环调用 _stream,每次只传入一条数据:
for i in range(input_ids.size(0)):
    non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)  # batch_size=1
    out = self._stream(non_pad, ...)
  1. 所以在 _stream 中使用 input_ids.tolist()[0] 是没有问题的,因为:
  • batch_size 始终为 1
  • [0] 索引就是获取这唯一的一条数据
  • set() 用于获取已生成的 token 集合,用于重复惩罚

如果将来要支持真正的批量生成(parallel generation),那么这行代码确实需要修改,每个样本都应该根据自己的历史 tokens 来进行重复惩罚。但目前的实现在当前的使用场景下是正确的。

Replies: 1 comment

Comment options

You must be logged in to vote
0 replies
Answer selected by jingyaogong
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
2 participants