Skip to content

Commit 964c565

Browse files
pytorchbot and Github Executorch authored
Support multimethod in runner (#17398)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #17228 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/131/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/131/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/131/orig Differential Revision: [D92225533](https://our.internmc.facebook.com/intern/diff/D92225533/) @diff-train-skip-merge Co-authored-by: Github Executorch <github_executorch@arm.com>
1 parent 9ae8181 commit 964c565

File tree

8 files changed

+125
-33
lines changed

8 files changed

+125
-33
lines changed

examples/models/llama/main.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ DEFINE_string(
7777
"etdump.in",
7878
"If an etdump path is provided, generate an ETDump file at the specified path for profiling purposes.");
7979

80+
DEFINE_string(
81+
method_name,
82+
"forward",
83+
"Method name to execute in the model (e.g., 'forward', 'lora_forward').");
84+
8085
// Helper function to parse comma-separated string lists
8186
std::vector<std::string> parseStringList(const std::string& input) {
8287
std::vector<std::string> result;
@@ -145,11 +150,11 @@ int32_t main(int32_t argc, char** argv) {
145150
data_paths,
146151
temperature,
147152
#ifdef ET_EVENT_TRACER_ENABLED
148-
std::move(etdump_gen_ptr)
153+
std::move(etdump_gen_ptr),
149154
#else
150-
nullptr
155+
nullptr,
151156
#endif
152-
);
157+
FLAGS_method_name);
153158

154159
if (runner == nullptr) {
155160
ET_LOG(Error, "Failed to create llama runner");

examples/models/llama/runner/runner.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
3737
const std::string& tokenizer_path,
3838
std::optional<const std::string> data_path,
3939
float temperature,
40-
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer) {
40+
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer,
41+
const std::string& method_name) {
4142
if (data_path.has_value()) {
4243
std::vector<std::string> data_files;
4344
data_files.push_back(data_path.value());
@@ -46,22 +47,25 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
4647
tokenizer_path,
4748
std::move(data_files),
4849
temperature,
49-
std::move(event_tracer));
50+
std::move(event_tracer),
51+
method_name);
5052
}
5153
return create_llama_runner(
5254
model_path,
5355
tokenizer_path,
5456
std::vector<std::string>(),
5557
temperature,
56-
std::move(event_tracer));
58+
std::move(event_tracer),
59+
method_name);
5760
}
5861

5962
std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
6063
const std::string& model_path,
6164
const std::string& tokenizer_path,
6265
std::vector<std::string> data_files,
6366
float temperature,
64-
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer) {
67+
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer,
68+
const std::string& method_name) {
6569
ET_LOG(
6670
Info,
6771
"Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
@@ -84,7 +88,8 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
8488
std::move(tokenizer),
8589
data_files,
8690
temperature,
87-
std::move(event_tracer));
91+
std::move(event_tracer),
92+
method_name);
8893
}
8994

9095
} // namespace example

examples/models/llama/runner/runner.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,16 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
2929
const std::string& tokenizer_path,
3030
std::optional<const std::string> data_path,
3131
float temperature = -1.0f,
32-
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr);
32+
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr,
33+
const std::string& method_name = "forward");
3334

3435
std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
3536
const std::string& model_path,
3637
const std::string& tokenizer_path,
3738
std::vector<std::string> data_files = {},
3839
float temperature = -1.0f,
39-
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr);
40+
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr,
41+
const std::string& method_name = "forward");
4042

4143
std::unique_ptr<tokenizers::Tokenizer> load_llama_tokenizer(
4244
const std::string& tokenizer_path,

extension/llm/runner/llm_runner_helper.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,26 +182,35 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(
182182
const std::string& model_path,
183183
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
184184
std::optional<const std::string> data_path,
185-
float temperature) {
185+
float temperature,
186+
const std::string& method_name) {
186187
if (data_path.has_value()) {
187188
std::vector<std::string> data_files;
188189
data_files.push_back(data_path.value());
189190
return create_text_llm_runner(
190-
model_path, std::move(tokenizer), std::move(data_files), temperature);
191+
model_path,
192+
std::move(tokenizer),
193+
std::move(data_files),
194+
temperature,
195+
nullptr,
196+
method_name);
191197
}
192198
return create_text_llm_runner(
193199
model_path,
194200
std::move(tokenizer),
195201
std::vector<std::string>(),
196-
temperature);
202+
temperature,
203+
nullptr,
204+
method_name);
197205
}
198206

199207
std::unique_ptr<TextLLMRunner> create_text_llm_runner(
200208
const std::string& model_path,
201209
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
202210
std::vector<std::string> data_files,
203211
float temperature,
204-
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer) {
212+
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer,
213+
const std::string& method_name) {
205214
// Sanity check tokenizer
206215
if (!tokenizer || !tokenizer->is_loaded()) {
207216
ET_LOG(Error, "Tokenizer is null or not loaded");
@@ -236,10 +245,10 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(
236245
// Create IOManager
237246
std::unique_ptr<IOManager> io_manager = std::make_unique<IOManager>(*module);
238247

239-
// Create text_decoder_runner. Use a shared_ptr so that it can be shared with
240-
// TextPrefiller and TextTokenGenerator
241-
auto text_decoder_runner =
242-
std::make_unique<TextDecoderRunner>(module.get(), io_manager.get());
248+
// Create text_decoder_runner
249+
ET_LOG(Info, "Using method: %s", method_name.c_str());
250+
auto text_decoder_runner = std::make_unique<TextDecoderRunner>(
251+
module.get(), io_manager.get(), method_name);
243252

244253
// Create text_prefiller
245254
auto text_prefiller = std::make_unique<TextPrefiller>(

extension/llm/runner/llm_runner_helper.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,16 @@ ET_EXPERIMENTAL std::unordered_set<uint64_t> get_eos_ids(
9595
* @param data_path Optional path to additional data required by the model
9696
* @param temperature Optional temperature parameter for controlling randomness
9797
* (deprecated)
98+
* @param method_name Name of the method to execute in the model
9899
* @return std::unique_ptr<TextLLMRunner> Initialized TextLLMRunner instance, or
99100
* nullptr on failure
100101
*/
101102
ET_EXPERIMENTAL std::unique_ptr<TextLLMRunner> create_text_llm_runner(
102103
const std::string& model_path,
103104
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
104105
std::optional<const std::string> data_path,
105-
float temperature = -1.0f);
106+
float temperature = -1.0f,
107+
const std::string& method_name = "forward");
106108

107109
/**
108110
* @brief Creates a TextLLMRunner instance with dependency injection
@@ -116,6 +118,8 @@ ET_EXPERIMENTAL std::unique_ptr<TextLLMRunner> create_text_llm_runner(
116118
* @param data_files Vector of paths to additional data required by the model
117119
* @param temperature Optional temperature parameter for controlling randomness
118120
* (deprecated)
121+
* @param event_tracer Optional event tracer for profiling
122+
* @param method_name Name of the method to execute in the model
119123
* @return std::unique_ptr<TextLLMRunner> Initialized TextLLMRunner instance, or
120124
* nullptr on failure
121125
*/
@@ -124,7 +128,8 @@ ET_EXPERIMENTAL std::unique_ptr<TextLLMRunner> create_text_llm_runner(
124128
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
125129
std::vector<std::string> data_files = {},
126130
float temperature = -1.0f,
127-
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr);
131+
std::unique_ptr<::executorch::runtime::EventTracer> event_tracer = nullptr,
132+
const std::string& method_name = "forward");
128133

129134
/**
130135
* @brief Creates a MultimodalRunner instance with dependency injection

extension/llm/runner/test/test_text_decoder_runner.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,41 @@ class TextDecoderRunnerTest : public Test {
4747
std::unique_ptr<IOManager> io_manager_;
4848
};
4949

50+
// Test that method_name defaults to "forward"
51+
TEST_F(TextDecoderRunnerTest, MethodNameDefaultsToForward) {
52+
EXPECT_EQ(runner_->method_name(), "forward");
53+
}
54+
55+
// Test that method_name can be set to a custom value
56+
TEST_F(TextDecoderRunnerTest, MethodNameCustomValue) {
57+
auto custom_runner = std::make_unique<TextDecoderRunner>(
58+
mock_module_.get(), io_manager_.get(), "encode");
59+
EXPECT_EQ(custom_runner->method_name(), "encode");
60+
}
61+
62+
// Test that load() uses method_name (not hardcoded "forward")
63+
TEST_F(TextDecoderRunnerTest, LoadUsesMethodName) {
64+
// Get an available model
65+
const char* model_path = std::getenv("KVCACHE_CACHE_POS");
66+
if (!model_path) {
67+
GTEST_SKIP() << "No PTE model environment variable set";
68+
}
69+
auto module = std::make_unique<Module>(model_path);
70+
auto load_result = module->load();
71+
if (load_result != Error::Ok) {
72+
GTEST_SKIP() << "Failed to load model";
73+
}
74+
75+
auto io_mgr = std::make_unique<IOManager>(*module);
76+
77+
// Create runner with a method name that doesn't exist
78+
TextDecoderRunner runner(module.get(), io_mgr.get(), "nonexistent_method");
79+
80+
// load() should fail because "nonexistent_method" doesn't exist
81+
auto result = runner.load();
82+
EXPECT_NE(result, Error::Ok);
83+
}
84+
5085
// Test logits_to_token() method with Float tensor
5186
TEST_F(TextDecoderRunnerTest, LogitsToTokenFloat) {
5287
TensorFactory<executorch::aten::ScalarType::Float> tf_float;

extension/llm/runner/text_decoder_runner.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ namespace llm {
2222
// NOTE: we observed ~2x loading performance increase on iPhone 15
2323
// and a ~5% improvement on Galaxy S22 by switching to
2424
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
25-
TextDecoderRunner::TextDecoderRunner(Module* module, IOManager* io_manager)
26-
: module_(module), io_manager_(io_manager) {}
25+
TextDecoderRunner::TextDecoderRunner(
26+
Module* module,
27+
IOManager* io_manager,
28+
std::string method_name)
29+
: module_(module),
30+
io_manager_(io_manager),
31+
method_name_(std::move(method_name)) {}
2732

2833
// This function is functional, meaning it shouldn't modify any state of the
2934
// input. It should be safe to call multiple times with the same inputs. The
@@ -32,7 +37,7 @@ ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
3237
TensorPtr& tokens,
3338
int64_t start_pos) {
3439
// ET_LOG(Info, "Input token %" PRIu64, input_token);
35-
auto method_meta_result = module_->method_meta("forward");
40+
auto method_meta_result = module_->method_meta(method_name_);
3641
if (!method_meta_result.ok()) {
3742
return method_meta_result.error();
3843
}
@@ -44,25 +49,31 @@ ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
4449

4550
if (use_kv_cache) {
4651
auto start_pos_tensor_result = populate_start_pos_or_cache_position(
47-
module_, start_pos, cache_positions, tokens->numel(), "forward");
52+
module_,
53+
start_pos,
54+
cache_positions,
55+
tokens->numel(),
56+
method_name_.c_str());
4857
if (!start_pos_tensor_result.ok()) {
4958
return start_pos_tensor_result.error();
5059
}
5160
auto start_pos_tensor = std::move(*start_pos_tensor_result);
5261

5362
std::vector<runtime::EValue> inputs;
54-
auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor);
63+
auto inputs_res =
64+
io_manager_->prepare_decode(tokens, start_pos_tensor, method_name_);
5565
ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error());
5666
inputs = inputs_res.get();
57-
auto outputs_res = module_->forward(inputs);
67+
auto outputs_res = module_->execute(method_name_, inputs);
5868
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
5969

60-
auto update_err = io_manager_->update_decode(outputs_res.get());
70+
auto update_err =
71+
io_manager_->update_decode(outputs_res.get(), method_name_);
6172
ET_CHECK_OK_OR_RETURN_ERROR(update_err);
6273

6374
ET_CHECK_MSG(
6475
outputs_res.get().size() == 1,
65-
"More then one output returned from executing LLM.");
76+
"More than one output returned from executing LLM.");
6677
ET_CHECK_MSG(
6778
outputs_res.get()[0].isTensor(),
6879
"Non Tensor Output returned from executing LLM");
@@ -72,11 +83,12 @@ ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
7283
} else { // no kv cache
7384
(void)start_pos; // unused
7485

75-
auto outputs_res = module_->forward(tokens);
86+
std::vector<runtime::EValue> inputs{tokens};
87+
auto outputs_res = module_->execute(method_name_, inputs);
7688
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
7789
ET_CHECK_MSG(
7890
outputs_res.get().size() == 1,
79-
"More then one output returned from executing LLM.");
91+
"More than one output returned from executing LLM.");
8092
ET_CHECK_MSG(
8193
outputs_res.get()[0].isTensor(),
8294
"Non Tensor Output returned from executing LLM");

extension/llm/runner/text_decoder_runner.h

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ namespace llm {
2020

2121
class ET_EXPERIMENTAL TextDecoderRunner {
2222
public:
23-
explicit TextDecoderRunner(Module* module, IOManager* io_manager);
23+
explicit TextDecoderRunner(
24+
Module* module,
25+
IOManager* io_manager,
26+
std::string method_name = "forward");
2427

2528
virtual ~TextDecoderRunner() = default;
2629

@@ -40,15 +43,30 @@ class ET_EXPERIMENTAL TextDecoderRunner {
4043
* @return The error code.
4144
*/
4245
virtual ::executorch::runtime::Error load() {
43-
return module_->load_method("forward");
46+
auto err = module_->load_method(method_name_);
47+
if (err != ::executorch::runtime::Error::Ok) {
48+
ET_LOG(
49+
Error,
50+
"Failed to load method '%s'. Check available methods in the model.",
51+
method_name_.c_str());
52+
}
53+
return err;
4454
}
4555

4656
/**
4757
* Check if the required methods in the Module is loaded.
4858
* @return True if the Module is loaded, false otherwise.
4959
*/
5060
virtual bool is_method_loaded() {
51-
return module_->is_method_loaded("forward");
61+
return module_->is_method_loaded(method_name_);
62+
}
63+
64+
/**
65+
* Get the method name used by this runner.
66+
* @return The method name.
67+
*/
68+
const std::string& method_name() const {
69+
return method_name_;
5270
}
5371

5472
inline void stop() {
@@ -79,6 +97,7 @@ class ET_EXPERIMENTAL TextDecoderRunner {
7997
*/
8098
Module* module_;
8199
IOManager* io_manager_;
100+
std::string method_name_;
82101
bool should_stop_{false};
83102
};
84103

0 commit comments

Comments (0)