@@ -59,7 +59,7 @@ fn model_type_from_config(config: &str) -> Option<String> {
/// Paths of the model files that `build_model_info` returns to the model loaders.
5959pub struct LocalModelInfo {
    /// Path to the model config; how it is fetched is not visible in this excerpt.
6060 pub config_path : PathBuf ,
    /// Path obtained from `api.get("tokenizer.json")`.
6161 pub tokenizer_path : PathBuf ,
62- pub weights_path : PathBuf ,
    /// Replaces the single `weights_path`. Holds one entry when `model.safetensors`
    /// is fetched directly; otherwise one entry per distinct shard file named in the
    /// `weight_map` of `model.safetensors.index.json`, in sorted order. The whole
    /// list is handed to `VarBuilder::from_mmaped_safetensors`.
62+ pub weights_paths : Vec < PathBuf > ,
6363}
6464
6565/// Download and cache model files from HuggingFace
@@ -81,14 +81,48 @@ pub fn build_model_info(
8181 let tokenizer_path = api
8282 . get ( "tokenizer.json" )
8383 . map_err ( |_| LibError :: ModelTokenizerFetchFailed ) ?;
84- let weights_path = api
85- . get ( "model.safetensors" )
86- . map_err ( |_| LibError :: ModelWeightsFetchFailed ) ?;
84+ let weights_paths = match api. get ( "model.safetensors" ) {
85+ Ok ( path) => vec ! [ path] ,
86+ Err ( _) => {
87+ // Support sharded safetensors via model.safetensors.index.json
88+ let index_path = api
89+ . get ( "model.safetensors.index.json" )
90+ . map_err ( |_| LibError :: ModelWeightsFetchFailed ) ?;
91+ let index_contents = std:: fs:: read_to_string ( & index_path)
92+ . map_err ( |_| LibError :: ModelWeightsFetchFailed ) ?;
93+ let index_json: Value = serde_json:: from_str ( & index_contents)
94+ . map_err ( |_| LibError :: ModelWeightsFetchFailed ) ?;
95+ let weight_map = index_json
96+ . get ( "weight_map" )
97+ . and_then ( Value :: as_object)
98+ . ok_or_else ( || LibError :: ModelWeightsFetchFailed ) ?;
99+
100+ let mut shards: Vec < String > = weight_map
101+ . values ( )
102+ . filter_map ( |v| v. as_str ( ) . map ( |s| s. to_string ( ) ) )
103+ . collect ( ) ;
104+ shards. sort ( ) ;
105+ shards. dedup ( ) ;
106+
107+ if shards. is_empty ( ) {
108+ return Err ( Box :: new ( LibError :: ModelWeightsFetchFailed ) ) ;
109+ }
110+
111+ let mut paths = Vec :: with_capacity ( shards. len ( ) ) ;
112+ for shard in shards {
113+ let p = api
114+ . get ( & shard)
115+ . map_err ( |_| LibError :: ModelWeightsFetchFailed ) ?;
116+ paths. push ( p) ;
117+ }
118+ paths
119+ }
120+ } ;
87121
88122 Ok ( LocalModelInfo {
89123 config_path,
90124 tokenizer_path,
91- weights_path ,
125+ weights_paths ,
92126 } )
93127}
94128
@@ -153,7 +187,7 @@ impl BertEmbeddingModel {
153187 let _ = tokenizer. with_truncation ( None ) ;
154188
155189 let vb = unsafe {
156- VarBuilder :: from_mmaped_safetensors ( & [ model_info. weights_path ] , BERT_DTYPE , & device)
190+ VarBuilder :: from_mmaped_safetensors ( & model_info. weights_paths , BERT_DTYPE , & device)
157191 . map_err ( |_| LibError :: ModelWeightsLoadFailed ) ?
158192 } ;
159193
@@ -216,12 +250,8 @@ impl CausalEmbeddingModel {
216250
217251 let dtype = dtype_from_config ( & config, & device) ;
218252 let vb = unsafe {
219- VarBuilder :: from_mmaped_safetensors (
220- std:: slice:: from_ref ( & model_info. weights_path ) ,
221- dtype,
222- & device,
223- )
224- . map_err ( |_| LibError :: ModelWeightsLoadFailed ) ?
253+ VarBuilder :: from_mmaped_safetensors ( & model_info. weights_paths , dtype, & device)
254+ . map_err ( |_| LibError :: ModelWeightsLoadFailed ) ?
225255 } ;
226256
227257 let vb = if vb. contains_tensor ( "model.embed_tokens.weight" ) {
0 commit comments