Tutorial
The following tutorial gives a peek under the hood, and is intended for someone who might want to write their own variant of a ptychography model or modify an existing model to meet a specific need. If you just need to use CDTools for a reconstruction, or are just starting to work with the code, the examples section is a great first introduction.
In the first section of the tutorial, we will discuss how the datasets are defined and go through the process of defining a new dataset type. Following that, we will go through the process of defining a simplified model for standard ptychography.
Datasets
In this section, we will write a bare-bones dataset class for 2D ptychography data to demonstrate the process of writing a new dataset class. At the end of the tutorial, we will have written the file examples/tutorial_basic_ptycho_dataset.py, which can be consulted for reference.
Basic Idea
At its core, a dataset object for CDTools is just an object that implements the dataset interface from pytorch. For this reason, the base class (CDataset) from which all the datasets are defined is itself a subclass of torch.utils.data.Dataset. In addition, CDataset implements an extra layer that allows for a separation between the device (CPU or GPU) that the data is stored on and the device that it returns data on. This allows for GPU-based reconstructions on datasets that are too large to fit into the GPU in their entirety.
The pytorch Dataset interface is very simple. A dataset simply has to define two functions, __len__() and __getitem__(). Thus, we can always access the data in a Dataset mydata using the syntax mydata[index] or mydata[slice]. Overriding these functions will be the first task in defining a new dataset.
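As a minimal sketch of this two-method interface, the hypothetical class below uses plain Python and numpy rather than subclassing torch.utils.data.Dataset, so it runs without pytorch; pytorch's map-style datasets only require these same two methods. Slicing works here simply because numpy arrays support it.

```python
import numpy as np


class ToyPatternDataset:
    """Hypothetical map-style dataset: only __len__ and __getitem__."""

    def __init__(self, patterns):
        # patterns: array of shape (N, H, W), one diffraction pattern per index
        self.patterns = np.asarray(patterns)

    def __len__(self):
        # The number of diffraction patterns
        return self.patterns.shape[0]

    def __getitem__(self, index):
        # numpy indexing handles both integers and slices
        return self.patterns[index]


data = ToyPatternDataset(np.zeros((5, 4, 4)))
print(len(data))         # 5
print(data[2].shape)     # (4, 4)
print(data[1:3].shape)   # (2, 4, 4)
```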
In CDTools datasets, the layer that allows for separation between the device that data is stored on and the device that data is loaded onto is implemented in the __getitem__() function. Instead of overriding this function directly, one should override the _load() function, which is used internally by __getitem__().
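The split between __getitem__() and _load() can be sketched as follows. This is a simplified, hypothetical illustration, not the CDataset source: a numpy dtype stands in for the device and dtype that CDataset manages, and the get_as() method here is named after CDataset's method of the same name but only mimics the idea of choosing the format data are returned in.

```python
import numpy as np


class StorageSeparatedDataset:
    """Hypothetical sketch of a __getitem__ / _load split (not CDataset)."""

    def __init__(self, patterns):
        # Data are stored in one format...
        self.patterns = np.asarray(patterns, dtype=np.float64)
        # ...and returned in another, chosen independently
        self.output_dtype = np.float32

    def get_as(self, dtype):
        # Changes only the returned format; the stored data are untouched
        self.output_dtype = dtype

    def __len__(self):
        return self.patterns.shape[0]

    def _load(self, index):
        # Subclasses override this and return data in their stored format
        return self.patterns[index]

    def __getitem__(self, index):
        # Shared layer: load, then convert on the way out
        return np.asarray(self._load(index), dtype=self.output_dtype)


ds = StorageSeparatedDataset(np.ones((3, 2, 2)))
print(ds[0].dtype)  # float32
ds.get_as(np.float16)
print(ds[0].dtype)  # float16
```

A subclass only has to write _load(); the conversion logic lives in one place.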
In addition to acting as a pytorch Dataset, CDTools Datasets also work as interfaces to .cxi files. Therefore, when writing a new dataset, it is important to also override the functions to_cxi() and from_cxi(), which handle writing to and reading from .cxi files, respectively.
The final piece of the puzzle is the inspect() method. It is not required for all datasets; however, it is extremely valuable to offer a simple way of exploring a dataset visually. Therefore it is highly recommended to implement this function, which should produce a plot or interactive plot that allows a user to visualize the data that they have loaded.
Writing the Skeleton
We can start with the basic skeleton for this file. In addition to our standard imports, we also import the base CDataset class and the data tools. We then define an __all__ list as good practice, and set up the inheritance of our class.
import numpy as np
import torch as t
from matplotlib import pyplot as plt
from cdtools.datasets import CDataset
from cdtools.tools import data as cdtdata

__all__ = ['BasicPtychoDataset']

class BasicPtychoDataset(CDataset):
    """The standard dataset for a 2D ptychography scan"""
    pass
Initialization
The next thing to implement is the initialization code. Here we can leverage some of the work already done in the base CDataset class. There are a number of kinds of metadata that can be stored in a .cxi file that aren’t related to the kind of experiment you’re performing - sample ID, start and end times, and so on. The CDataset’s initialization routine handles loading and storing these various kinds of metadata, so we can start the definition of our initialization routine by leveraging this:
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
Of course, there is also some data that are unique to this kind of dataset. In this case, those data are the probe translations and the measured diffraction patterns. Therefore, we extend this definition to the following:
def __init__(self, translations, patterns, *args, **kwargs):
    """Initialize the dataset from python objects"""
    super().__init__(*args, **kwargs)
    self.translations = t.Tensor(translations).clone()
    self.patterns = t.Tensor(patterns).clone()
Dataset Interface
The next set of functions to write are those that plug into the dataset interface. We want len(dataset) to return the number of diffraction patterns, which is straightforward to implement.
For the _load() implementation, we need to consider what format the data should be returned in. The standard for all CDTools datasets is to return a tuple of (inputs, output). The inputs should always be defined as a tuple of inputs, even if there is only one input for this kind of data. As we will see later in the section on constructing models, this makes it possible to write the automatic differentiation code in a way that is applicable to every model.
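To see why a tuple of inputs helps, here is a hypothetical generic loop; total_loss, toy_model, and squared_error are illustrative names, not CDTools functions. Because every dataset entry unpacks as (inputs, output), the single call model(*inputs) works whether a model takes one input or several.

```python
import numpy as np


def total_loss(model, loss_fn, dataset):
    """Hypothetical generic loop: it never inspects what the inputs are."""
    total = 0.0
    for i in range(len(dataset)):
        inputs, output = dataset[i]
        # Works for any number of inputs, as long as they come as a tuple
        total += loss_fn(model(*inputs), output)
    return total


# Toy "dataset": inputs are (index, translation), output is a measurement
translations = np.array([[0.0, 1.0], [2.0, 3.0]])
dataset = [((i, translations[i]), translations[i].sum()) for i in range(2)]


def toy_model(index, translation):
    # A stand-in forward model that just sums the translation
    return translation.sum()


def squared_error(sim, real):
    return float((sim - real) ** 2)


print(total_loss(toy_model, squared_error, dataset))  # 0.0
```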
In this case, our “inputs” will be a tuple of (pattern index, probe translation). This is not the only reasonable choice - it would also be possible, for example, to define the input as just a pattern index (and store the probe translations in the model). For simple ptychography models with no error correction, it’s also possible to just take a probe translation as an input with no index. Requiring both is the compromise that’s been implemented in the default ptychography models defined with CDTools, and therefore we will follow that format here.
def __len__(self):
    return self.patterns.shape[0]

def _load(self, index):
    return (index, self.translations[index]), self.patterns[index]
Remember that there is no need to worry about what device or datatype the data is stored as here, as the relevant conversions will be performed by the __getitem__() method defined in the superclass. However, we do generally implement a method, to(), that moves the data back and forth between devices and datatypes. This lets a user speed up data loading onto the GPU by preloading the data, for example, provided there is enough space.
def to(self, *args, **kwargs):
    """Sends the relevant data to the given device and dtype"""
    super().to(*args, **kwargs)
    self.translations = self.translations.to(*args, **kwargs)
    self.patterns = self.patterns.to(*args, **kwargs)
Here we can see that we first make sure to call the superclass function to handle sending any information (such as a pixel mask, or detector background) that would have been defined in CDataset to the relevant device. Then we handle the new objects that are defined specifically for this kind of dataset.
Loading and Saving
Now we turn to writing the tools to load and save data. First, to load the data, we override from_cxi(), which is a factory method. In this case, we start by using the superclass to load the metadata. Then we explicitly load in and add the data that’s specific to this dataset class:
@classmethod
def from_cxi(cls, cxi_file):
    """Generates a new CDataset from a .cxi file directly"""
    # Generate a base dataset
    dataset = CDataset.from_cxi(cxi_file)
    # Mutate the class to this subclass (BasicPtychoDataset)
    dataset.__class__ = cls
    # Load the data that is only relevant for this class
    patterns, axes = cdtdata.get_data(cxi_file)
    translations = cdtdata.get_ptycho_translations(cxi_file)
    # And now re-add it, keeping the axes so they can be saved back out
    dataset.translations = t.Tensor(translations).clone()
    dataset.patterns = t.Tensor(patterns).clone()
    dataset.axes = axes
    return dataset
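The class-mutation trick can be illustrated with plain Python. In this hypothetical sketch, Base plays the role of CDataset and Derived plays the role of BasicPtychoDataset. If Derived simply reused the inherited factory with cls bound to Derived, the factory would call Derived(wavelength=...) and fail, because Derived.__init__ also requires patterns. Building a Base first and then reassigning __class__ sidesteps that. This is our reading of why the pattern is useful; the CDataset internals may differ.

```python
class Base:
    """Stands in for CDataset: its factory builds only shared metadata."""

    def __init__(self, wavelength=None):
        self.wavelength = wavelength

    @classmethod
    def from_record(cls, record):
        # cls(...) means a subclass calling this needs a compatible __init__
        return cls(wavelength=record.get('wavelength'))


class Derived(Base):
    """Stands in for BasicPtychoDataset: needs extra constructor arguments."""

    def __init__(self, patterns, wavelength=None):
        super().__init__(wavelength=wavelength)
        self.patterns = patterns

    @classmethod
    def from_record(cls, record):
        # Build a plain Base, so Derived.__init__ is never called...
        dataset = Base.from_record(record)
        # ...then relabel its class and attach the subclass-only data
        dataset.__class__ = cls
        dataset.patterns = record['patterns']
        return dataset


d = Derived.from_record({'wavelength': 1.5e-9, 'patterns': [1, 2, 3]})
print(type(d).__name__, d.wavelength, d.patterns)  # Derived 1.5e-09 [1, 2, 3]
```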
Now to save the data, we override to_cxi(), in a fairly self-explanatory way.
def to_cxi(self, cxi_file):
    """Saves out a BasicPtychoDataset as a .cxi file"""
    super().to_cxi(cxi_file)
    # Fall back to None if no axes were stored on this dataset
    cdtdata.add_data(cxi_file, self.patterns, axes=getattr(self, 'axes', None))
    cdtdata.add_ptycho_translations(cxi_file, self.translations)
Note that these functions should be defined to work on h5py file objects representing the .cxi files (.cxi files are just .h5 files with a special formatting).
Inspecting
The final piece of the puzzle is writing a function to look at your data! This is an important thing to work on for a dataset class that you intend to use regularly, as being able to easily peruse your raw data has incalculable value. Here, we satisfy ourselves with just plotting a random diffraction pattern.
def inspect(self):
    """Plots a random diffraction pattern"""
    index = np.random.randint(len(self))
    plt.figure()
    plt.imshow(self.patterns[index,:,:].cpu().numpy())
Notes
This is a bare-bones class, set up to demonstrate the minimum necessary to develop a new type of dataset class. As a result, it doesn’t implement a number of things that are useful or valuable in practice (and which the default Ptycho2DDataset does implement). That includes a useful data inspector, the ability to load datasets directly from filenames, and default tweaks to how metadata such as backgrounds and masks are loaded.
Models
In this section, we will write a basic model for 2D ptychography reconstructions. At the end of this tutorial, we will have written the class defined in examples/tutorial_simple_ptycho.py.
Basic Idea
Just like CDTools Datasets subclass pytorch Datasets, CDTools models subclass pytorch modules. However, the concept of a CDTools model does differ slightly from that of a pytorch module, because the CDTools models also contain a few standard methods to run automatic differentiation reconstructions on themselves.
This isn’t necessarily the cleanest or most portable approach, but we’ve found that it feels very natural from the perspective of an end user interacting with the toolbox only through the reconstruction scripts.
The heart of each model is a model.forward() function. In any CDTools model, this forward function maps a set of parameters describing which diffraction pattern to simulate to the simulated pattern. When it’s paired with an appropriate dataset for a reconstruction, it maps from the “inputs” defined by the dataset to the “outputs”.
For ptychography, this information is usually the index of the exposure within the dataset (which is used to retrieve exposure-to-exposure information, like probe intensity factors) and the object translation.
A simple forward model is defined in the top-level CDIModel class from which all other models are derived, and rarely needs to be overridden. The definition is quite simple:
def forward(self, *args):
    return self.measurement(self.forward_propagator(self.interaction(*args)))
So we can see that to fully implement this forward model, we have to define the three functions model.interaction(), model.forward_propagator(), and model.measurement(), which simulate conceptual stages in the diffraction process.
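As a concrete, hypothetical sketch of the three stages, the numpy class below multiplies a probe by a whole-pixel window of the object, propagates with a centered 2D FFT (a common far-field model), and records intensity. It mirrors the composition in CDIModel.forward(), but the normalization and shift conventions of the CDTools propagators may differ. The index argument is accepted but unused, only to match the (index, translation) input format.

```python
import numpy as np


class ToyForwardModel:
    """Hypothetical interaction -> propagation -> measurement chain."""

    def __init__(self, probe, obj):
        self.probe = probe  # complex (H, W) illumination
        self.obj = obj      # complex object, larger than the probe

    def interaction(self, index, pix_translation):
        # Exit wave: probe times the object window at a whole-pixel offset
        r, c = pix_translation
        h, w = self.probe.shape
        return self.probe * self.obj[r:r + h, c:c + w]

    def forward_propagator(self, wavefield):
        # Far-field propagation modeled as a centered 2D FFT
        return np.fft.fftshift(np.fft.fft2(wavefield))

    def measurement(self, wavefield):
        # The detector records intensity, not phase
        return np.abs(wavefield) ** 2

    def forward(self, *args):
        # Same composition as CDIModel.forward
        return self.measurement(self.forward_propagator(self.interaction(*args)))


model = ToyForwardModel(np.ones((4, 4), dtype=complex),
                        np.ones((8, 8), dtype=complex))
pattern = model.forward(0, (2, 1))
print(pattern.shape)  # (4, 4)
```

For a flat probe and object, all the intensity lands in the zero-frequency pixel, which fftshift moves to the center of the array.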
In addition to the core model definition, a few other functions need to be defined to make the model useful. The model needs an initializer to create itself from a dataset, an appropriate loss function for use with automatic differentiation, a way of plotting the progress of a reconstruction, and a way of saving the results of a reconstruction in a useful format.
Writing the Skeleton
Once again, we start with the basic skeleton:
import torch as t
from cdtools.models import CDIModel
from cdtools import tools
from cdtools.tools import plotting as p

__all__ = ['SimplePtycho']

class SimplePtycho(CDIModel):
    """A simple ptychography model to demonstrate the structure of a model
    """
Note that we imported the full tools package, as we will find ourselves using many low-level functions defined there to implement the model.
Initialization from Python
Two initialization functions need to be written. First, we write the __init__() function, which initializes the model from a collection of python objects describing the system. We then write an initializer that creates a model using a dataset to initialize the various parameters.
There is no requirement for what the arguments to the initialization function of any particular model should be, only that they contain enough information to run the simulations! They should be chosen on a model-by-model basis to allow for the most sensible code.
def __init__(
        self,
        wavelength,
        probe_basis,
        probe_guess,
        obj_guess,
        min_translation=[0, 0],
):
    # We initialize the superclass
    super().__init__()

    # We register all the constants, like wavelength, as buffers. This
    # lets the model hook into some nice pytorch features, like using
    # model.to, and broadcasting the model state across multiple GPUs
    self.register_buffer('wavelength', t.as_tensor(wavelength))
    self.register_buffer('min_translation', t.as_tensor(min_translation))
    self.register_buffer('probe_basis', t.as_tensor(probe_basis))

    # We cast the probe and object to single-precision complex (complex64)
    # tensors
    probe_guess = t.as_tensor(probe_guess, dtype=t.complex64)
    obj_guess = t.as_tensor(obj_guess, dtype=t.complex64)

    # We rescale the probe here so it learns at the same rate as the
    # object when using optimizers, like Adam, which set the step size
    # to a fixed maximum
    self.register_buffer('probe_norm', t.max(t.abs(probe_guess)))

    # And we store the probe and object guesses as parameters, so
    # they can get optimized by pytorch
    self.probe = t.nn.Parameter(probe_guess / self.probe_norm)
    self.obj = t.nn.Parameter(obj_guess)
The first thing to notice about this model is that all the fixed, geometric information is stored with the module.register_buffer() function. This is what makes it possible to move all the relevant tensors between devices using a single call to module.to(), for example. It stores the tensor as an object attribute, but it also registers it so that pytorch is aware that this attribute helps to encode the state of the model.
The supporting information we need is the wavelength of the illumination, the basis of the probe array in real space, and an offset to define the zero point of the translation.
The final two pieces of information that we need to save are the probe and object, and both of these get defined as t.nn.Parameter objects instead of Tensors. As a result, they get registered as parameters in the pytorch module, and will therefore be optimized over in any later reconstructions. In addition, the requires_grad flag is set to True, which means that the information needed for gradient calculations will be stored on every Tensor that results from a calculation including a Parameter.
All the parameters associated with the module can be iterated over by calling module.parameters().
Any additional targets of reconstruction - such as exposure-to-exposure illumination weights, translation offsets, or a detector background - would be added to the model as a parameter in a similar way.
One final note is that we actually store a scaled version of the probe. This is a specific case of a general policy designed around making it easy to use the Adam optimizer.
The Adam optimizer is designed so that the learning rate approximately sets the maximum step size that will be taken for each parameter in any single iteration. Therefore, it is important to make sure that all parameters of the model are of order unity. To enable this, we divide the probe by its maximum amplitude, probe_norm, so that the largest pixel amplitude in the stored probe array is 1. The forward model multiplies probe_norm back in.
This is important to remember when adding additional error models. Rescaling all the parameters to have a typical amplitude near 1 is the best way to get well-behaved reconstructions.
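A short numpy sketch shows why the learning rate acts as a step-size scale in Adam. Starting from zero moment estimates, the first bias-corrected update has a magnitude close to lr no matter how large the gradient is, so a parameter of size 10^4 would barely move relative to its size while a parameter of size 1 moves noticeably. The last lines reproduce the probe_norm rescaling from the initializer. The function name adam_first_step is illustrative only.

```python
import numpy as np


def adam_first_step(grad, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """First Adam update, starting from zero moment estimates (t = 1)."""
    m = (1 - beta1) * grad          # first-moment estimate
    v = (1 - beta2) * grad ** 2     # second-moment estimate
    m_hat = m / (1 - beta1)         # bias corrections
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (np.sqrt(v_hat) + eps)


# Gradients spanning six orders of magnitude all give a step of about -lr
steps = [adam_first_step(np.array([g]))[0] for g in (1e-3, 1.0, 1e3)]
print(steps)

# The probe_norm rescaling: optimize probe / max|probe|, then multiply back
probe = 1e4 * np.exp(1j * np.linspace(0, 1, 5))
probe_norm = np.max(np.abs(probe))
scaled = probe / probe_norm      # peak amplitude is 1
```

Later Adam steps are only approximately bounded by lr, but the same order-of-magnitude argument holds.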
Initialization from Dataset
To initialize the object from a dataset, we need to start by extracting the relevant information from the dataset, before calling the constructor we defined above:
@classmethod
def from_dataset(cls, dataset):
    # We get the key geometry information from the dataset
    wavelength = dataset.wavelength
    det_basis = dataset.detector_geometry['basis']
    det_shape = dataset[0][1].shape
    distance = dataset.detector_geometry['distance']

    # Then, we generate the probe geometry
    ewg = tools.initializers.exit_wave_geometry
    probe_basis = ewg(det_basis, det_shape, wavelength, distance)

    # Next generate the object geometry from the probe geometry and
    # the translations
    (indices, translations), patterns = dataset[:]
    pix_translations = tools.interactions.translations_to_pixel(
        probe_basis,
        translations,
    )
    obj_size, min_translation = tools.initializers.calc_object_setup(
        det_shape,
        pix_translations,
    )

    # Finally, initialize the probe and object using this information
    probe = tools.initializers.SHARP_style_probe(dataset)
    obj = t.ones(obj_size).to(dtype=t.complex64)

    return cls(
        wavelength,
        probe_basis,
        probe,
        obj,
        min_translation=min_translation
    )
Here, we start by pulling the basic geometric information from the dataset. Then, we use a number of the basic tools to do calculations such as finding the probe basis from the detector geometry, or calculating how big our object array should be.
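The object sizing step can be sketched as follows. The function object_setup is hypothetical and is not the source of tools.initializers.calc_object_setup, which may, for example, add padding. It only shows the core idea: the object must be large enough to hold the probe window at every translation, and min_translation shifts the smallest translation to pixel (0, 0).

```python
import numpy as np


def object_setup(probe_shape, pix_translations):
    """Hypothetical: size an object array to hold every probe window.

    pix_translations: (N, 2) probe positions in pixel units.
    Returns the object shape and the offset (min_translation) that moves
    the smallest translation to pixel (0, 0).
    """
    pix_translations = np.asarray(pix_translations, dtype=float)
    min_translation = pix_translations.min(axis=0)
    span = pix_translations.max(axis=0) - min_translation
    # Rounded window starts never exceed ceil(span); add one probe width
    obj_shape = np.ceil(span).astype(int) + np.asarray(probe_shape)
    return tuple(int(s) for s in obj_shape), min_translation


shape, min_t = object_setup((4, 4), [[0, 0], [3, 5], [-2, 1]])
print(shape)  # (9, 9)
```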
Once we have the basic setup ready, we then use one of the initialization functions (in this case, tools.initializers.SHARP_style_probe) to find a sensible initialization for the probe. This particular initialization is based on the approach used in the SHARP package, where the square root of the mean diffraction pattern intensity is used to estimate the structure of the illumination at focus.
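The idea behind that initialization can be sketched in a few lines of numpy. This hypothetical sharp_like_probe assumes centered diffraction patterns, takes the square root of the mean intensity as a Fourier-space amplitude with flat phase, and inverse-transforms it to real space. A flat Fourier phase maximizes the peak amplitude of the resulting probe, which loosely corresponds to a focused beam. It is not the CDTools implementation, which may normalize or propagate the result differently.

```python
import numpy as np


def sharp_like_probe(patterns):
    """Hypothetical SHARP-style probe guess from centered diffraction patterns."""
    patterns = np.asarray(patterns, dtype=float)
    # Square root of the mean intensity, used as a Fourier-space amplitude
    mean_amplitude = np.sqrt(patterns.mean(axis=0))
    # Flat phase: undo the centering, then inverse FFT back to real space
    return np.fft.ifft2(np.fft.ifftshift(mean_amplitude))


rng = np.random.default_rng(0)
patterns = rng.random((3, 8, 8))
probe = sharp_like_probe(patterns)
print(probe.shape)  # (8, 8)
```

By Parseval's theorem, the total power of this probe equals the summed mean intensity divided by the number of pixels, given numpy's default FFT normalization.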
Once all the needed information has been collected, we call the constructor defined above to create the model.
The Forward Model
First, we have to implement the interaction model, as below:
def interaction(self, index, translations):
    # We map from real-space to pixel-space units
    pix_trans = tools.interactions.translations_to_pixel(
        self.probe_basis,
        translations)
    pix_trans -= self.min_translation

    # This function extracts the appropriate window from the object and
    # multiplies the object and probe functions
    return tools.interactions.ptycho_2D_round(
        self.probe_norm * self.probe,
        self.obj,
        pix_trans)
Here, we take input in the form of an index and a translation. Note that this input format must match the format that is output by the associated datasets that we will use for reconstruction, in this case BasicPtychoDataset.
We start by mapping the translation, given in real space, into pixel coordinates, and subtract min_translation so that the smallest translation lands at the corner of the object array. Then, we use an “off-the-shelf” interaction model, ptycho_2D_round, which models a standard 2D ptychography interaction but rounds the translations to the nearest whole pixel (it does not attempt subpixel translations).
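A minimal sketch of these two steps follows. Both functions are hypothetical stand-ins for tools.interactions.translations_to_pixel and tools.interactions.ptycho_2D_round. We assume probe_basis is a (3, 2) array whose columns are the real-space vectors of one pixel step along each array axis, and we convert translations with a pseudo-inverse; the CDTools sign and axis conventions may differ.

```python
import numpy as np


def translations_to_pixel_sketch(probe_basis, translations):
    """Hypothetical: convert real-space translations to pixel units.

    probe_basis: (3, 2) array; column j is the real-space displacement of
    one pixel step along array axis j. translations: (N, 3) array.
    Solves probe_basis @ pix = translation in the least-squares sense.
    """
    return np.asarray(translations) @ np.linalg.pinv(probe_basis).T


def ptycho_2d_round_sketch(probe, obj, pix_translations):
    """Hypothetical: multiply the probe by whole-pixel windows of the object."""
    h, w = probe.shape
    exit_waves = []
    # Round each translation to the nearest pixel; no subpixel shifts
    for r, c in np.round(pix_translations).astype(int):
        exit_waves.append(probe * obj[r:r + h, c:c + w])
    return np.stack(exit_waves)


# 1 micron pixels along both array axes
basis = np.array([[1e-6, 0.0], [0.0, 1e-6], [0.0, 0.0]])
translations = np.array([[2e-6, 3e-6, 0.0], [0.0, 1e-6, 0.0]])

pix = translations_to_pixel_sketch(basis, translations)  # ~[[2, 3], [0, 1]]
pix_shifted = pix - pix.min(axis=0)  # plays the role of min_translation

probe = np.ones((4, 4), dtype=complex)
obj = np.arange(64).reshape(8, 8).astype(complex)
exit_waves = ptycho_2d_round_sketch(probe, obj, pix_shifted)
print(exit_waves.shape)  # (2, 4, 4)
```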
The next three definitions amount to just choosing an off-the-shelf function to simulate each step in the chain.
def forward_propagator(self, wavefields):
    return tools.propagators.far_field(wavefields)

def measurement(self, wavefields):
    return tools.measurements.intensity(wavefields)

def loss(self, sim_data, real_data):
    return tools.losses.amplitude_mse(real_data, sim_data)
The forward propagator maps the exit wave to the wave at the surface of the detector, here using a far-field propagator. The measurement maps that wavefield to measured pixel values (here, its intensity), and the loss defines the function that the reconstruction attempts to minimize. The loss function we’ve chosen, the amplitude mean squared error, has in our experience been the most reliable one, and it can easily be overridden by an end user.
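For reference, an amplitude mean squared error can be sketched as below. This hypothetical amplitude_mse_sketch takes intensities, compares their square roots, and averages over all pixels; its argument order mirrors the call in loss(). The CDTools version may also apply a detector mask or a different normalization. Taking square roots compresses the large dynamic range of diffraction intensities, so the brightest pixels do not completely dominate the loss.

```python
import numpy as np


def amplitude_mse_sketch(real_data, sim_data):
    """Hypothetical amplitude MSE: compare square roots of intensities."""
    real_amp = np.sqrt(np.asarray(real_data, dtype=float))
    sim_amp = np.sqrt(np.asarray(sim_data, dtype=float))
    # Average the squared amplitude difference over every pixel
    return np.mean((sim_amp - real_amp) ** 2)


print(amplitude_mse_sketch([4.0, 9.0], [1.0, 16.0]))  # 1.0
```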
Plotting
The base CDIModel class has a function, model.inspect(), which looks for a class variable called plot_list and plots everything contained within. The plot list should be formatted as a list of tuples, with each tuple containing:
The title of the plot
A function that takes in the model and a figure, and generates the relevant plot in that figure
Optionally, a function that takes in the model and returns whether or not the plot should be generated
# This lists all the plots to display on a call to model.inspect()
plot_list = [
    ('Probe Amplitude',
     lambda self, fig: p.plot_amplitude(self.probe, fig=fig, basis=self.probe_basis)),
    ('Probe Phase',
     lambda self, fig: p.plot_phase(self.probe, fig=fig, basis=self.probe_basis)),
    ('Object Amplitude',
     lambda self, fig: p.plot_amplitude(self.obj, fig=fig, basis=self.probe_basis)),
    ('Object Phase',
     lambda self, fig: p.plot_phase(self.obj, fig=fig, basis=self.probe_basis))
]
In this case, we’ve made use of the convenience plotting functions defined in tools.plotting.
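To make the tuple format concrete, here is a hypothetical, headless driver for a plot_list. It is not the code of model.inspect(): it uses a stand-in FakeFigure instead of matplotlib and ignores figure management, but it shows how the optional third element can suppress a plot.

```python
class FakeFigure:
    """Stands in for a matplotlib figure so this sketch runs headless."""

    def __init__(self):
        self.title = None


def inspect_sketch(model, plot_list):
    """Hypothetical driver for entries of (title, plot_fn[, condition]).

    Returns the titles of the plots that were actually drawn.
    """
    drawn = []
    for entry in plot_list:
        title, plot_fn = entry[0], entry[1]
        condition = entry[2] if len(entry) > 2 else None
        # The optional third element decides whether the plot applies
        if condition is not None and not condition(model):
            continue
        fig = FakeFigure()
        plot_fn(model, fig)
        fig.title = title
        drawn.append(title)
    return drawn


class ToyModel:
    has_background = False


plots = [
    ('Probe Amplitude', lambda self, fig: None),
    ('Background', lambda self, fig: None, lambda self: self.has_background),
]
print(inspect_sketch(ToyModel(), plots))  # ['Probe Amplitude']
```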
Saving
By default, a function model.save_results() is defined, which returns a python dictionary with an entry, 'state_dict', containing all the registered parameters and buffers in the model. The dictionary also contains a basic record of the model’s training history. This function is used internally by model.save_to_h5(), as well as by all the other convenience functions for saving results.
Sometimes, it is also useful to return a more user-friendly version of the results, such as a properly rescaled version of the probe. To make this possible, model.save_results() is often overridden:
def save_results(self, dataset):
    # This will save out everything needed to recreate the object
    # in the same state, but it's not the best formatted.
    base_results = super().save_results()

    # So we also save out the main results in a more usable format
    probe_basis = self.probe_basis.detach().cpu().numpy()
    probe = self.probe.detach().cpu().numpy()
    probe = probe * self.probe_norm.detach().cpu().numpy()
    obj = self.obj.detach().cpu().numpy()
    wavelength = self.wavelength.cpu().numpy()

    results = {
        'probe_basis': probe_basis,
        'probe': probe,
        'obj': obj,
        'wavelength': wavelength,
    }

    return {**base_results, **results}
However, it is perfectly possible to write a new ptychography model without overriding model.save_results().
Testing
We can test this model with a simple script, in examples/tutorial_finale.py. By filling in the backend here, we’ve been able to create a ptychography model that can be accessed and used in reconstructions via the same interface as the models we discussed in the examples section.
from tutorial_basic_ptycho_dataset import BasicPtychoDataset
from tutorial_simple_ptycho import SimplePtycho
from h5py import File
from matplotlib import pyplot as plt

filename = 'example_data/lab_ptycho_data.cxi'
with File(filename, 'r') as f:
    dataset = BasicPtychoDataset.from_cxi(f)

dataset.inspect()

model = SimplePtycho.from_dataset(dataset)

# Set to 'cuda', 'mps', or 'cpu' depending on the available hardware
device = 'cuda'
model.to(device=device)
dataset.get_as(device=device)

for loss in model.Adam_optimize(10, dataset):
    model.inspect(dataset)
    print(model.report())

model.inspect(dataset)
model.compare(dataset)
plt.show()
Happy modeling!