1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
embed/* -text
18 changes: 18 additions & 0 deletions CMakeLists.txt
@@ -11,6 +11,24 @@ endif()
if (MSVC)
    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
    add_compile_definitions(_SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING)

    message(STATUS "MSVC detected, polyfilling #embed via Python...")

    set(VOCAB_HEADER "${CMAKE_CURRENT_SOURCE_DIR}/vocab.hpp")
    set(GENERATED_VOCAB_HPP "${CMAKE_CURRENT_SOURCE_DIR}/vocab_generated.hpp")
    execute_process(
        COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/embed/embed_fix.py
                ${VOCAB_HEADER}
                ${GENERATED_VOCAB_HPP}
        RESULT_VARIABLE result
    )

    if (NOT result EQUAL 0)
        message(FATAL_ERROR "Failed to run embed_fix.py")
    endif()

    add_definitions(-DUSE_GENERATED_VOCAB)

endif()

set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
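Reviewer context: vocab.hpp itself is not part of this diff, so the following is only a sketch of how the USE_GENERATED_VOCAB guard might be consumed on the C++ side; the variable name umt5_vocab and the embedded file name are invented for illustration, not taken from this PR.

// vocab.hpp -- sketch only, not part of this diff. Names are assumptions.
#pragma once

#ifdef USE_GENERATED_VOCAB
// MSVC path: embed_fix.py has already expanded the #embed directive below
// into a plain byte array inside vocab_generated.hpp.
#include "vocab_generated.hpp"
#else
// Compilers that support #embed (e.g. recent Clang/GCC) pull the file in
// directly at compile time.
char umt5_vocab[] = {
#embed "embed/umt5_tokenizer.json"
};
#endif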
62 changes: 62 additions & 0 deletions embed/convert_vocab.py
@@ -0,0 +1,62 @@
import re
import json
import sys
import os

def minify_json(input_file, output_file):
    try:
        # 1. Read the original JSON file
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # 2. Write the minified JSON
        # separators=(',', ':') strips all spaces and newlines
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, separators=(',', ':'), ensure_ascii=False)

        print(f"Success! Minified file saved to: {output_file}")

        # Report the size reduction
        orig_size = os.path.getsize(input_file)
        mini_size = os.path.getsize(output_file)
        reduction = (1 - mini_size / orig_size) * 100
        print(f"Original: {orig_size} bytes | Minified: {mini_size} bytes | Reduced by: {reduction:.2f}%")

    except FileNotFoundError:
        print("Error: input file not found.")
    except json.JSONDecodeError:
        print("Error: the file is not valid JSON.")
    except Exception as e:
        print(f"Unexpected error: {e}")

def extract_to_bin(source_file, output_bin):
    with open(source_file, 'r', encoding='utf-8') as f:
        content = f.read()

    # Match the array body (everything between { and })
    # Assumes the data is formatted like 0xab, 0x12 ...
    match = re.search(r'\{(.*)\}', content, re.DOTALL)
    if not match:
        print("No array content found")
        return

    data_str = match.group(1)

    # Extract every hexadecimal or decimal number
    # This regex matches formats like 0xFF, 0xaf, 123
    numbers = re.findall(r'(0x[0-9a-fA-F]+|[0-9]+)', data_str)

    # Convert to a byte sequence
    binary_data = bytearray()
    for n in numbers:
        if n.startswith('0x'):
            binary_data.append(int(n, 16))
        else:
            binary_data.append(int(n, 10))

    with open(output_bin, 'wb') as f:
        f.write(binary_data)
    print(f"Conversion complete: {output_bin}, size: {len(binary_data)} bytes")

#extract_to_bin('vocab_umt5.hpp', 'umt5_tokenizer.json')
minify_json("umt5_tokenizer.json.bak", "umt5_tokenizer.json")
43 changes: 43 additions & 0 deletions embed/embed_fix.py
@@ -0,0 +1,43 @@
import re
import os
import sys

def file_to_c_array(filepath):
    with open(filepath, "rb") as f:
        data = f.read()

    cont = []
    for i, b in enumerate(data):
        val = ""
        if (i + 1) % 16 == 0:
            val = "\n"
        val += f'{b}'
        cont.append(val)

    return ",".join(cont)

def process_header(input_path, output_hpp):
    base_dir = os.path.dirname(os.path.abspath(input_path))

    with open(input_path, "r", encoding="utf-8") as f:
        content = f.read()

    pattern = re.compile(r'char\s+(\w+)\[\]\s*=?\s*\{\s*#embed\s+"([^"]+)"\s*\};')
    matches = pattern.findall(content)
    if not matches:
        print(f"No #embed found in {input_path}")
        return

    out_content = '#pragma once\n\n'
    for var_name, file_path in matches:
        file_path = os.path.join(base_dir, file_path)

        print(f"Embedding {file_path} into {var_name}...")
        hex_data = file_to_c_array(file_path)
        out_content += f"static const unsigned char {var_name}[] = {{\n{hex_data}\n}};\n"

    with open(output_hpp, "w") as f: f.write(out_content)

if __name__ == "__main__":
    # Usage: python embed_fix.py <vocab.hpp> <out.hpp>
    process_header(sys.argv[1], sys.argv[2])
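For concreteness, here is roughly what this script turns a matching declaration into. The variable name and the JSON bytes below are made up, but the output shape (decimal, comma-separated bytes with a line break before every 16th value) follows file_to_c_array above.

// Input matched in vocab.hpp (the pattern the regex above looks for):
//   char umt5_vocab[] = {
//   #embed "umt5_tokenizer.json"
//   };
//
// Output written to vocab_generated.hpp (bytes are illustrative: the ASCII
// of a tiny made-up JSON string):
#pragma once

static const unsigned char umt5_vocab[] = {
123,34,118,111,99,97,98,34,58,123,34,60,112,97,100,
62,34,58,48,125,125
};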