tractolearn.models package

Submodules

tractolearn.models.autoencoding_utils module

tractolearn.models.autoencoding_utils.encode_data(latent_space_loader: DataLoader, device: device, model: Module, limit_batch: int | None = None) → Tuple[array, array]

Encode the streamlines held by a DataLoader into a lower-dimensional latent space.

Parameters:
  • latent_space_loader (DataLoader) – DataLoader containing the streamlines to encode.

  • device (torch.device) – Device to use.

  • model (Module) – Deep learning model used to encode the streamlines.

  • limit_batch (int, optional) – Maximum number of batches to encode from latent_space_loader (default: None).

Returns:

Latent-space representations of the streamlines and their bundle classes.

Return type:

Tuple[np.array, np.array]
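The loop below is a minimal sketch of what an encoding pass like this typically looks like. It is not tractolearn's implementation. It assumes the loader yields `(streamlines, labels)` pairs and that the model exposes an `encode` method, as the autoencoder class documented in this package does. `encode_data_sketch` and `ToyEncoder` are hypothetical names.

```python
from typing import Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def encode_data_sketch(
    loader: DataLoader,
    device: torch.device,
    model: nn.Module,
    limit_batch: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Encode (streamlines, labels) batches and stack latents and labels."""
    model.eval()  # switch dropout/batch-norm layers to inference behavior
    latents, labels = [], []
    with torch.no_grad():  # no gradients are needed while encoding
        for i, (x, y) in enumerate(loader):
            if limit_batch is not None and i >= limit_batch:
                break  # stop after `limit_batch` batches
            z = model.encode(x.to(device))  # assumed `encode` method
            latents.append(z.cpu().numpy())
            labels.append(y.numpy())
    return np.concatenate(latents), np.concatenate(labels)


class ToyEncoder(nn.Module):
    """Hypothetical stand-in model: flattens each sample, then projects to 2D."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(6, 2)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x.flatten(start_dim=1))


torch.manual_seed(0)
loader = DataLoader(TensorDataset(torch.randn(10, 2, 3), torch.arange(10)), batch_size=4)
z, labels = encode_data_sketch(loader, torch.device("cpu"), ToyEncoder(), limit_batch=2)
```

With `batch_size=4` and `limit_batch=2`, only the first 8 of the 10 samples are encoded.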

tractolearn.models.forward module

tractolearn.models.forward.forward_ae(model, loss_fn, device, batch)

Take a labeled batch from HDF5Dataset, encode it, decode it, and compute the appropriate loss.
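Here is a minimal sketch of such an autoencoder step. It assumes the labeled batch is a `(streamlines, labels)` pair and that calling the model runs encode followed by decode. `forward_ae_sketch` is a hypothetical name, and mean squared error stands in for whatever loss is actually configured.

```python
import torch
import torch.nn.functional as F
from torch import nn


def forward_ae_sketch(model: nn.Module, loss_fn, device: torch.device, batch):
    """Reconstruct a labeled batch and return the reconstruction loss."""
    x, _labels = batch          # assumed (streamlines, labels) pair
    x = x.to(device)
    x_recon = model(x)          # assumed: __call__ runs encode then decode
    return loss_fn(x_recon, x)  # e.g. F.mse_loss(input, target)


# Toy usage: a linear layer acting on the last (coordinate) axis.
model = nn.Linear(3, 3)
batch = (torch.randn(4, 5, 3), torch.zeros(4))
loss = forward_ae_sketch(model, F.mse_loss, torch.device("cpu"), batch)
loss.backward()  # gradients reach the model parameters
```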

tractolearn.models.forward.forward_ae_contrastive(model, loss_fn, device, streamline_batch)

Take a batch of streamlines from ContrastiveDataset, encode it, decode it, and compute the supplied loss function.

The loss function takes x, x_recon, and z as inputs, which allows it to compute both a reconstruction loss and a contrastive loss.
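One way to build a loss with that `(x, x_recon, z)` signature is to add a weighted term on `z` to a reconstruction term on `x_recon`. The sketch below assumes mean squared error for reconstruction. `make_ae_contrastive_loss`, `contrastive_fn`, and `weight` are hypothetical and not part of tractolearn.

```python
from typing import Callable

import torch
import torch.nn.functional as F

Tensor = torch.Tensor


def make_ae_contrastive_loss(
    contrastive_fn: Callable[[Tensor], Tensor], weight: float = 1.0
) -> Callable[[Tensor, Tensor, Tensor], Tensor]:
    """Build loss_fn(x, x_recon, z) = reconstruction + weight * contrastive."""

    def loss_fn(x: Tensor, x_recon: Tensor, z: Tensor) -> Tensor:
        recon = F.mse_loss(x_recon, x)  # how well the decoder reproduces x
        contrast = contrastive_fn(z)    # structure imposed on the latent codes
        return recon + weight * contrast

    return loss_fn


# Toy usage with a placeholder latent term (mean squared latent value).
loss_fn = make_ae_contrastive_loss(lambda z: z.pow(2).mean(), weight=0.5)
loss = loss_fn(torch.ones(2, 4), torch.zeros(2, 4), torch.full((2, 3), 2.0))
```

Passing `z` alongside `x` and `x_recon` lets one callable weigh both objectives, so the forward function does not need to know how the two terms are combined.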

tractolearn.models.forward.forward_contrastive(model, loss_fn, device, streamline_batch)

Take a batch of streamlines from ContrastiveDataset, encode it, and compute a contrastive loss.
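The contrastive loss itself is not specified here. As one common example, the sketch below implements an NT-Xent (SimCLR-style) loss. It assumes the embeddings `z` have shape `(2n, d)` and that rows `i` and `i + n` are two views of the same streamline. This is an illustrative choice, not necessarily the loss tractolearn uses. `nt_xent_loss` and `temperature` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent loss; rows i and i + n of z are assumed to be positive pairs."""
    n = z.shape[0] // 2
    z = F.normalize(z, dim=1)          # dot products become cosine similarities
    sim = z @ z.t() / temperature      # (2n, 2n) similarity logits
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))  # a row is not its own positive
    # Row i's positive is row i + n and vice versa; all other rows are negatives.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)


# Toy check: identical views of mutually orthogonal embeddings.
a = torch.eye(4, 16)
loss = nt_xent_loss(torch.cat([a, a]), temperature=0.5)
```

In the toy check each row has one positive logit of 2 and six negative logits of 0, so the loss is log(1 + 6e^-2).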

tractolearn.models.forward.make_forward(model, device, experiment_dict)

Build a forward-pass function that combines model execution with a loss computation.
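The sketch below shows the idea in plain Python. It looks up a forward function and a loss in registries, then binds the model, loss, and device with `functools.partial` so the training loop only passes batches. The `"forward"` and `"loss"` keys of `experiment_dict` and the registry arguments are hypothetical, because the real function's configuration keys are not documented here.

```python
from functools import partial
from typing import Any, Callable, Dict


def make_forward_sketch(
    model: Any,
    device: Any,
    experiment_dict: Dict[str, Any],
    forward_fns: Dict[str, Callable],
    loss_fns: Dict[str, Callable],
) -> Callable[[Any], Any]:
    """Return a `batch -> loss` function with model, loss and device pre-bound."""
    forward_fn = forward_fns[experiment_dict["forward"]]  # hypothetical key
    loss_fn = loss_fns[experiment_dict["loss"]]           # hypothetical key
    # Same argument order as the documented forward functions:
    # forward_fn(model, loss_fn, device, batch).
    return partial(forward_fn, model, loss_fn, device)


# Toy usage with plain Python stand-ins for a model and a loss.
def toy_forward(model, loss_fn, device, batch):
    return loss_fn(model(batch), batch)


step = make_forward_sketch(
    model=lambda b: [2 * v for v in b],
    device="cpu",
    experiment_dict={"forward": "ae", "loss": "l1"},
    forward_fns={"ae": toy_forward},
    loss_fns={"l1": lambda out, ref: sum(abs(o - r) for o, r in zip(out, ref))},
)
result = step([1, 2, 3])
```

Binding everything except the batch keeps the training loop identical across model and loss combinations.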

tractolearn.models.model_performance_history module

class tractolearn.models.model_performance_history.LossHistory

Bases: object

History of the loss during training (a lighter version of MetricHistory).

Usage:

    monitor = LossHistory()
    …
    # Call update at each iteration
    monitor.update(2.3)
    …
    monitor.avg  # returns the average loss
    …
    monitor.end_epoch()  # call at epoch end
    …
    monitor.epochs  # returns the loss curve as a list

property avg
end_epoch(write=True)
update(value)
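The usage above suggests a running-average tracker. The following is a minimal re-implementation under two stated assumptions: `avg` is the mean of the values passed to `update` since the last `end_epoch`, and `end_epoch` appends that mean to `epochs`. The documented `write=True` parameter of `end_epoch` is omitted because its effect is not described here. `LossHistorySketch` is a hypothetical name.

```python
from typing import List


class LossHistorySketch:
    """Hypothetical running-average loss tracker (assumed semantics)."""

    def __init__(self) -> None:
        self.epochs: List[float] = []  # one averaged loss per finished epoch
        self._total = 0.0
        self._count = 0

    def update(self, value: float) -> None:
        # Accumulate one iteration's loss.
        self._total += value
        self._count += 1

    @property
    def avg(self) -> float:
        # Mean of the values seen since the last end_epoch (0.0 if none).
        return self._total / self._count if self._count else 0.0

    def end_epoch(self) -> None:
        self.epochs.append(self.avg)        # record the epoch mean...
        self._total, self._count = 0.0, 0   # ...and start the next epoch fresh


monitor = LossHistorySketch()
for loss in (2.0, 3.0):
    monitor.update(loss)
epoch_avg = monitor.avg
monitor.end_epoch()
```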

tractolearn.models.model_pool module

tractolearn.models.model_pool.get_model(model_name, latent_space_dims, device)

Autoencoders (AE) [1]; Variational Autoencoders (VAE) [2]; Reparameterization trick [2].
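The reparameterization trick writes a Gaussian sample as z = μ + σ ⊙ ε with ε ~ N(0, I). This places the randomness in ε, so gradients can flow through μ and σ. The following is a minimal PyTorch sketch. It assumes the encoder outputs a mean `mu` and a log-variance `logvar`; both names are hypothetical.

```python
import torch


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z ~ N(mu, sigma^2) in a way that is differentiable in mu and logvar."""
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)    # noise with no gradient path of its own
    return mu + eps * std


mu = torch.zeros(4, 32, requires_grad=True)
logvar = torch.zeros(4, 32, requires_grad=True)
z = reparameterize(mu, logvar)
z.sum().backward()  # gradients reach both distribution parameters
```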

References

tractolearn.models.track_ae_cnn1d_incr_feat_strided_conv_fc_upsamp_reflect_pad_pytorch module

class tractolearn.models.track_ae_cnn1d_incr_feat_strided_conv_fc_upsamp_reflect_pad_pytorch.IncrFeatStridedConvFCUpsampReflectPadAE(latent_space_dims)

Bases: Module

AE based on strided convolutions and upsampling, using reflection padding and increasing feature maps in the decoder.
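The toy PyTorch sketch below illustrates the building blocks named here: strided 1D convolutions for downsampling, a fully connected bottleneck, and upsampling followed by reflection padding and convolution. It assumes streamlines resampled to 256 points with 3 coordinates, and it uses made-up channel counts and depth. It does not reproduce this class's actual layer configuration or feature-map schedule. `TinyStridedConvUpsampAE` and `up_block` are hypothetical names.

```python
import torch
from torch import nn


def up_block(c_in: int, c_out: int) -> nn.Sequential:
    """Upsample x2, then reflection-pad so a kernel-3 conv keeps the length."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.ReflectionPad1d(1),
        nn.Conv1d(c_in, c_out, kernel_size=3),
    )


class TinyStridedConvUpsampAE(nn.Module):
    """Hypothetical toy AE; input shape (batch, 3, 256)."""

    def __init__(self, latent_space_dims: int = 32) -> None:
        super().__init__()
        # Encoder: each stride-2 convolution halves the number of points.
        self.encoder = nn.Sequential(
            nn.Conv1d(3, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),    # 256 -> 128
            nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),   # 128 -> 64
            nn.Conv1d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU(),  # 64 -> 32
        )
        self.fc_enc = nn.Linear(128 * 32, latent_space_dims)  # FC bottleneck
        self.fc_dec = nn.Linear(latent_space_dims, 128 * 32)
        self.decoder = nn.Sequential(
            up_block(128, 64), nn.ReLU(),  # 32 -> 64 points
            up_block(64, 32), nn.ReLU(),   # 64 -> 128 points
            up_block(32, 3),               # 128 -> 256 points, back to 3 coordinates
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc_enc(self.encoder(x).flatten(start_dim=1))

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.fc_dec(z).view(-1, 128, 32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))


model = TinyStridedConvUpsampAE(latent_space_dims=32)
x = torch.randn(2, 3, 256)
z = model.encode(x)
x_recon = model(x)  # call the module, not forward(), so hooks run
```

Reflection padding mirrors the sequence at its ends rather than padding with zeros. This keeps the length unchanged after each kernel-3 convolution without adding artificial zero values at the streamline endpoints.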

decode(z)
encode(x)
forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
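The following short check demonstrates the hook behavior described in this note. It uses a plain `nn.Linear` layer from standard PyTorch, unrelated to tractolearn's model.

```python
import torch
from torch import nn

calls = []
layer = nn.Linear(4, 2)
# Forward hooks run only when the module is invoked through __call__.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 4)
layer(x)          # runs forward() and the registered hook
layer.forward(x)  # bypasses the hook machinery entirely
```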

Module contents