File size: 474 Bytes
53abcc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Text2VideoTrainer:
    """Run single optimisation steps for a text-to-video model.

    Each step feeds a text batch through ``model`` and regresses the output
    onto the reference video batch with mean-squared-error loss.
    """

    def __init__(self, model, optimizer, device):
        """Store the model, its optimizer and the target device.

        Args:
            model: Callable module mapping a text batch to a generated video
                tensor.
            optimizer: Optimizer over ``model``'s parameters.
            device: Device that batches are moved to before the forward pass
                (e.g. ``"cpu"`` or ``"cuda"``). ``None`` leaves batches where
                they are.
        """
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def _to_device(self, batch):
        """Return *batch* moved to ``self.device`` when that is possible.

        Objects without a ``.to`` method (e.g. a list of raw strings) are
        returned unchanged, as are all batches when no device is configured.
        """
        if self.device is None or not hasattr(batch, "to"):
            return batch
        return batch.to(self.device)

    def train_step(self, text_batch, video_batch):
        """Perform one forward/backward/update step.

        Args:
            text_batch: Input accepted by ``model``. It is presumably a token
                tensor or a tokenizer output; TODO confirm against the model.
            video_batch: Target tensor. It must be shape-compatible with the
                model output for ``mse_loss``.

        Returns:
            The scalar loss value for this step, as a Python float.
        """
        # Local import: the module has no top-level import of
        # torch.nn.functional, so referencing a bare ``F`` raised NameError.
        import torch.nn.functional as F

        # Keep inputs on the same device as the model's parameters.
        text_batch = self._to_device(text_batch)
        video_batch = self._to_device(video_batch)

        self.optimizer.zero_grad()

        generated_video = self.model(text_batch)
        loss = F.mse_loss(generated_video, video_batch)

        loss.backward()
        self.optimizer.step()

        # .item() detaches the value so callers do not retain the graph.
        return loss.item()