tractolearn.learning package
Submodules
tractolearn.learning.data_manager module
tractolearn.learning.dataset module
- class tractolearn.learning.dataset.ContrastiveDataset(experiment_dict: dict, type_set: str, seed, num_pairs)
  Bases: IterableDataset
  This dataset yields whole batches rather than individual items. Use it with batch_size=None so that the DataLoader does not batch the already-batched output a second time.
  - property num_points
  - property point_dims
  - property rng
  - set_seed(seed)
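The following is a minimal sketch of what "yields batches, not items" means. It uses a toy set of labelled streamlines. `ToyPairBatchDataset`, its `streamlines`, `labels` and `num_batches` arguments, and the `(a, b, same_class)` pair format are all hypothetical. They are not the ContrastiveDataset implementation. A real version would subclass `torch.utils.data.IterableDataset`. In PyTorch, passing `batch_size=None` to `DataLoader` disables automatic batching, so each yielded batch passes through unchanged.

```python
import random
from typing import Iterator, List, Sequence, Tuple

# Hypothetical pair format: (streamline_a, streamline_b, 1 if same class else 0).
Pair = Tuple[Sequence, Sequence, int]


class ToyPairBatchDataset:
    """Hypothetical stand-in: every iteration step yields a whole batch of pairs."""

    def __init__(self, streamlines: Sequence, labels: Sequence[int],
                 num_pairs: int, num_batches: int, seed: int = 0):
        self.streamlines = streamlines
        self.labels = labels
        self.num_pairs = num_pairs
        self.num_batches = num_batches
        self.set_seed(seed)

    def set_seed(self, seed: int) -> None:
        # A private RNG keeps sampling reproducible and independent of global state.
        self._rng = random.Random(seed)

    @property
    def rng(self) -> random.Random:
        return self._rng

    def __iter__(self) -> Iterator[List[Pair]]:
        n = len(self.streamlines)
        for _ in range(self.num_batches):
            batch = []
            for _ in range(self.num_pairs):
                i, j = self._rng.randrange(n), self._rng.randrange(n)
                # Label 1 when both streamlines share a class, 0 otherwise.
                same = int(self.labels[i] == self.labels[j])
                batch.append((self.streamlines[i], self.streamlines[j], same))
            yield batch  # one whole batch per step, not one item


toy = ToyPairBatchDataset([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]], [0, 0, 1],
                          num_pairs=4, num_batches=2, seed=42)
batches = list(toy)  # 2 batches, each a list of 4 (a, b, same_class) tuples
```

With PyTorch, such a dataset would be consumed as `DataLoader(dataset, batch_size=None)`. The batch size is then controlled by the dataset (here `num_pairs`) rather than by the loader.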
- class tractolearn.learning.dataset.HierarchicalDataset(experiment_dict: dict, type_set: str, seed: int, num_pairs: int)
  Bases: IterableDataset
  - property num_points
  - property point_dims
  - property rng
  - set_seed(seed)
- class tractolearn.learning.dataset.OnTheFlyDataset(X: array, y: array, to_transpose=True)
  Bases: Dataset
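OnTheFlyDataset wraps in-memory arrays `X` and `y` as a map-style `Dataset`. The sketch below makes two assumptions that are not taken from the tractolearn source. First, each entry of `X` is one streamline laid out as points × dims. Second, `to_transpose=True` swaps those two axes per sample, giving the dims × points (channels-first) layout that a 1-D convolution would expect. `ToyOnTheFlyDataset` and `transpose` are hypothetical, and nested lists stand in for NumPy arrays so the example has no dependencies.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def transpose(sample: Sequence[Sequence[float]]) -> Matrix:
    # Swap axes: points x dims -> dims x points.
    return [list(col) for col in zip(*sample)]


class ToyOnTheFlyDataset:
    """Hypothetical map-style dataset: supports len() and integer indexing."""

    def __init__(self, X: Sequence[Sequence[Sequence[float]]], y: Sequence[int],
                 to_transpose: bool = True):
        if len(X) != len(y):
            raise ValueError("X and y must have the same number of samples")
        self.X = X
        self.y = y
        self.to_transpose = to_transpose

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int) -> Tuple[Matrix, int]:
        sample = self.X[idx]
        if self.to_transpose:
            sample = transpose(sample)
        else:
            sample = [list(p) for p in sample]  # copy, keep points x dims
        return sample, self.y[idx]


# Two streamlines of 3 points in 3-D (points x dims).
X = [[[0, 0, 0], [1, 1, 1], [2, 2, 2]],
     [[0, 1, 2], [3, 4, 5], [6, 7, 8]]]
y = [0, 1]
ds = ToyOnTheFlyDataset(X, y)
features, label = ds[1]  # features == [[0, 3, 6], [1, 4, 7], [2, 5, 8]], label == 1
```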
- class tractolearn.learning.dataset.StreamlineClassificationDataset(experiment_dict: dict, set: str, seed: int)
  Bases: IterableDataset
  - get_random_streamline_from_class(class_idx)
  - set_seed(seed)
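The following is a hedged sketch of seeded per-class sampling. It assumes streamlines are already grouped in a dict keyed by integer class index. `ToyClassSampler` and that dict layout are hypothetical. The method names mirror the entries above only for readability and do not reproduce tractolearn's implementation.

```python
import random
from typing import Dict, List, Sequence


class ToyClassSampler:
    """Hypothetical sampler drawing one random streamline from a requested class."""

    def __init__(self, by_class: Dict[int, List[Sequence[float]]], seed: int):
        self.by_class = by_class
        self.set_seed(seed)

    def set_seed(self, seed: int) -> None:
        # Re-seeding restarts the random sequence, so draws are reproducible.
        self.rng = random.Random(seed)

    def get_random_streamline_from_class(self, class_idx: int) -> Sequence[float]:
        candidates = self.by_class[class_idx]  # KeyError for an unknown class
        return candidates[self.rng.randrange(len(candidates))]


sampler = ToyClassSampler({0: [[0.0], [0.1]], 1: [[1.0], [1.1], [1.2]]}, seed=7)
first = [sampler.get_random_streamline_from_class(1) for _ in range(5)]
sampler.set_seed(7)
again = [sampler.get_random_streamline_from_class(1) for _ in range(5)]
# first == again: the same seed gives the same sequence of draws
```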
- class tractolearn.learning.dataset.StreamlineClassificationDatasetTree(experiment_dict: dict, set: str, seed: int)
  Bases: IterableDataset
  - get_random_streamline_from_class_with_merge(class_idx)
  - get_random_streamline_from_class_without_merge()
  - set_seed(seed)
- class tractolearn.learning.dataset.TripletDataset(experiment_dict: dict, type_set: str, seed, num_pairs)
  Bases: ContrastiveDataset
  This dataset yields whole batches rather than individual items. Use it with batch_size=None so that the DataLoader does not batch the already-batched output a second time.
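TripletDataset inherits the batch-level contract of ContrastiveDataset. The sketch below illustrates what one batch might contain: (anchor, positive, negative) triplets, where the positive shares the anchor's class and the negative does not. `make_triplet_batch` and its arguments are hypothetical and do not describe tractolearn's sampling strategy. The sketch assumes at least two classes, with at least one class holding two or more streamlines.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Triplet = Tuple[Sequence[float], Sequence[float], Sequence[float]]


def make_triplet_batch(streamlines: Sequence[Sequence[float]], labels: Sequence[int],
                       num_triplets: int, rng: random.Random) -> List[Triplet]:
    """Return one batch of (anchor, positive, negative) triplets (hypothetical helper)."""
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_class[lab].append(idx)
    # An anchor class needs two distinct members: the anchor and its positive.
    anchor_classes = sorted(c for c, idxs in by_class.items() if len(idxs) >= 2)
    classes = sorted(by_class)
    if not anchor_classes or len(classes) < 2:
        raise ValueError("need >= 2 classes and one class with >= 2 streamlines")

    batch = []
    for _ in range(num_triplets):
        pos_class = rng.choice(anchor_classes)
        a, p = rng.sample(by_class[pos_class], 2)  # two distinct indices
        neg_class = rng.choice([c for c in classes if c != pos_class])
        n = rng.choice(by_class[neg_class])
        batch.append((streamlines[a], streamlines[p], streamlines[n]))
    return batch


batch = make_triplet_batch([[0.0], [0.1], [1.0], [1.1]], [0, 0, 1, 1],
                           num_triplets=3, rng=random.Random(0))
```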
tractolearn.learning.trainer_manager module
- class tractolearn.learning.trainer_manager.Trainer(experiment_dict: dict, experiment_dir: str, device: device, data: Tuple[DataLoader, DataLoader, DataLoader], input_size: Tuple[int, int], isocenter: array, volume: array, experiment_recorder: Experiment)
  Bases: object
  - property best_checkpoint
  - best_model_fname = 'best_model.pt'
  - build_model()
  - property experiment
  - get_batch_iterator(dataloader, phase)
  - load_checkpoint(fname, device)
  - property lowest_loss
  - property model
  - property model_name
  - property normalize
  - plot_results(data, fname: str)
  - save_checkpoint(state, fname='best_model.pt')
  - save_loss_history()
  - property test_loader
  - property test_loss_recorder
  - train(epoch)
  - property train_loader
  - property train_loss_recorder
  - valid(epoch)
  - property valid_loader
  - property valid_loss_recorder
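Several Trainer members suggest the common pattern of saving a checkpoint whenever the validation loss improves. These are lowest_loss, best_checkpoint, valid(epoch), and save_checkpoint with its default fname='best_model.pt'. The sketch below shows that pattern under this assumption only. `ToyBestCheckpoint` and its `update` method are hypothetical. It uses `pickle` instead of `torch.save` so it runs without PyTorch.

```python
import math
import os
import pickle
import tempfile
from typing import Any, Dict, Optional


class ToyBestCheckpoint:
    """Hypothetical tracker: keep the lowest validation loss and persist that state."""

    def __init__(self, experiment_dir: str):
        self.experiment_dir = experiment_dir
        self.lowest_loss = math.inf
        self.best_checkpoint: Optional[str] = None

    def save_checkpoint(self, state: Dict[str, Any], fname: str = "best_model.pt") -> str:
        path = os.path.join(self.experiment_dir, fname)
        with open(path, "wb") as f:
            pickle.dump(state, f)  # a PyTorch trainer would typically use torch.save
        return path

    def update(self, epoch: int, valid_loss: float, model_state: Dict[str, Any]) -> bool:
        """Save only when the validation loss improves; return True if saved."""
        if valid_loss >= self.lowest_loss:
            return False
        self.lowest_loss = valid_loss
        self.best_checkpoint = self.save_checkpoint(
            {"epoch": epoch, "loss": valid_loss, "model": model_state})
        return True


with tempfile.TemporaryDirectory() as tmp:
    tracker = ToyBestCheckpoint(tmp)
    saved = [tracker.update(e, loss, {"w": e})
             for e, loss in enumerate([0.9, 0.7, 0.8, 0.5])]
    # saved == [True, True, False, True]; tracker.lowest_loss == 0.5
    with open(tracker.best_checkpoint, "rb") as f:
        best = pickle.load(f)  # {"epoch": 3, "loss": 0.5, "model": {"w": 3}}
```

Overwriting a single fixed file name keeps exactly one "best" checkpoint on disk. Keeping every improving epoch would require distinct file names.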