| from typing import List | |
| import torch | |
| from ._ops import ops | |
| from . import layers | |
| def ms_deform_attn_backward( | |
| value: torch.Tensor, | |
| spatial_shapes: torch.Tensor, | |
| level_start_index: torch.Tensor, | |
| sampling_loc: torch.Tensor, | |
| attn_weight: torch.Tensor, | |
| grad_output: torch.Tensor, | |
| im2col_step: int, | |
| ) -> List[torch.Tensor]: | |
| return ops.ms_deform_attn_backward( | |
| value, | |
| spatial_shapes, | |
| level_start_index, | |
| sampling_loc, | |
| attn_weight, | |
| grad_output, | |
| im2col_step, | |
| ) | |
| def ms_deform_attn_forward( | |
| value: torch.Tensor, | |
| spatial_shapes: torch.Tensor, | |
| level_start_index: torch.Tensor, | |
| sampling_loc: torch.Tensor, | |
| attn_weight: torch.Tensor, | |
| im2col_step: int, | |
| ) -> torch.Tensor: | |
| return ops.ms_deform_attn_forward( | |
| value, | |
| spatial_shapes, | |
| level_start_index, | |
| sampling_loc, | |
| attn_weight, | |
| im2col_step, | |
| ) | |
| __all__ = ["layers", "ms_deform_attn_forward", "ms_deform_attn_backward"] | |