Datasets

This module contains all the datasets for interacting with ptychography data

All access to data from standard ptychography and CDI experiments is coordinated through the datasets defined in this module. They use the lower-level data reading and writing functions defined in tools.data, but, critically, all of these datasets subclass torch.utils.data.Dataset. This lets them be used as standard torch datasets during reconstructions, so the data-handling strategies that pytorch implements by default (such as drawing data in a random order or drawing minibatches) can be used directly.

New datasets can be defined as subclasses of the main CDataset class, defined in base.py, and should define the following functions:

  • __init__

  • __len__

  • _load

  • to

  • from_cxi

  • to_cxi

  • inspect

Example implementations of all these functions can be found in the code for the Ptycho2DDataset class. Before defining a new dataset, it is also worth reading the tutorial section on defining a new CDI dataset.

class cdtools.datasets.CDataset(entry_info=None, sample_info=None, wavelength=None, detector_geometry=None, mask=None, background=None)

Bases: Dataset

The base dataset class which all other datasets subclass

Subclasses torch.utils.data.Dataset

This base dataset class defines the functionality which should be common to all subclassed datasets. This includes the loading and storage of the metadata portions of .cxi files, as well as the tools needed to allow for easy mixing of data on the CPU and GPU.

The __init__ function allows construction from python objects.

The detector_geometry dictionary should contain the entries returned by data.get_detector_geometry.

Parameters:
  • entry_info (dict) – A dictionary containing the entry_info metadata

  • sample_info (dict) – A dictionary containing the sample_info metadata

  • wavelength (float) – The wavelength of light used in the experiment

  • detector_geometry (dict) – A dictionary containing the various detector geometry parameters

  • mask (array) – A mask for the detector, defined as 1 for live pixels, 0 for dead

  • background (array) – An initial guess for the not-previously-subtracted detector background

__init__(entry_info=None, sample_info=None, wavelength=None, detector_geometry=None, mask=None, background=None)


to(*args, **kwargs)

Sends the relevant data to the given device and dtype

This function sends the stored mask and background to the specified device and dtype

Accepts the same parameters as torch.Tensor.to

get_as(*args, **kwargs)

Sets the dataset to return data on the given device and dtype

Often there isn’t room to store an entire dataset on the GPU, but it is still worth running the calculation there despite the overhead of transferring data back and forth. In that case, get_as can be used instead of to: it declares the device and dtype that data should be returned as whenever it is accessed through the __getitem__ function (as it is during any reconstruction).

Accepts the same parameters as torch.Tensor.to
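The idea behind get_as can be sketched without torch: the data stays in its stored form, and only the slice requested through __getitem__ is converted on the way out. In this hypothetical LazyCastDataset (not part of cdtools), a numpy dtype stands in for the device and dtype arguments that the real get_as accepts.

```python
import numpy as np


class LazyCastDataset:
    """Hypothetical toy dataset: stores data in one form, returns it in another."""

    def __init__(self, patterns):
        self.patterns = np.asarray(patterns)  # stored form (think: on the CPU)
        self._get_as_dtype = None             # no conversion by default

    def get_as(self, dtype):
        # Only record the requested form; the stored data is left untouched
        self._get_as_dtype = dtype

    def __len__(self):
        return len(self.patterns)

    def __getitem__(self, index):
        out = self.patterns[index]
        if self._get_as_dtype is not None:
            # Convert just the requested slice, at access time
            out = out.astype(self._get_as_dtype)
        return out


ds = LazyCastDataset(np.ones((4, 8, 8), dtype=np.float64))
ds.get_as(np.float32)
batch = ds[0:2]
print(batch.dtype, ds.patterns.dtype)  # float32 float64
```

The design trade-off is the one described above: memory on the target device only ever holds the current batch, at the cost of a conversion (or transfer) on every access.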

_load(index)

Internal function to load data

In all subclasses of CDataset, a _load function should be defined. This function is used internally by the global __getitem__ function defined in the base class, which handles moving data around when the dataset is (for example) storing the data on the CPU but getting data as GPU tensors.

It should accept an index or slice and return a tuple. The first item is itself a tuple containing the inputs to the forward model of the associated ptychography model. The second item should be the set of diffraction patterns associated with those inputs.

Because a CDataset stores no data, this function raises a NotImplementedError.
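The division of labor between __getitem__ and _load can be sketched in plain Python. BaseDatasetSketch and ListDataset are hypothetical stand-ins, not cdtools classes, and the real base class also handles device and dtype conversion, which this sketch only marks with a comment.

```python
class BaseDatasetSketch:
    """Hypothetical stand-in for the CDataset pattern: __getitem__ lives in
    the base class and delegates the actual loading to _load."""

    def __getitem__(self, index):
        inputs, output = self._load(index)
        # A real base class would also move/cast inputs and output here
        return inputs, output

    def _load(self, index):
        # The base class stores no data, so subclasses must override this
        raise NotImplementedError()


class ListDataset(BaseDatasetSketch):
    """Minimal hypothetical subclass that only implements __len__ and _load."""

    def __init__(self, patterns):
        self.patterns = patterns

    def __len__(self):
        return len(self.patterns)

    def _load(self, index):
        # (tuple of forward-model inputs, matching diffraction pattern(s))
        return (index,), self.patterns[index]


ds = ListDataset(["p0", "p1", "p2"])
print(ds[1])  # ((1,), 'p1')
try:
    BaseDatasetSketch()[0]
except NotImplementedError:
    print("base class has nothing to load")
```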

classmethod from_cxi(cxi_file)

Generates a new CDataset from a .cxi file directly

This is the most commonly used constructor for CDatasets and their subclasses. It populates the dataset using the information in a .cxi file, and accepts either an h5py.File object or a filename or pathlib.Path pointing to the file.

Parameters:

cxi_file (str, pathlib.Path, or h5py.File) – The .cxi file to load from

Returns:

dataset – The constructed dataset object

Return type:

CDataset

to_cxi(cxi_file)

Saves out a CDataset as a .cxi file

This function saves all the compatible information in a CDataset object into a .cxi file. This is useful for saving out modified or simulated datasets

Parameters:

cxi_file (str, pathlib.Path, or h5py.File) – The .cxi file to write to

inspect()

The prototype for the inspect function

In all subclasses of CDataset, an inspect function should be defined which opens a tool that shows the data in a natural layout for that kind of experiment. In the base class, no actual data is stored, so this is defined to raise a NotImplementedError

class cdtools.datasets.Ptycho2DDataset(translations, patterns, intensities=None, axes=None, *args, **kwargs)

Bases: CDataset

The standard dataset for a 2D ptychography scan

Subclasses datasets.CDataset

This class loads and saves 2D ptychography scan data from .cxi files. It should save and load files compatible with most reconstruction programs, although it is only tested against SHARP.

The __init__ function allows construction from python objects.

The detector_geometry dictionary should contain the entries returned by data.get_detector_geometry.

Note that the created dataset object does not copy the data in the patterns parameter, to avoid doubling the memory requirement for large datasets.

Parameters:
  • translations (array) – An nx3 array containing the probe translations at each scan point

  • patterns (array) – An nxmxl array containing the full stack of measured diffraction patterns

  • axes (list(str)) – A list of names for the axes of the probe translations

  • entry_info (dict) – A dictionary containing the entry_info metadata

  • sample_info (dict) – A dictionary containing the sample_info metadata

  • wavelength (float) – The wavelength of light used in the experiment

  • detector_geometry (dict) – A dictionary containing the various detector geometry parameters

  • mask (array) – A mask for the detector, defined as 1 for live pixels, 0 for dead

  • background (array) – An initial guess for the not-previously-subtracted detector background

  • intensities (array) – A list of measured shot-to-shot intensities

__init__(translations, patterns, intensities=None, axes=None, *args, **kwargs)


_load(index)

Internal function to load data

This function is used internally by the global __getitem__ function defined in the base class, which handles moving data around when the dataset is (for example) storing the data on the CPU but getting data as GPU tensors.

It loads data in the format (inputs, output)

The inputs for a 2D ptychography dataset are:

  1. The indices of the patterns to use

  2. The recorded probe positions associated with those points

Parameters:

index (int or slice) – The index or indices of the scan points to use

Returns:

  • inputs (tuple) – A tuple of the inputs to the related forward models

  • outputs (tuple) – The output pattern or stack of output patterns
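The (inputs, output) layout above can be sketched with numpy arrays. load_ptycho2d is a hypothetical helper, not the class's own _load, and turning a slice into an explicit array of indices is an illustrative choice; the docstring does not say how the real method represents indices for a slice.

```python
import numpy as np


def load_ptycho2d(translations, patterns, index):
    """Hypothetical sketch of the (inputs, output) layout described above.

    inputs = (indices of the patterns, probe translations at those points)
    output = a single pattern (int index) or a stack of patterns (slice)
    """
    if isinstance(index, slice):
        # Express a slice as explicit integer indices (an illustrative choice)
        indices = np.arange(len(patterns))[index]
    else:
        indices = index
    return (indices, translations[index]), patterns[index]


translations = np.random.rand(10, 3)   # n x 3 probe translations
patterns = np.random.rand(10, 16, 16)  # n x m x l diffraction patterns

(idx, pos), out = load_ptycho2d(translations, patterns, slice(2, 5))
print(idx, pos.shape, out.shape)  # [2 3 4] (3, 3) (3, 16, 16)

(idx, pos), out = load_ptycho2d(translations, patterns, 7)
print(idx, pos.shape, out.shape)  # 7 (3,) (16, 16)
```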

to(*args, **kwargs)

Sends the relevant data to the given device and dtype

This function sends the stored translations, patterns, mask and background to the specified device and dtype

Accepts the same parameters as torch.Tensor.to

classmethod from_cxi(cxi_file, cut_zeros=True, load_patterns=True)

Generates a new Ptycho2DDataset from a .cxi file directly

This generates a new Ptycho2DDataset from a .cxi file storing a 2D ptychography scan.

Parameters:
  • cxi_file (str, pathlib.Path, or h5py.File) – The .cxi file to load from

  • cut_zeros (bool) – Default True, whether to set all negative data to zero

Returns:

dataset – The constructed dataset object

Return type:

Ptycho2DDataset
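In array terms, cut_zeros=True amounts to clipping the data at zero. The numpy call below illustrates that effect and is not the loader's actual code; negative values can appear, for example, in data that had a background subtracted before it was saved.

```python
import numpy as np

# A tiny 1 x 2 x 2 stack with some negative readings
patterns = np.array([[[-1.5, 2.0], [0.0, -0.2]]])

# cut_zeros-style cleanup: every negative value becomes zero
cleaned = np.maximum(patterns, 0)
print(cleaned.tolist())  # [[[0.0, 2.0], [0.0, 0.0]]]
```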

to_cxi(cxi_file)

Saves out a Ptycho2DDataset as a .cxi file

This function saves all the compatible information in a Ptycho2DDataset object into a .cxi file. This saved .cxi file should be compatible with any standard .cxi file based reconstruction tool, such as SHARP.

Parameters:

cxi_file (str, pathlib.Path, or h5py.File) – The .cxi file to write to

inspect(logarithmic=True, units='um', log_offset=1, plot_mean_pattern=True)

Launches an interactive plot for perusing the data

This launches an interactive plotting tool in matplotlib that shows the spatial map constructed from the integrated intensity at each position on the left, next to a panel on the right that can display a base-10 log plot of the detector readout at each position.
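The quantity behind the left-hand spatial map, the integrated intensity at each scan position, can be computed as below. The random arrays are placeholder data, plotting is left out so the sketch runs without matplotlib, and treating the first two columns of translations as x and y is an assumption; the real inspect may build its map differently.

```python
import numpy as np

rng = np.random.default_rng(0)
patterns = rng.poisson(5.0, size=(6, 32, 32)).astype(float)  # n x m x l
translations = rng.uniform(-1e-6, 1e-6, size=(6, 3))         # n x 3, meters

# One number per scan point: the total counts on the detector
integrated = patterns.sum(axis=(-2, -1))

# Assumed: columns 0 and 1 of translations are x and y. Each position
# paired with its integrated intensity is one point of the spatial map.
for (x, y, _), total in zip(translations, integrated):
    print(f"x={x * 1e6:+.3f} um, y={y * 1e6:+.3f} um, total={total:.0f}")
```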

plot_mean_pattern(log_offset=1)

Plots the mean diffraction pattern across the dataset

The output is normalized so that the summed intensity on the detector is equal to the total intensity of light that passed through the sample within each detector conjugate field of view.

The plot shows log10(output + log_offset). By default, log_offset is 1, which works well for shot-noise-limited data measured in units of photons. More generally, log_offset should be set roughly at the background noise level.
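Setting aside the normalization described above, the displayed quantity is log10(mean pattern + log_offset). The sketch below computes it with numpy on placeholder Poisson data; it deliberately omits the normalization step, which depends on details not given here.

```python
import numpy as np

rng = np.random.default_rng(1)
patterns = rng.poisson(2.0, size=(50, 64, 64)).astype(float)  # placeholder data

log_offset = 1  # roughly the background noise level; 1 suits photon counts
mean_pattern = patterns.mean(axis=0)
log_image = np.log10(mean_pattern + log_offset)

# Pixels with zero mean signal land at exactly 0 on this scale
print(np.log10(0 + log_offset))  # 0.0
```

The offset keeps empty pixels finite: without it, log10(0) would be -inf.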

split()

Splits a dataset into two pseudorandomly selected sub-datasets
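The docstring does not say how the split is made. One plausible reading, sketched below purely as an assumption, is to shuffle the scan indices and cut them into two halves; split_indices is a hypothetical helper, and the real method may choose sizes or seeding differently.

```python
import numpy as np


def split_indices(n, seed=None):
    """Hypothetical sketch: shuffle the scan indices and cut them in half."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n)
    half = n // 2
    return np.sort(order[:half]), np.sort(order[half:])


first, second = split_indices(11, seed=0)
print(len(first), len(second))  # 5 6
```

Each index array can then select the matching rows of translations and patterns for one sub-dataset.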

pad(to_pad, value=0, mask=True)

Pads all the diffraction patterns by a specified amount

This is useful when the diffraction remains strong out to the edge of the detector, in which case the discrete version of the ptychography model will alias. Padding the diffraction patterns to enlarge them, and masking off the new outer region, can account for this effect.

If to_pad is an integer, the patterns will be padded on all sides by this value. If it is a tuple of length 2, the padding is done as (left/right, top/bottom). If it is a tuple of length 4, the padding is done as (left, right, top, bottom), following the convention of torch.nn.functional.pad.

Any mask and background data which is stored with the dataset will be padded along with the diffraction patterns

Parameters:
  • to_pad (int or tuple(int)) – The number of pixels to pad by.

  • value (float) – Optional, the fill value to pad with. Default is 0

  • mask (bool) – Optional, whether to mask off the new pixels. Default is True
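The padding conventions above can be checked with numpy. pad_patterns and normalize_padding are hypothetical helpers, mask_new plays the role of the method's mask argument, and how the real method pads a stored background is not covered here. New pixels get mask value 0 (dead) when they should be masked, following the 1-live/0-dead convention used throughout this module.

```python
import numpy as np


def normalize_padding(to_pad):
    """Turn an int, 2-tuple or 4-tuple into (left, right, top, bottom)."""
    if isinstance(to_pad, int):
        return (to_pad,) * 4
    if len(to_pad) == 2:
        lr, tb = to_pad
        return (lr, lr, tb, tb)
    if len(to_pad) == 4:
        return tuple(to_pad)
    raise ValueError("to_pad must be an int or a tuple of length 2 or 4")


def pad_patterns(patterns, mask, to_pad, value=0, mask_new=True):
    """Hypothetical sketch of the padding described above (not library code)."""
    left, right, top, bottom = normalize_padding(to_pad)
    # numpy wants (before, after) per axis: rows get top/bottom, columns
    # get left/right, and the pattern-stack axis is left alone
    padded = np.pad(patterns, ((0, 0), (top, bottom), (left, right)),
                    constant_values=value)
    # New pixels are dead (0) if they should be masked, live (1) otherwise
    padded_mask = np.pad(mask, ((top, bottom), (left, right)),
                         constant_values=0 if mask_new else 1)
    return padded, padded_mask


patterns = np.ones((3, 4, 4))
mask = np.ones((4, 4), dtype=np.uint8)
p, m = pad_patterns(patterns, mask, (1, 2))
print(p.shape, m.shape, int(m.sum()))  # (3, 8, 6) (8, 6) 16
```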

downsample(factor=2)

Downsamples all diffraction patterns by the specified factor

This is an easy way to shrink the amount of data you need to work with if the speckle size is much larger than the detector pixel size.

The downsampling factor must be an integer. The size of the output patterns is reduced by the specified factor, with each output pixel equal to the sum of a <factor> x <factor> region of pixels in the input pattern. This summation is computed using torch.nn.functional.avg_pool2d.

Any mask and background data which is stored with the dataset is downsampled with the data. The background is downsampled using the same method as the data. The mask is expanded so that any output pixel containing a masked pixel will be masked.

Parameters:

factor (int) – Default 2, the factor to downsample by
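The block-summing and mask-expansion rules can be written out directly with numpy reshapes. downsample_sketch is hypothetical: it sums explicitly rather than calling avg_pool2d, and it trims any edge rows or columns that do not fill a whole block, which is an assumption about how leftover pixels are treated. A background array would go through the same block sum as the patterns.

```python
import numpy as np


def downsample_sketch(patterns, mask, factor=2):
    """Sum factor x factor blocks; an output pixel is live only if every
    input pixel in its block was live."""
    n, rows, cols = patterns.shape
    r, c = rows // factor, cols // factor
    # Trim leftover edge pixels so the blocks tile exactly (an assumption)
    blocks = patterns[:, :r * factor, :c * factor].reshape(n, r, factor, c, factor)
    summed = blocks.sum(axis=(2, 4))
    mask_blocks = mask[:r * factor, :c * factor].reshape(r, factor, c, factor)
    # mask is 1 for live, 0 for dead: min() is 0 if any pixel is dead
    new_mask = mask_blocks.min(axis=(1, 3))
    return summed, new_mask


patterns = np.ones((2, 4, 4))
mask = np.ones((4, 4), dtype=np.uint8)
mask[0, 1] = 0  # one dead pixel in the top-left 2 x 2 block
s, new_mask = downsample_sketch(patterns, mask, factor=2)
print(s[0].tolist())      # [[4.0, 4.0], [4.0, 4.0]]
print(new_mask.tolist())  # [[0, 1], [1, 1]]
```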

remove_translations_mask(mask_remove)

Removes one or more translation positions, and their associated properties, from the dataset using logical indexing.

This takes a 1D mask (boolean torch tensor) with the length self.translations.shape[0] (i.e., the number of individual translated points). Patterns, translations, and intensities associated with indices that are “True” will be removed.

Parameters:

mask_remove (1D torch.Tensor with dtype=torch.bool) – The boolean mask indicating which elements are to be removed from the dataset. True indicates that the corresponding element will be removed.
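Logical indexing with the complement of mask_remove is all the removal needs. In this sketch, numpy arrays stand in for the dataset's torch tensors, and the low-intensity criterion is only a hypothetical example of how such a mask might be built.

```python
import numpy as np

translations = np.arange(15, dtype=float).reshape(5, 3)  # n x 3
patterns = np.zeros((5, 8, 8))                           # n x m x l
intensities = np.array([1.0, 0.9, 1.1, 0.2, 1.0])

# True marks a scan point to drop (here: a hypothetical low-intensity cut)
mask_remove = intensities < 0.5

keep = ~mask_remove
translations, patterns, intensities = (
    translations[keep], patterns[keep], intensities[keep])
print(len(translations), len(patterns), len(intensities))  # 4 4 4
```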

crop_translations(roi)

Shrinks the range of translation positions that are analyzed

This deletes all diffraction patterns associated with x- and y-translations that lie outside of a specified rectangular region of interest. In essence, this operation crops the “relative displacement map” (shown in self.inspect()) down to the region of interest.

Parameters:

roi (tuple(float, float, float, float)) – The x and y translation coordinates (in meters) of the rectangular region of interest, given as (left, right, bottom, top). These bounds follow the way an image is normally displayed with matplotlib’s imshow. Within each pair, the order of the two bounds does not matter, as long as roi[:2] holds the x bounds and roi[2:] holds the y bounds.
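The region-of-interest test can be sketched as a boolean selection. crop_translations_sketch is hypothetical and only computes which points fall inside the ROI; it assumes x and y are the first two columns of the n x 3 translations and treats the bounds as inclusive, neither of which is stated above.

```python
import numpy as np


def crop_translations_sketch(translations, roi):
    """Return a boolean array marking scan points inside roi.

    roi[:2] are the x bounds and roi[2:] the y bounds, in meters; within
    each pair the order does not matter, so each pair is sorted first.
    """
    x_lo, x_hi = sorted(roi[:2])
    y_lo, y_hi = sorted(roi[2:])
    # Assumed: column 0 is x and column 1 is y
    x, y = translations[:, 0], translations[:, 1]
    return (x >= x_lo) & (x <= x_hi) & (y >= y_lo) & (y <= y_hi)


translations = np.array([[0.0, 0.0, 0.0],
                         [2e-6, 0.0, 0.0],
                         [0.0, -3e-6, 0.0]])
inside = crop_translations_sketch(translations, (1e-6, -1e-6, -1e-6, 1e-6))
print(inside.tolist())  # [True, False, False]
```

The complement, ~inside, has the form that remove_translations_mask expects, apart from being a numpy array rather than a torch tensor.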