Datasets
This module contains all the datasets for interacting with ptychography data.
All access to data from standard ptychography and CDI experiments is coordinated through the datasets defined in this module. They use the lower-level data reading and writing functions defined in tools.data, but critically, all of these datasets subclass torch.utils.data.Dataset. This allows them to be used as standard torch datasets during reconstructions, which makes it easy to use the data-handling strategies that pytorch implements by default (such as drawing data in a random order, drawing minibatches, etc.).
New Datasets can be defined as subclasses of the main CDataset class defined in base.py, and should define the following methods:
__init__
__len__
_load
to
from_cxi
to_cxi
inspect
Example implementations of all these methods can be found in the code for the Ptycho2DDataset class. It is also recommended to read the tutorial section on defining a new CDI dataset before attempting to write one.
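The skeleton below is a minimal sketch of the structure described above, using a hypothetical stand-in base class (SketchBase) and NumPy arrays instead of the real CDataset and torch tensors. It shows only __init__, __len__, _load and inspect; to, from_cxi and to_cxi are omitted, and the class names are illustrative, not part of cdtools.

```python
import numpy as np


class SketchBase:
    """Hypothetical stand-in for CDataset (not the real class)."""

    def __getitem__(self, index):
        # The base class routes every access through the subclass's _load
        return self._load(index)

    def _load(self, index):
        # No data is stored in the base class
        raise NotImplementedError()

    def inspect(self):
        raise NotImplementedError()


class ToyDataset(SketchBase):
    """Hypothetical subclass holding scan positions and a pattern stack."""

    def __init__(self, positions, patterns):
        self.positions = np.asarray(positions)
        self.patterns = np.asarray(patterns)

    def __len__(self):
        # One entry per measured diffraction pattern
        return self.patterns.shape[0]

    def _load(self, index):
        # (inputs, output): inputs is itself a tuple of forward-model inputs
        return (index, self.positions[index]), self.patterns[index]


ds = ToyDataset(np.zeros((5, 2)), np.ones((5, 4, 4)))
(idx, pos), pattern = ds[2]
print(len(ds), pos.shape, pattern.shape)  # 5 (2,) (4, 4)
```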
- class cdtools.datasets.CDataset(entry_info=None, sample_info=None, wavelength=None, detector_geometry=None, mask=None, background=None)
Bases: Dataset
The base dataset class which all other datasets subclass
Subclasses torch.utils.data.Dataset
This base dataset class defines the functionality which should be common to all subclassed datasets. This includes the loading and storage of the metadata portions of .cxi files, as well as the tools needed to allow for easy mixing of data on the CPU and GPU.
The __init__ function allows construction from python objects.
The detector_geometry dictionary should contain the entries output by data.get_detector_geometry.
- Parameters:
entry_info (dict) – A dictionary containing the entry_info metadata
sample_info (dict) – A dictionary containing the sample_info metadata
wavelength (float) – The wavelength of light used in the experiment
detector_geometry (dict) – A dictionary containing the various detector geometry parameters
mask (array) – A mask for the detector, defined as 1 for live pixels, 0 for dead
background (array) – An initial guess for the not-previously-subtracted detector background
- to(*args, **kwargs)
Sends the relevant data to the given device and dtype
This function sends the stored mask and background to the specified device and dtype
Accepts the same parameters as torch.Tensor.to
- get_as(*args, **kwargs)
Sets the dataset to return data on the given device and dtype
Oftentimes there isn’t room to store an entire dataset on a GPU, but it is still worth running the calculation on the GPU even with the overhead of transferring data back and forth. In that case, get_as can be used instead of to, to declare the device and dtype that data should be returned as whenever it is accessed through the __getitem__ function (as it is during reconstructions).
Accepts the same parameters as torch.Tensor.to
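As a minimal sketch of the get_as idea, assuming NumPy dtypes as a stand-in for torch device/dtype pairs: the stored array is left unchanged, and the conversion is applied only when an item is fetched through __getitem__. The class GetAsSketch and its attribute names are hypothetical; the real method takes the same arguments as torch.Tensor.to.

```python
import numpy as np


class GetAsSketch:
    """Hypothetical dataset: data stays in its storage dtype, converted on access."""

    def __init__(self, patterns):
        self.patterns = np.asarray(patterns)  # stored copy (e.g. "on the CPU")
        self._get_as_dtype = None              # conversion requested via get_as

    def get_as(self, dtype):
        # Record how data should be returned; the stored data is untouched
        self._get_as_dtype = dtype

    def _load(self, index):
        return (index,), self.patterns[index]

    def __getitem__(self, index):
        inputs, output = self._load(index)
        # Only the returned output is converted, never the stored array
        if self._get_as_dtype is not None:
            output = output.astype(self._get_as_dtype)
        return inputs, output


ds = GetAsSketch(np.arange(8, dtype=np.float64).reshape(2, 2, 2))
ds.get_as(np.float32)
_, pattern = ds[0]
print(ds.patterns.dtype, pattern.dtype)  # float64 float32
```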
- _load(index)
Internal function to load data
In all subclasses of CDataset, a _load function should be defined. This function is used internally by the global __getitem__ function defined in the base class, which handles moving data around when the dataset is (for example) storing the data on the CPU but getting data as GPU tensors.
It should accept an index or slice, and return output as a tuple. The first item of the tuple is a tuple containing the inputs to the forward model for the related ptychography model. The second item of the tuple should be the set of diffraction patterns associated with the returned inputs.
Since no data is stored in a CDataset, this function raises a NotImplementedError.
- classmethod from_cxi(cxi_file)
Generates a new CDataset from a .cxi file directly
This is the most commonly used constructor for CDatasets and subclasses thereof. It populates the dataset using the information in a .cxi file. It can either take an h5py.File object directly, or a filename or pathlib object pointing to the file
- Parameters:
cxi_file (str, pathlib.Path, or h5py.File) – The .cxi file to load from
- Returns:
dataset – The constructed dataset object
- Return type:
CDataset
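from_cxi accepts either an open h5py.File or a path. One common way to support both is sketched below, with Python's built-in open standing in for h5py.File (an assumption made so the example runs without h5py): str and pathlib.Path inputs are opened, while already-open objects are wrapped in contextlib.nullcontext so the with block does not close them. The helper open_if_path is hypothetical, not a cdtools function.

```python
import pathlib
import tempfile
from contextlib import nullcontext


def open_if_path(file_or_path, mode="r"):
    """Hypothetical helper: open str/Path inputs, pass open objects through."""
    if isinstance(file_or_path, (str, pathlib.Path)):
        return open(file_or_path, mode)  # closed when the with block exits
    # nullcontext yields the object unchanged and does not close it
    return nullcontext(file_or_path)


with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "demo.txt"
    path.write_text("wavelength=1e-9")
    with open_if_path(path) as f:        # pathlib.Path input
        from_path = f.read()
    with open_if_path(str(path)) as f:   # str input
        from_str = f.read()
    with open(path) as handle:
        with open_if_path(handle) as f:  # already-open object
            from_handle = f.read()

print(from_path == from_str == from_handle)  # True
```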
- to_cxi(cxi_file)
Saves out a CDataset as a .cxi file
This function saves all the compatible information in a CDataset object into a .cxi file. This is useful for saving out modified or simulated datasets
- Parameters:
cxi_file (str, pathlib.Path, or h5py.File) – The .cxi file to write to
- inspect()
The prototype for the inspect function
In all subclasses of CDataset, an inspect function should be defined which opens a tool that shows the data in a natural layout for that kind of experiment. In the base class, no actual data is stored, so this is defined to raise a NotImplementedError
- class cdtools.datasets.Ptycho2DDataset(translations, patterns, intensities=None, axes=None, *args, **kwargs)
Bases: CDataset
The standard dataset for a 2D ptychography scan
Subclasses datasets.CDataset
This class loads and saves 2D ptychography scan data from .cxi files. It should save and load files compatible with most reconstruction programs, although it is only tested against SHARP.
The __init__ function allows construction from python objects.
The detector_geometry dictionary should contain the entries output by data.get_detector_geometry.
Note that the created dataset object does not copy the data in the patterns parameter, to avoid doubling the memory requirement for large datasets.
- Parameters:
translations (array) – An nx3 array containing the probe translations at each scan point
patterns (array) – An nxmxl array containing the full stack of measured diffraction patterns
axes (list(str)) – A list of names for the axes of the probe translations
entry_info (dict) – A dictionary containing the entry_info metadata
sample_info (dict) – A dictionary containing the sample_info metadata
wavelength (float) – The wavelength of light used in the experiment
detector_geometry (dict) – A dictionary containing the various detector geometry parameters
mask (array) – A mask for the detector, defined as 1 for live pixels, 0 for dead
background (array) – An initial guess for the not-previously-subtracted detector background
intensities (array) – A list of measured shot-to-shot intensities
- _load(index)
Internal function to load data
This function is used internally by the global __getitem__ function defined in the base class, which handles moving data around when the dataset is (for example) storing the data on the CPU but getting data as GPU tensors.
It loads data in the format (inputs, output).
The inputs for a 2D ptychography dataset are:
The indices of the patterns to use
The recorded probe positions associated with those points
- Parameters:
index (int or slice) – The index or indices of the scan points to use
- Returns:
inputs (tuple) – A tuple of the inputs to the related forward models
outputs (tuple) – The output pattern or stack of output patterns
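A minimal sketch of the (inputs, output) structure described above, assuming NumPy arrays in place of torch tensors. The helper load_ptycho2d is hypothetical and is not the library's implementation; it returns the pattern indices and their translations as the inputs, and the matching pattern(s) as the output, for either an integer or a slice index.

```python
import numpy as np


def load_ptycho2d(translations, patterns, index):
    """Hypothetical Ptycho2DDataset-style _load.

    Returns ((indices, translations), patterns) for an int or slice index.
    """
    # Indexing an arange with the same index gives the matching pattern indices
    indices = np.arange(len(patterns))[index]
    return (indices, translations[index]), patterns[index]


translations = np.random.default_rng(0).normal(size=(10, 3))  # n x 3
patterns = np.zeros((10, 64, 64))                              # n x m x l
(idx, trans), pats = load_ptycho2d(translations, patterns, slice(2, 5))
print(idx, trans.shape, pats.shape)  # [2 3 4] (3, 3) (3, 64, 64)
```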
- to(*args, **kwargs)
Sends the relevant data to the given device and dtype
This function sends the stored translations, patterns, mask and background to the specified device and dtype
Accepts the same parameters as torch.Tensor.to
- classmethod from_cxi(cxi_file, cut_zeros=True, load_patterns=True)
Generates a new Ptycho2DDataset from a .cxi file directly
This generates a new Ptycho2DDataset from a .cxi file storing a 2D ptychography scan.
- Parameters:
cxi_file (str, pathlib.Path, or h5py.File) – The .cxi file to load from
cut_zeros (bool) – Default True, whether to set all negative data to zero
- Returns:
dataset – The constructed dataset object
- Return type:
Ptycho2DDataset
- to_cxi(cxi_file)
Saves out a Ptycho2DDataset as a .cxi file
This function saves all the compatible information in a Ptycho2DDataset object into a .cxi file. The saved file should be compatible with any standard .cxi-based reconstruction tool, such as SHARP.
- Parameters:
cxi_file (str, pathlib.Path, or h5py.File) – The .cxi file to write to
- inspect(logarithmic=True, units='um', log_offset=1, plot_mean_pattern=True)
Launches an interactive plot for perusing the data
This launches an interactive plotting tool in matplotlib that shows the spatial map constructed from the integrated intensity at each position on the left, next to a panel on the right that can display a base-10 log plot of the detector readout at each position.
- plot_mean_pattern(log_offset=1)
Plots the mean diffraction pattern across the dataset
The output is normalized so that the summed intensity on the detector is equal to the total intensity of light that passed through the sample within each detector conjugate field of view.
The plot shows the base-10 logarithm of the output plus log_offset. By default, log_offset is 1, which is a good level for shot-noise-limited data measured in units of photons. More generally, log_offset should be set to roughly the background noise level.
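The display transform can be sketched as below, assuming NumPy arrays. This hypothetical helper reproduces only the log10(mean + log_offset) step; the normalization to the intensity within each detector-conjugate field of view is not reproduced. It shows why log_offset matters: pixels with a zero mean map to log10(log_offset) rather than to negative infinity.

```python
import numpy as np


def log_mean_pattern(patterns, log_offset=1):
    """Hypothetical sketch: base-10 log of the mean pattern plus an offset."""
    mean_pattern = np.mean(patterns, axis=0)  # average over the scan points
    return np.log10(mean_pattern + log_offset)


patterns = np.array([[[0.0, 9.0]], [[0.0, 189.0]]])  # two 1x2 "patterns"
print(log_mean_pattern(patterns))  # [[0. 2.]]
```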
- split()
Splits a dataset into two pseudorandomly selected sub-datasets
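The docstring does not say how the split is performed. One plausible approach, sketched here purely as an assumption, is to shuffle the scan indices with a seeded random generator and divide them into two disjoint halves; split_indices and its seed parameter are hypothetical.

```python
import numpy as np


def split_indices(n, seed=0):
    """Hypothetical sketch: pseudorandomly divide n scan indices into two halves."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n)
    half = n // 2
    # Sorting keeps each half in scan order; the halves are disjoint
    return np.sort(shuffled[:half]), np.sort(shuffled[half:])


a, b = split_indices(11)
print(len(a), len(b), len(np.intersect1d(a, b)))  # 5 6 0
```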
- pad(to_pad, value=0, mask=True)
Pads all the diffraction patterns by a specified amount
This is useful for scenarios where the diffraction is strong, even near the edge of the detector. In this scenario, the discrete version of the ptychography model will alias. Padding the diffraction patterns to increase their size and masking off the outer region can account for this effect.
If to_pad is an integer, the patterns will be padded on all sides by this value. If it is a tuple of length 2, the patterns will be padded as (left/right, top/bottom). If it is a tuple of length 4, the padding is done as (left, right, top, bottom), following the convention of torch.nn.functional.pad.
Any mask and background data which is stored with the dataset will be padded along with the diffraction patterns
- Parameters:
to_pad (int or tuple(int)) – The number of pixels to pad by.
value (float) – Optional, the fill value to pad with. Default is 0
mask (bool) – Optional, whether to mask off the new pixels. Default is True
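The padding conventions above can be sketched with NumPy as follows. normalize_padding and pad_patterns are hypothetical helpers, not cdtools functions: they translate an int, 2-tuple, or 4-tuple into per-axis widths following the torch.nn.functional.pad ordering (last axis left/right, second-to-last axis top/bottom), and they always mark the new pixels as dead in the mask, which corresponds to mask=True.

```python
import numpy as np


def normalize_padding(to_pad):
    """Return (left, right, top, bottom) from an int, 2-tuple, or 4-tuple."""
    if isinstance(to_pad, int):
        return (to_pad,) * 4
    if len(to_pad) == 2:
        lr, tb = to_pad
        return (lr, lr, tb, tb)
    if len(to_pad) == 4:
        return tuple(to_pad)
    raise ValueError("to_pad must be an int or a tuple of length 2 or 4")


def pad_patterns(patterns, mask, to_pad, value=0):
    left, right, top, bottom = normalize_padding(to_pad)
    # Axis order: (pattern index, rows = top/bottom, columns = left/right)
    widths = ((0, 0), (top, bottom), (left, right))
    padded = np.pad(patterns, widths, constant_values=value)
    # New pixels are marked dead (0), following "1 = live, 0 = dead"
    padded_mask = np.pad(mask, widths[1:], constant_values=0)
    return padded, padded_mask


patterns = np.ones((3, 4, 4))
mask = np.ones((4, 4), dtype=np.int8)
padded, padded_mask = pad_patterns(patterns, mask, (1, 2))
print(padded.shape, padded_mask.sum())  # (3, 8, 6) 16
```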
- downsample(factor=2)
Downsamples all diffraction patterns by the specified factor
This is an easy way to shrink the amount of data you need to work with if the speckle size is much larger than the detector pixel size.
The downsampling factor must be an integer. The size of the output patterns is reduced by the specified factor, with each output pixel equal to the sum of a <factor> x <factor> region of pixels in the input pattern. This summation is implemented with torch.nn.functional.avg_pool2d.
Any mask and background data which is stored with the dataset is downsampled with the data. The background is downsampled using the same method as the data. The mask is expanded so that any output pixel containing a masked pixel will be masked.
- Parameters:
factor (int) – Default 2, the factor to downsample by
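The summing and mask-expansion rules can be reproduced with a NumPy reshape, sketched below as an alternative to the avg_pool2d route described above. The helper is hypothetical, assumes the pattern dimensions are divisible by factor, and omits the background, which the text says is downsampled using the same method as the data.

```python
import numpy as np


def downsample(patterns, mask, factor=2):
    """Hypothetical sketch of block-sum downsampling with mask expansion."""
    n, h, w = patterns.shape
    if h % factor or w % factor:
        raise ValueError("pattern size must be divisible by factor in this sketch")
    blocks = patterns.reshape(n, h // factor, factor, w // factor, factor)
    # Each output pixel is the sum of a factor x factor block
    summed = blocks.sum(axis=(2, 4))
    mask_blocks = mask.reshape(h // factor, factor, w // factor, factor)
    # An output pixel is live (1) only if every input pixel in its block is live
    new_mask = mask_blocks.min(axis=(1, 3))
    return summed, new_mask


patterns = np.arange(16, dtype=float).reshape(1, 4, 4)
mask = np.ones((4, 4), dtype=np.int8)
mask[0, 0] = 0  # one dead pixel in the top-left block
small, small_mask = downsample(patterns, mask)
print(small[0].tolist(), small_mask.tolist())
# [[10.0, 18.0], [42.0, 50.0]] [[0, 1], [1, 1]]
```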
- remove_translations_mask(mask_remove)
Removes one or more translation positions, and their associated properties, from the dataset using logical indexing.
This takes a 1D mask (a boolean torch tensor) of length self.translations.shape[0], i.e., the number of individual translation points. Patterns, translations, and intensities associated with indices that are True will be removed.
- Parameters:
mask_remove (1D torch.Tensor, dtype=torch.bool) – The boolean mask indicating which elements are to be removed from the dataset. True indicates that the corresponding element will be removed.
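remove_translations_mask is documented as taking a torch.bool tensor; the sketch below uses a NumPy boolean array as a stand-in to show the logical-indexing idea. Inverting the removal mask selects the points to keep, and the same selection is applied to the translations, patterns, and intensities so they stay aligned. The arrays are illustrative values, not real data.

```python
import numpy as np

translations = np.array([[0.0, 0.0, 0.0], [1e-6, 0.0, 0.0], [2e-6, 0.0, 0.0]])
patterns = np.arange(3 * 2 * 2).reshape(3, 2, 2)
intensities = np.array([1.0, 0.9, 1.1])

mask_remove = np.array([False, True, False])  # True = drop this scan point
keep = ~mask_remove                           # invert to get the points to keep

# Apply the same selection to every per-point array
translations = translations[keep]
patterns = patterns[keep]
intensities = intensities[keep]
print(translations.shape, patterns.shape, intensities.tolist())
# (2, 3) (2, 2, 2) [1.0, 1.1]
```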
- crop_translations(roi)
Shrinks the range of translation positions that are analyzed
This deletes all diffraction patterns associated with x- and y-translations that lie outside of a specified rectangular region of interest. In essence, this operation crops the “relative displacement map” (shown in self.inspect()) down to the region of interest.
- Parameters:
roi (tuple(float, float, float, float)) – The translation x- and y-coordinates, in meters, that define the rectangular region of interest as (left, right, bottom, top). These bounds follow the convention of an image displayed with matplotlib’s imshow. The order of the two x bounds, and of the two y bounds, does not matter, as long as roi[:2] holds the x coordinates and roi[2:] holds the y coordinates.
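A minimal sketch of the region-of-interest test, assuming that columns 0 and 1 of the translations array hold x and y, and that points exactly on a boundary count as inside (neither is stated in the docstring). Sorting each pair of bounds is what makes their order within roi irrelevant. crop_mask is a hypothetical helper.

```python
import numpy as np


def crop_mask(translations, roi):
    """Boolean mask of scan points inside roi; roi[:2] are x bounds, roi[2:] y bounds."""
    x_lo, x_hi = sorted(roi[:2])  # order of the x bounds does not matter
    y_lo, y_hi = sorted(roi[2:])  # order of the y bounds does not matter
    x, y = translations[:, 0], translations[:, 1]
    # Boundaries are treated as inclusive in this sketch
    return (x >= x_lo) & (x <= x_hi) & (y >= y_lo) & (y <= y_hi)


translations = np.array([[0.0, 0.0, 0.0], [2e-6, 1e-6, 0.0], [5e-6, 5e-6, 0.0]])
inside = crop_mask(translations, (3e-6, 0.0, 0.0, 3e-6))  # x bounds given in reverse
print(inside.tolist())  # [True, True, False]
```

Under these assumptions, ~inside has the form of a removal mask that remove_translations_mask accepts (after conversion to a torch.bool tensor); whether crop_translations is built that way internally is not stated here.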