|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from argparse import Namespace | 
					
						
						|  | from typing import NamedTuple | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from torch import nn | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class AdaptorInput(NamedTuple): | 
					
						
						|  | images: torch.Tensor | 
					
						
						|  | summary: torch.Tensor | 
					
						
						|  | features: torch.Tensor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class RadioOutput(NamedTuple): | 
					
						
						|  | summary: torch.Tensor | 
					
						
						|  | features: torch.Tensor | 
					
						
						|  |  | 
					
						
						|  | def to(self, *args, **kwargs): | 
					
						
						|  | return RadioOutput( | 
					
						
						|  | self.summary.to(*args, **kwargs) if self.summary is not None else None, | 
					
						
						|  | self.features.to(*args, **kwargs) if self.features is not None else None, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class AdaptorBase(nn.Module): | 
					
						
						|  | def forward(self, input: AdaptorInput) -> RadioOutput: | 
					
						
						|  | raise NotImplementedError("Subclasses must implement this!") | 
					
						
						|  |  |