Datasets
In PyTorch Lightning, datasets can be integrated by overriding the dataloader functions in the LightningModule:
train_dataloader()
val_dataloader()
test_dataloader()
This is exactly what a RideDataset does.
In addition, RideDataset adds num_workers and batch_size configs as well as self.input_shape and self.output_shape tuples (which are very handy for computing layer shapes).
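As a rough illustration of why such shape tuples are handy, the sketch below derives the (in_features, out_features) pairs of a simple multilayer perceptron from an input and an output shape. The helper mlp_weight_shapes and the example shapes are hypothetical and only assume that both shapes are tuples of integers.

```python
import math
from typing import List, Sequence, Tuple


def mlp_weight_shapes(
    input_shape: Sequence[int],
    hidden_dims: Sequence[int],
    output_shape: Sequence[int],
) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each linear layer of an MLP."""
    in_features = math.prod(input_shape)    # flattened input size
    out_features = math.prod(output_shape)  # flattened output size
    dims = [in_features, *hidden_dims, out_features]
    # Pair each layer width with the next one.
    return list(zip(dims[:-1], dims[1:]))


shapes = mlp_weight_shapes(
    input_shape=(1, 28, 28),  # e.g. a single-channel 28x28 image
    hidden_dims=[128, 64],
    output_shape=(10,),
)
print(shapes)  # [(784, 128), (128, 64), (64, 10)]
```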
For classification datasets, the RideClassificationDataset expects a list of class names defined in self.classes and provides a self.num_classes attribute.
self.classes is then used for plotting, e.g. if "--test_confusion_matrix True" is specified in the CLI.
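To make the role of the class names concrete, here is a small library-independent sketch that computes a confusion matrix and labels its rows with class names. It is not Ride's plotting code: the confusion_matrix helper, the class names, and the assumption that the number of classes equals len(classes) are all illustrative.

```python
from typing import List, Sequence


def confusion_matrix(
    targets: Sequence[int], preds: Sequence[int], num_classes: int
) -> List[List[int]]:
    """Count how often each target class (row) was predicted as each class (column)."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for target, pred in zip(targets, preds):
        matrix[target][pred] += 1
    return matrix


classes = ["cat", "dog", "bird"]  # hypothetical class names
num_classes = len(classes)        # assumed relation between the two attributes

matrix = confusion_matrix(
    targets=[0, 0, 1, 2, 2],
    preds=[0, 1, 1, 2, 0],
    num_classes=num_classes,
)

# Label each row with its class name, much as a plot would label its axes.
labelled = dict(zip(classes, matrix))
print(labelled)  # {'cat': [1, 1, 0], 'dog': [0, 1, 0], 'bird': [1, 0, 1]}
```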
In order to define a RideDataset, one can either define the train_dataloader(), val_dataloader(), and test_dataloader() functions or assign a LightningDataModule to self.datamodule as seen here:
from ride.core import AttributeDict, RideClassificationDataset, Configs
from ride.utils.env import DATASETS_PATH
import pl_bolts


class MnistDataset(RideClassificationDataset):

    @staticmethod
    def configs():
        c = Configs.collect(MnistDataset)
        c.add(
            name="val_split",
            type=int,
            default=5000,
            strategy="constant",
            description="Number samples from train dataset used for val split.",
        )
        c.add(
            name="normalize",
            type=int,
            default=1,
            choices=[0, 1],
            strategy="constant",
            description="Whether to normalize dataset.",
        )
        return c

    def __init__(self, hparams: AttributeDict):
        self.datamodule = pl_bolts.datamodules.MNISTDataModule(
            data_dir=DATASETS_PATH,
            val_split=self.hparams.val_split,
            num_workers=self.hparams.num_workers,
            normalize=self.hparams.normalize,
            batch_size=self.hparams.batch_size,
            seed=42,
            shuffle=True,
            pin_memory=self.hparams.num_workers > 1,
            drop_last=False,
        )
        self.output_shape = 10
        self.classes = list(range(10))
        self.input_shape = self.datamodule.dims
Changing dataset
Though the dataset is specified at module definition, we can change the dataset using with_dataset().
This is especially handy for experiments using a single module over multiple datasets:
MyRideModuleWithMnistDataset = MyRideModule.with_dataset(MnistDataset)
MyRideModuleWithCifar10Dataset = MyRideModule.with_dataset(Cifar10Dataset)
...
Next, we’ll cover how the RideModule integrates with Main.