from mmocr.apis import MMOCRInferencer
from mmocr.apis.inferencers.base_mmocr_inferencer import BaseMMOCRInferencer

import torch
import torch.nn as nn
from rich.progress import track

class BackboneEncoderOnly(nn.Module):

    def __init__(self, original_model):
        super().__init__()
        self.backbone = original_model.backbone
        self.encoder = original_model.encoder

    def forward(self, x):
        x = self.backbone(x)
        return self.encoder(x)
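
# Note: with the pretrained 'satrn' model loaded further down, `backbone` is the
# shallow CNN feature extractor and `encoder` is the SATRN transformer encoder,
# so this wrapper maps a normalized image batch straight to the encoder memory
# that DecoderOnly attends over. (This description assumes the stock MMOCR 1.x
# SATRN config; other configs may wire these modules differently.)
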

class DecoderOnly(nn.Module):

    def __init__(self, original_model):
        super().__init__()

        original_decoder = original_model.decoder

        self.classifier = original_decoder.classifier
        self.trg_word_emb = original_decoder.trg_word_emb
        self.position_enc = original_decoder.position_enc
        self._get_target_mask = original_decoder._get_target_mask
        self.dropout = original_decoder.dropout
        self.layer_stack = original_decoder.layer_stack
        self.layer_norm = original_decoder.layer_norm
        self._get_source_mask = original_decoder._get_source_mask
        self.postprocessor = original_decoder.postprocessor
        # Start/padding token indices and max decode length, hard-coded to match
        # the dictionary of the pretrained 'satrn' model loaded below.
        self.start_idx = 90
        self.padding_idx = 91
        self.max_seq_len = 25
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, trg_seq, src, src_mask, step):
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        trg_mask = self._get_target_mask(trg_seq)
        tgt_seq = self.dropout(trg_pos_encoded)

        output = tgt_seq
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)

        # Return the classifier logits for the current decoding position only.
        step_result = self.classifier(output[:, step, :])
        return step_result
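
# Usage sketch for DecoderOnly (a single greedy-decoding step; the names and
# shapes here reflect how the script below calls it, not a general MMOCR API):
#
#     logits = decoder(trg_seq, out_enc, src_mask, step)   # [N, num_classes]
#
# `trg_seq` is a LongTensor of shape [N, max_seq_len + 1] that is filled in
# place as decoding proceeds, `out_enc` is the encoder memory produced by
# BackboneEncoderOnly, and `step` selects which position's logits to return.
# The full loop is spelled out in the inference script further down.
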

def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Normalize a uint8 image tensor.

    Args:
        tensor: Input tensor of shape [3, 32, 100] with dtype uint8.

    Returns:
        The normalized tensor with the same shape and dtype float32.
    """
    assert tensor.shape == (3, 32, 100), 'Input tensor must have shape [3, 32, 100]'
    assert tensor.dtype == torch.uint8, 'Input tensor must have dtype uint8'

    tensor = tensor.float()

    # ImageNet per-channel mean/std in the 0-255 range.
    mean = torch.tensor([123.675, 116.28, 103.53], dtype=torch.float32).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375], dtype=torch.float32).view(3, 1, 1)

    normalized_tensor = (tensor - mean) / std
    return normalized_tensor
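
# Minimal self-check of normalize_tensor on dummy data: a random uint8 tensor in
# the expected [3, 32, 100] layout (not a real image crop). It only illustrates
# the intended call pattern; the real input comes from the inferencer below.
_dummy = torch.randint(0, 256, (3, 32, 100), dtype=torch.uint8)
_dummy_norm = normalize_tensor(_dummy)
assert _dummy_norm.shape == (3, 32, 100) and _dummy_norm.dtype == torch.float32
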

infer = MMOCRInferencer(rec='satrn')
model = infer.textrec_inferencer.model
model.eval()
model.cpu()

input_path = 'mmor_demo/demo/demo_text_recog.jpg'
ori_inputs = infer._inputs_to_list([input_path])
# Reuse the inferencer's private helpers to run MMOCR's preprocessing pipeline
# without invoking the full end-to-end inference call.
base = BaseMMOCRInferencer(model='satrn')
chunked_inputs = base._get_chunk_data(ori_inputs, 1)

for ori_inputs in track(chunked_inputs, description='Inference'):
    input = ori_inputs[0][1]
    input_img = input['inputs']
    input_image = normalize_tensor(input_img).unsqueeze(0)
    input_sample = input['data_samples']

    model_backbone_encoder = BackboneEncoderOnly(model)
    model_decoder = DecoderOnly(model)

    # Run the backbone + encoder once; the resulting memory is reused at every
    # decoding step.
    out_enc = model_backbone_encoder(input_image)
    data_samples = None

    # Initialize the target sequence with padding tokens and a start token at
    # position 0; greedy decoding fills in one position per step.
    N = out_enc.size(0)
    init_target_seq = torch.full((N, model_decoder.max_seq_len + 1),
                                 model_decoder.padding_idx,
                                 device=out_enc.device,
                                 dtype=torch.long)
    init_target_seq[:, 0] = model_decoder.start_idx

    outputs = []
    for step in range(0, model_decoder.max_seq_len):
        valid_ratios = [1.0 for _ in range(out_enc.size(0))]
        if data_samples is not None:
            valid_ratios = []
            for data_sample in data_samples:
                valid_ratios.append(data_sample.get('valid_ratio'))

        src_mask = model_decoder._get_source_mask(out_enc, valid_ratios)
        step_result = model_decoder(init_target_seq, out_enc, src_mask, step)
        outputs.append(step_result)
        # Greedy decoding: feed the most likely token back as the next input.
        _, step_max_index = torch.max(step_result, dim=-1)
        init_target_seq[:, step + 1] = step_max_index

    outputs = torch.stack(outputs, dim=1)
    out_dec = model_decoder.softmax(outputs)
    output = model_decoder.postprocessor(out_dec, [input_sample])
    # The postprocessed sample exposes the decoded string via `pred_text.item`
    # and the confidence scores via `pred_text.score`.
    outstr = output[0].pred_text.item
    outscore = output[0].pred_text.score

    print('pred_text:', outstr)
    print('score:', outscore)