Examples
Included with the repository are a number of example scripts that demonstrate how various aspects of CDTools work. It is recommended to read through at least a few of them before continuing to the tutorial.
All the datasets used in these example scripts are included in the repository, and the scripts should be runnable as soon as CDTools is installed.
Inspect Dataset
The first example loads and visualizes ptychography data stored in a .cxi file.
import cdtools
from matplotlib import pyplot as plt
# First, we load an example dataset from a .cxi file
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# And we take a look at the data
dataset.inspect()
plt.show()
First, data is read into a Ptycho2DDataset object, which is a subclass of a pytorch dataset that knows a bit about the structure of ptychography data. Calling dataset.inspect generates a plot showing an overview of the ptychography scan data. On the left is a scatter plot showing the integrated detector intensity at each probe location. On the right, raw detector images are shown. The dataset can be navigated by clicking around the scatter plot on the left.
Simple Ptycho
This script runs a ptychography reconstruction using the SimplePtycho model, a bare-bones ptychography model for the transmission geometry.
The purpose of the SimplePtycho model is pedagogical: there are very few situations where it would be preferred to the FancyPtycho model, which will be introduced later.
Because it is so bare-bones, its definition is much simpler than that of FancyPtycho, and it is therefore a good first model to study when learning how to implement a custom ptychography model in CDTools.
"""
Runs a very simple reconstruction using the SimplePtycho model, which was
designed to be an easy introduction to show how the models are made and used.
For a more realistic example of how to use cdtools for real-world data,
look at fancy_ptycho.py and gold_ball_ptycho.py, both of which use the
more powerful FancyPtycho model and include more information on how to
correct for common sources of error.
"""
import cdtools
import torch as t
from matplotlib import pyplot as plt
# We load an example dataset from a .cxi file
filename = 'example_data/lab_ptycho_data.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# We create a ptychography model from the dataset
model = cdtools.models.SimplePtycho.from_dataset(dataset)
# We move the model to the GPU, if possible
if t.cuda.is_available():
    model.to(device='cuda')
    dataset.get_as(device='cuda')
model.inspect(dataset)
# We run the reconstruction
for loss in model.Adam_optimize(100, dataset, batch_size=10):
    # We print a quick report of the optimization status
    print(model.report())
    # And liveplot the updates to the model as they happen
    model.inspect(dataset)
# We open a comparison of the simulated and measured data
model.compare(dataset)
plt.show()
When reading this script, note the basic workflow. After the data is loaded, a model is created to match the geometry stored in the dataset with a sensible default initialization for all the parameters.
Next, the model is moved to the GPU using the model.to function. Any device understood by torch.Tensor.to can be specified here. The next line is a bit more subtle - the dataset is told to move patterns to the GPU before passing them to the model using the dataset.get_as function. This function does not move the stored patterns to the GPU. If there is sufficient GPU memory, the patterns can also be pre-moved to the GPU using dataset.to, but the speedup is empirically quite small.
Once the device is selected, a reconstruction is run using model.Adam_optimize. This is a generator function which will yield at the end of every epoch, to allow some monitoring code to be run. Inside the loop, model.inspect(dataset) is called every epoch to live-update a set of plots showing the current state of the model parameters.
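The generator pattern described here can be illustrated without CDTools. The following is a minimal sketch, not the CDTools implementation: `toy_optimize` is a hypothetical function that minimizes a one-dimensional quadratic by gradient descent and yields the loss after each epoch, so the body of the caller's loop runs once per epoch.

```python
def toy_optimize(n_epochs, x0=5.0, lr=0.1):
    """Hypothetical stand-in for an optimizer that yields once per epoch.

    Minimizes f(x) = x**2 by plain gradient descent. This only
    illustrates the control flow; it is not how CDTools is implemented.
    """
    x = x0
    for _ in range(n_epochs):
        grad = 2.0 * x          # derivative of x**2
        x -= lr * grad          # one "epoch" of work
        yield x ** 2            # hand control back to the caller


losses = []
for epoch, loss in enumerate(toy_optimize(20)):
    # Monitoring code goes here; it runs at the end of every epoch
    losses.append(loss)
    if epoch % 5 == 0:
        print(f'epoch {epoch}: loss = {loss:.3e}')
```

Because the optimizer is a generator, the caller decides what to do between epochs (printing, plotting, early stopping with `break`) without the optimizer needing any callback arguments.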
Finally, model.compare(dataset) is called to show how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset.
Fancy Ptycho
This script runs a reconstruction on the same data, but using the workhorse FancyPtycho model, demonstrating some of its more commonly used features.
import cdtools
import torch as t
from matplotlib import pyplot as plt
filename = 'example_data/lab_ptycho_data.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# FancyPtycho is the workhorse model
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3, # Use 3 incoherently mixing probe modes
    oversampling=2, # Simulate the probe on a 2x larger real-space array
    probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix
    propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm
    units='mm', # Set the units for the live plots
    obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix
)
if t.cuda.is_available():
    model.to(device='cuda')
    dataset.get_as(device='cuda')
# For this script, we use a slightly different pattern where we explicitly
# create a `Reconstructor` class to orchestrate the reconstruction. The
# reconstructor will store the model and dataset and create an appropriate
# optimizer. This allows the optimizer to persist between loops, along with
# e.g. estimates of the moments of individual parameters
recon = cdtools.reconstructors.AdamReconstructor(model, dataset)
# The learning rate parameter sets the alpha for Adam.
# The beta parameters are (0.9, 0.999) by default
# The batch size sets the minibatch size
for loss in recon.optimize(50, lr=0.02, batch_size=10):
    print(model.report())
    # Because plotting can be expensive, setting a minimum plotting interval
    # (in seconds) can avoid excessive replots.
    model.inspect(dataset, min_interval=10)
# It's common to chain several different reconstruction loops. Here, we
# started with an aggressive refinement to find the probe in the previous
# loop, and now we polish the reconstruction with a lower learning rate
# and larger minibatch
for loss in recon.optimize(50, lr=0.005, batch_size=50):
    print(model.report())
    model.inspect(dataset, min_interval=10)
# This orthogonalizes the recovered probe modes
model.tidy_probes()
# Setting replot_all will reopen any windows which were closed earlier
model.inspect(dataset, replot_all=True)
model.compare(dataset)
plt.show()
The FancyPtycho.from_dataset factory function has many keyword arguments which can turn on or modify various mixins. In this case, we perform a reconstruction with:
3 incoherently mixing probe modes (in the vein of doi:10.1038/nature11806)
A probe array expanded by a factor of 2 in real space, i.e. simulated on a 2x2 upsampled grid in Fourier space (in the vein of doi:10.1103/PhysRevA.87.053850)
A circular finite support constraint applied to the probe
An initial guess for the probe which has been propagated from its focus position
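As a rough illustration of the first item in this list, the sketch below simulates a diffraction pattern from several incoherently mixing probe modes: each mode produces its own exit wave and far-field pattern, and the intensities (not the amplitudes) are summed. The array sizes, the random test arrays, and the helper name `simulate_pattern` are illustrative assumptions and are not taken from CDTools.

```python
import numpy as np


def simulate_pattern(probe_modes, obj_patch):
    """Far-field intensity from incoherently mixing probe modes.

    probe_modes: complex array of shape (n_modes, N, N)
    obj_patch:   complex array of shape (N, N)
    """
    exit_waves = probe_modes * obj_patch            # broadcast over modes
    far_field = np.fft.fftshift(
        np.fft.fft2(exit_waves, norm='ortho'), axes=(-2, -1))
    # Incoherent sum: add the intensities mode by mode
    return np.sum(np.abs(far_field) ** 2, axis=0)


rng = np.random.default_rng(0)
n_modes, N = 3, 32
probe_modes = (rng.normal(size=(n_modes, N, N))
               + 1j * rng.normal(size=(n_modes, N, N)))
# A pure-phase object, so the object does not absorb any intensity
obj_patch = np.exp(1j * rng.uniform(0, 1, size=(N, N)))

pattern = simulate_pattern(probe_modes, obj_patch)
```

With the orthonormal FFT convention and a pure-phase object, the total intensity in `pattern` equals the total intensity in the probe modes, which is a convenient sanity check.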
By default, FancyPtycho will also optimize over the following model parameters, each of which corrects for a specific source of error:
model.background: A frame-independent detector background
model.weights: A frame-to-frame variation in the incoming probe intensity
model.translation_offsets: A frame-to-frame error in the probe positions
Each correction can be turned off (or back on) by setting model.<parameter>.requires_grad = False (or True).
Note as well two other changes that are made in this script, when compared to simple_ptycho.py. First, a Reconstructor object is explicitly created, in this case an AdamReconstructor. This object stores a model, dataset, and pytorch optimizer. It is then used to orchestrate the later reconstruction using a call to Reconstructor.optimize().
We use this pattern, instead of the simpler call to model.Adam_optimize(), because having the reconstructor store the optimizer as well as the model and dataset allows the moment estimates to persist between multiple rounds of optimization. This leads to the second change: in this script, we run two optimization loops. The first loop aggressively refines the probe, with a small minibatch size and a high learning rate. The second loop has a lower learning rate and a larger minibatch size, which allows for a more precise final estimate of the object.
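To make the idea of persistent moment estimates concrete, here is a minimal, self-contained sketch of an Adam loop for a single scalar parameter. The class `ToyAdamReconstructor` and its toy objective are hypothetical and only mimic the calling pattern; they are not the CDTools AdamReconstructor.

```python
import math


class ToyAdamReconstructor:
    """Hypothetical, minimal Adam loop for a single scalar parameter.

    It minimizes f(x) = (x - 3)**2 and keeps the moment estimates as
    attributes, so they survive between calls to optimize().
    """

    def __init__(self, x0=0.0, betas=(0.9, 0.999), eps=1e-8):
        self.x = x0
        self.betas = betas
        self.eps = eps
        self.m = 0.0      # first-moment estimate
        self.v = 0.0      # second-moment estimate
        self.step = 0     # number of updates so far

    def optimize(self, n_epochs, lr):
        b1, b2 = self.betas
        for _ in range(n_epochs):
            grad = 2.0 * (self.x - 3.0)
            self.step += 1
            self.m = b1 * self.m + (1 - b1) * grad
            self.v = b2 * self.v + (1 - b2) * grad ** 2
            m_hat = self.m / (1 - b1 ** self.step)   # bias correction
            v_hat = self.v / (1 - b2 ** self.step)
            self.x -= lr * m_hat / (math.sqrt(v_hat) + self.eps)
            yield (self.x - 3.0) ** 2


recon = ToyAdamReconstructor()
for loss in recon.optimize(50, lr=0.02):      # aggressive first round
    pass
moments_after_first_round = (recon.m, recon.v, recon.step)

for loss in recon.optimize(50, lr=0.005):     # polishing round
    pass
```

The second call starts from the moments accumulated in the first call rather than from zero. Creating a fresh `ToyAdamReconstructor` instead would discard that history.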
In this case, we used one reconstructor, but it is possible to create additional reconstructors to zero out all the persistent information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. model.LBFGS_optimize()).
Note also the use of min_interval=10 in the calls to model.inspect(dataset). Because generating plots can be expensive, passing a minimum interval (in seconds) prevents excessive replots. Finally, the call to model.inspect(dataset, replot_all=True) at the end of the script reopens any plot windows that the user may have closed during the reconstruction, so that all results are visible at the end.
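The effect of a minimum plotting interval can be sketched with a simple time-based throttle. `ThrottledPlotter` is a hypothetical helper written for this illustration; how CDTools implements min_interval internally is not shown here.

```python
import time


class ThrottledPlotter:
    """Hypothetical helper that skips redraws requested too soon."""

    def __init__(self):
        self._last_draw = None
        self.n_draws = 0

    def inspect(self, min_interval=0.0):
        now = time.monotonic()
        too_soon = (self._last_draw is not None
                    and now - self._last_draw < min_interval)
        if too_soon:
            return False          # skip the expensive redraw
        self._last_draw = now
        self.n_draws += 1         # stand-in for the real plotting work
        return True


plotter = ThrottledPlotter()
for epoch in range(1000):
    # Called every epoch, but only redraws if 10 s have passed
    plotter.inspect(min_interval=10)

print(f'{plotter.n_draws} redraw(s) for 1000 calls')
```

The first call always draws, and subsequent calls return immediately until the interval has elapsed, so fast epochs are not slowed down by plotting.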
Gold Ball Ptycho
This script shows how the FancyPtycho model might be used in a realistic situation, to perform a reconstruction on the classic gold balls dataset. This script also shows how to save results!
import cdtools
import torch as t
from matplotlib import pyplot as plt
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# We pad the dataset with 10 pixels of zeroes around the edge. This
# data gets masked off, so it is not used for the reconstruction. This padding
# helps prevent aliasing when the probe and object get multiplied. It's a
# helpful step when there is signal present out to the edge of the detector,
# and is usually set to the radius of the probe's Fourier transform (in pixels)
pad = 10
dataset.pad(pad)
dataset.inspect()
# When the dataset is padded with zeroes and masked, the probe reconstruction
# becomes very unstable, and often develops noise at these high-frequency,
# masked off frequencies. To combat this, we simulate the probe at lower
# resolution, using the probe_fourier_crop argument. This is good practice
# in general when padding the dataset
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3,
    probe_support_radius=50,
    propagation_distance=2e-6,
    units='um',
    probe_fourier_crop=pad,
    plot_level=2,
)
# This is a trick that my grandmother taught me, to combat the raster grid
# pathology: we randomize our initial guess of the probe positions.
# The units here are pixels in the object array.
# Try running this script with and without this line to see the difference!
model.translation_offsets.data += 0.7 * t.randn_like(model.translation_offsets)
# Not much probe intensity instability in this dataset, no need for this
model.weights.requires_grad = False
if t.cuda.is_available():
    model.to(device='cuda')
    dataset.get_as(device='cuda')
# Create the reconstructor
recon = cdtools.reconstructors.AdamReconstructor(model, dataset)
# This will save out the intermediate results if an exception is thrown
# during the reconstruction
with model.save_on_exception(
        'example_reconstructions/gold_balls_earlyexit.h5', dataset):
    for loss in recon.optimize(20, lr=0.005, batch_size=50):
        print(model.report())
        model.inspect(dataset, min_interval=5)
    for loss in recon.optimize(50, lr=0.002, batch_size=100):
        print(model.report())
        model.inspect(dataset, min_interval=5)
    # We can often reset our guess of the probe positions once we have a
    # good guess of probe and object, but in this case it causes the
    # raster grid pathology to return.
    # model.translation_offsets.data[:] = 0
    # Setting schedule=True automatically lowers the learning rate if
    # the loss fails to improve after 10 epochs
    for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True):
        print(model.report())
        model.inspect(dataset, min_interval=5)
model.tidy_probes()
# This saves the final result
model.save_to_h5('example_reconstructions/gold_balls.h5', dataset)
model.inspect(dataset, replot_all=True)
model.compare(dataset)
plt.show()
Note first the explicit addition of the plot_level=2 argument in the call to FancyPtycho.from_dataset. This value controls which plots are generated. With plot_level=1, only the main results are shown; plot_level=2 adds more advanced monitoring of the error correction terms (background, position error, etc.); and plot_level=3 shows all registered plots.
Note also the use of model.save_on_exception and model.save_to_h5 to save the results of the reconstruction. If a different file format is required, model.save_results will save to a pure-python dictionary which can be processed further.
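The save-on-exception idea can be sketched with a small context manager from the standard library. The function below is a hypothetical stand-in that writes a JSON file; it only illustrates the pattern and makes no claim about how model.save_on_exception is implemented or which exceptions it handles.

```python
import json
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def save_on_exception(results, path):
    """Hypothetical sketch: dump `results` to `path` if the body raises."""
    try:
        yield
    except BaseException:
        # Save whatever has been computed so far, then re-raise so the
        # caller still sees the original error (or the Ctrl-C).
        with open(path, 'w') as f:
            json.dump(results, f)
        raise


results = {'losses': []}
path = os.path.join(tempfile.mkdtemp(), 'earlyexit.json')

try:
    with save_on_exception(results, path):
        for epoch in range(10):
            results['losses'].append(1.0 / (epoch + 1))
            if epoch == 3:
                raise RuntimeError('simulated failure mid-reconstruction')
except RuntimeError:
    print(f'Intermediate results written to {path}')
```

Wrapping every optimization loop in one such block means a crash or an interrupted run late in a long reconstruction does not lose the work done so far.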
Finally, note that there are several small adjustments made to the script to counteract particular sources of error that are present in this dataset, for example the raster grid pathology caused by the scan pattern used. Also note that not every mixin is needed every time - in this case, we turn off optimization of the weights parameter.
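The position randomization used against the raster grid pathology can be pictured with a few lines of NumPy. This sketch jitters a toy list of scan positions directly, whereas the script adds the noise to model.translation_offsets; the 5 x 5 grid and 10 pixel step are arbitrary values chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# A perfectly regular 5 x 5 raster scan, in units of object pixels
# (the 10 pixel step is an arbitrary choice for this sketch)
ys, xs = np.meshgrid(np.arange(5) * 10.0, np.arange(5) * 10.0, indexing='ij')
positions = np.stack([ys.ravel(), xs.ravel()], axis=-1)   # shape (25, 2)

# Add Gaussian jitter with a standard deviation of 0.7 pixels
jittered = positions + 0.7 * rng.standard_normal(positions.shape)

# On the regular grid every row shares exactly the same y coordinate;
# after jittering, that exact periodicity is gone
print(len(np.unique(positions[:, 0])), 'distinct y values before')
print(len(np.unique(jittered[:, 0])), 'distinct y values after')
```

Breaking the exact periodicity of the assumed positions is the point of the trick: the optimizer then refines the positions from a starting guess that does not share the symmetry of the scan grid.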
Near-Field Ptycho
This script shows how the FancyPtycho model can be used on a typical near-field ptychography (also known as Fresnel ptychography) dataset.
import cdtools
import torch as t
from matplotlib import pyplot as plt
filename = 'example_data/PETRAIII_P25_Near_Field_Ptycho.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
dataset.inspect()
# Setting near_field equal to True uses an angular spectrum propagator in
# lieu of the default Fourier-transform propagator for far-field ptychography.
#
# If propagation_distance is not set, it assumes that the geometry is
# a standard near-field geometry with flat illumination wavefronts, and
# pulls the sample to detector distance from dataset.distance
#
# If propagation_distance is set, it assumes a Fresnel scaling theorem
# geometry with:
#
# - distance (from the dataset): The sample-to-detector distance
# - propagation_distance: The focus-to-sample distance
#
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=1,
    near_field=True,
    propagation_distance=3.65e-3, # 3.65 mm downstream from focus
    units='um', # Set the units for the live plots
    obj_view_crop=-35, # Expand the view for the live plots
    loss="poisson_nll", # Best option for photon-counting detectors
    panel_plot_mode=True, # Set to False to get individual figures
)
if t.cuda.is_available():
    model.to(device='cuda')
    dataset.get_as(device='cuda')
model.inspect(dataset)
recon = cdtools.reconstructors.AdamReconstructor(model, dataset)
for loss in recon.optimize(100, lr=0.04, batch_size=10):
    print(model.report())
    model.inspect(dataset, min_interval=5)
for loss in recon.optimize(50, lr=0.005, batch_size=50):
    print(model.report())
    model.inspect(dataset, min_interval=5)
# This orthogonalizes the recovered probe modes
model.tidy_probes()
model.inspect(dataset, replot_all=True)
model.compare(dataset)
plt.show()
The major change here is the setting of the near_field=True argument to FancyPtycho.from_dataset. This changes the propagator to a near-field propagator. As noted in the comments, if propagation_distance is not set, the model will assume a standard near-field geometry with flat illumination.
If propagation_distance is set, it will assume a Fresnel scaling theorem-type geometry, with propagation_distance as the focus-to-sample distance, and the distance set in the dataset object as the sample-to-detector distance.
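For reference, the Fresnel scaling theorem relates a diverging-beam measurement to an equivalent flat-illumination one: with focus-to-sample distance z1 and sample-to-detector distance z2, the geometric magnification is M = (z1 + z2) / z1 and the effective propagation distance is z1 * z2 / (z1 + z2). The sketch below just evaluates these two textbook formulas. The function name is hypothetical, the 1 m detector distance is a made-up value, and whether CDTools applies the theorem in exactly this form is an assumption.

```python
def fresnel_scaling(focus_to_sample, sample_to_detector):
    """Equivalent plane-wave geometry for a diverging-beam measurement.

    Returns (magnification, effective_distance), with both input
    distances and the effective distance in the same length unit.
    """
    z1, z2 = focus_to_sample, sample_to_detector
    magnification = (z1 + z2) / z1
    effective_distance = z1 * z2 / (z1 + z2)
    return magnification, effective_distance


# 3.65 mm matches propagation_distance in the script above; the
# sample-to-detector distance of 1 m is a made-up illustrative value
M, z_eff = fresnel_scaling(3.65e-3, 1.0)
print(f'magnification = {M:.1f}, effective distance = {z_eff * 1e3:.3f} mm')
```

Note that the product of the two returned values is always the sample-to-detector distance, and that when the detector is far away the effective distance approaches the focus-to-sample distance.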
Finally, note the addition of the panel_plot_mode=True argument. This is the default mode, and returns the plots in a panel format, good for easily monitoring the progress of a reconstruction. If individual plots are needed for use in presentations, papers, or otherwise, setting panel_plot_mode=False will plot each output in its own window.
Gold Ball Split
It is very common to run reconstructions on two disjoint subsets of the same dataset, as well as the full dataset. This is primarily done to estimate the resolution of a reconstruction via the Fourier ring correlation (FRC).
import cdtools
import torch as t
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
pad = 10
dataset.pad(pad)
# This splits the dataset into a pseudorandomly chosen set of two disjoint
# datasets. The partitioning is drawn from a saved list, so the split is
# deterministic
dataset_1, dataset_2 = dataset.split()
datasets = [dataset_1, dataset_2, dataset]
labels = ['half_1', 'half_2', 'full']
for label, dataset in zip(labels, datasets):
    print(f'Working on dataset {label}')
    model = cdtools.models.FancyPtycho.from_dataset(
        dataset,
        n_modes=3,
        probe_support_radius=50,
        propagation_distance=2e-6,
        units='um',
        probe_fourier_crop=pad
    )
    model.translation_offsets.data += \
        0.7 * t.randn_like(model.translation_offsets)
    model.weights.requires_grad = False
    if t.cuda.is_available():
        model.to(device='cuda')
        dataset.get_as(device='cuda')
    # Create the reconstructor
    recon = cdtools.reconstructors.AdamReconstructor(model, dataset)
    # For batched reconstructions like this, there's no need to live-plot
    # the progress
    for loss in recon.optimize(20, lr=0.005, batch_size=50):
        print(model.report())
    for loss in recon.optimize(50, lr=0.002, batch_size=100):
        print(model.report())
    for loss in recon.optimize(100, lr=0.001, batch_size=100,
                               schedule=True):
        print(model.report())
    model.tidy_probes()
    model.save_to_h5(f'example_reconstructions/gold_balls_{label}.h5', dataset)
This script simply divides the dataset in two, and then performs the same reconstruction on both halves of the dataset, as well as the full dataset.
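The essential property of such a split is that the two halves are disjoint, together cover every pattern, and come out the same on every run. The sketch below achieves this with a seeded shuffle, which is a substitute technique chosen for illustration: according to the script's comments, CDTools draws its partition from a saved list instead. The helper name `split_indices` is hypothetical.

```python
import random


def split_indices(n_patterns, seed=0):
    """Hypothetical helper: deterministically split range(n_patterns) in two."""
    indices = list(range(n_patterns))
    random.Random(seed).shuffle(indices)      # seeded, so reproducible
    half = n_patterns // 2
    return sorted(indices[:half]), sorted(indices[half:])


half_1, half_2 = split_indices(11)
print('half 1:', half_1)
print('half 2:', half_2)
```

Determinism matters here because it lets the half-data reconstructions be rerun or compared across sessions without the partition itself changing.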
Gold Ball Synthesize
Once we have a set of three reconstructions (two half-data reconstructions and a full-data reconstruction), we need to calculate the resolution metrics. This script shows how that is done.
import cdtools
from cdtools.tools import plotting as p
from matplotlib import pyplot as plt
import numpy as np
# We load all three reconstructions
half_1 = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_half_1.h5')
half_2 = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_half_2.h5')
full = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_full.h5')
# This defines the region of recovered object to use for the analysis.
pad = 260
window = np.s_[pad:-pad, pad:-pad]
# This brings all three reconstructions to a common basis, correcting for
# possible global phase offsets, position shifts, and phase ramps. It also
# calculates a Fourier ring correlation and spectral signal-to-noise ratio
# estimate from the two half reconstructions.
results = cdtools.tools.analysis.standardize_reconstruction_set(
    half_1,
    half_2,
    full,
    window=window,
    nbins=40, # The number of bins to use for the FRC calculation
)
# We plot the normalized object images
p.plot_amplitude(results['obj_half_1'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_half_1'][window], basis=results['obj_basis'])
p.plot_amplitude(results['obj_half_2'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_half_2'][window], basis=results['obj_basis'])
p.plot_amplitude(results['obj_full'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_full'][window], basis=results['obj_basis'])
# We plot the calculated Fourier ring correlation
plt.figure()
plt.plot(1e-6*results['frc_freqs'], results['frc'])
plt.plot(1e-6*results['frc_freqs'], results['frc_threshold'], 'k--')
plt.xlabel('Frequency (cycles / um)')
plt.ylabel('Fourier Ring Correlation')
plt.legend(['FRC', 'threshold'])
# We plot the calculated spectral signal-to-noise ratio
plt.figure()
plt.semilogy(1e-6*results['frc_freqs'], results['ssnr'])
plt.xlabel('Frequency (cycles / um)')
plt.ylabel('Spectral signal-to-noise ratio')
plt.grid('on')
plt.show()
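For readers unfamiliar with the Fourier ring correlation, the following NumPy sketch computes it for two square images of equal size. It is a bare-bones illustration under simplifying assumptions (no image alignment, no phase-ramp correction, no threshold curve) and is not the routine used by standardize_reconstruction_set; the function name and test images are invented for the example.

```python
import numpy as np


def fourier_ring_correlation(im1, im2, nbins):
    """Fourier ring correlation between two equally sized square images."""
    f1 = np.fft.fftshift(np.fft.fft2(im1))
    f2 = np.fft.fftshift(np.fft.fft2(im2))

    # Radial spatial frequency of every pixel, in cycles per pixel
    freqs = np.fft.fftshift(np.fft.fftfreq(im1.shape[0]))
    fy, fx = np.meshgrid(freqs, freqs, indexing='ij')
    radius = np.sqrt(fx ** 2 + fy ** 2)

    # Assign each pixel below the Nyquist frequency to a ring
    edges = np.linspace(0, 0.5, nbins + 1)
    ring = np.digitize(radius, edges) - 1

    frc = np.zeros(nbins)
    for i in range(nbins):
        sel = ring == i
        # Real part of the cross-correlation, normalized per ring
        numerator = np.real(np.sum(f1[sel] * np.conj(f2[sel])))
        denominator = np.sqrt(np.sum(np.abs(f1[sel]) ** 2)
                              * np.sum(np.abs(f2[sel]) ** 2))
        frc[i] = numerator / denominator
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, frc


rng = np.random.default_rng(0)
truth = rng.standard_normal((64, 64))
half_a = truth + 0.5 * rng.standard_normal((64, 64))
half_b = truth + 0.5 * rng.standard_normal((64, 64))

freqs, frc = fourier_ring_correlation(half_a, half_b, nbins=16)
_, frc_same = fourier_ring_correlation(truth, truth, nbins=16)
```

An image correlated with itself gives an FRC of 1 in every ring, while independent noise in the two inputs pulls the curve down. The frequency at which the curve crosses a chosen threshold is what is usually reported as the resolution.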
Transmission RPI
CDTools also contains a forward model for randomized probe imaging (RPI, doi:10.1364/OE.397421). This is currently not documented fully, but hopefully will be soon.