Skip to content

Commit e2a4405

Browse files
committed
feat(local): support sharded safetensors
1 parent 9a5dd2f commit e2a4405

File tree

1 file changed

+42
-12
lines changed

1 file changed

+42
-12
lines changed

embeddings/src/model/local.rs

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ fn model_type_from_config(config: &str) -> Option<String> {
5959
pub struct LocalModelInfo {
6060
pub config_path: PathBuf,
6161
pub tokenizer_path: PathBuf,
62-
pub weights_path: PathBuf,
62+
pub weights_paths: Vec<PathBuf>,
6363
}
6464

6565
/// Download and cache model files from HuggingFace
@@ -81,14 +81,48 @@ pub fn build_model_info(
8181
let tokenizer_path = api
8282
.get("tokenizer.json")
8383
.map_err(|_| LibError::ModelTokenizerFetchFailed)?;
84-
let weights_path = api
85-
.get("model.safetensors")
86-
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
84+
let weights_paths = match api.get("model.safetensors") {
85+
Ok(path) => vec![path],
86+
Err(_) => {
87+
// Support sharded safetensors via model.safetensors.index.json
88+
let index_path = api
89+
.get("model.safetensors.index.json")
90+
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
91+
let index_contents = std::fs::read_to_string(&index_path)
92+
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
93+
let index_json: Value = serde_json::from_str(&index_contents)
94+
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
95+
let weight_map = index_json
96+
.get("weight_map")
97+
.and_then(Value::as_object)
98+
.ok_or_else(|| LibError::ModelWeightsFetchFailed)?;
99+
100+
let mut shards: Vec<String> = weight_map
101+
.values()
102+
.filter_map(|v| v.as_str().map(|s| s.to_string()))
103+
.collect();
104+
shards.sort();
105+
shards.dedup();
106+
107+
if shards.is_empty() {
108+
return Err(Box::new(LibError::ModelWeightsFetchFailed));
109+
}
110+
111+
let mut paths = Vec::with_capacity(shards.len());
112+
for shard in shards {
113+
let p = api
114+
.get(&shard)
115+
.map_err(|_| LibError::ModelWeightsFetchFailed)?;
116+
paths.push(p);
117+
}
118+
paths
119+
}
120+
};
87121

88122
Ok(LocalModelInfo {
89123
config_path,
90124
tokenizer_path,
91-
weights_path,
125+
weights_paths,
92126
})
93127
}
94128

@@ -153,7 +187,7 @@ impl BertEmbeddingModel {
153187
let _ = tokenizer.with_truncation(None);
154188

155189
let vb = unsafe {
156-
VarBuilder::from_mmaped_safetensors(&[model_info.weights_path], BERT_DTYPE, &device)
190+
VarBuilder::from_mmaped_safetensors(&model_info.weights_paths, BERT_DTYPE, &device)
157191
.map_err(|_| LibError::ModelWeightsLoadFailed)?
158192
};
159193

@@ -216,12 +250,8 @@ impl CausalEmbeddingModel {
216250

217251
let dtype = dtype_from_config(&config, &device);
218252
let vb = unsafe {
219-
VarBuilder::from_mmaped_safetensors(
220-
std::slice::from_ref(&model_info.weights_path),
221-
dtype,
222-
&device,
223-
)
224-
.map_err(|_| LibError::ModelWeightsLoadFailed)?
253+
VarBuilder::from_mmaped_safetensors(&model_info.weights_paths, dtype, &device)
254+
.map_err(|_| LibError::ModelWeightsLoadFailed)?
225255
};
226256

227257
let vb = if vb.contains_tensor("model.embed_tokens.weight") {

0 commit comments

Comments
 (0)