spacetimeformer.data package
- class spacetimeformer.data.csv_dataset.CSVTimeSeries(data_path: str, target_cols: List[str], read_csv_kwargs={}, val_split: float = 0.2, test_split: float = 0.15)[source]
Bases: object
- property test_data
- property train_data
- property val_data
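The signature above exposes val_split and test_split fractions, but this page does not say how the rows are partitioned. A minimal sketch, assuming a chronological (unshuffled) split in which the validation and test portions are taken from the end of the series; chronological_split is a hypothetical helper and not part of spacetimeformer:

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def chronological_split(
    rows: Sequence[T], val_split: float = 0.2, test_split: float = 0.15
) -> Tuple[List[T], List[T], List[T]]:
    """Hypothetical helper: split time-ordered rows without shuffling.

    Assumption: train comes first, then validation, then test.
    The real CSVTimeSeries may partition the data differently.
    """
    if not 0.0 <= val_split + test_split < 1.0:
        raise ValueError("val_split + test_split must be in [0, 1)")
    n = len(rows)
    n_val = int(n * val_split)
    n_test = int(n * test_split)
    n_train = n - n_val - n_test
    train = list(rows[:n_train])
    val = list(rows[n_train:n_train + n_val])
    test = list(rows[n_train + n_val:])
    return train, val, test


rows = list(range(100))
train, val, test = chronological_split(rows)
print(len(train), len(val), len(test))
```

Keeping the order intact means no validation or test row precedes a training row, which is the usual reason to avoid random splits for forecasting data.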
- class spacetimeformer.data.csv_dataset.CSVTorchDset(csv_time_series: spacetimeformer.data.csv_dataset.CSVTimeSeries, split: str = 'train', context_points: int = 128, target_points: int = 32, time_resolution: int = 1)[source]
Bases: torch.utils.data.dataset.Dataset
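The context_points and target_points arguments suggest that each sample pairs a window of past values with the window that follows it. A minimal sketch of that idea in plain Python, assuming a stride of one step between samples; WindowedSeries is a hypothetical stand-in that only mimics the map-style dataset protocol (__len__ and __getitem__), and time_resolution is not modeled because this page does not define it:

```python
from typing import List, Sequence, Tuple


class WindowedSeries:
    """Hypothetical stand-in for a map-style dataset over one time series."""

    def __init__(
        self, series: Sequence[float], context_points: int = 128, target_points: int = 32
    ) -> None:
        if context_points <= 0 or target_points <= 0:
            raise ValueError("window sizes must be positive")
        self.series = list(series)
        self.context_points = context_points
        self.target_points = target_points

    def __len__(self) -> int:
        # Number of start positions that leave room for a full context + target window.
        window = self.context_points + self.target_points
        return max(0, len(self.series) - window + 1)

    def __getitem__(self, i: int) -> Tuple[List[float], List[float]]:
        if not 0 <= i < len(self):
            raise IndexError(i)
        split = i + self.context_points
        context = self.series[i:split]
        target = self.series[split:split + self.target_points]
        return context, target


dset = WindowedSeries(range(10), context_points=4, target_points=2)
context, target = dset[0]
print(len(dset), context, target)
```

With the defaults in the signature above, a split would need at least 128 + 32 = 160 rows to yield a single sample under this assumption.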
- class spacetimeformer.data.datamodule.DataModule(*args: Any, **kwargs: Any)[source]
Bases: pytorch_lightning.core.datamodule.LightningDataModule
- test_dataloader()[source]
Implement one or multiple PyTorch DataLoaders for testing.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs (an argument of pytorch_lightning.trainer.Trainer) to a positive integer.
For data processing use the following pattern:
  - download in prepare_data()
  - process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
  - fit()
  - …
  - prepare_data()
  - setup()
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns
A torch.utils.data.DataLoader or a sequence of them specifying testing samples.
Example:
    def test_dataloader(self):
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                        download=True)
        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=False
        )
        return loader

    # can also return multiple dataloaders
    def test_dataloader(self):
        return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
Note
In the case where you return multiple test dataloaders, the test_step() will have an argument dataloader_idx which matches the order here.
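The correspondence between the returned list and dataloader_idx can be illustrated with a simplified loop. This is a sketch in plain Python, not Lightning's actual evaluation loop; run_test_loop and record_test_step are hypothetical names, and plain lists stand in for DataLoader objects:

```python
from typing import Any, Callable, Dict, Iterable, List, Sequence


def run_test_loop(
    loaders: Sequence[Iterable[Any]], step: Callable[[Any, int, int], None]
) -> None:
    """Simplified stand-in for an evaluation loop (assumption, not Lightning code)."""
    for dataloader_idx, loader in enumerate(loaders):
        for batch_idx, batch in enumerate(loader):
            # dataloader_idx is the position of the loader in the returned list.
            step(batch, batch_idx, dataloader_idx)


seen: Dict[int, List[Any]] = {}


def record_test_step(batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    # Group batches by the loader they came from.
    seen.setdefault(dataloader_idx, []).append(batch)


loader_a = [[1, 2], [3, 4]]  # plain lists standing in for DataLoaders
loader_b = [[10, 20]]
run_test_loop([loader_a, loader_b], record_test_step)
print(seen)
```

Because the index follows list order, reordering the list returned by test_dataloader() changes which dataset each dataloader_idx refers to.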
- train_dataloader()[source]
Implement one or more PyTorch DataLoaders for training.
- Returns
A collection of torch.utils.data.DataLoader specifying training samples. In the case of multiple dataloaders, please see the PyTorch Lightning documentation on multiple dataloaders.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs (an argument of pytorch_lightning.trainer.Trainer) to a positive integer.
For data processing use the following pattern:
  - download in prepare_data()
  - process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
  - fit()
  - …
  - prepare_data()
  - setup()
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Example:
    # single dataloader
    def train_dataloader(self):
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                        download=True)
        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=True
        )
        return loader

    # multiple dataloaders, return as list
    def train_dataloader(self):
        mnist = MNIST(...)
        cifar = CIFAR(...)
        mnist_loader = torch.utils.data.DataLoader(
            dataset=mnist, batch_size=self.batch_size, shuffle=True
        )
        cifar_loader = torch.utils.data.DataLoader(
            dataset=cifar, batch_size=self.batch_size, shuffle=True
        )
        # each batch will be a list of tensors: [batch_mnist, batch_cifar]
        return [mnist_loader, cifar_loader]

    # multiple dataloaders, return as dict
    def train_dataloader(self):
        mnist = MNIST(...)
        cifar = CIFAR(...)
        mnist_loader = torch.utils.data.DataLoader(
            dataset=mnist, batch_size=self.batch_size, shuffle=True
        )
        cifar_loader = torch.utils.data.DataLoader(
            dataset=cifar, batch_size=self.batch_size, shuffle=True
        )
        # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
        return {'mnist': mnist_loader, 'cifar': cifar_loader}
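The prepare_data()/setup() pattern described for train_dataloader() can be sketched without Lightning itself. The class below is a hypothetical plain-Python stand-in that only mirrors the hook names; the file name, the ten-row series, and the 8/2 split are assumptions made for illustration. The point it shows is that prepare_data() writes to disk and sets no attributes, while setup() reads that file and assigns state:

```python
import os
import tempfile
from typing import List, Optional


class SketchDataModule:
    """Plain-Python stand-in mirroring the hook names; not a LightningDataModule."""

    def __init__(self, data_dir: str) -> None:
        self.data_path = os.path.join(data_dir, "series.txt")  # hypothetical file name
        self.train_rows: Optional[List[int]] = None
        self.val_rows: Optional[List[int]] = None

    def prepare_data(self) -> None:
        # One-time work such as a download: write to disk only, assign no state.
        if not os.path.exists(self.data_path):
            with open(self.data_path, "w") as f:
                f.write("\n".join(str(i) for i in range(10)))

    def setup(self) -> None:
        # Read what prepare_data() left on disk and assign state here.
        with open(self.data_path) as f:
            rows = [int(line) for line in f.read().splitlines()]
        self.train_rows, self.val_rows = rows[:8], rows[8:]


with tempfile.TemporaryDirectory() as tmp:
    dm = SketchDataModule(tmp)
    dm.prepare_data()
    state_after_prepare = (dm.train_rows, dm.val_rows)
    dm.setup()
    train_rows, val_rows = dm.train_rows, dm.val_rows

print(state_after_prepare, train_rows, val_rows)
```

Keeping state out of prepare_data() matters in distributed runs, where that hook is not executed in every process and attributes set there would be missing elsewhere.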
- val_dataloader()[source]
Implement one or multiple PyTorch DataLoaders for validation.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs (an argument of pytorch_lightning.trainer.Trainer) to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
  - fit()
  - …
  - prepare_data()
Note
Lightning adds the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns
A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
Examples:
    def val_dataloader(self):
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                        download=True)
        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=False
        )
        return loader

    # can also return multiple dataloaders
    def val_dataloader(self):
        return [loader_a, loader_b, ..., loader_n]
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
Note
In the case where you return multiple validation dataloaders, the validation_step() will have an argument dataloader_idx which matches the order here.