BootyShakerAI / trainer.py
ChromiumPlutoniumAI's picture
Create trainer.py
53abcc6 verified
raw
history blame
474 Bytes
class Text2VideoTrainer:
    """Run single optimization steps for a model that maps text to video.

    The trainer only wires together objects supplied by the caller: a
    callable ``model``, an ``optimizer`` that updates the model's
    parameters, and a ``device`` handle.

    NOTE: ``device`` is stored for the caller's convenience only; nothing
    in this class moves the model or the batches onto it. Callers must
    place the model and tensors on matching devices themselves.
    """

    def __init__(self, model, optimizer, device) -> None:
        """Store the training collaborators without modifying them.

        Args:
            model: Callable invoked as ``model(text_batch)``; its output is
                compared against the target video batch.
            optimizer: Object exposing ``zero_grad()`` and ``step()``.
            device: Device handle kept on the instance (not used by
                ``train_step``).
        """
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def train_step(self, text_batch, video_batch) -> float:
        """Perform one gradient update and return the scalar loss.

        Args:
            text_batch: Input passed unchanged to ``self.model``.
            video_batch: Target tensor; must be compatible with the model
                output under ``torch.nn.functional.mse_loss``.

        Returns:
            The mean-squared-error loss for this batch as a Python number.
        """
        # Imported here because this file has no module-level import of
        # torch.nn.functional; without it the bare name ``F`` is undefined
        # and the first training step fails with NameError.
        import torch.nn.functional as F

        # Clear gradients left over from the previous step so they do not
        # accumulate into this update.
        self.optimizer.zero_grad()

        generated_video = self.model(text_batch)
        loss = F.mse_loss(generated_video, video_batch)

        loss.backward()
        self.optimizer.step()

        # .item() detaches the value from the autograd graph.
        return loss.item()