import pytorch_lightning as pl
class TeacherForcingAnnealCallback(pl.Callback):
    """Linearly lower a model's ``teacher_forcing_prob`` after each validation epoch.

    Every call to :meth:`on_validation_epoch_end` subtracts a fixed step of
    ``(start - end) / epochs`` from ``model.teacher_forcing_prob`` and clamps
    the result so that it never drops below ``end``.

    Args:
        start: Upper end of the schedule; only used to size the per-epoch step.
        end: Floor that the probability is clamped to.
        epochs: Number of decay steps it takes to travel from ``start`` to ``end``.

    Raises:
        ValueError: If ``start`` is smaller than ``end`` or ``epochs`` is not
            strictly positive.
    """

    def __init__(self, start: float, end: float, epochs: int) -> None:
        super().__init__()
        # Explicit raises rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let an inverted schedule through silently.
        if not start >= end:
            raise ValueError(
                f"start ({start}) must be greater than or equal to end ({end})"
            )
        # A zero value would divide by zero below; a negative one would flip the
        # sign of the slope and make the probability grow instead of decay.
        if not epochs > 0:
            raise ValueError(f"epochs must be a positive number, got {epochs}")
        self.start = start
        self.end = end
        self.epochs = epochs
        # Amount removed from the probability on every validation epoch.
        self.slope = float(start - end) / epochs

    def on_validation_epoch_end(self, trainer, model) -> None:
        """Take one decay step on ``model.teacher_forcing_prob``.

        The step starts from whatever value the model currently holds; this
        callback never writes ``start`` to the model itself.
        """
        # NOTE(review): if the trainer also fires this hook for its pre-training
        # sanity validation pass, the schedule advances before training begins
        # -- confirm whether that is intended.
        decayed = model.teacher_forcing_prob - self.slope
        model.teacher_forcing_prob = max(self.end, decayed)

    @classmethod
    def add_cli(cls, parser) -> None:
        """Register the teacher-forcing schedule options on ``parser``.

        Adds ``--teacher_forcing_start`` (float, default 0.8),
        ``--teacher_forcing_end`` (float, default 0.0) and
        ``--teacher_forcing_anneal_epochs`` (int, default 8).

        Args:
            parser: Object exposing an argparse-style ``add_argument`` method.
        """
        parser.add_argument("--teacher_forcing_start", type=float, default=0.8)
        parser.add_argument("--teacher_forcing_end", type=float, default=0.0)
        parser.add_argument("--teacher_forcing_anneal_epochs", type=int, default=8)
class TimeMaskedLossCallback(pl.Callback):
    """Linearly raise a model's ``time_masked_idx`` after each training batch.

    The callback keeps a floating-point position that begins at ``start`` and
    grows by ``(end - start) / steps`` per training batch, capped at ``end``.
    The rounded position is written to ``model.time_masked_idx`` and logged
    through ``model.log`` under the key ``"time_masked_idx"``.

    Args:
        start: Initial position of the schedule.
        end: Ceiling that the position is clamped to.
        steps: Number of training batches it takes to travel from ``start``
            to ``end``.

    Raises:
        ValueError: If ``start`` is larger than ``end`` or ``steps`` is not
            strictly positive.
    """

    def __init__(self, start: int, end: int, steps: int) -> None:
        super().__init__()
        # Explicit raises rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let an inverted schedule through silently.
        if not start <= end:
            raise ValueError(
                f"start ({start}) must be less than or equal to end ({end})"
            )
        # A zero value would divide by zero below; a negative one would flip the
        # sign of the slope and make the position shrink instead of grow.
        if not steps > 0:
            raise ValueError(f"steps must be a positive number, got {steps}")
        self.start = start
        self.end = end
        self.steps = steps
        # Amount added to the unrounded position on every training batch.
        self.slope = float(end - start) / steps
        # Unrounded schedule position; exposed rounded through ``time_mask``.
        self._time_mask = self.start

    @property
    def time_mask(self) -> int:
        """Current schedule position rounded with the built-in ``round``."""
        return round(self._time_mask)

    def on_train_start(self, trainer, model) -> None:
        """Seed ``model.time_masked_idx`` when the model has it set to ``None``."""
        # A value already present on the model is left untouched here.
        if model.time_masked_idx is None:
            model.time_masked_idx = self.time_mask

    def on_train_batch_end(self, trainer, model, *args) -> None:
        """Advance the schedule one step, then update and log the model's index."""
        self._time_mask = min(self.end, self._time_mask + self.slope)
        # Round once and reuse the value for both the attribute and the log entry.
        masked_idx = self.time_mask
        model.time_masked_idx = masked_idx
        model.log("time_masked_idx", masked_idx)

    @classmethod
    def add_cli(cls, parser) -> None:
        """Register the time-mask schedule options on ``parser``.

        Adds ``--time_mask_start`` (int, default 1), ``--time_mask_end``
        (int, default 12), ``--time_mask_anneal_steps`` (int, default 1000)
        and the boolean flag ``--time_mask_loss``.

        Args:
            parser: Object exposing an argparse-style ``add_argument`` method.
        """
        parser.add_argument("--time_mask_start", type=int, default=1)
        parser.add_argument("--time_mask_end", type=int, default=12)
        parser.add_argument("--time_mask_anneal_steps", type=int, default=1000)
        parser.add_argument("--time_mask_loss", action="store_true")