class Text2VideoTrainer:
    """Minimal single-step trainer for a text-conditioned video model.

    Holds a model, its optimizer and a target device, and performs one
    optimization step per ``train_step`` call using a mean-squared-error loss
    between the model output and the reference video batch.
    """

    def __init__(self, model, optimizer, device):
        """Store the training components.

        Args:
            model: Callable mapping a text batch to a generated-video tensor;
                presumably a ``torch.nn.Module`` — TODO confirm.
            optimizer: Optimizer over ``model``'s parameters; must provide
                ``zero_grad()`` and ``step()``.
            device: Device the batches are moved to before the forward pass
                (anything accepted by ``Tensor.to``).
        """
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def train_step(self, text_batch, video_batch):
        """Run one forward/backward/update step and return the loss.

        Args:
            text_batch: Model input. Moved to ``self.device`` when it is a
                tensor; any other input (e.g. a list of prompt strings) is
                passed to the model unchanged.
            video_batch: Target tensor compared against the model output with
                mean-squared error; moved to ``self.device``.

        Returns:
            The scalar loss value as a Python float.

        Note:
            This method does not switch the model into training mode; call
            ``self.model.train()`` beforehand if the model may have been left
            in eval mode.
        """
        # Local imports: no module-level import of torch / ``F`` is visible
        # for this class, and re-importing is harmless if one exists.
        import torch
        import torch.nn.functional as F

        # ``device`` used to be stored but never applied, leaving batches on
        # whatever device the caller created them on (device-mismatch errors
        # in mse_loss). ``.to`` is a no-op when already on the right device.
        if torch.is_tensor(text_batch):
            text_batch = text_batch.to(self.device)
        video_batch = video_batch.to(self.device)

        self.optimizer.zero_grad()
        generated_video = self.model(text_batch)
        loss = F.mse_loss(generated_video, video_batch)
        loss.backward()
        self.optimizer.step()
        return loss.item()