import argparse

import librosa
import torch
from datasets import Audio, load_dataset
from evaluate import load
from huggingface_hub import login
from transformers import WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration
|
|
# Command-line interface for the evaluation run.
my_parser = argparse.ArgumentParser()

# (option name, default) pairs. Each option is a string stored under its own
# name and is reachable both as "--name" and as the single-dash alias "-name".
# A default of None matches argparse's behaviour when no default is given.
_CLI_OPTIONS = (
    ("model_name", "openai/whisper-tiny"),
    ("hf_token", None),
    ("dataset_name", "google/fleurs"),
    ("split", "test"),
    ("subset", None),
)
for _option, _default in _CLI_OPTIONS:
    my_parser.add_argument(f"--{_option}", f"-{_option}", type=str, action="store", default=_default)

args = my_parser.parse_args()
# Authenticate with the Hugging Face Hub using the token passed on the command
# line; load_dataset(..., use_auth_token=True) further down relies on the
# stored credentials.
try:
    login(args.hf_token)
except Exception as err:
    # Raise a real exception type (a bare string cannot be raised: doing so
    # fails with TypeError and masks the actual login error), keep the
    # original cause chained, and do not echo the token, which is a secret.
    raise RuntimeError(
        "Can't login, please pass a valid Hugging Face token via --hf_token"
    ) from err
|
|
|
|
# Expose the CLI options under the module-level names used by the rest of the script.
model_name = args.model_name
dataset_name = args.dataset_name
subset = args.subset

# Name of the column holding the reference transcript: "transcription" for
# google/fleurs, "sentence" for every other dataset.
if dataset_name == "google/fleurs":
    text_column = "transcription"
else:
    text_column = "sentence"

print(f"Evaluating {args.model_name} on {args.dataset_name} [{subset}]")
|
|
|
|
# Whisper components for `model_name`. The processor already bundles a feature
# extractor and a tokenizer, so reuse those members instead of downloading and
# instantiating the same artifacts a second time. WhisperProcessor.from_pretrained
# forwards language/task to its tokenizer, so label ids produced with
# `tokenizer` carry the Arabic/transcribe prefix tokens.
processor = WhisperProcessor.from_pretrained(model_name, language="Arabic", task="transcribe")
feature_extractor = processor.feature_extractor
tokenizer = processor.tokenizer
model = WhisperForConditionalGeneration.from_pretrained(model_name)

# `subset` is passed as the dataset configuration name.
# NOTE(review): newer `datasets` releases replace `use_auth_token` with `token`;
# switch once the installed version is known to support it.
test_dataset = load_dataset(dataset_name, subset, split=args.split, use_auth_token=True)
# Decode/resample the audio column at 16 kHz, the rate Whisper feature
# extractors operate on.
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16000))
def prepare_dataset(batch):
    """Add Whisper encoder inputs and label token ids to one dataset example.

    Args:
        batch: a single example whose ``"audio"`` entry provides ``"array"``
            and ``"sampling_rate"``, with the reference transcript stored
            under the module-level ``text_column``.

    Returns:
        The same example (mutated in place) with ``"input_features"`` set to
        the feature extractor output for this clip and ``"labels"`` set to the
        tokenizer's ids for the transcript.
    """
    clip = batch["audio"]
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    # The extractor returns a batch containing one clip; keep only that entry.
    batch["input_features"] = extracted.input_features[0]

    transcript = batch[text_column]
    batch["labels"] = tokenizer(transcript).input_ids
    return batch
|
|
# Compute encoder features and label ids for every example.
test_dataset = test_dataset.map(prepare_dataset)

model = model.to("cuda")

# Force the decoder prompt to Arabic transcription instead of letting the
# model detect the language.
forced_ids = processor.get_decoder_prompt_ids(language="ar", task="transcribe")
model.config.forced_decoder_ids = forced_ids
# Recent transformers releases take decoding settings from `generation_config`
# (loaded from the checkpoint's generation_config.json), so changing
# `model.config` alone can be silently ignored. Mirror the prompt there too;
# the getattr guard keeps older releases without generation_config working.
if getattr(model, "generation_config", None) is not None:
    model.generation_config.forced_decoder_ids = forced_ids
|
|
def map_to_result(batch):
    """Transcribe one prepared example and decode its reference text.

    Args:
        batch: a single example after ``prepare_dataset``; reads
            ``"input_features"`` and ``"labels"``.

    Returns:
        The same example with ``"pred_str"`` (the model's transcription) and
        ``"text"`` (the reference decoded from the label ids) added.
    """
    with torch.no_grad():
        # Build the input on the model's own device and dtype rather than a
        # hard-coded "cuda", so it always matches wherever (and in whatever
        # precision) the model was placed; unsqueeze adds a batch axis of 1.
        input_values = torch.tensor(
            batch["input_features"], device=model.device, dtype=model.dtype
        ).unsqueeze(0)
        pred_ids = model.generate(input_values)

    # Only one sequence was generated, so take the first decoded string.
    batch["pred_str"] = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
    batch["text"] = processor.decode(batch["labels"], skip_special_tokens=True)
    return batch
# Run inference over the whole prepared split.
results = test_dataset.map(map_to_result)

# Score predictions against the decoded references with the `evaluate` WER metric.
wer = load("wer")
wer_score = wer.compute(predictions=results["pred_str"], references=results["text"])
print("Test WER: {:.3f}".format(wer_score))
|
|