---
license: apache-2.0
language:
- zh
- ja
- ar
- en
- hi
metrics:
- accuracy
---
## Language Identification

This is a language identification model trained with AllenNLP on the [qgyd2021/language_identification](https://huggingface.co/datasets/qgyd2021/language_identification) dataset.


Test code:
```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import time

from allennlp.models.archival import load_archive
from allennlp.predictors.text_classifier import TextClassifierPredictor

from project_settings import project_path


def get_args():
    """
    python3 step_5_predict_by_archive.py
    :return:
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text",
        default="hello guy.",
        type=str
    )
    parser.add_argument(
        "--archive_file",
        default=(project_path / "trained_models/language_identification").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

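    # Load the trained model and its dataset reader from the archive.
    archive = load_archive(archive_file=args.archive_file)

    predictor = TextClassifierPredictor(
        model=archive.model,
        dataset_reader=archive.dataset_reader,
    )

    # TextClassifierPredictor expects the input text under the "sentence" key.
    json_dict = {
        "sentence": args.text
    }

    # Time only the prediction call, not the printing that follows.
    begin_time = time.time()
    outputs = predictor.predict_json(
        json_dict
    )
    time_cost = time.time() - begin_time

    # "label" is the predicted language; the largest entry in "probs" is its probability.
    label = outputs["label"]
    prob = round(max(outputs["probs"]), 4)
    print(label)
    print(prob)

    print('time cost: {}'.format(time_cost))
    return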


if __name__ == '__main__':
    main()

```
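The post-processing step in the test code can also be sketched on its own, without loading the model. The helper name `summarize_prediction`, the `min_prob` threshold, and the example dict below are hypothetical and only illustrate the output shape the script reads (a `label` string and a `probs` list); they are not part of the model or of AllenNLP.

```python
from typing import Any, Dict, Tuple


def summarize_prediction(
    outputs: Dict[str, Any], min_prob: float = 0.5
) -> Tuple[str, float, bool]:
    """Return (label, top probability, is_confident) from a predictor output dict.

    Assumes `outputs` has a "label" string and a "probs" list, as read
    by the test code above. `min_prob` is an illustrative threshold.
    """
    label = outputs["label"]
    prob = round(max(outputs["probs"]), 4)
    return label, prob, prob >= min_prob


# Hypothetical output dict, shaped like the one the predictor returns.
example_outputs = {"label": "en", "probs": [0.02, 0.95, 0.03]}
print(summarize_prediction(example_outputs))
```

A threshold like this may be useful for routing low-confidence inputs to a fallback, but a suitable value would need to be chosen from validation data.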

requirements.txt
```text
allennlp==2.10.1
allennlp-models==2.10.1
torch==1.12.1
overrides==1.9.0
pytorch_pretrained_bert==0.6.2
```