Source code for spacetimeformer.data.datamodule

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader


[docs]class DataModule(pl.LightningDataModule): def __init__(self, datasetCls, dataset_kwargs: dict, batch_size: int, workers: int): super().__init__() self.datasetCls = datasetCls self.batch_size = batch_size if "split" in dataset_kwargs.keys(): del dataset_kwargs["split"] self.dataset_kwargs = dataset_kwargs self.workers = workers
[docs] def train_dataloader(self): return self._make_dloader("train")
[docs] def val_dataloader(self): return self._make_dloader("val")
[docs] def test_dataloader(self): return self._make_dloader("test")
def _make_dloader(self, split): return DataLoader( self.datasetCls(**self.dataset_kwargs, split=split), shuffle=True if split == "train" else False, batch_size=self.batch_size, num_workers=self.workers, )
[docs] @classmethod def add_cli(self, parser): parser.add_argument("--batch_size", type=int, default=128) parser.add_argument( "--workers", type=int, default=6, help="number of parallel workers for pytorch dataloader", )