---
license: apache-2.0
language:
- zh
- ja
- ar
- en
- hi
metrics:
- accuracy
---
## Language Identification
This is a language identification model trained with AllenNLP on the [qgyd2021/language_identification](https://huggingface.co/datasets/qgyd2021/language_identification) dataset.
Test code:
```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import time

from allennlp.models.archival import load_archive
from allennlp.predictors.text_classifier import TextClassifierPredictor

# project_settings is a module local to the training project; project_path is
# used below as a pathlib-style path to the project root.
from project_settings import project_path


def get_args():
    """
    python3 step_5_predict_by_archive.py

    :return: the parsed command-line arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text",
        default="hello guy.",
        type=str
    )
    parser.add_argument(
        "--archive_file",
        default=(project_path / "trained_models/language_identification").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    # Load the trained archive and wrap it in a text-classification predictor.
    archive = load_archive(archive_file=args.archive_file)
    predictor = TextClassifierPredictor(
        model=archive.model,
        dataset_reader=archive.dataset_reader,
    )

    json_dict = {
        "sentence": args.text
    }

    # Time only the prediction call and the printing, not the model loading.
    begin_time = time.time()
    outputs = predictor.predict_json(json_dict)

    # Report the predicted label and the highest class probability.
    label = outputs["label"]
    prob = round(max(outputs["probs"]), 4)
    print(label)
    print(prob)
    print('time cost: {}'.format(time.time() - begin_time))


if __name__ == '__main__':
    main()
```
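The test script prints the top label and its probability as-is. When the prediction feeds another system, it can help to reject low-confidence results. The following is a minimal sketch of such post-processing, not part of the original project: the helper name `summarize_prediction`, the `min_prob` threshold, and the `"unknown"` fallback are all hypothetical. The only thing carried over from the script is the assumption that the predictor output is a dict with a `"label"` string and a `"probs"` list.

```python
from typing import Any, Dict, Tuple


def summarize_prediction(
    outputs: Dict[str, Any],
    min_prob: float = 0.5,      # hypothetical threshold, tune for your data
    fallback: str = "unknown",  # hypothetical label for rejected predictions
) -> Tuple[str, float]:
    """Return (label, probability), replacing the label when confidence is low."""
    probs = outputs["probs"]
    if not probs:
        raise ValueError("outputs['probs'] is empty")

    # Same rounding as the test script: highest class probability, 4 decimals.
    prob = round(max(probs), 4)
    label = outputs["label"] if prob >= min_prob else fallback
    return label, prob


if __name__ == "__main__":
    # Hand-written stand-in for predictor.predict_json(...); not real model output.
    fake_outputs = {"label": "en", "probs": [0.02, 0.95, 0.03]}
    print(summarize_prediction(fake_outputs))
```

In the script above, this would be called as `summarize_prediction(outputs)` right after `predictor.predict_json(json_dict)`.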
requirements.txt
```text
allennlp==2.10.1
allennlp-models==2.10.1
torch==1.12.1
overrides==1.9.0
pytorch_pretrained_bert==0.6.2
```
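Because the versions above are pinned exactly, a mismatched environment is a likely cause of load errors. Below is a rough sketch of a check that compares installed packages against these pins. The helper names `parse_pins` and `find_mismatches` are hypothetical, and the sketch assumes Python 3.8 or newer, where `importlib.metadata` is in the standard library.

```python
from importlib import metadata
from typing import Callable, Dict

# The pins copied from requirements.txt above.
PINNED = """\
allennlp==2.10.1
allennlp-models==2.10.1
torch==1.12.1
overrides==1.9.0
pytorch_pretrained_bert==0.6.2
"""


def parse_pins(text: str) -> Dict[str, str]:
    """Parse 'name==version' lines into a {name: version} dict."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, version = line.partition("==")
        pins[name.strip()] = version.strip()
    return pins


def find_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, str]:
    """Return {name: installed_version} for every package that differs from its pin."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = "not installed"
        if found != wanted:
            mismatches[name] = found
    return mismatches


if __name__ == "__main__":
    for name, found in find_mismatches(parse_pins(PINNED)).items():
        print("{}: pinned {}, found {}".format(name, parse_pins(PINNED)[name], found))
```

An empty result from `find_mismatches` means every pinned package is installed at the listed version.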