Skip to content

Commit b4eeffb

Browse files
committed
chore: ensure consistent .pt device naming
1 parent 3c0323f commit b4eeffb

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

era5_training/infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def main():
2121
input_data = load_nc_dataset(args.test_data_dir / Path(prefix + "-input.nc"))
2222
pred_reference = load_nc_dataset(args.test_data_dir / Path(prefix + "-predict.nc"))
2323

24-
model_path = args.scripted_model_dir / Path(f"nlgw_{prefix}_gpu_scripted.pt")
24+
model_path = args.scripted_model_dir / Path(f"nlgw_{prefix}_{device}_scripted.pt")
2525
print(f"loading model {model_path}...")
2626
model = torch.jit.load(model_path)
2727

utils/function_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def Inference_and_Save_AttentionUNet(
411411
xdata.to_netcdf(f"test-data/unet-{k}.nc")
412412

413413
print("scripting...")
414-
script_to_torchscript(model, filename="nlgw_unet_gpu_scripted.pt")
414+
script_to_torchscript(model, filename=f"nlgw_unet_{device}_scripted.pt")
415415
print("complete")
416416

417417
# write to netCDF

0 commit comments

Comments
 (0)