from typing import List, Tuple
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import torch.nn as nn
from ._ops import ops


class MultiScaleDeformableAttentionFunction(Function):
    """Autograd ``Function`` wrapping the fused multi-scale deformable attention kernel."""

@staticmethod
def forward(
context,
value: Tensor,
value_spatial_shapes: Tensor,
value_level_start_index: Tensor,
sampling_locations: Tensor,
attention_weights: Tensor,
im2col_step: int,
):
        # Stash im2col_step on the autograd context so backward can pass it
        # to the kernel as well.
        context.im2col_step = im2col_step
output = ops.ms_deform_attn_forward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
context.im2col_step,
)
context.save_for_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
)
        return output

@staticmethod
@once_differentiable
def backward(context, grad_output):
(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
) = context.saved_tensors
grad_value, grad_sampling_loc, grad_attn_weight = ops.ms_deform_attn_backward(
value,
value_spatial_shapes,
value_level_start_index,
sampling_locations,
attention_weights,
grad_output,
context.im2col_step,
)
        # Gradients correspond positionally to the forward inputs
        # (value, spatial_shapes, level_start_index, sampling_locations,
        # attention_weights, im2col_step); non-differentiable arguments get None.
        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


class MultiScaleDeformableAttention(nn.Module):
    """``nn.Module`` wrapper around :class:`MultiScaleDeformableAttentionFunction`."""

def forward(
self,
value: Tensor,
value_spatial_shapes: Tensor,
        value_spatial_shapes_list: List[Tuple],  # accepted for interface compatibility; unused
level_start_index: Tensor,
sampling_locations: Tensor,
attention_weights: Tensor,
im2col_step: int,
):
return MultiScaleDeformableAttentionFunction.apply(
value,
value_spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
im2col_step,
        )


__all__ = ["MultiScaleDeformableAttention"]
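
# A minimal, commented-out usage sketch. Assumptions (not guaranteed by this
# file): a CUDA build of the kernel and the usual Deformable DETR shape
# conventions -- value: (batch, sum(H_l * W_l), n_heads, head_dim),
# sampling_locations: (batch, n_queries, n_heads, n_levels, n_points, 2) with
# normalized coordinates, attention_weights: the same shape minus the
# coordinate axis, summing to 1 over (n_levels, n_points).
#
#     import torch
#
#     # Two feature levels, 32x32 and 16x16 -> 1280 flattened positions.
#     shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long, device="cuda")
#     start_index = torch.as_tensor([0, 32 * 32], dtype=torch.long, device="cuda")
#     value = torch.rand(1, 1280, 8, 32, device="cuda")
#     locs = torch.rand(1, 100, 8, 2, 4, 2, device="cuda")
#     weights = torch.softmax(torch.rand(1, 100, 8, 2 * 4, device="cuda"), -1)
#     weights = weights.view(1, 100, 8, 2, 4)
#
#     attn = MultiScaleDeformableAttention()
#     out = attn(value, shapes, [(32, 32), (16, 16)], start_index, locs, weights, 64)
#     # Expected output shape: (batch, n_queries, n_heads * head_dim) = (1, 100, 256).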