RideModule
The RideModule works in conjunction with the LightningModule to add functionality to a plain Module.
While LightningModule adds structural code that integrates with the Trainer, the RideModule provides sensible defaults for:
Train loop - training_step()
Validation loop - validation_step()
Test loop - test_step()
Optimizers - configure_optimizers()
The only things left to define are:
Initialisation - __init__()
Network forward pass - forward()
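This is a minimal pure-Python sketch of the division of labour described above. It illustrates the template-method pattern and is not Ride's implementation. The class names `DefaultStepsMixin` and `Doubler`, the placeholder mean-absolute-error loss, and the `val/` and `test/` prefixes are all hypothetical. The mixin supplies the loop steps, and the subclass defines only `__init__()` and `forward()`.

```python
class DefaultStepsMixin:
    """Hypothetical mixin: default loop steps built on forward() and common_step()."""

    def common_step(self, pred, target, prefix="train/"):
        # Placeholder loss: mean absolute error over two equal-length lists.
        loss = sum(abs(p - t) for p, t in zip(pred, target)) / len(target)
        return {f"{prefix}loss": loss}

    def training_step(self, batch, batch_idx=None):
        x, target = batch
        return self.common_step(self.forward(x), target, prefix="train/")

    def validation_step(self, batch, batch_idx=None):
        x, target = batch
        return self.common_step(self.forward(x), target, prefix="val/")

    def test_step(self, batch, batch_idx=None):
        x, target = batch
        return self.common_step(self.forward(x), target, prefix="test/")


class Doubler(DefaultStepsMixin):
    # The user only supplies initialisation and the forward pass.
    def __init__(self, factor=2.0):
        self.factor = factor

    def forward(self, x):
        return [self.factor * v for v in x]


model = Doubler()
print(model.training_step(([1.0, 2.0], [2.0, 5.0])))  # {'train/loss': 0.5}
```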
The following thus constitutes a fully functional neural network module, which (when integrated with ride.Main) provides full functionality for training, testing, hyperparameter search, profiling, etc., via a command line interface.
import numpy as np
import torch
from ride import RideModule
from .examples.mnist_dataset import MnistDataset

class MyRideModule(RideModule, MnistDataset):
    def __init__(self, hparams):
        hidden_dim = 128
        # `self.input_shape` and `self.output_shape` were injected via `MnistDataset`
        self.l1 = torch.nn.Linear(int(np.prod(self.input_shape)), hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, self.output_shape)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten each sample into a vector
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x
Configs
Out of the box, a wide selection of parameters is integrated into self.hparams through ride.Main. These include all the pytorch_lightning.Trainer options, the configs in ride.lifecycle.Lifecycle.configs(), and the configs of the selected optimizer (default: ride.optimizers.SgdOptimizer.configs()).
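One way to picture how configs from several sources end up in a single `self.hparams` is to walk the class hierarchy and merge each class's own `configs()`. This sketch makes several assumptions. The configs are plain dicts of defaults. The classes `TrainerLike`, `LifecycleLike` and `OptimizerLike` are hypothetical stand-ins. The names and values are illustrative and are not Ride's actual configs.

```python
class TrainerLike:
    # Stand-in for Trainer options.
    @staticmethod
    def configs():
        return {"max_epochs": 10}


class LifecycleLike:
    # Stand-in for lifecycle configs.
    @staticmethod
    def configs():
        return {"loss": "cross_entropy"}


class OptimizerLike:
    # Stand-in for the selected optimizer's configs.
    @staticmethod
    def configs():
        return {"learning_rate": 0.1, "momentum": 0.9}


class MyModule(LifecycleLike, OptimizerLike, TrainerLike):
    # User-defined hyperparameters.
    @staticmethod
    def configs():
        return {"hidden_dim": 128}


def collect_hparams(cls):
    """Merge configs() from every class in the MRO; subclasses win on name clashes."""
    hparams = {}
    # Walk from the most generic base (object) to the most specific class.
    for klass in reversed(cls.__mro__):
        # Only use configs() defined directly on this class, not inherited ones.
        if "configs" in vars(klass):
            hparams.update(klass.configs())
    return hparams


print(collect_hparams(MyModule))
```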
User-defined hyperparameters are reflected in self.hparams, the command line interface, and the hyperparameter search space (via the choices and strategy arguments). They can be declared by adding a configs() method to MyRideModule:
# Defined inside MyRideModule (requires `import ride`)
@staticmethod
def configs() -> ride.Configs:
    c = ride.Configs()
    c.add(
        name="hidden_dim",
        type=int,
        default=128,
        strategy="choice",
        choices=[128, 256, 512, 1024],
        description="Number of hidden units.",
    )
    return c
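The `strategy` and `choices` fields describe a search space rather than a single value. The hypothetical sampler below shows one plausible reading of them. `"choice"` picks uniformly from a list. `"loguniform"`, which the Optimizer section below uses for the learning rate, treats `choices` as a `(low, high)` range and samples uniformly in log-space. The dict format and the `sample()` function are assumptions made for illustration, and a real search backend may interpret these fields differently.

```python
import math
import random


def sample(spec, rng):
    """Draw one value for a hypothetical {"strategy", "choices"} spec."""
    if spec["strategy"] == "choice":
        # Pick one of the listed options uniformly at random.
        return rng.choice(spec["choices"])
    if spec["strategy"] == "loguniform":
        # Treat choices as a (low, high) range and sample uniformly in log-space,
        # so each order of magnitude is equally likely.
        low, high = spec["choices"]
        return math.exp(rng.uniform(math.log(low), math.log(high)))
    raise ValueError(f"unknown strategy: {spec['strategy']}")


search_space = {
    "hidden_dim": {"strategy": "choice", "choices": [128, 256, 512, 1024]},
    "learning_rate": {"strategy": "loguniform", "choices": (1e-6, 1)},
}

rng = random.Random(0)  # seeded for reproducibility
trial = {name: sample(spec, rng) for name, spec in search_space.items()}
print(trial)
```

Each call produces one trial configuration. A search would repeat this and train once per trial.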
The configs package is also available separately in the Co-Rider package.
Advanced behavior overloading
Lifecycle methods
Naturally, training_step(), validation_step(), and test_step() can still be overridden if complex computational schemes are required. In that case, return the result of self.common_step() at the end of the method. This ensures that loss computation and metric collection still work as expected:
def training_step(self, batch, batch_idx=None):
    x, target = batch
    pred = self.forward(x)  # replace with complex interaction
    return self.common_step(pred, target, prefix="train/", log=True)
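The sketch below shows what a `common_step()`-style helper might do. It computes the loss, computes metrics, prefixes every key (e.g. `train/loss`), and optionally records them. It then returns a dict that also contains an unprefixed `loss` key, which is the key PyTorch Lightning expects when a training step returns a dict. Everything here is a hypothetical pure-Python stand-in and is not Ride's code. That includes `CommonStepSketch`, the rounding-based `accuracy` metric, and the `logged` dict that stands in for a logger.

```python
class CommonStepSketch:
    """Hypothetical stand-in for a shared step helper (not Ride's implementation)."""

    def __init__(self):
        self.logged = {}  # stands in for a real logger

    def loss(self, pred, target):
        # Mean squared error over equal-length lists.
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)

    def metrics(self, pred, target):
        # Hypothetical metric: fraction of predictions that round to the target.
        hits = sum(round(p) == t for p, t in zip(pred, target))
        return {"accuracy": hits / len(target)}

    def common_step(self, pred, target, prefix="train/", log=False):
        results = {"loss": self.loss(pred, target), **self.metrics(pred, target)}
        # Prefix keys so train/val/test values do not collide.
        prefixed = {f"{prefix}{key}": value for key, value in results.items()}
        if log:
            self.logged.update(prefixed)
        # Keep an unprefixed "loss" entry for the training loop to optimise.
        return {"loss": results["loss"], **prefixed}


step = CommonStepSketch()
out = step.common_step([1.0, 2.25], [1, 3], prefix="train/", log=True)
print(out)  # {'loss': 0.28125, 'train/loss': 0.28125, 'train/accuracy': 0.5}
```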
Loss
By default, RideModule automatically integrates the loss functions in torch.nn.functional (selected on the command line using the “--loss” flag). If other options are needed, one can define a loss() method in the module:
def loss(self, pred, target):
    # `my_exotic_loss` is a placeholder for any callable returning a scalar loss
    return my_exotic_loss(pred, target)
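Selecting a loss by name, and overriding that choice with a method, can be sketched as follows. The pure-Python `l1_loss` and `mse_loss` functions stand in for their `torch.nn.functional` namesakes. The `LOSSES` registry, `Module`, and `CustomLossModule` are hypothetical. The only point is that a `loss()` method defined on the subclass takes precedence over the flag-based lookup.

```python
import argparse


# Pure-Python stand-ins for the torch.nn.functional functions of the same name.
def l1_loss(pred, target):
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


def mse_loss(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


LOSSES = {"l1_loss": l1_loss, "mse_loss": mse_loss}  # hypothetical registry


class Module:
    def __init__(self, hparams):
        self.hparams = hparams

    def loss(self, pred, target):
        # Default behaviour: use the loss function named by the --loss flag.
        return LOSSES[self.hparams.loss](pred, target)


class CustomLossModule(Module):
    def loss(self, pred, target):
        # Overriding loss() bypasses the --loss lookup (here: worst-case error).
        return max(abs(p - t) for p, t in zip(pred, target))


parser = argparse.ArgumentParser()
parser.add_argument("--loss", choices=sorted(LOSSES), default="mse_loss")

pred, target = [0.0, 4.0], [1.0, 2.0]
print(Module(parser.parse_args([])).loss(pred, target))  # 2.5
print(Module(parser.parse_args(["--loss", "l1_loss"])).loss(pred, target))  # 1.5
print(CustomLossModule(parser.parse_args([])).loss(pred, target))  # 2.0
```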
Optimizer
The SgdOptimizer is added automatically if no other Optimizer is found and configure_optimizers() is not manually defined. Other optimizers can thus be specified either by using Mixins:
import ride

class MyModel(
    ride.RideModule,
    ride.AdamWOneCycleOptimizer
):
    def __init__(self, hparams):
        ...
or by overriding configure_optimizers():
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return optimizer
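The fallback rule described at the start of this section can be sketched as a class-level check. The class is kept as-is if it already has an optimizer mixin or its own `configure_optimizers()`. Otherwise a default mixin is added. All classes here are hypothetical stand-ins, and the returned strings replace real optimizer objects so that the example runs without PyTorch.

```python
class OptimizerMixin:
    """Marker base class for optimizer mixins (hypothetical)."""


class SgdOptimizer(OptimizerMixin):
    def configure_optimizers(self):
        return "SGD"  # stand-in for a real optimizer object


class AdamWOneCycleOptimizer(OptimizerMixin):
    def configure_optimizers(self):
        return "AdamW + OneCycle"


class Base:
    """Stand-in for a module without any optimizer logic."""


def with_default_optimizer(cls):
    """Return cls unchanged if it picks an optimizer, else mix in SgdOptimizer."""
    has_mixin = issubclass(cls, OptimizerMixin)
    defines_own = any("configure_optimizers" in vars(k) for k in cls.__mro__)
    if has_mixin or defines_own:
        return cls
    # Create a subclass that also inherits the default optimizer mixin.
    return type(cls.__name__, (cls, SgdOptimizer), {})


class PlainModel(Base):
    pass


class MixinModel(Base, AdamWOneCycleOptimizer):
    pass


class ManualModel(Base):
    def configure_optimizers(self):
        return "Adam (manual)"


for model_cls in (PlainModel, MixinModel, ManualModel):
    print(model_cls.__name__, with_default_optimizer(model_cls)().configure_optimizers())
```

If the base module class already provides a default configure_optimizers(), a real check would have to ignore that particular definition.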
Specifying a parent Mixin automatically adds its configs (e.g. ride.AdamWOneCycleOptimizer.configs()) and the corresponding hparams. The overriding approach, by contrast, must be supplemented with a configs() method for the parameters to appear in the command line tool and the hyperparameter search space.
@staticmethod
def configs() -> ride.Configs:
    c = ride.Configs()
    c.add(
        name="learning_rate",
        type=float,
        default=0.1,
        choices=(1e-6, 1),
        strategy="loguniform",
        description="Learning rate.",
    )
    return c  # without this, configs() would return None

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    return optimizer
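Here is a sketch, built on the standard library's `argparse`, of how a config entry can surface as a command line flag and then as `self.hparams.learning_rate`. The spec dict mirrors the `c.add(...)` arguments above. However, `build_parser()` and `Model` are hypothetical, and the returned dict stands in for a real `torch.optim.Adam` instance.

```python
import argparse

# Hypothetical spec mirroring the c.add(...) arguments shown above.
SPECS = [
    dict(
        name="learning_rate",
        type=float,
        default=0.1,
        choices=(1e-6, 1),
        strategy="loguniform",
        description="Learning rate.",
    ),
]


def build_parser(specs):
    parser = argparse.ArgumentParser()
    for spec in specs:
        # `choices`/`strategy` describe the search space; a (low, high) range is
        # not a list of allowed CLI values, so it is not passed to argparse.
        parser.add_argument(
            f"--{spec['name']}",
            type=spec["type"],
            default=spec["default"],
            help=spec["description"],
        )
    return parser


class Model:
    def __init__(self, hparams):
        self.hparams = hparams

    def configure_optimizers(self):
        # Stand-in for torch.optim.Adam(self.parameters(), lr=...).
        return {"optimizer": "Adam", "lr": self.hparams.learning_rate}


parser = build_parser(SPECS)
print(Model(parser.parse_args([])).configure_optimizers())
print(Model(parser.parse_args(["--learning_rate", "0.001"])).configure_optimizers())
```

Passing `--learning_rate 0.001` therefore changes the value seen by configure_optimizers() without any code edits.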
Next, we’ll see how to specify a dataset.