File size: 474 Bytes
53abcc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
class Text2VideoTrainer:
    """Run single optimisation steps of a text-to-video model against target videos.

    The trainer only holds references to the model, its optimizer and a
    target device; it does not create or own any of them.
    """

    def __init__(self, model, optimizer, device):
        """Store the training components.

        Args:
            model: Callable mapping a text batch to a generated-video tensor.
                Presumably a ``torch.nn.Module`` already placed on ``device``
                -- TODO confirm with callers.
            optimizer: Optimizer over ``model``'s parameters; must provide
                ``zero_grad()`` and ``step()``.
            device: Value passed to ``.to()`` on each incoming batch, or
                ``None`` to leave batches wherever they already are.
        """
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def _to_device(self, batch):
        """Return *batch* moved to ``self.device`` when it supports ``.to()``.

        Inputs without a callable ``.to`` (e.g. raw strings or plain lists)
        are returned unchanged so the model can handle them itself.
        """
        move = getattr(batch, "to", None)
        if self.device is None or not callable(move):
            return batch
        return move(self.device)

    def train_step(self, text_batch, video_batch):
        """Perform one forward/backward/optimizer step and return the loss.

        Args:
            text_batch: Input passed directly to ``self.model``; moved to
                ``self.device`` first if it supports ``.to()``.
            video_batch: Target tensor compared with the model output using
                mean-squared error; moved to ``self.device`` first if it
                supports ``.to()``.

        Returns:
            The MSE loss as a Python float, computed with the parameters as
            they were *before* this step's update.

        Raises:
            ValueError: If the generated video's shape differs from
                ``video_batch``'s shape.
        """
        text_batch = self._to_device(text_batch)
        video_batch = self._to_device(video_batch)

        self.optimizer.zero_grad()
        generated_video = self.model(text_batch)
        # F.mse_loss merely warns on mismatched shapes and then broadcasts,
        # which produces a silently wrong loss; fail loudly instead.
        if generated_video.shape != video_batch.shape:
            raise ValueError(
                f"generated video shape {tuple(generated_video.shape)} does not "
                f"match target shape {tuple(video_batch.shape)}"
            )
        loss = F.mse_loss(generated_video, video_batch)
        loss.backward()
        self.optimizer.step()
        return loss.item()
|