Examples

Included with the repository are a number of example scripts that demonstrate how various aspects of CDTools work. It is recommended to read through at least a few of them before continuing to the tutorial.

All the datasets used in these example scripts are included in the repository, and the scripts should be runnable as soon as CDTools is installed.

Inspect Dataset

The first example loads and visualizes ptychography data stored in a .cxi file.

import cdtools
from matplotlib import pyplot as plt

# First, we load an example dataset from a .cxi file
filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

# And we take a look at the data
dataset.inspect()
plt.show()

First, data is read into a Ptycho2DDataset object, which is a subclass of a PyTorch dataset that knows a bit about the structure of ptychography data. Calling dataset.inspect generates a plot showing an overview of the ptychography scan data. On the left is a scatter plot showing the integrated detector intensity at each probe location. On the right, raw detector images are shown. The dataset can be navigated by clicking around the scatter plot on the left.

Simple Ptycho

This script runs a ptychography reconstruction using the SimplePtycho model, a bare-bones ptychography model for the transmission geometry.

The purpose of the SimplePtycho model is pedagogical: there are very few situations where it would be preferred to the FancyPtycho model, which will be introduced later.

Because of its simplicity, the definition of this model is much shorter than the definition of FancyPtycho, and it is therefore a good first model to look at to learn how to implement a custom ptychography model in CDTools.

import cdtools
from matplotlib import pyplot as plt

# We load an example dataset from a .cxi file
filename = 'example_data/lab_ptycho_data.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

# We create a ptychography model from the dataset
model = cdtools.models.SimplePtycho.from_dataset(dataset)

# We move the model to the GPU
device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)

# We run the actual reconstruction
for loss in model.Adam_optimize(100, dataset, batch_size=10):
    # We print a quick report of the optimization status
    print(model.report())
    # And liveplot the updates to the model as they happen
    model.inspect(dataset)

# We study the results
model.inspect(dataset)
model.compare(dataset)
plt.show()

When reading this script, note the basic workflow. After the data is loaded, a model is created to match the geometry stored in the dataset with a sensible default initialization for all the parameters.

Next, the model is moved to the GPU using the model.to function. Any device understood by torch.Tensor.to can be specified here. The next line is a bit more subtle: using the dataset.get_as function, the dataset is told to move each pattern to the GPU just before passing it to the model. This function does not move the stored patterns to the GPU. If there is sufficient GPU memory, the patterns can also be pre-moved to the GPU using dataset.to, but the speedup is empirically quite small.

Once the device is selected, a reconstruction is run using model.Adam_optimize. This is a generator function which will yield at every epoch, to allow some monitoring code to be run.
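
The generator pattern described above can be illustrated with a minimal pure-Python sketch. The function toy_optimize below is hypothetical and is not part of CDTools. It only mimics the idea of yielding a loss value once per epoch, so that the body of the caller's for loop can run monitoring code between epochs.

```python
def toy_optimize(n_epochs, x0=5.0, lr=0.25):
    """Hypothetical stand-in for an optimizer loop (not CDTools code).

    Minimizes f(x) = x**2 by gradient descent and yields the loss once
    per epoch, handing control back to the caller between epochs.
    """
    x = x0
    for epoch in range(n_epochs):
        grad = 2 * x          # derivative of x**2
        x = x - lr * grad     # one gradient-descent step
        yield x ** 2          # pause here until the caller asks for more


losses = []
for loss in toy_optimize(5):
    # Anything placed here runs once per epoch, e.g. printing or plotting
    losses.append(loss)
    print(f'epoch {len(losses)}: loss = {loss:.6f}')
```

Because the loop body belongs to the caller, the same optimizer can be run silently, with a printed report, or with live plots, without changing the optimizer itself.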

Finally, the results can be studied using model.inspect(dataset), which creates or updates a set of plots showing the current state of the model parameters. model.compare(dataset) is also called, which shows how the simulated diffraction patterns compare to the measured diffraction patterns in the dataset.

Fancy Ptycho

This script runs a reconstruction on the same data, but using the workhorse FancyPtycho model, demonstrating some of its more commonly used features.

import cdtools
from matplotlib import pyplot as plt

filename = 'example_data/lab_ptycho_data.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

# FancyPtycho is the workhorse model
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3, # Use 3 incoherently mixing probe modes
    oversampling=2, # Simulate the probe on a 2x larger real-space array
    probe_support_radius=120, # Force the probe to 0 outside a radius of 120 pix
    propagation_distance=5e-3, # Propagate the initial probe guess by 5 mm
    units='mm', # Set the units for the live plots
    obj_view_crop=-50, # Expands the field of view in the object plot by 50 pix
)

device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)

# The learning rate parameter sets the alpha for Adam.
# The beta parameters are (0.9, 0.999) by default
# The batch size sets the minibatch size
for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10):
    print(model.report())
    # Plotting is expensive, so we only do it every tenth epoch
    if model.epoch % 10 == 0:
        model.inspect(dataset)

# It's common to chain several different reconstruction loops. Here, we
# started with an aggressive refinement to find the probe, and now we
# polish the reconstruction with a lower learning rate and larger minibatch
for loss in model.Adam_optimize(50, dataset, lr=0.005, batch_size=50):
    print(model.report())
    if model.epoch % 10 == 0:
        model.inspect(dataset)

# This orthogonalizes the recovered probe modes
model.tidy_probes()

model.inspect(dataset)
model.compare(dataset)
plt.show()

The FancyPtycho.from_dataset factory function has many keyword arguments which can turn on or modify various mixins. In this case, we perform a reconstruction with:

  • 3 incoherently mixing probe modes (in the vein of doi:10.1038/nature11806)

  • A probe array expanded by a factor of 2 in real space, i.e. simulated on a 2x2 upsampled grid in Fourier space (in the vein of doi:10.1103/PhysRevA.87.053850)

  • A circular finite support constraint applied to the probe

  • An initial guess for the probe which has been propagated from its focus position
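
To make the first item concrete, the following NumPy sketch shows what "incoherently mixing" means: each probe mode produces its own far-field pattern, and the intensities (not the amplitudes) are summed. It also shows why orthogonalizing the modes after a reconstruction is harmless. All array names and sizes here are hypothetical, the far field is modeled as a plain FFT, and the SVD is just one way to orthogonalize; none of this is taken from the CDTools source.

```python
import numpy as np

rng = np.random.default_rng(0)
n_modes, size = 3, 16

# Hypothetical stand-ins: three complex probe modes and one object patch
probes = rng.normal(size=(n_modes, size, size)) \
    + 1j * rng.normal(size=(n_modes, size, size))
obj = np.exp(1j * rng.uniform(0, 1, size=(size, size)))


def simulate_intensity(probes, obj):
    """Incoherent sum over modes of |FFT(probe_k * obj)|**2."""
    exit_waves = probes * obj                      # broadcast over modes
    far_field = np.fft.fft2(exit_waves, norm='ortho')
    return np.sum(np.abs(far_field) ** 2, axis=0)  # intensities add


intensity = simulate_intensity(probes, obj)

# Orthogonalize the modes with an SVD. This mixes the modes by a unitary
# matrix, which leaves the incoherent sum unchanged, so the simulated
# pattern is the same while the new modes are mutually orthogonal.
flat = probes.reshape(n_modes, -1)
u, s, vh = np.linalg.svd(flat, full_matrices=False)
ortho_probes = (s[:, None] * vh).reshape(probes.shape)

ortho_intensity = simulate_intensity(ortho_probes, obj)

# Gram matrix of the new modes: off-diagonal entries should vanish
ortho_flat = ortho_probes.reshape(n_modes, -1)
gram = ortho_flat @ ortho_flat.conj().T

print(np.allclose(intensity, ortho_intensity))
print(np.round(np.abs(gram), 6))
```

The diagonal of the Gram matrix gives the power in each orthogonalized mode, sorted from strongest to weakest.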

By default, FancyPtycho will also optimize over the following model parameters, each of which corrects for a specific source of error:

model.background

A frame-independent detector background

model.weights

A frame-to-frame variation in the incoming probe intensity

model.translation_offsets

A frame-to-frame offset in the probe positions

These corrections can be turned off (or back on) by setting model.<parameter>.requires_grad = False (or True).
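
The effect of the requires_grad flag can be seen in a small stand-alone PyTorch sketch. The two parameters below are hypothetical stand-ins named after the model attributes above; they are not created by CDTools, and the loss is a toy one chosen only so that both parameters would be updated if both were trainable.

```python
import torch

# Two hypothetical parameters standing in for model attributes
weights = torch.nn.Parameter(torch.ones(4))
background = torch.nn.Parameter(torch.zeros(4))

# Turn off optimization of one parameter
weights.requires_grad = False

optimizer = torch.optim.Adam([weights, background], lr=0.1)
target = torch.full((4,), 2.0)

for _ in range(20):
    optimizer.zero_grad()
    loss = torch.sum((weights + background - target) ** 2)
    loss.backward()    # no gradient is computed for the frozen parameter
    optimizer.step()   # parameters without a gradient are left untouched

print(weights.data)     # still all ones
print(background.data)  # has moved towards the target
```

Setting the flag back to True before a later optimization loop makes the parameter trainable again.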

Gold Ball Ptycho

This script shows how the FancyPtycho model might be used in a realistic situation, to perform a reconstruction on the classic gold balls dataset. This script also shows how to save results!

import cdtools
from matplotlib import pyplot as plt
import torch as t

filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

# We pad the dataset with 10 pixels of zeroes around the edge. This
# data gets masked off, so it is not used for the reconstruction. This padding
# helps prevent aliasing when the probe and object get multiplied. It's a
# helpful step when there is signal present out to the edge of the detector,
# and is usually set to the radius of the probe's Fourier transform (in pixels)
pad = 10
dataset.pad(pad)

dataset.inspect()

# When the dataset is padded with zeroes and masked, the probe reconstruction
# becomes very unstable, and often develops noise at these high-frequency,
# masked off frequencies. To combat this, we simulate the probe at lower
# resolution, using the probe_fourier_crop argument. This is good practice
# in general when padding the dataset
model = cdtools.models.FancyPtycho.from_dataset(
    dataset,
    n_modes=3,
    probe_support_radius=50,
    propagation_distance=2e-6,
    units='um',
    probe_fourier_crop=pad 
)

# This is a trick that my grandmother taught me, to combat the raster grid
# pathology: we randomize our initial guess of the probe positions.
# The units here are pixels in the object array.
# Try running this script with and without this line to see the difference!
model.translation_offsets.data += 0.7 * t.randn_like(model.translation_offsets)

# Not much probe intensity instability in this dataset, no need for this
model.weights.requires_grad = False

device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)

# This will save out the intermediate results if an exception is thrown
# during the reconstruction
with model.save_on_exception(
        'example_reconstructions/gold_balls_earlyexit.h5', dataset):
    
    for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)

    for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)

    # We can often reset our guess of the probe positions once we have a
    # good guess of probe and object, but in this case it causes the
    # raster grid pathology to return.
    # model.translation_offsets.data[:] = 0

    # Setting schedule=True automatically lowers the learning rate if
    # the loss fails to improve after 10 epochs
    for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100,
                                    schedule=True):
        print(model.report())
        if model.epoch % 10 == 0:
            model.inspect(dataset)


model.tidy_probes()

# This saves the final result
model.save_to_h5('example_reconstructions/gold_balls.h5', dataset)

model.inspect(dataset)
model.compare(dataset)
plt.show()

Note, in particular, the use of model.save_on_exception and model.save_to_h5 to save the results of the reconstruction. If a different file format is required, model.save_results will instead collect the results into a pure-Python dictionary.
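
The save-on-exception idea is a standard context-manager pattern, which can be sketched in pure Python as follows. The helper below is hypothetical: it writes a plain dictionary to a JSON file rather than a model to an .h5 file, and catching BaseException (so that an interrupt from the keyboard also triggers a save) is an assumption of this sketch, not a statement about how CDTools behaves.

```python
import contextlib
import json
import os
import tempfile


@contextlib.contextmanager
def save_on_exception(path, state):
    """Hypothetical helper: dump `state` to `path` if the body raises."""
    try:
        yield
    except BaseException:
        with open(path, 'w') as f:
            json.dump(state, f)
        raise  # re-raise so the caller still sees the original error


state = {'epoch': 0, 'loss': None}
path = os.path.join(tempfile.mkdtemp(), 'earlyexit.json')

try:
    with save_on_exception(path, state):
        for epoch in range(1, 6):
            state['epoch'] = epoch
            state['loss'] = 1.0 / epoch
            if epoch == 3:
                raise RuntimeError('simulated failure')
except RuntimeError as err:
    print(f'Caught: {err}')

# The file holds the state as it was when the failure occurred
with open(path) as f:
    saved = json.load(f)
print(saved)
```

Wrapping only the long-running loops in the context manager, as the script above does, keeps the normal save at the end of the script separate from the emergency save.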

Finally, note that there are several small adjustments made to the script to counteract particular sources of error that are present in this dataset, for example the raster grid pathology caused by the scan pattern used. Also note that not every mixin is needed every time - in this case, we turn off optimization of the weights parameter.
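
One further detail of the script is the schedule=True option, which its comment describes as lowering the learning rate if the loss fails to improve after 10 epochs. The rule can be sketched in pure Python as below. The function name, the reduction factor of 0.1, and the exact counting convention are all assumptions made for illustration; the script does not say which values CDTools uses.

```python
def plateau_schedule(losses, lr=0.001, patience=10, factor=0.1):
    """Hypothetical reduce-on-plateau rule (not CDTools code).

    Returns the learning rate used at each epoch. The rate is multiplied
    by `factor` once the loss has gone `patience` consecutive epochs
    without improving on its best value.
    """
    best = float('inf')
    stale = 0
    history = []
    for loss in losses:
        history.append(lr)      # rate used for this epoch
        if loss < best:
            best = loss
            stale = 0
        else:
            stale += 1
        if stale >= patience:
            lr *= factor        # takes effect from the next epoch
            stale = 0
    return history


# Three improving epochs followed by twelve epochs stuck at the same loss
losses = [1.0, 0.5, 0.4] + [0.4] * 12
history = plateau_schedule(losses)
print(history)
```

In this toy run the rate stays at its initial value while the loss improves, and drops only after ten stagnant epochs in a row.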

Gold Ball Split

It is very common to run reconstructions on two disjoint subsets of the same dataset, as well as the full dataset. This is primarily done to estimate the resolution of a reconstruction via the Fourier ring correlation (FRC).

import cdtools
import torch as t

filename = 'example_data/AuBalls_700ms_30nmStep_3_6SS_filter.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(filename)

pad = 10
dataset.pad(pad)

# This splits the dataset into a pseudorandomly chosen set of two disjoint
# datasets. The partitioning is drawn from a saved list, so the split is
# deterministic
dataset_1, dataset_2 = dataset.split()

datasets = [dataset_1, dataset_2, dataset]
labels = ['half_1', 'half_2', 'full']

for label, dataset in zip(labels, datasets):
    print(f'Working on dataset {label}')

    model = cdtools.models.FancyPtycho.from_dataset(
        dataset,
        n_modes=3,
        probe_support_radius=50,
        propagation_distance=2e-6,
        units='um',
        probe_fourier_crop=pad 
    )

    model.translation_offsets.data += \
        0.7 * t.randn_like(model.translation_offsets)
    
    model.weights.requires_grad = False
    
    device = 'cuda'
    model.to(device=device)
    dataset.get_as(device=device)

    # For batched reconstructions like this, there's no need to live-plot
    # the progress
    for loss in model.Adam_optimize(20, dataset, lr=0.005, batch_size=50):
        print(model.report())

    for loss in model.Adam_optimize(50, dataset, lr=0.002, batch_size=100):
        print(model.report())

    for loss in model.Adam_optimize(100, dataset, lr=0.001, batch_size=100,
                                    schedule=True):
        print(model.report())

        
    model.tidy_probes()

    model.save_to_h5(f'example_reconstructions/gold_balls_{label}.h5', dataset)

This script simply divides the dataset in two, and then performs the same reconstruction on both halves of the dataset, as well as the full dataset.
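
The key properties of the split are that the two halves are disjoint, together cover the full dataset, and come out the same on every run. A minimal sketch of such a split over pattern indices is shown below. The script's comment says that CDTools draws its partitioning from a saved list; this sketch uses a fixed random seed instead, which is an assumption made to keep the example self-contained, and split_indices is a hypothetical helper.

```python
import random


def split_indices(n_patterns, seed=0):
    """Hypothetical helper: deterministic split into two disjoint halves."""
    indices = list(range(n_patterns))
    random.Random(seed).shuffle(indices)   # fixed seed -> same split each run
    half = n_patterns // 2
    return sorted(indices[:half]), sorted(indices[half:])


half_1, half_2 = split_indices(11)
print(half_1)
print(half_2)
```

A reproducible split matters because the two half reconstructions are compared against each other later, so they should be based on the same partition each time the analysis is rerun.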

Gold Ball Synthesize

Once we have a set of three reconstructions (two half-data reconstructions and one full-data reconstruction), we need to calculate the resolution metrics. This script shows how that is done.

import cdtools
from cdtools.tools import plotting as p
from matplotlib import pyplot as plt
import numpy as np

# We load all three reconstructions
half_1 = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_half_1.h5')
half_2 = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_half_2.h5')
full = cdtools.tools.data.h5_to_nested_dict(
    'example_reconstructions/gold_balls_full.h5')

# This defines the region of recovered object to use for the analysis.
pad = 260
window = np.s_[pad:-pad, pad:-pad]

# This brings all three reconstructions to a common basis, correcting for
# possible global phase offsets, position shifts, and phase ramps. It also
# calculates a Fourier ring correlation and spectral signal-to-noise ratio
# estimate from the two half reconstructions.
results = cdtools.tools.analysis.standardize_reconstruction_set(
    half_1,
    half_2,
    full,
    window=window,
    nbins=40, # The number of bins to use for the FRC calculation
)

# We plot the normalized object images
p.plot_amplitude(results['obj_half_1'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_half_1'][window], basis=results['obj_basis'])
p.plot_amplitude(results['obj_half_2'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_half_2'][window], basis=results['obj_basis'])
p.plot_amplitude(results['obj_full'][window], basis=results['obj_basis'])
p.plot_phase(results['obj_full'][window], basis=results['obj_basis'])

# We plot the calculated Fourier ring correlation
plt.figure()
plt.plot(1e-6*results['frc_freqs'], results['frc'])
plt.plot(1e-6*results['frc_freqs'], results['frc_threshold'], 'k--')
plt.xlabel('Frequency (cycles / um)')
plt.ylabel('Fourier Ring Correlation')
plt.legend(['FRC', 'threshold'])

# We plot the calculated spectral signal-to-noise ratio
plt.figure()
plt.semilogy(1e-6*results['frc_freqs'], results['ssnr'])
plt.xlabel('Frequency (cycles / um)')
plt.ylabel('Spectral signal-to-noise ratio')
plt.grid(True)

plt.show()
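
For readers unfamiliar with the Fourier ring correlation plotted above, the following NumPy sketch computes it for two square images. This is a textbook-style illustration, not the CDTools implementation: it assumes the two images are already aligned, works in cycles per pixel rather than physical units, takes the magnitude of the cross-correlation in each ring, and does not compute a threshold curve.

```python
import numpy as np


def fourier_ring_correlation(im1, im2, nbins):
    """Correlate two equally sized square images ring by ring in Fourier space."""
    f1 = np.fft.fftshift(np.fft.fft2(im1))
    f2 = np.fft.fftshift(np.fft.fft2(im2))

    # Radial spatial frequency (cycles per pixel) of every Fourier pixel
    axis_freqs = np.fft.fftshift(np.fft.fftfreq(im1.shape[0]))
    fy, fx = np.meshgrid(axis_freqs, axis_freqs, indexing='ij')
    radius = np.hypot(fx, fy)

    # Assign each pixel to a ring; pixels at or beyond Nyquist are ignored
    edges = np.linspace(0, 0.5, nbins + 1)
    ring = np.digitize(radius.ravel(), edges) - 1

    cross = (f1 * np.conj(f2)).ravel()
    power1 = (np.abs(f1) ** 2).ravel()
    power2 = (np.abs(f2) ** 2).ravel()

    frc = np.zeros(nbins)
    for i in range(nbins):
        sel = ring == i
        numerator = np.abs(np.sum(cross[sel]))
        denominator = np.sqrt(np.sum(power1[sel]) * np.sum(power2[sel]))
        frc[i] = numerator / denominator
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, frc


rng = np.random.default_rng(0)
signal = rng.normal(size=(64, 64))
# Two "half" images: the same signal with independent noise added
half_a = signal + 0.5 * rng.normal(size=(64, 64))
half_b = signal + 0.5 * rng.normal(size=(64, 64))

freqs, frc_same = fourier_ring_correlation(signal, signal, nbins=8)
freqs, frc_noisy = fourier_ring_correlation(half_a, half_b, nbins=8)
print(np.round(frc_same, 3))
print(np.round(frc_noisy, 3))
```

Identical images correlate perfectly in every ring, while independent noise pulls the correlation below one. In a real reconstruction, the frequency at which the curve falls below a chosen threshold is what is read off as the resolution estimate.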

Transmission RPI

CDTools also contains a forward model for randomized probe imaging (RPI, doi:10.1364/OE.397421). This is currently not documented fully, but hopefully will be soon.