Examples
Included with the repository are a number of example scripts that demonstrate how various aspects of CDTools work. It is recommended to read through at least a few of them before continuing to the tutorial.
All the datasets used in these example scripts are included in the repository, and the scripts should be runnable as soon as CDTools is installed.
Inspect Dataset
The first example loads and visualizes ptychography data stored in a .cxi file.
import cdtools
from matplotlib import pyplot as plt
# First, we load an example dataset from a .cxi file
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# And we take a look at the data
dataset.inspect()
plt.show()
First, data is read into a Ptycho2DDataset object, which is a subclass of a PyTorch Dataset that knows a bit about the structure of ptychography data. Calling dataset.inspect() generates a plot showing an overview of the ptychography scan data. On the left is a scatter plot showing the integrated detector intensity at each probe location. On the right, raw detector images are shown. The dataset can be navigated by clicking around the scatter plot on the left.
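The left-hand panel of dataset.inspect() colors each probe position by the total counts in its diffraction pattern. A minimal NumPy sketch of that quantity is shown below, using randomly generated stand-in arrays; the names patterns and positions are hypothetical and are not CDTools attributes.

```python
import numpy as np

# Hypothetical stand-in data: 5 diffraction patterns of 64x64 pixels,
# each recorded at an (x, y) probe position given in meters
rng = np.random.default_rng(0)
patterns = rng.poisson(10.0, size=(5, 64, 64)).astype(float)
positions = rng.uniform(0, 1e-6, size=(5, 2))

# Summing each pattern over both detector axes gives one number per frame,
# which is the quantity plotted against probe position
integrated = patterns.sum(axis=(-2, -1))

for (x, y), total in zip(positions, integrated):
    print(f'x={x:.2e} m, y={y:.2e} m, total counts={total:.0f}')
```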
Simple Ptycho
This script runs a ptychography reconstruction using the SimplePtycho model, a bare-bones ptychography model for the transmission geometry.
The purpose of the SimplePtycho model is pedagogical: there are very few situations where it would be preferred over the FancyPtycho model, which will be introduced later.
Because the model is so minimal, its definition is much simpler than that of FancyPtycho, which makes it a good first model to study when learning how to implement a custom ptychography model in CDTools.
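For readers new to ptychography, the physics that a transmission model like SimplePtycho has to encode can be sketched in a few lines of NumPy. This is an illustrative sketch, not the CDTools implementation: it assumes a single coherent probe mode, far-field propagation modeled as a plain 2D FFT, integer-pixel probe positions, and an amplitude-based mean squared error chosen here only for illustration. The function names are hypothetical.

```python
import numpy as np

def simulate_pattern(probe, obj, top_left):
    """Far-field intensity for one probe position (single coherent mode)."""
    r, c = top_left
    n, m = probe.shape
    # Transmission geometry: the exit wave is the probe times the
    # illuminated patch of the object
    exit_wave = probe * obj[r:r + n, c:c + m]
    # Far-field propagation is modeled as a 2D Fourier transform; the
    # detector records only the squared magnitude
    far_field = np.fft.fftshift(np.fft.fft2(exit_wave, norm='ortho'))
    return np.abs(far_field) ** 2

def amplitude_loss(simulated, measured):
    """Mean squared error between simulated and measured amplitudes."""
    return np.mean((np.sqrt(simulated) - np.sqrt(measured)) ** 2)

# Hypothetical toy data: a flat probe and a pure-phase object
rng = np.random.default_rng(0)
probe = np.ones((32, 32), dtype=complex)
obj = np.exp(1j * rng.uniform(0, 1, size=(128, 128)))
pattern = simulate_pattern(probe, obj, (10, 20))
print(amplitude_loss(pattern, pattern))  # 0.0 for a perfect match
```

A reconstruction then adjusts the probe and object so that the loss between simulated and measured patterns, summed over all probe positions, is minimized.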
import cdtools
from matplotlib import pyplot as plt
# We load an example dataset from a .cxi file
filename = 'example_data/lab_ptycho_data.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# We create a ptychography model from the dataset
model = cdtools.models.SimplePtycho.from_dataset(dataset)
# We move the model to the GPU
device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)
# We run the actual reconstruction
for loss in model.Adam_optimize(100, dataset, batch_size=10):
    # We print a quick report of the optimization status
    print(model.report())
    # And liveplot the updates to the model as they happen
    model.inspect(dataset)
# We study the results
model.inspect(dataset)
model.compare(dataset)
plt.show()
When reading this script, note the basic workflow. After the data is loaded, a model is created to match the geometry stored in the dataset with a sensible default initialization for all the parameters.
Next, the model is moved to the GPU using the model.to function. Any device understood by torch.Tensor.to can be specified here. The next line is a bit more subtle: the dataset is told to move patterns to the GPU before passing them to the model, using the dataset.get_as function. This function does not move the stored patterns to the GPU; the patterns are only moved as they are passed to the model. If there is sufficient GPU memory, the patterns can also be pre-moved to the GPU using dataset.to, but the speedup is empirically quite small.
Once the device is selected, a reconstruction is run using model.Adam_optimize. This is a generator function which yields at every epoch, allowing monitoring code to be run.
Finally, the results can be studied using model.inspect(dataset), which creates or updates a set of plots showing the current state of the model parameters. model.compare(dataset) is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset.
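The yield-per-epoch pattern is ordinary Python generator behavior: the optimizer does one epoch of work, hands the loss back to the caller, and resumes only when the loop asks for the next value. The toy generator below shows that control flow without any real optimization; optimize and step are hypothetical stand-ins, not CDTools functions.

```python
def optimize(n_epochs, batches, step):
    """Hypothetical generator that mimics the yield-per-epoch pattern."""
    for epoch in range(n_epochs):
        losses = [step(batch) for batch in batches]
        # Control returns to the caller's loop body once per epoch
        yield sum(losses) / len(losses)

state = {'x': 10.0}

def step(batch):
    # Toy update: move x halfway toward zero and return its square as a "loss"
    state['x'] *= 0.5
    return state['x'] ** 2

history = []
for loss in optimize(3, batches=[None], step=step):
    # Anything here (reports, plots, checkpoints) runs between epochs
    history.append(loss)
    print(f'loss = {loss:.4f}')
```

Because the body of the for loop runs between epochs, it is a natural place for monitoring code such as model.report() and model.inspect(dataset).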
Fancy Ptycho
This script runs a reconstruction on the same data, but using the workhorse FancyPtycho model, demonstrating some of its more commonly used features.
import cdtools
from matplotlib import pyplot as plt
filename = 'example_data/lab_ptycho_data.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# FancyPtycho is the workhorse model
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3,  # Use 3 incoherently mixing probe modes
    oversampling=2,  # Simulate the probe on a 2x larger real-space array
    probe_support_radius=120,  # Force the probe to 0 outside a radius of 120 pix
    propagation_distance=5e-3,  # Propagate the initial probe guess by 5 mm
    units='mm',  # Set the units for the live plots
    obj_view_crop=-50,  # Expands the field of view in the object plot by 50 pix
)
device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)
# The learning rate parameter sets the alpha for Adam.
# The beta parameters are (0.9, 0.999) by default
# The batch size sets the minibatch size
for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10):
    print(model.report())
    # Plotting is expensive, so we only do it every tenth epoch
    if model.epoch % 10 == 0:
        model.inspect(dataset)
# It's common to chain several different reconstruction loops. Here, we
# started with an aggressive refinement to find the probe, and now we
# polish the reconstruction with a lower learning rate and larger minibatch
for loss in model.Adam_optimize(50, dataset, lr=0.005, batch_size=50):
    print(model.report())
    if model.epoch % 10 == 0:
        model.inspect(dataset)
# This orthogonalizes the recovered probe modes
model.tidy_probes()
model.inspect(dataset)
model.compare(dataset)
plt.show()
The FancyPtycho.from_dataset factory function has many keyword arguments which can turn on or modify various mixins. In this case, we perform a reconstruction with:
3 incoherently mixing probe modes (in the vein of doi:10.1038/nature11806)
A probe array expanded by a factor of 2 in real space, i.e. simulated on a 2x2 upsampled grid in Fourier space (in the vein of doi:10.1103/PhysRevA.87.053850)
A circular finite support constraint applied to the probe
An initial guess for the probe which has been propagated from its focus position
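Two of the features listed above have a compact mathematical form. Incoherently mixing probe modes add in intensity rather than in amplitude, so the simulated pattern is the sum of |FFT(P_k · O)|² over the modes P_k, and a circular finite support constraint zeroes the probe outside a chosen radius. The NumPy sketch below illustrates both under simplifying assumptions (no oversampling, no propagation, hypothetical function names); it is not the FancyPtycho code.

```python
import numpy as np

def circular_support(shape, radius):
    """Boolean mask that is True within `radius` pixels of the array center."""
    rows, cols = np.indices(shape)
    center_r, center_c = (shape[0] - 1) / 2, (shape[1] - 1) / 2
    return (rows - center_r) ** 2 + (cols - center_c) ** 2 <= radius ** 2

def incoherent_intensity(probe_modes, obj_patch):
    """Sum of the far-field intensities of each probe mode (no cross terms)."""
    exit_waves = probe_modes * obj_patch  # broadcasts over the mode axis
    far_fields = np.fft.fft2(exit_waves, norm='ortho')  # FFT of last two axes
    return np.sum(np.abs(far_fields) ** 2, axis=0)

rng = np.random.default_rng(1)
n_modes, size = 3, 64
modes = (rng.normal(size=(n_modes, size, size))
         + 1j * rng.normal(size=(n_modes, size, size)))
# Finite support constraint: force every mode to zero outside radius 20
modes *= circular_support((size, size), radius=20)
obj_patch = np.exp(1j * rng.uniform(0, 1, size=(size, size)))
intensity = incoherent_intensity(modes, obj_patch)
```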
By default, FancyPtycho will also optimize over the following model parameters, each of which corrects for a specific source of error:
model.background
A frame-independent detector background
model.weights
A frame-to-frame variation in the incoming probe intensity
model.translation_offsets
Frame-to-frame errors in the recorded probe positions
These corrections can be turned off by setting model.<parameter>.requires_grad = False, and turned back on by setting it to True.
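One simple way such corrections can enter a forward model is sketched below: each frame's simulated intensity is scaled by its own weight, and a single background image is added to every frame. This is an assumed, simplified formulation for illustration only; FancyPtycho's exact parameterization may differ, translation offsets (which would shift each frame's probe position) are omitted, and corrected_intensity is a hypothetical name. Setting requires_grad = False on a PyTorch parameter excludes it from gradient updates, so the corresponding term keeps its current value throughout the optimization.

```python
import numpy as np

def corrected_intensity(model_intensity, weight, background):
    """Apply a per-frame intensity weight and a shared detector background.

    model_intensity: (n_frames, ny, nx) simulated intensities
    weight:          (n_frames,) frame-to-frame probe intensity factors
    background:      (ny, nx) background shared by every frame
    """
    return weight[:, None, None] * model_intensity + background

# Hypothetical toy values
rng = np.random.default_rng(2)
sim = rng.uniform(0, 1, size=(4, 8, 8))
weight = np.array([1.0, 0.9, 1.1, 1.0])
background = np.full((8, 8), 0.05)
measured_model = corrected_intensity(sim, weight, background)
```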
Gold Ball Ptycho
This script shows how the FancyPtycho model might be used in a realistic situation, to perform a reconstruction on the classic gold balls dataset. This script also shows how to save results!
import cdtools
from matplotlib import pyplot as plt
import torch as t
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
# We pad the dataset with 10 pixels of zeroes around the edge. This
# data gets masked off, so it is not used for the reconstruction. This padding
# helps prevent aliasing when the probe and object get multiplied. It's a
# helpful step when there is signal present out to the edge of the detector,
# and is usually set to the radius of the probe's Fourier transform (in pixels)
pad = 10
dataset.pad(pad)
dataset.inspect()
# When the dataset is padded with zeroes and masked, the probe reconstruction
# becomes very unstable, and often develops noise at these high-frequency,
# masked off frequencies. To combat this, we simulate the probe at lower
# resolution, using the probe_fourier_crop argument. This is good practice
# in general when padding the dataset
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3,
    probe_support_radius=50,
    propagation_distance=2e-6,
    units='um',
    probe_fourier_crop=pad
)
# This is a trick that my grandmother taught me, to combat the raster grid
# pathology: we randomize our initial guess of the probe positions.
# The units here are pixels in the object array.
# Try running this script with and without this line to see the difference!
model.translation_offsets.data += 0.7 * t.randn_like(model.translation_offsets)
# Not much probe intensity instability in this dataset, no need for this
model.weights.requires_grad = False
device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)
# This will save out the intermediate results if an exception is thrown
# during the reconstruction
with model.save_on_exception(
        'example_reconstructions/gold_balls_earlyexit.h5', dataset):
    for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)
    for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)
    # We can often reset our guess of the probe positions once we have a
    # good guess of probe and object, but in this case it causes the
    # raster grid pathology to return.
    # model.translation_offsets.data[:] = 0
    # Setting schedule=True automatically lowers the learning rate if
    # the loss fails to improve after 10 epochs
    for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100,
                                    schedule=True):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)
model.tidy_probes()
# This saves the final result
model.save_to_h5('example_reconstructions/gold_balls.h5', dataset)
model.inspect(dataset)
model.compare(dataset)
plt.show()
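The schedule=True option used in the final loop above is described in the script's comments as lowering the learning rate when the loss fails to improve for 10 epochs. That rule can be sketched in plain Python as follows; the class name and the reduction factor of 0.2 are assumptions for illustration. PyTorch provides the same idea as torch.optim.lr_scheduler.ReduceLROnPlateau.

```python
class PlateauScheduler:
    """Hypothetical reduce-on-plateau rule: multiply the learning rate by
    `factor` once the loss has not improved for more than `patience` epochs."""

    def __init__(self, lr, patience=10, factor=0.2):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, loss):
        if loss < self.best:
            # Improvement: remember it and reset the counter
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                # Plateau detected: lower the learning rate and start over
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr

# Usage: call step() once per epoch with that epoch's loss
scheduler = PlateauScheduler(lr=0.001, patience=10)
lrs = [scheduler.step(loss) for loss in [1.0] * 12]
```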
Note, in particular, the use of model.save_on_exception and model.save_to_h5 to save the results of the reconstruction. If a different file format is required, model.save_results will save the results to a pure-Python dictionary, which can then be written out in any format.
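Since model.save_results gives a pure-Python dictionary, writing it to another format is mostly a matter of flattening nested keys. The sketch below writes a hypothetical nested results dictionary to a NumPy .npz archive; the dictionary contents and the flatten helper are illustrative assumptions, and the real keys depend on the model.

```python
import os
import tempfile
import numpy as np

def flatten(nested, prefix=''):
    """Flatten a nested dict into {'a.b.c': array} form."""
    flat = {}
    for key, value in nested.items():
        name = f'{prefix}.{key}' if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, name))
        else:
            flat[name] = np.asarray(value)
    return flat

# Hypothetical results dictionary; the real keys depend on the model
results = {
    'obj': np.ones((4, 4), dtype=complex),
    'probe': np.zeros((3, 4, 4), dtype=complex),
    'state': {'epoch': 170, 'loss_history': [0.5, 0.3, 0.2]},
}

path = os.path.join(tempfile.mkdtemp(), 'reconstruction.npz')
np.savez(path, **flatten(results))

with np.load(path) as loaded:
    print(sorted(loaded.files))
```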
Finally, note that several small adjustments are made in this script to counteract particular sources of error present in this dataset, such as the raster grid pathology caused by the scan pattern. Also note that not every mixin is needed every time: in this case, we turn off optimization of the weights parameter.
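Roughly speaking, the raster grid pathology arises because a perfectly periodic scan lets artifacts with the periodicity of the scan be traded between the probe and the object. Randomizing the initial position guesses, as the script does with translation_offsets, helps break that symmetry. The sketch below builds a hypothetical 10 × 10 raster with a 30 nm step, expresses it in object pixels using an assumed 10 nm pixel size, and adds Gaussian jitter with the script's 0.7-pixel standard deviation.

```python
import numpy as np

pixel_size = 10e-9  # assumed object-pixel size in meters (hypothetical)
step = 30e-9        # 30 nm raster step
n = 10

# Nominal raster positions, expressed in object pixels
grid = np.arange(n) * step / pixel_size
ys, xs = np.meshgrid(grid, grid, indexing='ij')
nominal = np.stack([ys.ravel(), xs.ravel()], axis=-1)  # shape (n*n, 2)

# Gaussian jitter with a 0.7-pixel standard deviation breaks the perfect
# periodicity of the initial position guesses
rng = np.random.default_rng(0)
jitter_pixels = 0.7 * rng.standard_normal(nominal.shape)
initial_guess = nominal + jitter_pixels
```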
Gold Ball Split
It is very common to run reconstructions on two disjoint subsets of the same dataset, as well as the full dataset. This is primarily done to estimate the resolution of a reconstruction via the Fourier ring correlation (FRC).
import cdtools
import torch as t
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)
pad = 10
dataset.pad(pad)
# This splits the dataset into a pseudorandomly chosen set of two disjoint
# datasets. The partitioning is drawn from a saved list, so the split is
# deterministic
dataset_1, dataset_2 = dataset.split()
datasets = [dataset_1, dataset_2, dataset]
labels = ['half_1', 'half_2', 'full']
for label, dataset in zip(labels, datasets):
    print(f'Working on dataset {label}')
    model = cdtools.models.FancyPtycho.from_dataset(
        dataset,
        n_modes=3,
        probe_support_radius=50,
        propagation_distance=2e-6,
        units='um',
        probe_fourier_crop=pad
    )
    model.translation_offsets.data += \
        0.7 * t.randn_like(model.translation_offsets)
    model.weights.requires_grad = False
    device = 'cuda'
    model.to(device=device)
    dataset.get_as(device=device)
    # For batched reconstructions like this, there's no need to live-plot
    # the progress
    for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50):
        print(model.report())
    for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100):
        print(model.report())
    for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100,
                                    schedule=True):
        print(model.report())
    model.tidy_probes()
    model.save_to_h5(f'example_reconstructions/gold_balls_{label}.h5', dataset)
This script simply divides the dataset in two, and then performs the same reconstruction on both halves of the dataset, as well as the full dataset.
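The key property of the split is that the two halves are disjoint and reproducible. The source notes that CDTools draws its partition from a saved list; the sketch below instead uses a seeded random number generator as a stand-in, so split_indices is a hypothetical helper rather than the dataset.split implementation.

```python
import random

def split_indices(n_frames, seed=0):
    """Deterministically split frame indices into two disjoint halves."""
    indices = list(range(n_frames))
    # A fixed seed makes the shuffle, and therefore the split, reproducible
    random.Random(seed).shuffle(indices)
    half = n_frames // 2
    return sorted(indices[:half]), sorted(indices[half:])

half_1, half_2 = split_indices(11)
```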
Gold Ball Synthesize
Once we have a set of three reconstructions (two half-data reconstructions and one full-data reconstruction), we can calculate the resolution metrics. This script shows how that is done.
import cdtools
from cdtools.tools import plotting as p
from matplotlib import pyplot as plt
import numpy as np
# We load all three reconstructions
half_1 = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_half_1.h5')
half_2 = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_half_2.h5')
full = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_full.h5')
# This defines the region of recovered object to use for the analysis.
pad = 260
window = np.s_[pad:-pad, pad:-pad]
# This brings all three reconstructions to a common basis, correcting for
# possible global phase offsets, position shifts, and phase ramps. It also
# calculates a Fourier ring correlation and spectral signal-to-noise ratio
# estimate from the two half reconstructions.
results = cdtools.tools.analysis.standardize_reconstruction_set(
    half_1,
    half_2,
    full,
    window=window,
    nbins=40,  # The number of bins to use for the FRC calculation
)
# We plot the normalized object images
p.plot_amplitude(results['obj_half_1'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_half_1'][window], basis=results['obj_basis'])
p.plot_amplitude(results['obj_half_2'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_half_2'][window], basis=results['obj_basis'])
p.plot_amplitude(results['obj_full'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_full'][window], basis=results['obj_basis'])
# We plot the calculated Fourier ring correlation
plt.figure()
plt.plot(1e-6*results['frc_freqs'], results['frc'])
plt.plot(1e-6*results['frc_freqs'], results['frc_threshold'], 'k--')
plt.xlabel('Frequency (cycles / um)')
plt.ylabel('Fourier Ring Correlation')
plt.legend(['FRC', 'threshold'])
# We plot the calculated spectral signal-to-noise ratio
plt.figure()
plt.semilogy(1e-6*results['frc_freqs'], results['ssnr'])
plt.xlabel('Frequency (cycles / um)')
plt.ylabel('Spectral signal-to-noise ratio')
plt.grid(True)
plt.show()
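The Fourier ring correlation compares the two half reconstructions ring by ring in Fourier space: FRC(r) = |Σ F₁F₂*| / sqrt(Σ|F₁|² · Σ|F₂|²), with each sum taken over the Fourier pixels in ring r. A minimal NumPy version is sketched below. It assumes the two images are already registered (standardize_reconstruction_set handles phase offsets, shifts, and phase ramps first), reports frequencies in cycles per pixel rather than physical units, and omits the threshold curve; fourier_ring_correlation is a hypothetical name, not a CDTools function.

```python
import numpy as np

def fourier_ring_correlation(img_1, img_2, nbins=40):
    """Return ring-center frequencies (cycles/pixel) and the FRC per ring."""
    f1 = np.fft.fftshift(np.fft.fft2(img_1))
    f2 = np.fft.fftshift(np.fft.fft2(img_2))

    # Radial spatial frequency of every Fourier pixel, in cycles per pixel
    fy = np.fft.fftshift(np.fft.fftfreq(img_1.shape[0]))
    fx = np.fft.fftshift(np.fft.fftfreq(img_1.shape[1]))
    radius = np.sqrt(fy[:, None] ** 2 + fx[None, :] ** 2)

    # Assign each pixel to one of nbins equal-width rings
    edges = np.linspace(0, radius.max(), nbins + 1)
    ring = np.clip(np.digitize(radius, edges) - 1, 0, nbins - 1).ravel()

    def ring_sum(values):
        return np.bincount(ring, weights=values.ravel(), minlength=nbins)

    cross = f1 * np.conj(f2)
    numerator = np.abs(ring_sum(cross.real) + 1j * ring_sum(cross.imag))
    denominator = np.sqrt(ring_sum(np.abs(f1) ** 2) * ring_sum(np.abs(f2) ** 2))
    # Empty rings (if any) are reported as 0 instead of dividing by zero
    frc = np.divide(numerator, denominator,
                    out=np.zeros(nbins), where=denominator > 0)
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, frc

rng = np.random.default_rng(3)
image = rng.normal(size=(64, 64))
freqs, frc_identical = fourier_ring_correlation(image, image)
noise_1, noise_2 = rng.normal(size=(2, 64, 64))
_, frc_noise = fourier_ring_correlation(noise_1, noise_2)
```

Identical inputs give an FRC of 1 in every ring, while independent noise gives values near 0 except in the sparsely populated central rings. Dividing freqs by the pixel size in meters gives cycles per meter, which the script above converts to cycles per µm by multiplying by 1e-6.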
Transmission RPI
CDTools also contains a forward model for randomized probe imaging (RPI, doi:10.1364/OE.397421). This is currently not fully documented, but hopefully will be soon.