Examples
Included with the repository are a number of example scripts that demonstrate how various aspects of CDTools work. It is recommended to read through at least a few of them before continuing to the tutorial.
All the datasets used in these example scripts are included in the repository, and the scripts should run as soon as CDTools is installed.
Inspect Dataset
The first example loads and visualizes ptychography data stored in a .cxi file.
import cdtools
from matplotlib import pyplot as plt
# First, we load an example dataset from a .cxi file
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# And we take a look at the data
dataset.inspect()
plt.show()
First, the data is read into a Ptycho2DDataset object, a subclass of the PyTorch Dataset class that knows a bit about the structure of ptychography data. Calling dataset.inspect generates a plot showing an overview of the ptychography scan data. On the left is a scatter plot showing the integrated detector intensity at each probe location. On the right, raw detector images are shown. The dataset can be navigated by clicking around the scatter plot on the left.
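The quantity in the left-hand scatter plot can be illustrated with a short NumPy sketch. This is not how dataset.inspect is implemented; it assumes the patterns are stacked in an array of shape (n_positions, ny, nx) and simply sums each pattern, giving one value per probe position. The data and the helper name `integrated_intensities` are hypothetical.

```python
import numpy as np


def integrated_intensities(patterns):
    """Total detector counts in each pattern.

    patterns: array of shape (n_positions, ny, nx)
    returns: array of shape (n_positions,)
    """
    return patterns.sum(axis=(-2, -1))


# Hypothetical stand-in data: 4 probe positions on a 2x2 grid, each paired
# with a small 8x8 detector image of Poisson-distributed counts
rng = np.random.default_rng(0)
positions = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
patterns = rng.poisson(5.0, size=(4, 8, 8)).astype(float)

totals = integrated_intensities(patterns)
# Each (position, total) pair corresponds to one point in the scatter plot
for (y, x), total in zip(positions, totals):
    print(f'position ({y:.0f}, {x:.0f}): {total:.0f} counts')
```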
Simple Ptycho
This script runs a ptychography reconstruction using the SimplePtycho model, a bare-bones ptychography model for the transmission geometry.
The purpose of the SimplePtycho model is pedagogical: there are very few situations where it would be preferred over the FancyPtycho model, which will be introduced later.
Because of its simplicity, the definition of this model is much easier to follow than the definition of FancyPtycho, and it is therefore a good first model to study when learning how to implement a custom ptychography model in CDTools.
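A bare-bones transmission model is built around the standard ptychography forward model: the probe multiplies the illuminated patch of the object, and the detector records the squared magnitude of the Fourier transform of that exit wave. The NumPy sketch below illustrates that idea under simplifying assumptions (integer probe positions, a single coherent probe mode, no background or detector effects). It is not CDTools' actual implementation, and `simulate_pattern` is a hypothetical name.

```python
import numpy as np


def simulate_pattern(probe, obj, translation):
    """Simulate one far-field diffraction intensity pattern.

    probe: complex array of shape (n, m)
    obj: complex array, larger than the probe
    translation: integer (row, col) of the probe's top-left corner in obj
    """
    r, c = translation
    n, m = probe.shape
    # The region of the object illuminated at this scan position
    obj_patch = obj[r:r + n, c:c + m]
    # Transmission geometry: the exit wave is the probe times the object
    exit_wave = probe * obj_patch
    # Far-field propagation is a 2D Fourier transform; the detector
    # records only the intensity, so the phase is lost
    far_field = np.fft.fftshift(np.fft.fft2(exit_wave, norm='ortho'))
    return np.abs(far_field) ** 2


# Hypothetical example: a Gaussian probe and a pure-phase random object
n = 32
y, x = np.mgrid[:n, :n] - n / 2
probe = np.exp(-(x ** 2 + y ** 2) / (2 * 5.0 ** 2)).astype(complex)
rng = np.random.default_rng(0)
obj = np.exp(1j * rng.uniform(0, 2 * np.pi, size=(64, 64)))

pattern = simulate_pattern(probe, obj, translation=(10, 20))
```

With the orthonormal FFT, the total intensity in the pattern equals the total intensity of the exit wave, which for a pure-phase object is just the probe intensity.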
"""
Runs a very simple reconstruction using the SimplePtycho model, which was
designed to be an easy introduction to show how the models are made and used.
For a more realistic example of how to use cdtools for real-world data,
look at fancy_ptycho.py and gold_ball_ptycho.py, both of which use the
more powerful FancyPtycho model and include more information on how to
correct for common sources of error.
"""
import cdtools
from matplotlib import pyplot as plt
# We load an example dataset from a .cxi file
filename = 'example_data/lab_ptycho_data.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# We create a ptychography model from the dataset
model = cdtools.models.SimplePtycho.from_dataset(dataset)
# We move the model to the GPU
device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)
# We run the reconstruction
for loss in model.Adam_optimize(100, dataset, batch_size=10):
    # We print a quick report of the optimization status
    print(model.report())
    # And liveplot the updates to the model as they happen
    model.inspect(dataset)
# We study the results
model.inspect(dataset)
model.compare(dataset)
plt.show()
When reading this script, note the basic workflow. After the data is loaded, a model is created to match the geometry stored in the dataset with a sensible default initialization for all the parameters.
Next, the model is moved to the GPU using the model.to function. Any device understood by torch.Tensor.to can be specified here. The next line is a bit more subtle: the dataset.get_as function tells the dataset to move patterns to the GPU just before passing them to the model. It does not move the stored patterns themselves to the GPU. If there is sufficient GPU memory, the patterns can also be moved to the GPU in advance using dataset.to, but empirically the speedup is quite small.
Once the device is selected, a reconstruction is run using model.Adam_optimize. This is a generator function that yields at the end of every epoch, so that monitoring code can run between epochs.
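The generator design means the caller, not the optimizer, decides what happens between epochs. The pure-Python sketch below shows the pattern; `optimize`, `run_epoch`, and `tracked_epoch` are hypothetical names, and the real Adam_optimize does far more work per epoch.

```python
def optimize(n_epochs, run_epoch):
    """Hypothetical generator-based optimization loop.

    run_epoch: callable that performs one epoch and returns its loss
    """
    for epoch in range(n_epochs):
        loss = run_epoch(epoch)
        # Execution pauses here and control returns to the caller's loop
        # body, once per epoch
        yield loss


# A toy "epoch" whose loss simply decays as 1 / (epoch + 1)
losses = []
for loss in optimize(5, lambda epoch: 1.0 / (epoch + 1)):
    # Monitoring code runs between epochs, e.g. printing a report
    losses.append(loss)
    print(f'loss: {loss:.3f}')

# Breaking out of the loop means the generator is never resumed,
# so no further epochs are run
calls = []


def tracked_epoch(epoch):
    calls.append(epoch)
    return 1.0 / (epoch + 1)


for loss in optimize(100, tracked_epoch):
    if loss < 0.3:
        break
```

In this sketch, the early-exit loop runs only four of the requested 100 epochs, because the generator stops doing work as soon as the caller stops asking for the next value.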
Finally, the results can be studied using model.inspect(dataset), which creates or updates a set of plots showing the current state of the model parameters. model.compare(dataset) is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset.
Fancy Ptycho
This script runs a reconstruction on the same data using the workhorse FancyPtycho model, demonstrating some of its more commonly used features.
import cdtools
from matplotlib import pyplot as plt
filename = 'example_data/lab_ptycho_data.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# FancyPtycho is the workhorse model
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3, # Use 3 incoherently mixing probe modes
    oversampling=2, # Simulate the probe on a 2x larger real-space array
    probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix
    propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm
    units='mm', # Set the units for the live plots
    obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix
)
device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)
# For this script, we use a slightly different pattern where we explicitly
# create a `Reconstructor` class to orchestrate the reconstruction. The
# reconstructor will store the model and dataset and create an appropriate
# optimizer. This allows the optimizer to persist between loops, along with
# e.g. estimates of the moments of individual parameters
recon = cdtools.reconstructors.AdamReconstructor(model, dataset)
# The learning rate parameter sets the alpha for Adam.
# The beta parameters are (0.9, 0.999) by default
# The batch size sets the minibatch size
for loss in recon.optimize(50, lr=0.02, batch_size=10):
    print(model.report())
    # Plotting is expensive, so we only do it every tenth epoch
    if model.epoch % 10 == 0:
        model.inspect(dataset)
# It's common to chain several different reconstruction loops. Here, we
# started with an aggressive refinement to find the probe in the previous
# loop, and now we polish the reconstruction with a lower learning rate
# and larger minibatch
for loss in recon.optimize(50, lr=0.005, batch_size=50):
    print(model.report())
    if model.epoch % 10 == 0:
        model.inspect(dataset)
# This orthogonalizes the recovered probe modes
model.tidy_probes()
model.inspect(dataset)
model.compare(dataset)
plt.show()
The FancyPtycho.from_dataset factory function has many keyword arguments which can turn on or modify various mixins. In this case, we perform a reconstruction with:
3 incoherently mixing probe modes (in the vein of doi:10.1038/nature11806)
A probe array expanded by a factor of 2 in real space, i.e. simulated on a 2x2 upsampled grid in Fourier space (in the vein of doi:10.1103/PhysRevA.87.053850)
A circular finite support constraint applied to the probe
An initial guess for the probe which has been propagated from its focus position
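The first and third items can be made concrete with a NumPy sketch. Under the mixed-state model, each probe mode diffracts independently and the detector records the sum of their intensities; a circular support simply zeroes the probe outside a given radius. The function names, the array-centering convention, and the random stand-in probe are assumptions for illustration, not CDTools internals.

```python
import numpy as np


def circular_support(shape, radius):
    """Boolean mask that is True within `radius` pixels of the array center."""
    ny, nx = shape
    y, x = np.ogrid[:ny, :nx]
    return (y - ny // 2) ** 2 + (x - nx // 2) ** 2 <= radius ** 2


def mixed_state_pattern(probe_modes, obj_patch):
    """Incoherent sum of the diffraction intensities from each probe mode.

    probe_modes: complex array of shape (n_modes, n, n)
    obj_patch: complex array of shape (n, n), the illuminated object region
    """
    exit_waves = probe_modes * obj_patch  # broadcasts over the mode axis
    # fft2 transforms the last two axes, i.e. each mode separately
    intensities = np.abs(np.fft.fft2(exit_waves, norm='ortho')) ** 2
    # The modes are mutually incoherent, so intensities (not fields) add
    return intensities.sum(axis=0)


rng = np.random.default_rng(0)
n, n_modes = 32, 3
support = circular_support((n, n), radius=10)
probe_modes = (rng.normal(size=(n_modes, n, n))
               + 1j * rng.normal(size=(n_modes, n, n)))
probe_modes *= support  # zero every mode outside the support
obj_patch = np.exp(1j * rng.uniform(0, 2 * np.pi, size=(n, n)))

pattern = mixed_state_pattern(probe_modes, obj_patch)
```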
By default, FancyPtycho will also optimize over the following model parameters, each of which corrects for a specific source of error:
model.background: A frame-independent detector background
model.weights: A frame-to-frame variation in the incoming probe intensity
model.translation_offsets: Frame-to-frame errors in the recorded probe positions
These corrections can be turned off (or back on) by setting model.<parameter>.requires_grad = False (or True).
Note as well two other changes made in this script, compared to simple_ptycho.py. First, a Reconstructor object is explicitly created, in this case an AdamReconstructor. This object stores a model, a dataset, and a PyTorch optimizer. It is then used to orchestrate the reconstruction through calls to Reconstructor.optimize().
We use this pattern, instead of the simpler call to model.Adam_optimize(), because having the reconstructor store the optimizer as well as the model and dataset allows the optimizer's moment estimates to persist between multiple rounds of optimization. This leads to the second change: in this script, we run two optimization loops. The first loop aggressively refines the probe, with a small minibatch size and a high learning rate. The second loop uses a smaller learning rate and a larger batch size, which allow for a more precise final estimate of the object.
In this case, we used a single reconstructor, but it is possible to create additional reconstructors to reset all the persistent information in the optimizer, if desired, or even to instantiate multiple reconstructors on the same model with different optimization algorithms (e.g. model.LBFGS_optimize()).
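Why persistence matters is easiest to see in Adam's update rule: the first- and second-moment estimates, and the step counter used for bias correction, are state that accumulates over every step. The NumPy sketch below is a textbook Adam implementation (not the PyTorch optimizer CDTools uses) applied to a hypothetical toy quadratic. Because the same optimizer object runs both rounds, the second, lower-learning-rate round starts from the moment estimates built up in the first rather than from zero.

```python
import numpy as np


class Adam:
    """Minimal Adam optimizer whose state persists across calls to step()."""

    def __init__(self, params, betas=(0.9, 0.999), eps=1e-8):
        self.params = params  # updated in place
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.m = np.zeros_like(params)  # first-moment (mean) estimate
        self.v = np.zeros_like(params)  # second-moment estimate
        self.t = 0  # step counter used for bias correction

    def step(self, grad, lr):
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad ** 2
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)
        self.params -= lr * m_hat / (np.sqrt(v_hat) + self.eps)


# Toy problem: minimize f(x) = sum((x - 3)^2), whose gradient is 2 * (x - 3)
x = np.zeros(2)
opt = Adam(x)
for _ in range(200):  # aggressive first round
    opt.step(2 * (x - 3), lr=0.1)
for _ in range(200):  # polishing round; m, v and t carry over
    opt.step(2 * (x - 3), lr=0.01)
print(x)
```

Creating a fresh `Adam(x)` before the second loop would reset `m`, `v`, and `t`, which corresponds to creating a new reconstructor as described above.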
Gold Ball Ptycho
This script shows how the FancyPtycho model might be used in a realistic situation, to perform a reconstruction on the classic gold balls dataset. This script also shows how to save results!
import cdtools
from matplotlib import pyplot as plt
import torch as t
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# We pad the dataset with 10 pixels of zeroes around the edge. This
# data gets masked off, so it is not used for the reconstruction. This padding
# helps prevent aliasing when the probe and object get multiplied. It's a
# helpful step when there is signal present out to the edge of the detector,
# and is usually set to the radius of the probe's Fourier transform (in pixels)
pad = 10
dataset.pad(pad)
dataset.inspect()
# When the dataset is padded with zeroes and masked, the probe reconstruction
# becomes very unstable, and often develops noise at these high-frequency,
# masked off frequencies. To combat this, we simulate the probe at lower
# resolution, using the probe_fourier_crop argument. This is good practice
# in general when padding the dataset
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3,
    probe_support_radius=50,
    propagation_distance=2e-6,
    units='um',
    probe_fourier_crop=pad
)
# This is a trick that my grandmother taught me, to combat the raster grid
# pathology: we randomize our initial guess of the probe positions.
# The units here are pixels in the object array.
# Try running this script with and without this line to see the difference!
model.translation_offsets.data += 0.7 * t.randn_like(model.translation_offsets)
# Not much probe intensity instability in this dataset, no need for this
model.weights.requires_grad = False
device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)
# Create the reconstructor
recon = cdtools.reconstructors.AdamReconstructor(model, dataset)
# This will save out the intermediate results if an exception is thrown
# during the reconstruction
with model.save_on_exception(
        'example_reconstructions/gold_balls_earlyexit.h5', dataset):
    for loss in recon.optimize(20, lr=0.005, batch_size=50):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)
    for loss in recon.optimize(50, lr=0.002, batch_size=100):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)
    # We can often reset our guess of the probe positions once we have a
    # good guess of probe and object, but in this case it causes the
    # raster grid pathology to return.
    # model.translation_offsets.data[:] = 0
    # Setting schedule=True automatically lowers the learning rate if
    # the loss fails to improve after 10 epochs
    for loss in recon.optimize(100, lr=0.001, batch_size=100, schedule=True):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)
model.tidy_probes()
# This saves the final result
model.save_to_h5('example_reconstructions/gold_balls.h5', dataset)
model.inspect(dataset)
model.compare(dataset)
plt.show()
Note, in particular, the use of model.save_on_exception and model.save_to_h5 to save the results of the reconstruction. If a different file format is required, model.save_results will save to a pure-python dictionary.
Finally, note that several small adjustments are made in the script to counteract particular sources of error present in this dataset, for example the raster grid pathology caused by the scan pattern used. Also note that not every mixin is needed every time: in this case, we turn off optimization of the weights parameter.
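Broadly speaking, the raster grid pathology arises because a perfectly periodic scan lets periodic artifacts shared between the probe and object explain the data about as well as the true solution does; breaking the periodicity of the initial position guess helps the reconstruction avoid it. In the script, this is done by adding Gaussian noise with a standard deviation of 0.7 pixels to model.translation_offsets. The NumPy sketch below shows the same operation on a hypothetical 5x5 raster grid; `jitter_positions` is a hypothetical helper, not part of CDTools.

```python
import numpy as np


def jitter_positions(positions, sigma, seed=None):
    """Add independent Gaussian noise (std `sigma`, in pixels) to each position."""
    rng = np.random.default_rng(seed)
    return positions + sigma * rng.standard_normal(positions.shape)


# A perfectly regular 5x5 raster grid with a 10-pixel step, shape (25, 2)
ys, xs = np.meshgrid(np.arange(5) * 10.0, np.arange(5) * 10.0, indexing='ij')
raster = np.stack([ys.ravel(), xs.ravel()], axis=-1)

# Analogous to adding 0.7 * t.randn_like(model.translation_offsets)
initial_guess = jitter_positions(raster, sigma=0.7, seed=0)
```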
Gold Ball Split
It is very common to run reconstructions on two disjoint subsets of the same dataset, as well as the full dataset. This is primarily done to estimate the resolution of a reconstruction via the Fourier ring correlation (FRC).
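The FRC measures, ring by ring in Fourier space, how well two independent reconstructions of the same object agree: values near 1 mean the two agree at that spatial frequency, while values near 0 mean they share only noise. A minimal NumPy sketch of the standard definition follows, assuming two real-valued, already-aligned images of the same shape. The CDTools analysis function used in the Gold Ball Synthesize script also brings the reconstructions to a common basis first, and its binning and threshold conventions may differ from this sketch; `fourier_ring_correlation` is a hypothetical name.

```python
import numpy as np


def fourier_ring_correlation(img_1, img_2, nbins=20):
    """Normalized cross-correlation of two images over rings in Fourier space.

    Returns the ring centers (in cycles per pixel) and the FRC in each ring.
    """
    f1 = np.fft.fftshift(np.fft.fft2(img_1))
    f2 = np.fft.fftshift(np.fft.fft2(img_2))
    ny, nx = img_1.shape
    fy = np.fft.fftshift(np.fft.fftfreq(ny))
    fx = np.fft.fftshift(np.fft.fftfreq(nx))
    radius = np.sqrt(fy[:, None] ** 2 + fx[None, :] ** 2)
    edges = np.linspace(0, 0.5, nbins + 1)
    ring = np.digitize(radius.ravel(), edges) - 1  # ring index of each pixel
    valid = (ring >= 0) & (ring < nbins)  # drop the corners beyond 0.5
    ring = ring[valid]

    def ring_sum(values):
        # Sum `values` over the pixels belonging to each ring
        return np.bincount(ring, weights=values.ravel()[valid], minlength=nbins)

    cross = ring_sum((f1 * np.conj(f2)).real)
    power_1 = ring_sum(np.abs(f1) ** 2)
    power_2 = ring_sum(np.abs(f2) ** 2)
    frc = cross / np.sqrt(power_1 * power_2)
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, frc
```

Because the FRC is normalized in each ring, it is insensitive to an overall scale factor between the two images.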
import cdtools
import torch as t
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
pad = 10
dataset.pad(pad)
# This splits the dataset into a pseudorandomly chosen set of two disjoint
# datasets. The partitioning is drawn from a saved list, so the split is
# deterministic
dataset_1, dataset_2 = dataset.split()
datasets = [dataset_1, dataset_2, dataset]
labels = ['half_1', 'half_2', 'full']
for label, dataset in zip(labels, datasets):
    print(f'Working on dataset {label}')
    model = cdtools.models.FancyPtycho.from_dataset(
        dataset,
        n_modes=3,
        probe_support_radius=50,
        propagation_distance=2e-6,
        units='um',
        probe_fourier_crop=pad
    )
    model.translation_offsets.data += \
        0.7 * t.randn_like(model.translation_offsets)
    model.weights.requires_grad = False
    device = 'cuda'
    model.to(device=device)
    dataset.get_as(device=device)
    # Create the reconstructor
    recon = cdtools.reconstructors.AdamReconstructor(model, dataset)
    # For batched reconstructions like this, there's no need to live-plot
    # the progress
    for loss in recon.optimize(20, lr=0.005, batch_size=50):
        print(model.report())
    for loss in recon.optimize(50, lr=0.002, batch_size=100):
        print(model.report())
    for loss in recon.optimize(100, lr=0.001, batch_size=100,
                               schedule=True):
        print(model.report())
    model.tidy_probes()
    model.save_to_h5(f'example_reconstructions/gold_balls_{label}.h5', dataset)
This script simply divides the dataset in two, and then performs the same reconstruction on both halves of the dataset, as well as the full dataset.
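For the FRC to be meaningful, the two half reconstructions must be built from independent measurements. The sketch below produces a reproducible, disjoint split of pattern indices using a seeded NumPy permutation. This is a swapped-in technique: according to the script's comments, dataset.split() draws its partition from a saved list instead. `split_indices` is a hypothetical helper, and in this sketch the two halves together cover every pattern.

```python
import numpy as np


def split_indices(n_patterns, seed=0):
    """Deterministically partition pattern indices into two disjoint halves."""
    rng = np.random.default_rng(seed)  # fixed seed -> same split every run
    shuffled = rng.permutation(n_patterns)
    half = n_patterns // 2
    return np.sort(shuffled[:half]), np.sort(shuffled[half:])


idx_1, idx_2 = split_indices(101)
```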
Gold Ball Synthesize
Once we have a set of three reconstructions (two half-data reconstructions and one full-data reconstruction), we need to calculate the resolution metrics. This script shows how that is done.
import cdtools
from cdtools.tools import plotting as p
from matplotlib import pyplot as plt
import numpy as np
# We load all three reconstructions
half_1 = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_half_1.h5')
half_2 = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_half_2.h5')
full = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_full.h5')
# This defines the region of recovered object to use for the analysis.
pad = 260
window = np.s_[pad:-pad, pad:-pad]
# This brings all three reconstructions to a common basis, correcting for
# possible global phase offsets, position shifts, and phase ramps. It also
# calculates a Fourier ring correlation and spectral signal-to-noise ratio
# estimate from the two half reconstructions.
results = cdtools.tools.analysis.standardize_reconstruction_set(
    half_1,
    half_2,
    full,
    window=window,
    nbins=40, # The number of bins to use for the FRC calculation
)
# We plot the normalized object images
p.plot_amplitude(results['obj_half_1'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_half_1'][window], basis=results['obj_basis'])
p.plot_amplitude(results['obj_half_2'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_half_2'][window], basis=results['obj_basis'])
p.plot_amplitude(results['obj_full'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_full'][window], basis=results['obj_basis'])
# We plot the calculated Fourier ring correlation
plt.figure()
plt.plot(1e-6*results['frc_freqs'], results['frc'])
plt.plot(1e-6*results['frc_freqs'], results['frc_threshold'], 'k--')
plt.xlabel('Frequency (cycles / um)')
plt.ylabel('Fourier Ring Correlation')
plt.legend(['FRC', 'threshold'])
# We plot the calculated spectral signal-to-noise ratio
plt.figure()
plt.semilogy(1e-6*results['frc_freqs'], results['ssnr'])
plt.xlabel('Frequency (cycles / um)')
plt.ylabel('Spectral signal-to-noise ratio')
plt.grid(True)
plt.show()
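One of the corrections standardize_reconstruction_set makes is the removal of global phase offsets. A ptychographic reconstruction is only determined up to a constant phase factor, so two otherwise identical reconstructions can differ by one. The NumPy sketch below shows one standard way to remove it: rotating one image by the constant phase that minimizes its least-squares distance to a reference. It ignores the position shifts and phase ramps that the CDTools function also corrects, and `remove_global_phase` is a hypothetical name.

```python
import numpy as np


def remove_global_phase(obj, reference):
    """Rotate `obj` by the constant phase that best aligns it with `reference`.

    The least-squares optimal phase is the angle of sum(conj(obj) * reference),
    which is exactly what np.vdot computes for complex inputs.
    """
    phase = np.angle(np.vdot(obj, reference))
    return obj * np.exp(1j * phase)


# Hypothetical example: the same pure-phase object with a 1.2 rad global offset
rng = np.random.default_rng(0)
reference = np.exp(1j * rng.uniform(-np.pi, np.pi, size=(16, 16)))
shifted = reference * np.exp(1j * 1.2)

aligned = remove_global_phase(shifted, reference)
```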
Transmission RPI
CDTools also contains a forward model for randomized probe imaging (RPI, doi:10.1364/OE.397421). This model is not yet fully documented, but documentation will hopefully be added soon.