tractolearn.learning package

Submodules

tractolearn.learning.data_manager module

class tractolearn.learning.data_manager.DataManager(experiment_dict: dict, seed: int)

Bases: object

setup_data_loader()
setup_dataset()

tractolearn.learning.dataset module

class tractolearn.learning.dataset.ContrastiveDataset(experiment_dict: dict, type_set: str, seed, num_pairs)

Bases: IterableDataset

This dataset yields whole batches rather than individual items, so it should be used with batch_size=None (which disables automatic batching in the DataLoader).

property num_points
property point_dims
property rng
set_seed(seed)
class tractolearn.learning.dataset.HierarchicalDataset(experiment_dict: dict, type_set: str, seed: int, num_pairs: int)

Bases: IterableDataset

property num_points
property point_dims
property rng
set_seed(seed)
class tractolearn.learning.dataset.OnTheFlyDataset(X: array, y: array, to_transpose=True)

Bases: Dataset

class tractolearn.learning.dataset.StreamlineClassificationDataset(experiment_dict: dict, set: str, seed: int)

Bases: IterableDataset

get_random_streamline_from_class(class_idx)
set_seed(seed)
class tractolearn.learning.dataset.StreamlineClassificationDatasetTree(experiment_dict: dict, set: str, seed: int)

Bases: IterableDataset

get_random_streamline_from_class_with_merge(class_idx)
get_random_streamline_from_class_without_merge()
set_seed(seed)
class tractolearn.learning.dataset.TripletDataset(experiment_dict: dict, type_set: str, seed, num_pairs)

Bases: ContrastiveDataset

This dataset yields whole batches rather than individual items, so it should be used with batch_size=None (which disables automatic batching in the DataLoader).

tractolearn.learning.trainer_manager module

class tractolearn.learning.trainer_manager.Trainer(experiment_dict: dict, experiment_dir: str, device: device, data: Tuple[DataLoader, DataLoader, DataLoader], input_size: Tuple[int, int], isocenter: array, volume: array, experiment_recorder: Experiment)

Bases: object

property best_checkpoint
best_model_fname = 'best_model.pt'
build_model()
property experiment
get_batch_iterator(dataloader, phase)
load_checkpoint(fname, device)
property lowest_loss
property model
property model_name
property normalize
plot_results(data, fname: str)
save_checkpoint(state, fname='best_model.pt')
save_loss_history()
property test_loader
property test_loss_recorder
train(epoch)
property train_loader
property train_loss_recorder
valid(epoch)
property valid_loader
property valid_loss_recorder

Module contents