@@ -22,6 +22,7 @@ namespace npuw {
/// Layer names for Eagle3 speculative decoding.
///
/// Each constant is the bare name with no surrounding whitespace: a padded
/// literal such as " hidden_states" neither compares equal to, nor occurs as a
/// substring of, a name like "hidden_states".
struct Eagle3LayerNames {
    /// Name of the `hidden_states` layer.
    static constexpr const char* hidden_states = "hidden_states";
    /// Name of the `last_hidden_state` layer.
    static constexpr const char* last_hidden_state = "last_hidden_state";
};
2627
2728// Utility functions for Eagle3 layer name matching
@@ -71,12 +72,25 @@ class Eagle3Extension {
7172 uint32_t chunk_start_token,
7273 uint32_t chunk_token_count);
7374
75+ // Retrieve and store the last_hidden_state output tensor (applies to both draft and target models).
// request   - infer request handle passed in by the caller.
//             NOTE(review): presumably the request whose outputs were just produced — confirm against the definition.
// out_ports - map from a std::string key to an output port.
//             NOTE(review): presumably keyed by layer name — confirm against the caller.
76+ void update_last_hidden_state (const std::shared_ptr<ov::IAsyncInferRequest>& request,
77+ const std::unordered_map<std::string, ov::Output<const ov::Node>>& out_ports);
78+
79+ ov::SoPtr<ov::ITensor> get_hidden_states () const {
80+ return m_hidden_states;
81+ }
82+
83+ ov::SoPtr<ov::ITensor> get_last_hidden_state () const {
84+ return m_last_hidden_state;
85+ }
86+
7487private:
// Private helper taking a tensor handle and a name string.
// NOTE(review): the definition is outside this chunk; judging by the identifier it validates `tensor`
// and uses `name` to identify it — confirm against the implementation.
7588 void validate_hidden_state_tensor (const ov::SoPtr<ov::ITensor>& tensor, const std::string& name);
7689
7790 Eagle3ModelRole m_role = Eagle3ModelRole::None;  ///< Model role; initialized to Eagle3ModelRole::None
7891
79- ov::SoPtr<ov::ITensor> m_hidden_states; ///< Draft model input: hidden_states
92+ ov::SoPtr<ov::ITensor> m_hidden_states; ///< Draft model input: hidden_states
93+ ov::SoPtr<ov::ITensor> m_last_hidden_state; ///< Draft/Target model output: last_hidden_state
8094};
8195
8296} // namespace npuw
0 commit comments