Forecasters
Module that handles all forecaster objects for training PyTorch models.
- class forecasters.Forecaster(model, loss, optimizer, n_epochs=1, device='cpu', checkpoint_path='./', verbose=True)
Handles training of a PyTorch model and can be used to generate samples from the approximate posterior predictive distribution.
- Arguments:
- model (torch.nn.Module): Instance of Deep4cast models.
- loss (torch.distributions): Instance of PyTorch distribution.
- optimizer (torch.optim): Instance of PyTorch optimizer.
- n_epochs (int): Number of training epochs.
- device (str): Device used for training (cpu or cuda).
- checkpoint_path (str): File system path for writing model checkpoints.
- verbose (bool): Verbosity of forecaster.
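The loss argument is a distribution rather than a pointwise error function. One plausible reading, which is an assumption here and not stated in this reference, is that training minimizes the negative log-likelihood of the targets under the distribution the model predicts. The sketch below computes that quantity for a Gaussian in plain Python. The function `gaussian_nll` and the sample values are hypothetical; with tensor inputs, PyTorch gives the same number via `-torch.distributions.Normal(mean, std).log_prob(y)`.

```python
import math

def gaussian_nll(y, mean, std):
    """Negative log-likelihood of y under Normal(mean, std)."""
    var = std ** 2
    return 0.5 * math.log(2 * math.pi * var) + (y - mean) ** 2 / (2 * var)

# Hypothetical targets and two candidate predictive distributions.
targets = [1.0, 1.2, 0.8]

# Average NLL when the predicted mean is close to the targets.
good = sum(gaussian_nll(y, mean=1.0, std=0.2) for y in targets) / len(targets)

# Average NLL when the predicted mean is far from the targets.
bad = sum(gaussian_nll(y, mean=3.0, std=0.2) for y in targets) / len(targets)

print(good < bad)
```

A lower value means the predicted distribution explains the observed targets better, which is why such a quantity can serve as a training loss.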
- embed(dataloader, n_samples=100)
Generates embedding vectors.
- Arguments:
- dataloader (torch.utils.data.DataLoader): Data to make embedding vectors.
- n_samples (int): Number of forecast samples.
- fit(dataloader_train, dataloader_val=None, eval_model=False)
Fits a model to a given dataset.
- Arguments:
- dataloader_train (torch.utils.data.DataLoader): Training data.
- dataloader_val (torch.utils.data.DataLoader): Validation data.
- eval_model (bool): Flag to switch on model evaluation after every epoch.
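The arguments of fit suggest a loop over epochs with an optional validation pass after each one. The sketch below illustrates that control flow only. It is not Deep4cast's implementation: `fit_sketch`, the constant-valued toy model, and the learning rate `lr` are all hypothetical, and plain lists of numbers stand in for the data loaders.

```python
def fit_sketch(train_batches, val_batches=None, n_epochs=1, eval_model=False, lr=0.1):
    """Toy loop: learn a single constant `w` that minimizes squared error."""
    w = 0.0
    history = []
    for epoch in range(n_epochs):
        # One gradient step per training batch.
        for batch in train_batches:
            grad = sum(2 * (w - y) for y in batch) / len(batch)
            w -= lr * grad
        record = {"epoch": epoch}
        # Validation runs only when requested and when validation data exists.
        if eval_model and val_batches is not None:
            errors = [(w - y) ** 2 for batch in val_batches for y in batch]
            record["val_loss"] = sum(errors) / len(errors)
        history.append(record)
    return w, history

# Hypothetical batches standing in for training and validation loaders.
train = [[1.0, 1.2], [0.8, 1.0]]
val = [[1.1, 0.9]]
w, history = fit_sketch(train, val, n_epochs=50, eval_model=True)
```

With `eval_model` left at its default of False, the loop skips the validation pass even if validation data is supplied.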
- predict(dataloader, n_samples=100)
Generates predictions.
- Arguments:
- dataloader (torch.utils.data.DataLoader): Data to make forecasts.
- n_samples (int): Number of forecast samples.
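Because predictions are drawn as n_samples forecast samples, a common next step is to reduce them to a point forecast and an interval. The sketch below shows one way to do that in plain Python. The layout of the samples (one list of per-step values for each draw) is an assumption, since the return type is not given in this reference, and `summarize` is a hypothetical helper; the random numbers merely stand in for real forecast samples.

```python
import random
import statistics

random.seed(0)
n_samples, horizon = 100, 3

# Stand-in for forecast samples: n_samples draws, each covering `horizon` steps.
# (Hypothetical layout; the real return type is not given in this reference.)
samples = [[random.gauss(10.0 + t, 1.0) for t in range(horizon)]
           for _ in range(n_samples)]

def summarize(samples, lower=0.05, upper=0.95):
    """Per-step mean and empirical quantile interval across the sample axis."""
    summary = []
    for step_values in zip(*samples):  # iterate over time steps
        ordered = sorted(step_values)
        lo = ordered[int(lower * (len(ordered) - 1))]
        hi = ordered[int(upper * (len(ordered) - 1))]
        summary.append((statistics.mean(ordered), lo, hi))
    return summary

forecast = summarize(samples)
for step, (mean, lo, hi) in enumerate(forecast):
    print(f"step {step}: mean={mean:.2f}, 90% interval=({lo:.2f}, {hi:.2f})")
```

A larger n_samples gives smoother interval estimates at the cost of more forward passes through the model.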