Losses
Contains various loss functions to be used for optimization.
It exposes three losses: the mean squared amplitude error, the mean squared intensity error, and the Poisson negative log likelihood, which is the maximum likelihood metric for a system with Poisson statistics. Each loss has a matching normalizer class for use with recon.optimize.
- cdtools.tools.losses.amplitude_mse(intensities, sim_intensities, mask=None, use_sum=False)
Returns the mean squared error of a simulated dataset’s amplitudes
Calculates the mean squared error between the amplitudes (the square roots of the intensities) of a measured set of diffraction intensities and a simulated set. Because the square root is not well defined for negative numbers, make sure that all the intensities are >0 before using this loss.
It can accept intensity and simulated intensity tensors of any shape as long as their shapes match, and the provided mask array can be broadcast correctly along them.
This is empirically the most useful loss function for most cases where a photon counting detector cannot be used.
Note that, when used with the AmplitudeMSENormalizer, this function should be called with use_sum=True, in order to return the sum-squared error rather than the mean-squared error. This allows for the AmplitudeMSENormalizer to properly weight the loss arising from minibatches which may not have equal length.
- Parameters:
intensities (torch.Tensor) – A tensor of measured detector intensities
sim_intensities (torch.Tensor) – A tensor of simulated detector intensities
mask (torch.Tensor, optional) – A mask with ones for pixels to include and zeros for pixels to exclude
use_sum (bool) – Default is False. If True, returns the sum squared error instead of the mean squared error
- Returns:
loss – A single value for the mean squared amplitude error (or the sum squared error if use_sum is True)
- Return type:
torch.Tensor
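The arithmetic described above can be sketched in plain Python. This is an illustration of the formula on flat lists, not the library's torch implementation; the name amplitude_mse_sketch is hypothetical, and the sketch assumes the mean is taken over the unmasked pixels only.

```python
import math

def amplitude_mse_sketch(intensities, sim_intensities, mask=None, use_sum=False):
    """Plain-Python illustration of a masked amplitude MSE over flat lists."""
    if mask is None:
        mask = [1] * len(intensities)
    # Squared difference of the amplitudes (square roots of the intensities),
    # kept only where the mask is nonzero.
    sq_errs = [
        (math.sqrt(sim) - math.sqrt(meas)) ** 2
        for meas, sim, m in zip(intensities, sim_intensities, mask)
        if m
    ]
    total = sum(sq_errs)
    # Assumption: the mean is taken over the included pixels only.
    return total if use_sum else total / len(sq_errs)

measured = [4.0, 9.0, 16.0, 25.0]
simulated = [1.0, 9.0, 25.0, 100.0]
mask = [1, 1, 1, 0]  # exclude the last pixel

mean_err = amplitude_mse_sketch(measured, simulated, mask)
sum_err = amplitude_mse_sketch(measured, simulated, mask, use_sum=True)
```

Here the three included pixels have amplitude errors of 1, 0, and 1, so the summed error is 2 and the mean is 2/3. The masked pixel, which has the largest error, does not contribute.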
- class cdtools.tools.losses.AmplitudeMSENormalizer
Bases: object
Normalizer for the amplitude MSE loss, used with recon.optimize
This is a normalizer designed for use with the recon.optimize function. The normalization is done separately from the loss, in order to make it simple to use different normalization strategies for different loss metrics and to make it easier to work with different minibatch sizes.
This normalizer accumulates the total number of pixels across all patterns during the first epoch, then divides the summed loss by this count to convert from sum-squared error to mean-squared error.
The normalizer is stateful: it completes its accumulation phase on the first epoch and then applies the same normalization factor for all subsequent epochs.
- accumulate(patterns, mask=None)
Accumulate the normalization factor (called once per minibatch).
- normalize_loss(loss)
Apply the accumulated normalization (called once per epoch).
- __init__()
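The reason for pairing use_sum=True with a pixel-count normalizer can be seen in a small plain-Python sketch. The class PixelCountNormalizerSketch below is a hypothetical stand-in, not the library's class; it assumes each pattern is a flat list, that any mask is shared by all patterns, and that the caller only invokes accumulate during the first epoch.

```python
class PixelCountNormalizerSketch:
    """Hypothetical stand-in illustrating pixel-count normalization."""

    def __init__(self):
        self.n_pixels = 0

    def accumulate(self, patterns, mask=None):
        # Count the pixels that contribute to the loss in this minibatch.
        for pattern in patterns:
            if mask is None:
                self.n_pixels += len(pattern)
            else:
                self.n_pixels += sum(1 for m in mask if m)

    def normalize_loss(self, loss):
        # Convert a summed loss over the epoch into a mean per pixel.
        return loss / self.n_pixels

# Two minibatches of unequal length: (measured patterns, summed squared error).
batches = [
    ([[4.0, 9.0, 1.0], [1.0, 4.0, 9.0]], 6.0),  # 2 patterns, 6 pixels
    ([[16.0, 1.0, 4.0]], 12.0),                 # 1 pattern, 3 pixels
]

normalizer = PixelCountNormalizerSketch()
epoch_loss = 0.0
for patterns, summed_error in batches:
    normalizer.accumulate(patterns)  # first epoch only
    epoch_loss += summed_error       # what a use_sum=True loss would contribute

mean_loss = normalizer.normalize_loss(epoch_loss)

# Averaging the per-minibatch means instead gives the short batch extra weight.
naive_mean = (6.0 / 6 + 12.0 / 3) / 2
```

With these numbers the normalized loss is 18 / 9 = 2.0, the true mean over all pixels, while the average of the per-minibatch means is 2.5.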
- cdtools.tools.losses.intensity_mse(intensities, sim_intensities, mask=None, use_sum=False)
Returns the mean squared error of a simulated dataset’s intensities
Calculates the mean squared error between the measured detector intensities and a simulated set of diffraction intensities.
It can accept intensity and simulated intensity tensors of any shape as long as their shapes match, and the provided mask array can be broadcast correctly along them.
This is rarely a good loss function for ptychography, but can occasionally be useful.
Note that, when used with the IntensityMSENormalizer, this function should be called with use_sum=True, in order to return the sum-squared error rather than the mean-squared error. This allows for the IntensityMSENormalizer to properly weight the loss arising from minibatches which may not have equal length.
- Parameters:
intensities (torch.Tensor) – A tensor of measured detector intensities
sim_intensities (torch.Tensor) – A tensor of simulated detector intensities
mask (torch.Tensor, optional) – A mask with ones for pixels to include and zeros for pixels to exclude
use_sum (bool) – Default is False. If True, returns the sum squared error instead of the mean squared error
- Returns:
loss – A single value for the mean squared intensity error (or the sum squared error if use_sum is True)
- Return type:
torch.Tensor
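A plain-Python sketch of the intensity MSE, placed next to the per-pixel amplitude errors for the same data, shows how differently the two losses weight bright and dim pixels. The name intensity_mse_sketch is hypothetical and the code illustrates the formula only, assuming flat lists and a mean over unmasked pixels.

```python
import math

def intensity_mse_sketch(intensities, sim_intensities, mask=None, use_sum=False):
    """Plain-Python illustration of a masked intensity MSE over flat lists."""
    if mask is None:
        mask = [1] * len(intensities)
    sq_errs = [
        (sim - meas) ** 2
        for meas, sim, m in zip(intensities, sim_intensities, mask)
        if m
    ]
    total = sum(sq_errs)
    # Assumption: the mean is taken over the included pixels only.
    return total if use_sum else total / len(sq_errs)

# One dim pixel and one bright pixel, each off by exactly 1 in amplitude.
measured = [1.0, 100.0]
simulated = [4.0, 121.0]

per_pixel_intensity = [(s - m) ** 2 for m, s in zip(measured, simulated)]
per_pixel_amplitude = [
    (math.sqrt(s) - math.sqrt(m)) ** 2 for m, s in zip(measured, simulated)
]
loss = intensity_mse_sketch(measured, simulated)
```

Both pixels contribute 1 to an amplitude-based error, but 9 and 441 to the intensity-based error, so in this example the bright pixel supplies almost all of the intensity loss.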
- class cdtools.tools.losses.IntensityMSENormalizer
Bases: object
Normalizer for the intensity MSE loss, used with recon.optimize
This is a normalizer designed for use with the recon.optimize function. The normalization is done separately from the loss, in order to make it simple to use different normalization strategies for different loss metrics and to make it easier to work with different minibatch sizes.
This normalizer accumulates the total number of pixels across all patterns during the first epoch, then divides the summed loss by this count to convert from sum-squared error to mean-squared error.
The normalizer is stateful: it completes its accumulation phase on the first epoch and then applies the same normalization factor for all subsequent epochs.
- accumulate(patterns, mask=None)
Accumulate the normalization factor (called once per minibatch).
- normalize_loss(loss)
Apply the accumulated normalization (called once per epoch).
- __init__()
- accumulate(patterns, mask=None)
Accumulate pixel counts from a batch of patterns.
- Parameters:
patterns (torch.Tensor) – A tensor of measured detector patterns
mask (torch.Tensor, optional) – A mask with ones for pixels to include and zeros for pixels to exclude. If provided, only masked pixels are counted.
- normalize_loss(loss)
Convert summed loss to mean loss by dividing by pixel count.
- Parameters:
loss (torch.Tensor) – The accumulated summed loss across minibatches in an epoch
- Returns:
normalized_loss – The loss divided by the total number of pixels
- Return type:
torch.Tensor
- cdtools.tools.losses.poisson_nll(intensities, sim_intensities, mask=None, eps=1e-06, subtract_min=False)
Returns the Poisson negative log likelihood for simulated intensities
Calculates the overall Poisson maximum likelihood metric between the measured detector intensities and a simulated set of intensities. This loss is appropriate for detectors in single-photon counting mode, with their output scaled to number of photons.
Note that this calculation ignores the log(intensities!) term in the full expression for the Poisson negative log likelihood. This term does not depend on the simulated intensities, so it does not change the calculated gradients and is not worth the time to compute.
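For reference, the full expression and its gradient can be written out as follows. The notation is an assumption introduced here: n_i is the measured intensity at pixel i and \lambda_i is the simulated intensity at that pixel (including any added eps).

```latex
\begin{aligned}
-\log P(n \mid \lambda) &= \sum_i \left[ \lambda_i - n_i \log \lambda_i + \log(n_i!) \right], \\
\frac{\partial}{\partial \lambda_i} \left[ -\log P(n \mid \lambda) \right] &= 1 - \frac{n_i}{\lambda_i}.
\end{aligned}
```

The \log(n_i!) term depends only on the measured data, so it drops out of the derivative with respect to every \lambda_i.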
It can accept intensity and simulated intensity tensors of any shape as long as their shapes match, and the provided mask array can be broadcast correctly along them.
The default value of eps is 1e-6 - a nonzero value here helps avoid divergence of the log function near zero.
This is generally the best loss metric to use for ptychography when a photon counting detector is used.
- Parameters:
intensities (torch.Tensor) – A tensor of measured detector intensities
sim_intensities (torch.Tensor) – A tensor of simulated detector intensities
mask (torch.Tensor, optional) – A mask with ones for pixels to include and zeros for pixels to exclude
eps (float) – Default is 1e-6. A small number added to the simulated intensities
subtract_min (bool) – Default is False. Whether to subtract a minimum value so that the output is nonnegative
- Returns:
loss – A single value for the Poisson negative log likelihood
- Return type:
torch.Tensor
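A plain-Python sketch of this loss is given below. It is a hedged illustration rather than the library's implementation: the name poisson_nll_sketch is hypothetical, the sketch returns a sum over included pixels (the library's reduction may differ), and the quantity subtracted for subtract_min is assumed to be the value of the same expression evaluated with the simulated intensity equal to the measured one.

```python
import math

def poisson_nll_sketch(intensities, sim_intensities, mask=None,
                       eps=1e-6, subtract_min=False):
    """Poisson NLL without the log(n!) term, summed over included pixels."""
    if mask is None:
        mask = [1] * len(intensities)
    nll = 0.0
    for n, sim, m in zip(intensities, sim_intensities, mask):
        if not m:
            continue
        lam = sim + eps  # eps keeps the logarithm finite when sim == 0
        nll += lam - n * math.log(lam)
        if subtract_min:
            # Assumed minimum: the per-pixel value at lam == n,
            # with 0 * log(0) treated as 0.
            nll -= n - (n * math.log(n) if n > 0 else 0.0)
    return nll

measured = [0.0, 3.0, 10.0]
good = poisson_nll_sketch(measured, measured, subtract_min=True)
bad = poisson_nll_sketch(measured, [1.0, 1.0, 1.0], subtract_min=True)

# A simulated zero at a pixel with measured counts stays finite thanks to eps.
finite = poisson_nll_sketch([2.0], [0.0])
```

With the minimum subtracted, a perfect simulation scores close to zero (a small residue from eps remains), and a poor simulation scores higher.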
- class cdtools.tools.losses.SimplePoissonNLLNormalizer
Bases: object
Normalizer for the Poisson NLL loss, used with recon.optimize
This is a normalizer designed for use with the recon.optimize function. The normalization is done separately from the loss, in order to make it simple to use different normalization strategies for different loss metrics and to make it easier to work with different minibatch sizes.
This normalizer converts raw Poisson negative log likelihood values into a statistic that is more interpretable for comparing reconstructions. It performs two operations:
- Offset subtraction: Subtracts the NLL calculated when comparing measured patterns to themselves (i.e., poisson_nll(data, data)). This represents the best-case scenario and makes the loss non-negative.
- Normalization scaling: Divides by 0.5 times the count of non-zero pixels in the measured patterns. This is because, roughly, each non-zero pixel is expected to contribute 0.5 to the Poisson NLL, if Poisson noise were the only relevant source of noise in the data.
The normalizer is stateful: it completes its accumulation phase on the first epoch by processing all patterns in the data, then applies the same normalization factors for all subsequent epochs.
- accumulate(patterns, mask=None)
Accumulate the normalization factor (called once per minibatch).
- normalize_loss(loss)
Apply the accumulated normalization (called once per epoch).
- __init__()
- accumulate(patterns, mask=None)
Accumulate statistics needed for normalization from a batch.
During the first epoch, this method counts non-zero pixels and computes the Poisson NLL comparing patterns to themselves, which defines the offset baseline for the loss.
- Parameters:
patterns (torch.Tensor) – A tensor of measured detector patterns
mask (torch.Tensor, optional) – A mask with ones for pixels to include and zeros for pixels to exclude. If provided, only masked pixels are counted.
- normalize_loss(loss)
Normalize the Poisson NLL for interpretability across datasets.
- Parameters:
loss (torch.Tensor) – The accumulated Poisson NLL across minibatches in an epoch
- Returns:
normalized_loss – The offset-corrected and scaled loss value
- Return type:
torch.Tensor
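The two operations of this normalizer, offset subtraction and scaling, can be sketched in plain Python. The class PoissonNLLNormalizerSketch and the helper _pixel_nll are hypothetical names; the sketch assumes flat-list patterns, a mask shared by all patterns, an eps of 1e-6 in the self-comparison, and a caller that only invokes accumulate during the first epoch.

```python
import math

def _pixel_nll(n, lam):
    # Poisson NLL for one pixel without the log(n!) term; 0 * log(.) is 0.
    return lam - (n * math.log(lam) if n > 0 else 0.0)

class PoissonNLLNormalizerSketch:
    """Hypothetical stand-in for the offset-and-scale normalization."""

    def __init__(self, eps=1e-6):
        self.eps = eps
        self.offset = 0.0   # NLL of the data compared against itself
        self.n_nonzero = 0  # count of non-zero measured pixels

    def accumulate(self, patterns, mask=None):
        for pattern in patterns:
            for i, n in enumerate(pattern):
                if mask is not None and not mask[i]:
                    continue
                self.offset += _pixel_nll(n, n + self.eps)
                if n > 0:
                    self.n_nonzero += 1

    def normalize_loss(self, loss):
        # Subtract the best-case NLL, then scale by 0.5 per non-zero pixel.
        return (loss - self.offset) / (0.5 * self.n_nonzero)

data = [[0.0, 4.0, 9.0]]
simulated = [[1.0, 5.0, 8.0]]

normalizer = PoissonNLLNormalizerSketch()
normalizer.accumulate(data)

raw_loss = sum(
    _pixel_nll(n, s + normalizer.eps)
    for pattern, sim in zip(data, simulated)
    for n, s in zip(pattern, sim)
)
normalized = normalizer.normalize_loss(raw_loss)
perfect = normalizer.normalize_loss(normalizer.offset)
```

A simulation that reproduces the data exactly normalizes to zero, and under the assumption stated in the description a value near 1 would indicate a fit that is consistent with Poisson noise alone.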