---
license: apache-2.0
language:
- zh
- ja
- ar
- en
- hi
metrics:
- accuracy
---
## Language Identification
This is a language identification model trained with AllenNLP on the qgyd2021/language_identification dataset.

Test code:
```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import time

from allennlp.models.archival import load_archive
from allennlp.predictors.text_classifier import TextClassifierPredictor

# project_settings is a local module of the training project (not a pip package);
# project_path is that project's root directory as a pathlib.Path.
from project_settings import project_path


def get_args():
    """
    Usage: python3 step_5_predict_by_archive.py [--text TEXT] [--archive_file PATH]
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text",
        default="hello guy.",
        type=str
    )
    parser.add_argument(
        "--archive_file",
        default=(project_path / "trained_models/language_identification").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    # Load the trained model and its dataset reader from the archive.
    archive = load_archive(archive_file=args.archive_file)
    predictor = TextClassifierPredictor(
        model=archive.model,
        dataset_reader=archive.dataset_reader,
    )

    # TextClassifierPredictor reads the input text from the "sentence" key.
    json_dict = {
        "sentence": args.text
    }

    begin_time = time.time()
    outputs = predictor.predict_json(
        json_dict
    )
    # "label" is the predicted language; "probs" holds one probability per label.
    label = outputs["label"]
    prob = round(max(outputs["probs"]), 4)
    print(label)
    print(prob)
    print('time cost: {}'.format(time.time() - begin_time))


if __name__ == '__main__':
    main()
```
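The test script prints only the top label and its rounded probability. The helpers below are a minimal sketch that pulls that post-processing out of the script so it can be reused or checked without loading the model. `summarize_prediction` mirrors the script's logic exactly. `rank_labels` is a hypothetical extra that assumes, as in AllenNLP's `BasicClassifier`, that `probs[i]` corresponds to index `i` of the model's `labels` vocabulary namespace; with a loaded archive, that mapping could be read with `archive.model.vocab.get_index_to_token_vocabulary("labels")`, which returns a dict keyed by index. The `fake_outputs` values and label order below are made up for illustration.

```python
from typing import Any, Dict, List, Sequence, Tuple


def summarize_prediction(outputs: Dict[str, Any], ndigits: int = 4) -> Tuple[str, float]:
    """Mirror the test script: take the predicted label and the largest probability."""
    label = outputs["label"]
    prob = round(max(outputs["probs"]), ndigits)
    return label, prob


def rank_labels(probs: Sequence[float], index_to_label: Sequence[str]) -> List[Tuple[str, float]]:
    """Pair each probability with its label and sort from most to least likely.

    Assumption: probs[i] belongs to index_to_label[i] (hypothetical ordering).
    """
    if len(probs) != len(index_to_label):
        raise ValueError("probs and index_to_label must have the same length")
    pairs = list(zip(index_to_label, probs))
    # sorted() is stable, so equal probabilities keep their original index order.
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)


if __name__ == "__main__":
    # Hypothetical outputs dict; real values come from predictor.predict_json(...).
    fake_outputs = {"label": "en", "probs": [0.01, 0.02, 0.9512345, 0.01, 0.0087655]}
    fake_labels = ["zh", "ja", "en", "ar", "hi"]  # hypothetical index order
    print(summarize_prediction(fake_outputs))
    print(rank_labels(fake_outputs["probs"], fake_labels)[:3])
```

Keeping the post-processing in plain functions means it can be unit-tested with hand-written dicts, while the model-dependent part stays in `main()`.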
requirements.txt:

```text
allennlp==2.10.1
allennlp-models==2.10.1
torch==1.12.1
overrides==1.9.0
pytorch_pretrained_bert==0.6.2
```
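Because the archive was produced with these exact versions, a quick environment check can save debugging time. The script below is a minimal sketch using only the standard library (`importlib.metadata`, Python 3.8+); it copies the pins above into a string and handles only exact `==` pins. The helper names `parse_pins` and `check_pins` are illustrative, not part of any project code.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Tuple

# Same pins as requirements.txt above.
REQUIREMENTS = """\
allennlp==2.10.1
allennlp-models==2.10.1
torch==1.12.1
overrides==1.9.0
pytorch_pretrained_bert==0.6.2
"""


def parse_pins(text: str) -> Dict[str, str]:
    """Parse 'name==version' lines; blank lines and '#' comments are skipped."""
    pins: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()
        if not line:
            continue
        name, sep, ver = line.partition("==")
        if not sep:
            raise ValueError(f"only exact '==' pins are handled: {raw!r}")
        pins[name.strip()] = ver.strip()
    return pins


def check_pins(pins: Dict[str, str]) -> List[Tuple[str, str, str]]:
    """Return (name, pinned, installed) for every package whose version differs."""
    mismatches = []
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = "not installed"
        if installed != pinned:
            mismatches.append((name, pinned, installed))
    return mismatches


if __name__ == "__main__":
    for name, pinned, installed in check_pins(parse_pins(REQUIREMENTS)):
        print(f"{name}: pinned {pinned}, found {installed}")
```

The comparison is a plain string match, so a CUDA build of torch reported as, say, `1.12.1+cu113` will be listed as a mismatch even though the base version is the pinned one.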