tractolearn.utils package

Submodules

tractolearn.utils.layer_utils module

class tractolearn.utils.layer_utils.PredictWrapper(fn)

Bases: object

predict(x)

tractolearn.utils.logging_setup module

tractolearn.utils.logging_setup.set_up(log_fname)

tractolearn.utils.losses module

tractolearn.utils.losses.loss_contrastive_lecun_classes(z, margin)

Attract pairs of latent vectors of the same class and repulse pairs of different classes.

This is the contrastive loss defined by Hadsell, Chopra and LeCun (2006). Their original formulation, however, does not use class labels; here, whether a pair is similar is determined by class membership.

Parameters:
  • z (torch.Tensor) – Tensor of size (num_pos_pairs * 2 + num_neg_pairs * 2, latent_size). This is the batch format output by ContrastiveDataset.

  • margin (float) – The margin hyperparameter

Returns:

Contrastive loss tensor

Return type:

torch.Tensor
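The following is a minimal PyTorch sketch of the Hadsell, Chopra and LeCun contrastive loss with class labels, illustrating the loss described above. It is not the tractolearn implementation. The layout of pairs inside z is not documented here, so the hypothetical contrastive_loss function takes the two members of each pair (z1, z2) and a same_class indicator explicitly. It also assumes Euclidean distance and a mean over pairs.

```python
import torch


def contrastive_loss(z1: torch.Tensor, z2: torch.Tensor,
                     same_class: torch.Tensor, margin: float) -> torch.Tensor:
    """Hypothetical sketch of a class-aware contrastive loss (Hadsell et al., 2006).

    z1, z2: (num_pairs, latent_size) latent vectors forming each pair.
    same_class: (num_pairs,) float tensor, 1.0 if the pair shares a class, else 0.0.
    """
    # Euclidean distance between the two members of each pair
    d = torch.linalg.norm(z1 - z2, dim=1)
    # Same-class pairs are pulled together: the penalty grows with d^2
    attract = same_class * d.pow(2)
    # Different-class pairs are pushed apart until they are at least `margin` away
    repulse = (1.0 - same_class) * torch.clamp(margin - d, min=0.0).pow(2)
    # Average over pairs, with the 1/2 factor of the original formulation
    return 0.5 * (attract + repulse).mean()


z1 = torch.zeros(2, 2)
z2 = torch.tensor([[3.0, 4.0], [0.0, 1.0]])
loss = contrastive_loss(z1, z2, torch.tensor([1.0, 0.0]), margin=2.0)
print(loss.item())
```

The first pair shares a class and lies at distance 5, so it contributes a large attraction term. The second pair belongs to different classes and sits inside the margin (distance 1 < 2), so it contributes a smaller repulsion term.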

tractolearn.utils.losses.loss_function_ae(recon_x, x)
tractolearn.utils.losses.loss_function_vae(recon_x, x, mu, logvar)
tractolearn.utils.losses.loss_triplet_classes(z, margin, metric='l2', swap=False)

Triplet loss implementation [1]

Parameters:
  • z (torch.Tensor) – Tensor of size (num_pos_pairs * 2 + num_neg_pairs * 2, latent_size). This is the batch format output by TripletDataset.

  • margin (float) – The margin hyperparameter

  • metric (str) – Latent-space distance metric (default: 'l2')

  • swap (bool) – If True, and if the positive example is closer to the negative example than the anchor is, swaps the positive example and the anchor in the loss computation.

References

[1] Balntas, V., Riba, E., Ponsa, D. & Mikolajczyk, K. Learning local feature descriptors with triplets and shallow convolutional neural networks. In Proceedings of the British Machine Vision Conference 2016, 119.1-119.11 (British Machine Vision Association, 2016). doi:10.5244/C.30.119.
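Below is a minimal sketch of a triplet margin loss with the optional distance swap of [1], written against explicit anchor, positive and negative tensors. The function name triplet_loss is hypothetical. It assumes the L2 metric (the default metric='l2' above) and mean reduction, and it does not reproduce how loss_triplet_classes splits z into triplets. PyTorch provides the same computation as torch.nn.functional.triplet_margin_loss.

```python
import torch
import torch.nn.functional as F


def triplet_loss(anchor: torch.Tensor, positive: torch.Tensor,
                 negative: torch.Tensor, margin: float = 1.0,
                 swap: bool = False) -> torch.Tensor:
    """Hypothetical sketch of an L2 triplet margin loss with distance swap."""
    # L2 distances for the anchor-positive and anchor-negative pairs
    d_ap = F.pairwise_distance(anchor, positive)
    d_an = F.pairwise_distance(anchor, negative)
    if swap:
        # Distance swap: if the positive is closer to the negative than the
        # anchor is, use that smaller (harder) negative distance instead
        d_pn = F.pairwise_distance(positive, negative)
        d_an = torch.minimum(d_an, d_pn)
    # Hinge: the positive should be closer than the negative by at least `margin`
    return torch.clamp(d_ap - d_an + margin, min=0.0).mean()


torch.manual_seed(0)
a, p, n = torch.randn(8, 4), torch.randn(8, 4), torch.randn(8, 4)
print(triplet_loss(a, p, n, swap=True).item())
```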

tractolearn.utils.losses.loss_triplet_hierarchical_classes(z, margin, metric='l2')

Custom implementation of a hierarchical triplet loss based on the QuickBundlesX clustering hierarchy.

Parameters:
  • z (torch.Tensor) – Tensor of size (num_pos_pairs * 2 + num_neg_pairs * 2, latent_size). This is the batch format output by TripletDataset.

  • margin (float) – The margin hyperparameter

  • metric (str) – Latent-space distance metric (default: 'l2')

tractolearn.utils.losses.triplet_margin_with_distance_loss_hierarchical(anchor: Tensor, positives: List[Tensor], negative: Tensor, distance_function: Callable[[Tensor, Tensor], Tensor] | None = None, margin: float = 1.0, reduction: str = 'mean') → Tensor

See PyTorch's torch.nn.TripletMarginWithDistanceLoss for details. Unlike the PyTorch version, this variant accepts a list of positive tensors instead of a single one.
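The PyTorch class referenced above accepts an arbitrary distance function. The following usage sketch pairs torch.nn.TripletMarginWithDistanceLoss with a cosine distance. The cosine_distance helper and the margin value are illustrative assumptions. The example uses a single positive rather than the list of positives accepted by the hierarchical variant, because how that variant combines its positives is not documented here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def cosine_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Row-wise cosine distance: 0 for parallel vectors, 1 for orthogonal ones
    return 1.0 - F.cosine_similarity(x, y, dim=1)


# Keyword-only arguments, as in the PyTorch constructor
loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=cosine_distance, margin=0.5, reduction="mean"
)

anchor = torch.tensor([[1.0, 0.0]])
positive = torch.tensor([[1.0, 0.0]])
negative = torch.tensor([[0.0, 1.0]])

# The negative is already more than `margin` farther away than the positive,
# so this triplet incurs no penalty
loss = loss_fn(anchor, positive, negative)
# Using the positive as the negative leaves both distances at 0: loss == margin
hard_loss = loss_fn(anchor, positive, positive)
print(loss.item(), hard_loss.item())
```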

tractolearn.utils.processing_utils module

tractolearn.utils.processing_utils.postprocess(x_reconstructed, isocenter, volume)

tractolearn.utils.timer module

class tractolearn.utils.timer.Timer

Bases: object

Timer class that measures the elapsed time of the code executed inside a with statement.
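The with-statement timing pattern can be sketched as a small context manager. ElapsedTimer below is a hypothetical stand-in, not tractolearn's Timer, whose attributes and output format are not documented here. It assumes wall-clock timing with time.perf_counter.

```python
import time


class ElapsedTimer:
    """Hypothetical with-statement timer (not tractolearn's Timer)."""

    def __enter__(self) -> "ElapsedTimer":
        # Reset the result and record a monotonic start time on entry
        self.elapsed = None
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Store the elapsed seconds when the block exits, even on error
        self.elapsed = time.perf_counter() - self.start
        # Returning False lets any exception raised in the block propagate
        return False


with ElapsedTimer() as t:
    total = sum(range(1_000_000))
print(f"elapsed: {t.elapsed:.4f} s")
```

time.perf_counter is monotonic and high-resolution, so the measurement is unaffected by adjustments to the system clock.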

tractolearn.utils.utils module

tractolearn.utils.utils.generate_uuid()
tractolearn.utils.utils.make_run_dir(out_path=None)

Create a directory for this training run.
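Here is a minimal sketch of creating a uniquely named directory for one training run. make_run_dir_sketch is hypothetical. The presence of generate_uuid() suggests UUID-based naming, but the actual naming scheme of make_run_dir and its behavior when out_path is None are assumptions. This sketch falls back to the current working directory.

```python
import tempfile
import uuid
from pathlib import Path
from typing import Optional, Union


def make_run_dir_sketch(out_path: Optional[Union[str, Path]] = None) -> Path:
    """Hypothetical: create a uniquely named directory for one training run."""
    # Assumption: fall back to the current working directory when no path is given
    base = Path(out_path) if out_path is not None else Path.cwd()
    # A random UUID4 hex string makes name collisions between runs practically impossible
    run_dir = base / uuid.uuid4().hex
    # parents=True creates missing intermediate directories;
    # exist_ok=False refuses to reuse an existing run directory
    run_dir.mkdir(parents=True, exist_ok=False)
    return run_dir


with tempfile.TemporaryDirectory() as tmp:
    run_dir = make_run_dir_sketch(tmp)
    created = run_dir.is_dir()
    parent_ok = run_dir.parent == Path(tmp)
print(created, parent_ok)
```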

Module contents