Skip to content

Commit 7fe552d

Browse files
authored
Ptycho improvements (#570)
* remove some logging. Reduce n_trials to 25 * use int for defocus and defocus search range. This allows comparison inside run.py * remove unused parameters * add bool parameters to enable optimization of diffraction_angle, defocus, and C12. Only search one side of focus based on the initial defocus. * fix json comma * fix typos. Fix error where probe step size was x10 twice * always use parallax for hyperparameter optimization * use only needed parameters. * add max_batch_size parameter to better manage memory * fix up all the parameters * fix json commas * remove naming typo
1 parent beb3505 commit 7fe552d

File tree

2 files changed

+85
-130
lines changed

2 files changed

+85
-130
lines changed

operators/quantem-direct-ptycho/operator.json

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,6 @@
2020
}
2121
],
2222
"parameters": [
23-
{
24-
"name": "calculation_frequency",
25-
"label": "Calculation Frequency",
26-
"type": "int",
27-
"default": "100",
28-
"description": "Number of frames to accumulate before recalculating the center and emitting a BF image.",
29-
"required": true
30-
},
31-
{
32-
"name": "max_concurrent_scans",
33-
"label": "Max Concurrent Scans",
34-
"type": "int",
35-
"default": "1",
36-
"description": "Maximum number of scans to keep in memory simultaneously. Oldest scans are evicted when this limit is exceeded.",
37-
"required": false
38-
},
3923
{
4024
"name": "accelerating_voltage",
4125
"label": "Accelerating voltage",
@@ -61,11 +45,19 @@
6145
"required": false
6246
},
6347
{
64-
"name": "initial_defocus",
65-
"label": "STEM defocus",
66-
"type": "float",
67-
"default": "0.0",
68-
"description": "The STEM defocus in nm.",
48+
"name": "defocus_search_range_min",
49+
"label": "Defocus search range minimum",
50+
"type": "int",
51+
"default": "0",
52+
"description": "The defocus search range minimum in nanometers.",
53+
"required": false
54+
},
55+
{
56+
"name": "defocus_search_range_max",
57+
"label": "Defocus search range maximum",
58+
"type": "int",
59+
"default": "30",
60+
"description": "The defocus search range maximum in nanometers.",
6961
"required": false
7062
},
7163
{
@@ -76,6 +68,14 @@
7668
"description": "The rotation of the diffraction pattern on the detector in degrees.",
7769
"required": false
7870
},
71+
{
72+
"name": "maximum_C12_magnitude",
73+
"label": "Maximum C12 magnitude",
74+
"type": "int",
75+
"default": "10",
76+
"description": "The maximum C12 magnitude in nanometers.",
77+
"required": false
78+
},
7979
{
8080
"name": "crop_probes",
8181
"label": "Crop probes on each side",
@@ -93,19 +93,19 @@
9393
"required": false
9494
},
9595
{
96-
"name": "defocus_search_range_nm",
97-
"label": "Defocus search range",
98-
"type": "float",
99-
"default": "50.0",
100-
"description": "The defocus search range in nanometers. Unused if defocus is input.",
96+
"name": "n_trials",
97+
"label": "Number of trials for hyperparameter optimization",
98+
"type": "int",
99+
"default": "25",
100+
"description": "The number of trials for hyperparameter optimization.",
101101
"required": false
102102
},
103-
{
104-
"name": "maximum_C12_magnitude_nm",
105-
"label": "Maximum C12 magnitude",
103+
{
104+
"name": "max_batch_size",
105+
"label": "Maximum batch size",
106106
"type": "int",
107-
"default": "2",
108-
"description": "The maximum C12 magnitude in nanometers.",
107+
"default": "10",
108+
"description": "The maximum batch size for processing frames. Reduce if you run out of GPU memory.",
109109
"required": false
110110
},
111111
{
@@ -116,14 +116,6 @@
116116
"options": ["parallax", "ssb", "icom"],
117117
"description": "The deconvolution kernel.",
118118
"required": false
119-
},
120-
{
121-
"name": "use_optimization",
122-
"label": "Use optimization routine",
123-
"type": "bool",
124-
"default": true,
125-
"description": "Use the optimization routine with initial parameters.",
126-
"required": false
127119
}
128120
],
129121
"parallel_config": {

operators/quantem-direct-ptycho/run.py

Lines changed: 54 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def quantem_direct_ptycho(
8080
scan_number = batch.header.scan_number
8181

8282
# --- 2. Get or Create FrameAccumulator ---
83-
max_concurrent_scans = int(parameters.get("max_concurrent_scans", 1))
83+
max_concurrent_scans = 1 # TODO: remove saving of old scans
8484

8585
if scan_number not in accumulators:
8686
# Check if we need to evict old accumulators before creating new one
@@ -114,56 +114,35 @@ def quantem_direct_ptycho(
114114
return None
115115

116116
# --- 5. Perform Calculation ---
117-
logger.info(
118-
f"Scan {scan_number}: Triggering calculation after {accumulator.num_batches_added} messages."
119-
)
120-
logger.info(f"Accumulator finished: {accumulator.finished}")
121-
122117
logger.info(f"Scan {scan_number}: Calculating ptycho images.")
123118

124119
# Calculation parameters
125-
probe_semiangle = parameters.get("probe_semiangle", 25.0)
126120
energy = parameters.get("accelerating_voltage", 300e3)
127-
probe_step_size = parameters.get(
128-
"probe_step_size", 0.1
129-
) # test data set: 0.14383155 nm
130-
crop_probes = parameters.get("crop_probes", 0)
121+
probe_semiangle = parameters.get("probe_semiangle", 25.0)
122+
# test data set: 0.14383155 nm probe step size
123+
probe_step_size_nm = parameters.get("probe_step_size", 0.1)
124+
probe_step_size_A = probe_step_size_nm * 10
131125
upsampling_factor = parameters.get("upsampling_factor", 2)
126+
n_trials = parameters.get("n_trials", 25)
127+
max_batch_size = parameters.get("max_batch_size", 10)
132128

133-
# Parameters for optimize_hyperparameters function
134-
initial_defocus_nm = parameters.get(
135-
"initial_defocus", None
136-
) # in nanometers, can be None
137-
if initial_defocus_nm is not None:
138-
initial_defocus_nm = initial_defocus_nm
139-
initial_defocus_A = initial_defocus_nm * 10 # convert to Angstroms
140-
else:
141-
initial_defocus_A = None
142-
143-
diffraction_rotation_angle = parameters.get(
144-
"diffraction_rotation_angle", None
145-
) # in degrees, can be None
146-
if diffraction_rotation_angle is not None:
147-
diffraction_rotation_angle = diffraction_rotation_angle
148-
rotation_angle = diffraction_rotation_angle * np.pi / 180 # convert to radians
149-
else:
150-
rotation_angle = None
129+
defocus_search_min_nm = parameters.get("defocus_search_range_min", 50)
130+
defocus_search_max_nm = parameters.get("defocus_search_range_max", 50)
131+
defocus_search_min_A = defocus_search_min_nm * 10 # convert to Angstroms
132+
defocus_search_max_A = defocus_search_max_nm * 10 # convert to Angstroms
133+
# Need to convert signs and order because of different conventions in FEI and quantem
134+
defocus_search_range_A = (-defocus_search_max_A, -defocus_search_min_A)
151135

152-
defocus_search_range_nm = parameters.get(
153-
"defocus_search_range", 50
154-
) # in nanometers
155-
defocus_search_range_A = defocus_search_range_nm * 10 # convert to Angstroms
136+
# in degrees
137+
diffraction_rotation_angle_deg = parameters.get("diffraction_rotation_angle", 0)
138+
rotation_angle = diffraction_rotation_angle_deg * np.pi / 180 # convert to radians
156139

157-
maximum_C12_magnitude_nm = parameters.get(
158-
"maximum_C12_magnitude", 10
159-
) # in nanometers
140+
maximum_C12_magnitude_nm = parameters.get("maximum_C12_magnitude", 10)
160141
maximum_C12_magnitude_A = maximum_C12_magnitude_nm * 10 # convert to Angstroms
161142

162143
deconvolution_kernel = parameters.get("deconvolution_kernel", "parallax")
163144

164-
# Determine whether to use optimization or manual settings
165-
use_optimization = bool(parameters.get("use_optimization", True))
166-
145+
crop_probes = parameters.get("crop_probes", 0)
167146
if crop_probes == 0:
168147
logger.info(f"Scan {scan_number}: No cropping of probes applied.")
169148
dense_data = accumulator[:, :-1, :, :].to_dense() ## remove the flyback column
@@ -173,6 +152,7 @@ def quantem_direct_ptycho(
173152
crop_probes:-crop_probes, crop_probes : -crop_probes - 1, :, :
174153
].to_dense() ## crop the edges if needed and remove the flyback column
175154

155+
# Convert SparseArray to Dataset4dstem
176156
dset = em.datastructures.Dataset4dstem.from_array(array=dense_data)
177157
logger.debug(f"dense shape = {dense_data.shape}")
178158

@@ -184,81 +164,64 @@ def quantem_direct_ptycho(
184164
dset.sampling[3] = probe_semiangle / probe_R
185165
dset.units[2:] = ["mrad", "mrad"]
186166

187-
dset.sampling[0] = (
188-
probe_step_size * 10
189-
) ## convert to be Anggstrom for quantem. distiller will give nanometers.
190-
dset.sampling[1] = probe_step_size * 10
167+
dset.sampling[0:2] = probe_step_size_A
191168
dset.units[0:2] = ["A", "A"]
192169

193170
logger.info(f"Scan {scan_number}: Start direct ptycho")
194171
try:
195-
# Initialize DirectPtychography with initial guesses
196-
aberration_coefs = {}
197-
if initial_defocus_A is not None:
198-
aberration_coefs["C10"] = -initial_defocus_A # Note the negative sign
172+
# Initialize DirectPtychography
199173

200174
direct_ptycho = DirectPtychography.from_dataset4d(
201175
dset,
202176
energy=energy,
203177
semiangle_cutoff=probe_semiangle,
204178
device=QUANTEM_DEVICE,
205-
aberration_coefs=aberration_coefs if aberration_coefs else None,
206-
max_batch_size=10,
179+
aberration_coefs={},
180+
max_batch_size=max_batch_size,
207181
rotation_angle=rotation_angle, # need radians
208182
)
209183

210-
if use_optimization:
211-
logger.info(f"Scan {scan_number}: Optimizing hyperparameters")
212-
213-
# Build optimization aberration coefficients
214-
opt_aberration_coefs = {}
215-
if initial_defocus_A is None:
216-
opt_aberration_coefs["C10"] = OptimizationParameter(
217-
defocus_search_range_A, defocus_search_range_A
218-
)
219-
else:
220-
opt_aberration_coefs["C10"] = -initial_defocus_A
221-
222-
opt_aberration_coefs["C12"] = OptimizationParameter(
223-
0, maximum_C12_magnitude_A
224-
)
225-
opt_aberration_coefs["phi12"] = OptimizationParameter(-np.pi / 2, np.pi / 2)
226-
227-
# Build rotation angle optimization
228-
if rotation_angle is None:
229-
opt_rotation_angle = OptimizationParameter(0, np.pi)
230-
else:
231-
opt_rotation_angle = rotation_angle
232-
233-
direct_ptycho.optimize_hyperparameters(
234-
aberration_coefs=opt_aberration_coefs,
235-
rotation_angle=opt_rotation_angle,
236-
deconvolution_kernel=deconvolution_kernel,
237-
n_trials=50,
238-
max_batch_size=10,
239-
)
240-
else:
241-
logger.info(f"Scan {scan_number}: Using manual hyperparameter settings")
184+
# Build optimization aberration coefficients
185+
logger.info(f"Scan {scan_number}: Optimizing hyperparameters")
186+
opt_aberration_coefs = {}
187+
opt_aberration_coefs["C10"] = OptimizationParameter(
188+
defocus_search_range_A[0], defocus_search_range_A[1]
189+
)
190+
opt_aberration_coefs["C12"] = OptimizationParameter(0, maximum_C12_magnitude_A)
191+
opt_aberration_coefs["phi12"] = OptimizationParameter(-np.pi / 2, np.pi / 2)
192+
193+
# Optimize hyperparameters
194+
direct_ptycho.optimize_hyperparameters(
195+
aberration_coefs=opt_aberration_coefs,
196+
deconvolution_kernel="parallax",
197+
n_trials=n_trials,
198+
max_batch_size=max_batch_size,
199+
)
242200

243-
initial_parallax = direct_ptycho.reconstruct(
201+
# Do reconstruction
202+
logger.info(f"Scan {scan_number}: Starting reconstruction")
203+
direct_ptycho.reconstruct(
244204
deconvolution_kernel=deconvolution_kernel,
245205
upsampling_factor=upsampling_factor,
246-
max_batch_size=10,
206+
max_batch_size=max_batch_size,
247207
)
248208

249209
# Process and return result
250210
logger.info(f"Scan {scan_number}: Reconstruction done")
251-
output_bytes = initial_parallax.obj.tobytes()
211+
output_bytes = direct_ptycho.obj.tobytes()
252212
output_meta = {
253213
"scan_number": scan_number,
254-
"shape": initial_parallax.obj.shape,
255-
"dtype": str(initial_parallax.obj.dtype),
214+
"shape": direct_ptycho.obj.shape,
215+
"dtype": str(direct_ptycho.obj.dtype),
256216
"source_operator": "quantem-direct-ptycho",
257-
"direct_ptycho_params": {'C12': direct_ptycho.hyperparameter_state.optimized_aberrations['C12'],
258-
'phi12': direct_ptycho.hyperparameter_state.optimized_aberrations['phi12'],
259-
'C10': -direct_ptycho.aberration_coefs['C10'],
260-
'rotation_angle': direct_ptycho.rotation_angle,
261-
},
217+
"direct_ptycho_params": {
218+
"C12": direct_ptycho.hyperparameter_state.optimized_aberrations["C12"],
219+
"phi12": direct_ptycho.hyperparameter_state.optimized_aberrations[
220+
"phi12"
221+
],
222+
"C10": -direct_ptycho.aberration_coefs["C10"],
223+
"rotation_angle": direct_ptycho.rotation_angle,
224+
},
262225
}
263226
header = MessageHeader(subject=MessageSubject.BYTES, meta=output_meta)
264227
return BytesMessage(header=header, data=output_bytes)

0 commit comments

Comments
 (0)