Skip to content

Commit b4770df

Browse files
authored
Migrate to PyTorch tokenizers (#781)
## Description This PR migrates us from tokenisers-cpp to PyTorch tokenisers that are by default bundled with executorch binaries ### Introduces a breaking change? - [ ] Yes - [x] No - User faces no changes ### Type of change - [ ] Bug fix (change which fixes an issue) - [ ] New feature (change which adds functionality) - [ ] Documentation update (improves or adds clarity to existing documentation) - [x] Other (chores, tests, code style improvements etc.) ### Tested on - [x] iOS - [x] Android ### Testing instructions This changes need to be tested manually. Try running all our apps that consume tokenizers and see whether the output is ok. ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings
1 parent 4af4cc7 commit b4770df

File tree

155 files changed

+36974
-300
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

155 files changed

+36974
-300
lines changed

packages/react-native-executorch/android/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ string(APPEND CMAKE_CXX_FLAGS " -DRCT_NEW_ARCH_ENABLED")
2121
set(ANDROID_CPP_DIR "${CMAKE_SOURCE_DIR}/src/main/cpp")
2222
set(COMMON_CPP_DIR "${CMAKE_SOURCE_DIR}/../common")
2323
set(LIBS_DIR "${CMAKE_SOURCE_DIR}/../third-party/android/libs")
24+
set(TOKENIZERS_DIR "${CMAKE_SOURCE_DIR}/../third-party/include/executorch/extension/llm/tokenizers/include")
2425
set(INCLUDE_DIR "${CMAKE_SOURCE_DIR}/../third-party/include")
2526

2627
# Treat third-party headers as system headers to suppress deprecation warnings

packages/react-native-executorch/android/src/main/cpp/CMakeLists.txt

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ target_include_directories(
1717
"${COMMON_CPP_DIR}"
1818
"${ANDROID_CPP_DIR}"
1919
"${INCLUDE_DIR}"
20+
"${TOKENIZERS_DIR}"
2021
"${REACT_NATIVE_DIR}/ReactCommon"
2122
"${REACT_NATIVE_DIR}/ReactAndroid/src/main/jni/react/turbomodule"
2223
"${REACT_NATIVE_DIR}/ReactCommon/callinvoker"
@@ -84,13 +85,6 @@ elseif(ANDROID_ABI STREQUAL "x86_64")
8485
set(OPENCV_THIRD_PARTY_LIBS "")
8586
endif()
8687

87-
# ------- tokenizers-cpp -------
88-
89-
set(TOKENIZERS_LIBS
90-
"${LIBS_DIR}/tokenizers-cpp/${ANDROID_ABI}/libtokenizers_c.a"
91-
"${LIBS_DIR}/tokenizers-cpp/${ANDROID_ABI}/libtokenizers_cpp.a"
92-
"${LIBS_DIR}/tokenizers-cpp/${ANDROID_ABI}/libsentencepiece.a"
93-
)
9488

9589
# ------- phonemis -------
9690

@@ -108,8 +102,6 @@ target_link_libraries(
108102
${RN_VERSION_LINK_LIBRARIES}
109103
${OPENCV_LIBS}
110104
${OPENCV_THIRD_PARTY_LIBS}
111-
${TOKENIZERS_LIBS}
112-
${TOKENIZERS_THIRD_PARTY_LIBS}
113105
${PHONEMIS_LIBS}
114106
executorch
115107
${EXECUTORCH_LIBS}

packages/react-native-executorch/common/rnexecutorch/ErrorCodes.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,18 @@ enum class RnExecutorchErrorCode : int32_t {
6767
WrongDimensions = 116,
6868
/**
6969
* Thrown when the input passed to our APIs is invalid, for example when
70-
* passing an empty message aray to LLM's generate().
70+
* passing an empty message array to LLM's generate().
7171
*/
7272
InvalidUserInput = 117,
7373
/**
7474
* Thrown when the number of downloaded files is unexpected, due to download
7575
* interruptions.
7676
*/
7777
DownloadInterrupted = 118,
78+
/**
79+
* Thrown when an error occurs with the tokenizer or tokenization process.
80+
*/
81+
TokenizerError = 167,
7882
/**
7983
* Thrown when there's a configuration mismatch between multilingual and
8084
* language settings in Speech-to-Text models.

packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
#include "TokenizerModule.h"
22
#include "Error.h"
33
#include "ErrorCodes.h"
4+
#include <cstdint>
45
#include <executorch/extension/module/module.h>
56
#include <filesystem>
6-
#include <rnexecutorch/data_processing/FileUtils.h>
7+
#include <pytorch/tokenizers/error.h>
8+
#include <runner/constants.h>
79

810
namespace rnexecutorch {
911
using namespace facebook;
12+
using namespace executorch::extension::constants;
1013

1114
TokenizerModule::TokenizerModule(
1215
std::string source, std::shared_ptr<react::CallInvoker> callInvoker)
13-
: tokenizer(tokenizers::Tokenizer::FromBlobJSON(
14-
file_utils::loadBytesFromFile(source))),
15-
memorySizeLowerBound(std::filesystem::file_size(source)) {}
16+
: tokenizer(std::make_unique<tokenizers::HFTokenizer>()),
17+
memorySizeLowerBound(std::filesystem::file_size(source)) {
18+
19+
auto status = tokenizer->load(source);
20+
21+
if (status != tokenizers::Error::Ok) {
22+
throw RnExecutorchError(RnExecutorchErrorCode::TokenizerError,
23+
"Unexpected issue occured while loading tokenizer");
24+
};
25+
}
1626

1727
void TokenizerModule::ensureTokenizerLoaded(
1828
const std::string &methodName) const {
@@ -23,31 +33,69 @@ void TokenizerModule::ensureTokenizerLoaded(
2333
}
2434
}
2535

26-
std::vector<int32_t> TokenizerModule::encode(std::string s) const {
36+
std::vector<uint64_t> TokenizerModule::encode(std::string s) const {
2737
ensureTokenizerLoaded("encode");
28-
return tokenizer->Encode(s);
38+
39+
// If the used tokenizer.json has defined post_processor field,
40+
// setting any of bos or eos arguments to value other than provided constant
41+
// ( which is 0) will result in running the post_processor with
42+
// 'add_special_token' flag
43+
auto encodeResult =
44+
tokenizer->encode(s, numOfAddedBoSTokens, numOfAddedEoSTokens);
45+
if (!encodeResult.ok()) {
46+
throw rnexecutorch::RnExecutorchError(
47+
rnexecutorch::RnExecutorchErrorCode::TokenizerError,
48+
"Unexpected issue occured while encoding: " +
49+
std::to_string(static_cast<int32_t>(encodeResult.error())));
50+
}
51+
return encodeResult.get();
2952
}
3053

31-
std::string TokenizerModule::decode(std::vector<int32_t> vec,
54+
std::string TokenizerModule::decode(std::vector<uint64_t> vec,
3255
bool skipSpecialTokens) const {
3356
ensureTokenizerLoaded("decode");
34-
return tokenizer->Decode(vec, skipSpecialTokens);
57+
58+
auto decodeResult = tokenizer->decode(vec, skipSpecialTokens);
59+
if (!decodeResult.ok()) {
60+
throw RnExecutorchError(
61+
RnExecutorchErrorCode::TokenizerError,
62+
"Unexpected issue occured while decoding: " +
63+
std::to_string(static_cast<int32_t>(decodeResult.error())));
64+
}
65+
66+
return decodeResult.get();
3567
}
3668

3769
size_t TokenizerModule::getVocabSize() const {
3870
ensureTokenizerLoaded("getVocabSize");
39-
return tokenizer->GetVocabSize();
71+
return static_cast<size_t>(tokenizer->vocab_size());
4072
}
4173

42-
std::string TokenizerModule::idToToken(int32_t tokenId) const {
74+
std::string TokenizerModule::idToToken(uint64_t tokenId) const {
4375
ensureTokenizerLoaded("idToToken");
44-
return tokenizer->IdToToken(tokenId);
76+
auto result = tokenizer->id_to_piece(tokenId);
77+
if (!result.ok()) {
78+
throw rnexecutorch::RnExecutorchError(
79+
rnexecutorch::RnExecutorchErrorCode::TokenizerError,
80+
"Unexpected issue occured while trying to convert id to token: " +
81+
std::to_string(static_cast<int32_t>(result.error())));
82+
}
83+
return result.get();
4584
}
4685

47-
int32_t TokenizerModule::tokenToId(std::string token) const {
86+
uint64_t TokenizerModule::tokenToId(std::string token) const {
4887
ensureTokenizerLoaded("tokenToId");
49-
return tokenizer->TokenToId(token);
88+
89+
auto result = tokenizer->piece_to_id(token);
90+
if (!result.ok()) {
91+
throw rnexecutorch::RnExecutorchError(
92+
rnexecutorch::RnExecutorchErrorCode::TokenizerError,
93+
"Unexpected issue occured while trying to convert token to id: " +
94+
std::to_string(static_cast<int32_t>(result.error())));
95+
}
96+
return result.get();
5097
}
98+
5199
std::size_t TokenizerModule::getMemoryLowerBound() const noexcept {
52100
return memorySizeLowerBound;
53101
}

packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,33 @@
22

33
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
44
#include <ReactCommon/CallInvoker.h>
5+
#include <pytorch/tokenizers/hf_tokenizer.h>
56
#include <string>
6-
#include <tokenizers-cpp/tokenizers_cpp.h>
77
namespace rnexecutorch {
88
using namespace facebook;
99

1010
class TokenizerModule {
1111
public:
1212
explicit TokenizerModule(std::string source,
1313
std::shared_ptr<react::CallInvoker> callInvoker);
14-
[[nodiscard("Registered non-void function")]] std::vector<int32_t>
14+
[[nodiscard("Registered non-void function")]] std::vector<uint64_t>
1515
encode(std::string s) const;
1616
[[nodiscard("Registered non-void function")]] std::string
17-
decode(std::vector<int32_t> vec, bool skipSpecialTokens) const;
17+
decode(std::vector<uint64_t> vec, bool skipSpecialTokens) const;
1818
[[nodiscard("Registered non-void function")]] std::string
19-
idToToken(int32_t tokenId) const;
20-
[[nodiscard("Registered non-void function")]] int32_t
19+
idToToken(uint64_t tokenId) const;
20+
[[nodiscard("Registered non-void function")]] uint64_t
2121
tokenToId(std::string token) const;
2222
[[nodiscard("Registered non-void function")]] std::size_t
2323
getVocabSize() const;
2424
std::size_t getMemoryLowerBound() const noexcept;
2525

2626
private:
2727
void ensureTokenizerLoaded(const std::string &methodName) const;
28-
std::unique_ptr<tokenizers::Tokenizer> tokenizer;
28+
std::unique_ptr<tokenizers::HFTokenizer> tokenizer;
2929
const std::size_t memorySizeLowerBound{0};
3030
};
3131

3232
REGISTER_CONSTRUCTOR(TokenizerModule, std::string,
3333
std::shared_ptr<react::CallInvoker>);
34-
} // namespace rnexecutorch
34+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,12 @@ getValue<std::vector<int64_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
218218
return getArrayAsVector<int64_t>(val, runtime);
219219
}
220220

221+
template <>
222+
inline std::vector<uint64_t>
223+
getValue<std::vector<uint64_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
224+
return getArrayAsVector<uint64_t>(val, runtime);
225+
}
226+
221227
// Template specializations for std::span<T> types
222228
template <>
223229
inline std::span<float> getValue<std::span<float>>(const jsi::Value &val,
@@ -273,6 +279,12 @@ inline std::span<int64_t> getValue<std::span<int64_t>>(const jsi::Value &val,
273279
return getTypedArrayAsSpan<int64_t>(val, runtime);
274280
}
275281

282+
template <>
283+
inline std::span<uint64_t>
284+
getValue<std::span<uint64_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
285+
return getTypedArrayAsSpan<uint64_t>(val, runtime);
286+
}
287+
276288
// Conversion from C++ types to jsi --------------------------------------------
277289

278290
// Implementation functions might return any type, but in a promise we can only
@@ -293,6 +305,15 @@ inline jsi::Value getJsiValue(const std::vector<int32_t> &vec,
293305
return {runtime, array};
294306
}
295307

308+
inline jsi::Value getJsiValue(const std::vector<uint64_t> &vec,
309+
jsi::Runtime &runtime) {
310+
jsi::Array array(runtime, vec.size());
311+
for (size_t i = 0; i < vec.size(); i++) {
312+
array.setValueAtIndex(runtime, i, jsi::Value(static_cast<double>(vec[i])));
313+
}
314+
return {runtime, array};
315+
}
316+
296317
inline jsi::Value getJsiValue(const std::vector<float> &vec,
297318
jsi::Runtime &runtime) {
298319
jsi::Array array(runtime, vec.size());
@@ -302,6 +323,16 @@ inline jsi::Value getJsiValue(const std::vector<float> &vec,
302323
return {runtime, array};
303324
}
304325

326+
inline jsi::Value getJsiValue(const std::vector<std::string> &vec,
327+
jsi::Runtime &runtime) {
328+
jsi::Array array(runtime, vec.size());
329+
for (size_t i = 0; i < vec.size(); i++) {
330+
array.setValueAtIndex(runtime, i,
331+
jsi::String::createFromUtf8(runtime, vec[i]));
332+
}
333+
return {runtime, array};
334+
}
335+
305336
inline jsi::Value getJsiValue(const std::vector<char> &vec,
306337
jsi::Runtime &runtime) {
307338
jsi::Array array(runtime, vec.size());
@@ -311,10 +342,28 @@ inline jsi::Value getJsiValue(const std::vector<char> &vec,
311342
return {runtime, array};
312343
}
313344

345+
// Conditional as on android, size_t and uint64_t reduce to the same type,
346+
// introducing ambiguity
347+
template <typename T,
348+
typename = std::enable_if_t<std::is_same_v<T, size_t> &&
349+
!std::is_same_v<size_t, uint64_t>>>
350+
inline jsi::Value getJsiValue(T val, jsi::Runtime &runtime) {
351+
return jsi::Value(static_cast<double>(val));
352+
}
353+
354+
inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) {
355+
jsi::BigInt bigInt = jsi::BigInt::fromUint64(runtime, val);
356+
return {runtime, bigInt};
357+
}
358+
314359
inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) {
315360
return {runtime, val};
316361
}
317362

363+
inline jsi::Value getJsiValue(bool val, jsi::Runtime &runtime) {
364+
return jsi::Value(val);
365+
}
366+
318367
inline jsi::Value getJsiValue(const std::shared_ptr<OwningArrayBuffer> &buf,
319368
jsi::Runtime &runtime) {
320369
jsi::ArrayBuffer arrayBuffer(runtime, buf);

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ SpeechToText::encode(std::span<float> waveform) const {
3737
}
3838

3939
std::shared_ptr<OwningArrayBuffer>
40-
SpeechToText::decode(std::span<int32_t> tokens,
40+
SpeechToText::decode(std::span<uint64_t> tokens,
4141
std::span<float> encoderOutput) const {
4242
std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
4343
return std::make_shared<OwningArrayBuffer>(decoderOutput);

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class SpeechToText {
2222
encode(std::span<float> waveform) const;
2323
[[nodiscard(
2424
"Registered non-void function")]] std::shared_ptr<OwningArrayBuffer>
25-
decode(std::span<int32_t> tokens, std::span<float> encoderOutput) const;
25+
decode(std::span<uint64_t> tokens, std::span<float> encoderOutput) const;
2626
[[nodiscard("Registered non-void function")]] std::vector<char>
2727
transcribe(std::span<float> waveform, std::string languageOption) const;
2828

0 commit comments

Comments
 (0)