license: apache-2.0
language:
  - zh
  - ja
  - ar
  - en
  - hi
metrics:
  - accuracy

Language Identification

This is a language identification model trained with AllenNLP on the qgyd2021/language_identification dataset.

Test code:

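#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import time

from allennlp.models.archival import load_archive
from allennlp.predictors.text_classifier import TextClassifierPredictor

# `project_settings` is a module from the original training project; it provides
# `project_path`, the project root (a pathlib.Path-like object).
from project_settings import project_path


def get_args():
    """
    Usage:
        python3 step_5_predict_by_archive.py --text "hello guy."
    :return: parsed command-line arguments
    """
    parser = argparse.ArgumentParser()
    # Text whose language should be identified.
    parser.add_argument(
        "--text",
        default="hello guy.",
        type=str
    )
    # Trained model archive (directory or model.tar.gz) produced by AllenNLP.
    parser.add_argument(
        "--archive_file",
        default=(project_path / "trained_models/language_identification").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    # Load the trained model and its dataset reader from the archive.
    archive = load_archive(archive_file=args.archive_file)

    predictor = TextClassifierPredictor(
        model=archive.model,
        dataset_reader=archive.dataset_reader,
    )

    # TextClassifierPredictor reads the input text from the "sentence" key.
    json_dict = {
        "sentence": args.text
    }

    begin_time = time.time()
    outputs = predictor.predict_json(
        json_dict
    )
    # "label" is the predicted language label; "probs" holds one probability per label.
    label = outputs["label"]
    prob = round(max(outputs["probs"]), 4)
    print(label)
    print(prob)

    # Only the prediction itself is timed, not model loading.
    print('time cost: {}'.format(time.time() - begin_time))
    return


if __name__ == '__main__':
    main()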
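The script above prints only the single most probable label. To inspect the runners-up as well, the probabilities can be paired with the label vocabulary. This is a minimal sketch, assuming that `outputs["probs"]` is ordered by label index and that the labels live in the vocabulary namespace `"labels"` (the default for AllenNLP's `BasicClassifier`; not confirmed for this model), so that `archive.model.vocab.get_index_to_token_vocabulary("labels")` would supply the index-to-label mapping. The helper name `top_k_labels` and the probability values are hypothetical, for illustration only.

```python
from typing import Dict, List, Sequence, Tuple


def top_k_labels(
    probs: Sequence[float],
    index_to_label: Dict[int, str],
    k: int = 3,
) -> List[Tuple[str, float]]:
    """Return the k most probable (label, probability) pairs, highest first.

    Assumes `probs` is ordered by label index (as `outputs["probs"]` is assumed
    to be) and `index_to_label` maps each index to its label string.
    """
    # Sort label indices by their probability, descending.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(index_to_label[i], round(float(probs[i]), 4)) for i in ranked[:k]]


if __name__ == "__main__":
    # Hypothetical values; in the test script these would come from
    # outputs["probs"] and archive.model.vocab.get_index_to_token_vocabulary("labels").
    probs = [0.05, 0.90, 0.03, 0.02]
    index_to_label = {0: "zh", 1: "en", 2: "ja", 3: "ar"}
    print(top_k_labels(probs, index_to_label, k=2))
```

Sorting the indices rather than the probabilities keeps each probability tied to its label index, so the mapping back to label strings stays correct.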

Dependencies (requirements.txt):

allennlp==2.10.1
allennlp-models==2.10.1
torch==1.12.1
overrides==1.9.0
pytorch_pretrained_bert==0.6.2
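Because the script depends on these exact pins, it can help to check the local environment against them before loading the archive. The following is a minimal sketch using only the standard library (`importlib.metadata`, Python 3.8+). The helper names `parse_pins` and `mismatches` are hypothetical. The comparison is an exact string match, so a build tag such as a CUDA suffix on the installed torch version would also be reported as a mismatch.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List


def parse_pins(requirements: str) -> Dict[str, str]:
    """Parse `name==version` lines into a dict, skipping blanks and comments."""
    pins = {}
    for line in requirements.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue
        name, pinned = line.split("==", 1)
        pins[name.strip()] = pinned.strip()
    return pins


def mismatches(pins: Dict[str, str]) -> List[str]:
    """Describe each pinned package that is missing or has a different version."""
    problems = []
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems.append(f"{name}: not installed (want {pinned})")
            continue
        if installed != pinned:  # exact string comparison
            problems.append(f"{name}: installed {installed}, want {pinned}")
    return problems


# The pins listed above.
REQUIREMENTS = """\
allennlp==2.10.1
allennlp-models==2.10.1
torch==1.12.1
overrides==1.9.0
pytorch_pretrained_bert==0.6.2
"""

if __name__ == "__main__":
    for problem in mismatches(parse_pins(REQUIREMENTS)):
        print(problem)
```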