tractolearn.generative package
Submodules
tractolearn.generative.generate_points module
- class tractolearn.generative.generate_points.RejectionSampler(data: ndarray, kde_bw: float | None = None, kde_bw_factor: float = 1, kernel='gaussian', proposal_distribution_params: dict | None = None, scaling_mode: str = 'max', allow_singular=False, kde_bw_auto_estimation: str = 'Cross-Validation', proposal_distribution_name: str = 'multivariate_normal', bundle_name=None, output=None, cluster_estimation: str = 'silhouette')
Bases: object
Implements the rejection sampling algorithm using an abstract interface that only requires data from the distribution we want to sample from.
- sample(nb_samples: int, batch_size: int | None = None, entropy: int = 1234) → tuple
Performs rejection sampling to draw nb_samples samples that follow the distribution of the input data.
- Parameters:
nb_samples (int) – Number of samples to draw from the data distribution.
batch_size (int | None) – Number of samples to generate in each batch. If None, defaults to nb_samples / 100.
entropy (int) – Entropy for the seed generator.
- Returns:
M x D array, where M equals nb_samples and D is the dimensionality of the sampled data.
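The docstrings above only state that samples are drawn by rejection sampling from the data alone. The following is a minimal NumPy sketch of that idea, not the class's implementation. It assumes a fixed-bandwidth Gaussian KDE of the data as the target density, a multivariate normal proposal fitted to the data (in the spirit of proposal_distribution_name='multivariate_normal'), and a scaling constant taken as the largest target/proposal ratio observed on the data points (one possible reading of scaling_mode='max'). The helper names and the fixed bw value are hypothetical. The class itself estimates the bandwidth (by default via cross-validation) and may differ in every detail.

```python
import numpy as np


def gaussian_kde_logpdf(points, data, bw):
    """Log-density of an isotropic Gaussian KDE fitted on `data` (N x D),
    evaluated at `points` (M x D) with bandwidth `bw`."""
    d = data.shape[1]
    # Squared distances between every query point and every data point (M x N).
    sq = ((points[:, None, :] - data[None, :, :]) ** 2).sum(axis=-1)
    log_kernels = -0.5 * sq / bw**2
    # Numerically stable log-mean-exp over the data points.
    m = log_kernels.max(axis=1, keepdims=True)
    log_mean = m[:, 0] + np.log(np.exp(log_kernels - m).mean(axis=1))
    # Normalizing constant of a D-dimensional Gaussian kernel with variance bw^2.
    return log_mean - 0.5 * d * np.log(2 * np.pi * bw**2)


def mvn_logpdf(x, mean, cov):
    """Log-density of a multivariate normal N(mean, cov) at the rows of x."""
    d = mean.shape[0]
    diff = x - mean
    inv = np.linalg.inv(cov)
    _, logdet = np.linalg.slogdet(cov)
    maha = np.einsum("ij,jk,ik->i", diff, inv, diff)
    return -0.5 * (maha + logdet + d * np.log(2 * np.pi))


def rejection_sample(data, nb_samples, batch_size=None, entropy=1234, bw=0.5):
    """Hypothetical sketch: draw nb_samples points approximately following a
    Gaussian KDE of `data`, using a multivariate normal proposal."""
    rng = np.random.default_rng(entropy)
    if batch_size is None:
        batch_size = max(1, nb_samples // 100)
    # Proposal q: normal fitted to the data, with an inflated covariance so it
    # covers the tails of the KDE target p.
    mean = data.mean(axis=0)
    cov = 2.0 * np.cov(data, rowvar=False) + bw**2 * np.eye(data.shape[1])
    # Scaling constant k = max p(x)/q(x), estimated on the data points only.
    # If p/q exceeds k elsewhere, the result is only approximately exact.
    log_k = np.max(gaussian_kde_logpdf(data, data, bw) - mvn_logpdf(data, mean, cov))
    accepted, n_acc = [], 0
    while n_acc < nb_samples:
        x = rng.multivariate_normal(mean, cov, size=batch_size)
        # Accept x with probability p(x) / (k * q(x)), computed in log space.
        log_ratio = gaussian_kde_logpdf(x, data, bw) - mvn_logpdf(x, mean, cov) - log_k
        keep = x[np.log(rng.uniform(size=batch_size)) < log_ratio]
        accepted.append(keep)
        n_acc += len(keep)
    return np.concatenate(accepted)[:nb_samples]


rng = np.random.default_rng(0)
data = rng.normal(size=(500, 2)) + np.array([3.0, -1.0])
samples = rejection_sample(data, nb_samples=300, batch_size=100, entropy=1234)
print(samples.shape)  # (300, 2)
```

Working in log space avoids underflow when the KDE is evaluated far from the data. Fixing the generator seed (entropy) makes the draws reproducible.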
- tractolearn.generative.generate_points.generate_points(output: str, name: str, device: str, model, bundle: array, num_generate_points: int = 1000, atlas_bundle: array | None = None, max_seeds: int | None = None, composition: Tuple[int, int] = (1, 0), bandwidth: float | None = None, plot_seeds_generated: bool = False, use_rs: bool = False, optimization='composition', gmm_n_component: int = 11, random_seed: int = 1234)
Generates new streamlines from an autoencoder (AE) model and seed streamlines.
- Parameters:
output (str) – Output path.
name (str) – Bundle name.
device (str) – Device to run the model on: cpu or cuda.
model (str) – AE model for streamline compression.
bundle (ndarray) – Bundle array (N x 256).
num_generate_points (int) – Number of streamlines to generate.
atlas_bundle (ndarray) – Atlas bundle array (N x 256).
max_seeds (int) – Maximum number of seed streamlines.
composition (Tuple[int, int]) – Composition of the seeds, given as (subject bundle, atlas bundle).
bandwidth (float) – Kernel density estimation bandwidth.
plot_seeds_generated (bool) – If True, plots a UMAP projection of the streamlines in latent space.
use_rs (bool) – If True, uses rejection sampling (RS) instead of Gaussian sampling.
optimization (str) – One of ‘composition’ or ‘max_seeds’.
gmm_n_component (int) – Number of GMM components for the RS proposal distribution.
random_seed (int) – Random seed.
- Returns:
ndarray
Generated streamlines.
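generate_points is documented only at the parameter level. The NumPy sketch below roughly illustrates the pipeline those parameters describe: pick seeds from the subject and atlas bundles, encode them, sample new latent codes, and decode them. Every part of it is an assumption, not tractolearn's implementation. select_seeds and generate_streamlines are hypothetical helpers, and the encode/decode callables stand in for the AE model (here a toy linear projection). The fallback bandwidth uses Scott's rule, and "Gaussian sampling" is read as drawing from a Gaussian KDE centred on the seed codes. The sketch does not cover the rejection-sampling path (use_rs=True) or plotting.

```python
import numpy as np


def select_seeds(bundle, atlas_bundle, composition, max_seeds, rng):
    """Hypothetical helper: draw up to max_seeds seed streamlines, split between
    the subject bundle and the atlas bundle in the ratio `composition`."""
    w_subj, w_atlas = composition
    if atlas_bundle is None:
        w_subj, w_atlas = 1, 0  # no atlas available: subject seeds only
    if max_seeds is None:
        max_seeds = len(bundle)  # assumed default
    n_subj = min(len(bundle), round(max_seeds * w_subj / (w_subj + w_atlas)))
    seeds = [bundle[rng.choice(len(bundle), n_subj, replace=False)]]
    if w_atlas > 0:
        n_atlas = min(len(atlas_bundle), max_seeds - n_subj)
        seeds.append(atlas_bundle[rng.choice(len(atlas_bundle), n_atlas, replace=False)])
    return np.concatenate(seeds)


def generate_streamlines(encode, decode, seeds, num_generate_points,
                         bandwidth=None, random_seed=1234):
    """Hypothetical helper: sample latent codes from a Gaussian KDE centred on
    the encoded seeds, then decode them back to streamlines."""
    rng = np.random.default_rng(random_seed)
    z = encode(seeds)  # (S x L) latent codes of the seed streamlines
    n, dim = z.shape
    if bandwidth is None:
        # Assumed fallback: Scott's rule factor n^(-1/(d+4)) times the mean std.
        bandwidth = n ** (-1.0 / (dim + 4)) * z.std(axis=0).mean()
    # Drawing from a Gaussian KDE: pick a random seed code, add N(0, h^2 I) noise.
    idx = rng.integers(0, n, size=num_generate_points)
    z_new = z[idx] + rng.normal(scale=bandwidth, size=(num_generate_points, dim))
    return decode(z_new)  # (num_generate_points x 256) streamlines


rng = np.random.default_rng(0)
W, _ = np.linalg.qr(rng.normal(size=(256, 32)))  # toy linear "autoencoder"
bundle = rng.normal(size=(400, 256))
atlas = rng.normal(size=(300, 256))
seeds = select_seeds(bundle, atlas, composition=(1, 1), max_seeds=200, rng=rng)
out = generate_streamlines(lambda x: x @ W, lambda z: z @ W.T, seeds, 1000)
print(seeds.shape, out.shape)  # (200, 256) (1000, 256)
```

With the default composition (1, 0), every seed comes from the subject bundle. Weights such as (1, 1) split max_seeds evenly between the subject and the atlas.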