tractolearn.generative package

Submodules

tractolearn.generative.generate_points module

class tractolearn.generative.generate_points.RejectionSampler(data: ndarray, kde_bw: float | None = None, kde_bw_factor: float = 1, kernel='gaussian', proposal_distribution_params: dict | None = None, scaling_mode: str = 'max', allow_singular=False, kde_bw_auto_estimation: str = 'Cross-Validation', proposal_distribution_name: str = 'multivariate_normal', bundle_name=None, output=None, cluster_estimation: str = 'silhouette')

Bases: object

Implements the rejection sampling algorithm using an abstract interface that only requires data from the distribution we want to sample from.
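The constructor arguments (a Gaussian `kernel`, a KDE bandwidth `kde_bw`, a `multivariate_normal` proposal, a `scaling_mode`) suggest a familiar structure. A kernel density estimate of `data` acts as the target density p. Candidates are drawn from a proposal q and accepted with probability p(x) / (M q(x)). The following is a minimal, dependency-free sketch of that idea, not tractolearn's implementation. The helper names (`gaussian_kde_pdf`, `normal_pdf`, `rejection_sample`), the way the proposal is fitted, and the heuristic for the envelope constant M are all assumptions made for illustration.

```python
import math
import random


def gaussian_kde_pdf(x, data, bw):
    """Density at x of a Gaussian KDE with isotropic bandwidth bw."""
    d = len(x)
    norm = (2.0 * math.pi * bw * bw) ** (-d / 2.0)
    total = 0.0
    for pt in data:
        sq = sum((xi - ci) ** 2 for xi, ci in zip(x, pt))
        total += math.exp(-0.5 * sq / (bw * bw))
    return norm * total / len(data)


def normal_pdf(x, mean, std):
    """Density at x of an isotropic normal N(mean, std^2 * I)."""
    d = len(x)
    norm = (2.0 * math.pi * std * std) ** (-d / 2.0)
    sq = sum((xi - mi) ** 2 for xi, mi in zip(x, mean))
    return norm * math.exp(-0.5 * sq / (std * std))


def rejection_sample(data, nb_samples, bw=0.5, seed=1234, margin=1.5):
    """Draw nb_samples points from a KDE fitted to `data` (hypothetical helper)."""
    rng = random.Random(seed)
    d = len(data[0])
    n = len(data)
    # Proposal q: isotropic normal centred on the data mean, wider than the data.
    mean = [sum(pt[k] for pt in data) / n for k in range(d)]
    spread = max(
        math.sqrt(sum((pt[k] - mean[k]) ** 2 for pt in data) / n) for k in range(d)
    )
    std = 2.0 * spread + bw
    # Envelope constant M: largest p/q ratio observed at the data points, times a
    # safety margin. This is a heuristic, not a guaranteed upper bound.
    m = margin * max(
        gaussian_kde_pdf(pt, data, bw) / normal_pdf(pt, mean, std) for pt in data
    )
    samples = []
    while len(samples) < nb_samples:
        x = [rng.gauss(mu, std) for mu in mean]  # candidate drawn from q
        u = rng.random()
        # Accept x with probability p(x) / (M q(x)).
        if u * m * normal_pdf(x, mean, std) <= gaussian_kde_pdf(x, data, bw):
            samples.append(x)
    return samples
```

For example, `rejection_sample(data, nb_samples=200)` with `data` a list of 2-D points returns 200 points that follow the smoothed (KDE) distribution of `data`. The wider the proposal relative to the data, the safer the envelope, but the lower the acceptance rate.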

sample(nb_samples: int, batch_size: int | None = None, entropy: int = 1234) → tuple

Performs rejection sampling to draw nb_samples samples that follow the distribution of the input data.

Parameters:

  • nb_samples (int) – The number of samples to draw from the data distribution.

  • batch_size (int | None) – Number of samples to generate in each batch. If None, defaults to nb_samples / 100.

  • entropy (int) – Entropy for the seed generator.

Returns:

  • An M x D array, where M equals nb_samples and D is the dimensionality of the sampled data.
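The batch_size and entropy parameters describe a batched loop. Candidates are generated and tested one batch at a time until enough have been accepted, and the result is then trimmed to exactly nb_samples. The sketch below shows that loop on a toy target, p(x) = 2x on [0, 1], with a Uniform(0, 1) proposal, so the envelope constant is M = 2. The function name, the toy target, the use of `random.Random(entropy)` for seeding, and the clamp of the default batch size to at least 1 are assumptions for illustration only.

```python
import random


def sample_in_batches(nb_samples, batch_size=None, entropy=1234):
    """Batched rejection sampling of p(x) = 2x on [0, 1] (hypothetical toy example).

    Proposal q is Uniform(0, 1), so p(x) / q(x) = 2x <= 2 and M = 2.
    """
    if batch_size is None:
        # Mirrors the documented default of nb_samples / 100; clamped to at
        # least 1 here (our assumption) so small requests still make progress.
        batch_size = max(1, nb_samples // 100)
    rng = random.Random(entropy)  # seeded generator for reproducible draws
    m = 2.0
    accepted = []
    while len(accepted) < nb_samples:
        xs = [rng.random() for _ in range(batch_size)]  # batch of candidates from q
        us = [rng.random() for _ in range(batch_size)]  # uniform acceptance draws
        # Keep x when u * M * q(x) <= p(x), with q(x) = 1 and p(x) = 2x.
        accepted.extend(x for x, u in zip(xs, us) if u * m * 1.0 <= 2.0 * x)
    # The final batch can overshoot, so trim to exactly nb_samples.
    return accepted[:nb_samples]
```

Larger batches mean fewer loop iterations at the cost of discarding more surplus accepted candidates from the final batch. The same entropy value always reproduces the same samples.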

tractolearn.generative.generate_points.generate_points(output: str, name: str, device: str, model, bundle: array, num_generate_points: int = 1000, atlas_bundle: array | None = None, max_seeds: int | None = None, composition: Tuple[int, int] = (1, 0), bandwidth: float | None = None, plot_seeds_generated: bool = False, use_rs: bool = False, optimization='composition', gmm_n_component: int = 11, random_seed: int = 1234)

Generate new streamlines from an autoencoder (AE) model and seed streamlines.

Parameters:
  • output (str) – Output path.

  • name (str) – Bundle name.

  • device (str) – Device to run the model on: ‘cpu’ or ‘cuda’.

  • model (str) – AE model for streamline compression.

  • bundle (ndarray) – Bundle array (N x 256).

  • num_generate_points (int) – Number of streamlines to generate.

  • atlas_bundle (ndarray) – Atlas bundle array (N x 256).

  • max_seeds (int) – Maximum number of seed streamlines.

  • composition (Tuple[int, int]) – Composition of seeds (subject bundle|atlas bundle).

  • bandwidth (float) – Kernel density estimation (KDE) bandwidth.

  • plot_seeds_generated (bool) – If True, plot a UMAP projection of the streamlines in latent space.

  • use_rs (bool) – If True, use rejection sampling (RS) instead of Gaussian sampling.

  • optimization (str) – Possible options are [‘composition’, ‘max_seeds’].

  • gmm_n_component (int) – Number of Gaussian mixture model (GMM) components for the RS proposal distribution.

  • random_seed (int) – Random seed.

Returns:

  • ndarray – Generated streamlines.
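Judging from its parameters, generation appears to work in the AE latent space. Seed streamlines are mixed from the subject and atlas bundles, encoded, and used to fit a density, from which new latent points are drawn and decoded. The sketch below illustrates only the Gaussian-KDE variant (use_rs=False). It relies on the standard way to sample from a Gaussian KDE: choose a seed's latent vector uniformly at random, then add N(0, bandwidth²) noise. Everything here is hypothetical. That covers the helper names, the reading of composition as relative (subject, atlas) weights capped by max_seeds, the 0.1 default bandwidth, and the `encode`/`decode` callables standing in for the real AE model.

```python
import random


def select_seeds(subject, atlas, composition=(1, 0), max_seeds=None, random_seed=1234):
    """Mix subject and atlas seeds; composition read as (w_subject, w_atlas) weights (assumption)."""
    rng = random.Random(random_seed)
    atlas = atlas or []
    w_sub, w_atl = composition
    total = len(subject) + len(atlas) if max_seeds is None else max_seeds
    n_sub = min(len(subject), round(total * w_sub / (w_sub + w_atl)))
    n_atl = min(len(atlas), round(total * w_atl / (w_sub + w_atl)))
    return rng.sample(subject, n_sub) + rng.sample(atlas, n_atl)


def generate_latent_kde(seeds, encode, decode, num_generate_points=1000,
                        bandwidth=0.1, random_seed=1234):
    """Sample new latent vectors from a Gaussian KDE over encoded seeds, then decode."""
    rng = random.Random(random_seed)
    latents = [encode(s) for s in seeds]  # stand-in for the AE encoder
    generated = []
    for _ in range(num_generate_points):
        centre = rng.choice(latents)  # pick one kernel uniformly
        z = [c + rng.gauss(0.0, bandwidth) for c in centre]  # add kernel noise
        generated.append(decode(z))  # stand-in for the AE decoder
    return generated
```

With the default composition of (1, 0), `select_seeds` takes every seed from the subject bundle and none from the atlas. A smaller bandwidth keeps generated streamlines closer to their seeds, and a larger one spreads them further through the latent space.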

Module contents