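"""Step-by-step SATRN text recognition with MMOCR.

The pretrained SATRN recognizer behind MMOCRInferencer is split into a
backbone+encoder wrapper and a decoder-only wrapper; the decoder is then run
autoregressively, predicting one character per step, and the decoder's own
postprocessor converts the resulting probabilities back into text.
"""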
import torch
import torch.nn as nn
from rich.progress import track

from mmocr.apis import MMOCRInferencer
from mmocr.apis.inferencers.base_mmocr_inferencer import BaseMMOCRInferencer

class BackboneEncoderOnly(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        # Keep only the backbone and encoder from the original model
        self.backbone = original_model.backbone
        self.encoder = original_model.encoder
        
    def forward(self, x):
        x = self.backbone(x)
        return self.encoder(x)
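
# Hedged sanity check (illustrative; not executed): with a normalized
# 1x3x32x100 input, this wrapper is expected to return SATRN encoder features
# of shape (N, T, C), which the decoder below consumes as its memory:
#   feats = BackboneEncoderOnly(model)(torch.zeros(1, 3, 32, 100))  # after `model` is built below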


class DecoderOnly(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        # Keep only the decoder submodules needed for single-step inference
        original_decoder = original_model.decoder
        # self._attention = original_decoder._attention
        self.classifier = original_decoder.classifier
        self.trg_word_emb = original_decoder.trg_word_emb
        self.position_enc = original_decoder.position_enc
        self._get_target_mask = original_decoder._get_target_mask
        self.dropout = original_decoder.dropout
        self.layer_stack = original_decoder.layer_stack
        self.layer_norm = original_decoder.layer_norm
        self._get_source_mask = original_decoder._get_source_mask
        self.postprocessor = original_decoder.postprocessor
        # Hard-coded to match the pretrained SATRN model's dictionary and test
        # config; adjust these if a different checkpoint or dictionary is used
        self.start_idx = 90
        self.padding_idx = 91
        self.max_seq_len = 25
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, trg_seq, src, src_mask, step):
        # Run one decoding step and return the classifier logits for position `step`
        # decoder_output = self._attention(init_target_seq, out_enc, src_mask=src_mask)
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        trg_mask = self._get_target_mask(trg_seq)
        tgt_seq = self.dropout(trg_pos_encoded)

        output = tgt_seq
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                self_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)
        # bsz * seq_len * C
        step_result = self.classifier(output[:, step, :])
        return step_result


def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """
    对 uint8 张量进行标准化处理
    参数:
        tensor: 输入张量,形状为 [3, 32, 100],数据类型为 uint8
    返回:
        标准化后的张量,形状不变,数据类型为 float32
    """
    # 检查输入张量的形状和数据类型
    assert tensor.shape == (3, 32, 100), "输入张量形状必须为 [3, 32, 100]"
    assert tensor.dtype == torch.uint8, "输入张量数据类型必须为 uint8"
    
    # Convert to float32
    tensor = tensor.float()
    
    # Normalization parameters (RGB channel order; ImageNet mean/std on a 0-255 scale)
    mean = torch.tensor([123.675, 116.28, 103.53], dtype=torch.float32).view(3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375], dtype=torch.float32).view(3, 1, 1)
    
    # Standardize: (x - mean) / std
    normalized_tensor = (tensor - mean) / std
    
    return normalized_tensor
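
# Example usage (hypothetical values; not executed as part of the script):
#   img = torch.randint(0, 256, (3, 32, 100), dtype=torch.uint8)
#   out = normalize_tensor(img)  # float32, same shape, per-channel standardized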


infer = MMOCRInferencer(rec='satrn')
model = infer.textrec_inferencer.model
model.eval()
model.cpu()
input_path = 'mmor_demo/demo/demo_text_recog.jpg'
ori_inputs = infer._inputs_to_list([input_path])
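# Note (assumption): a second inferencer is created here only to reuse its
# preprocessing pipeline via _get_chunk_data; infer.textrec_inferencer could
# likely serve the same purpose without building the model a second time.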
base = BaseMMOCRInferencer(model='satrn')
chunked_inputs = base._get_chunk_data(ori_inputs, 1)
for ori_inputs in track(chunked_inputs, description='Inference'):
    # Each chunk entry is (raw input, pipeline output); take the pipeline output
    data = ori_inputs[0][1]
    input_img = data['inputs']
    input_image = normalize_tensor(input_img).unsqueeze(0)
    input_sample = data['data_samples']
    
    # Split the model into a backbone+encoder wrapper and a decoder-only wrapper
    model_backbone_encoder = BackboneEncoderOnly(model)
    model_decoder = DecoderOnly(model)
    
    # Encode the image once; the decoder below is run step by step on the result
    out_enc = model_backbone_encoder(input_image)
    data_samples = None

    N = out_enc.size(0)
    # Target sequence starts as padding, with the start token at position 0
    init_target_seq = torch.full((N, model_decoder.max_seq_len + 1),
                                 model_decoder.padding_idx,
                                 device=out_enc.device,
                                 dtype=torch.long)
    # bsz * seq_len
    init_target_seq[:, 0] = model_decoder.start_idx

    outputs = []
    for step in range(0, model_decoder.max_seq_len):
        valid_ratios = [1.0 for _ in range(out_enc.size(0))]
        if data_samples is not None:
            valid_ratios = []
            for data_sample in data_samples:
                valid_ratios.append(data_sample.get('valid_ratio'))
        
        src_mask = model_decoder._get_source_mask(out_enc, valid_ratios)
        step_result = model_decoder(init_target_seq, out_enc, src_mask, step)
        outputs.append(step_result)
        # Greedy decoding: feed the argmax token back as the input for the next step
        _, step_max_index = torch.max(step_result, dim=-1)
        init_target_seq[:, step + 1] = step_max_index
    outputs = torch.stack(outputs, dim=1)
    out_dec = model_decoder.softmax(outputs)
    output = model_decoder.postprocessor(out_dec, [input_sample])
    # pred_text.item holds the decoded string, pred_text.score its confidence
    outstr = output[0].pred_text.item
    outscore = output[0].pred_text.score

    print('pred_text:', outstr)
    print('score:', outscore)
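
# Cross-check (hedged; relies only on the public MMOCRInferencer API): running
# the same image through the unsplit model should give a matching prediction:
#   print(infer(input_path)['predictions'])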