Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 83 additions & 12 deletions lua/model/core/chat.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,35 +35,76 @@ local function split_messages(text)
local chunk_lines = {}
local chunk_is_user = true

local blocks = {}
local block = {}

--- Insert message and reset/toggle chunk state. User text is trimmed.
local function add_message()
local text_ = table.concat(chunk_lines, '\n')

table.insert(messages, {
role = chunk_is_user and 'user' or 'assistant',
content = chunk_is_user and vim.trim(text_) or text_,
content = blocks,
})

chunk_lines = {}
blocks = {}
chunk_is_user = not chunk_is_user
end

local function add_block()
if block.type == nil then
block.type = 'text'
end

local text_ = table.concat(chunk_lines, '\n')
text_ = vim.trim(text_)

if block.type == 'text' then
block.text = text_
elseif block.type == 'thinking' then
block.thinking = text_
elseif block.type == 'redacted_thinking' then
block.data = text_
end

table.insert(blocks, block)

block = {}
chunk_lines = {}
end

for i, line in ipairs(text) do
if i == 1 then
local is_system = i == 1 and line:match('^> (.+)') ~= nil
local is_thinking = not chunk_is_user and line:match('<thinking') ~= nil
local is_end_thinking = not chunk_is_user
and line:match('^</thinking') ~= nil
local is_redacted_thinking = not chunk_is_user
and line:match('<redacted_thinking') ~= nil
local is_end_redacted_thinking = not chunk_is_user
and line:match('^</redacted_thinking') ~= nil
local is_end_message = line == '======'

if is_system then
system = line:match('^> (.+)')

if system == nil then
table.insert(chunk_lines, line)
end
elseif line == '======' then
elseif is_end_message then
add_block()
add_message()
elseif is_thinking then
block.type = 'thinking'
elseif is_end_thinking then
local signature = line:match('</thinking%s+signature="([^"]*)"')
block.signature = signature
add_block()
elseif is_redacted_thinking then
block.type = 'redacted_thinking'
elseif is_end_redacted_thinking then
add_block()
else
table.insert(chunk_lines, line)
end
end

-- add text after last `======` if not empty
if table.concat(chunk_lines, '') ~= '' then
add_block()
add_message()
end

Expand Down Expand Up @@ -137,6 +178,33 @@ local function parse_config(text)
end
end

local function message_content_to_string(content)
if type(content) == 'string' then
return content
end

local result = {}
for i, block in ipairs(content) do
if block.type == 'text' then
result[i] = block.text
elseif block.type == 'thinking' then
result[i] = '<thinking>\n'
.. block.thinking
.. '\n</thinking signature="'
.. block.signature
.. '">'
elseif block.type == 'redacted_thinking' then
result[i] = '<redacted_thinking>\n'
.. block.data
.. '\n</redacted_thinking>'
end
end

result = table.concat(result, '\n\n')

return result
end

--- Parse a chat file. Must start with a chat name, can follow with a lua table
--- of config between `---`. If the next line starts with `> `, it is parsed as
--- the system instruction. The rest of the text is parsed as alternating
Expand Down Expand Up @@ -184,9 +252,12 @@ function M.to_string(contents, name)
end

if message.role == 'user' then
result = result .. '\n' .. message.content .. '\n'
result = result
.. '\n'
.. message_content_to_string(message.content)
.. '\n'
else
result = result .. message.content
result = result .. message_content_to_string(message.content)
end
end

Expand Down
22 changes: 20 additions & 2 deletions lua/model/providers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,26 @@ local M = {
on_message = function(msg)
local data = util.json.decode(msg.data)

if msg.event == 'content_block_delta' then
consume(data.delta.text)
if msg.event == 'content_block_start' then
if data.content_block.type == 'thinking' then
consume('<thinking>\n')
elseif data.content_block.type == 'redacted_thinking' then
consume(
'<redacted_thinking>\n'
.. data.content_block.data
.. '\n</redacted_thinking>\n'
)
end
elseif msg.event == 'content_block_delta' then
if data.delta.type == 'thinking_delta' then
consume(data.delta.thinking)
elseif data.delta.type == 'signature_delta' then
consume(
'\n</thinking signature="' .. data.delta.signature .. '">\n\n'
)
else
consume(data.delta.text)
end
elseif msg.event == 'message_delta' then
util.show(data.usage.output_tokens, 'output tokens')
elseif msg.event == 'message_stop' then
Expand Down
Loading