Update README.md
README.md CHANGED
````diff
@@ -239,7 +239,6 @@ pipe = pipeline(
     model_kwargs=model_kwargs
 )
 
-
 # load sample audio & downsample to 16kHz
 dataset = load_dataset("japanese-asr/ja_asr.reazonspeech_test", split="test")
 
````
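The quick-start section touched by this first hunk ends with the ReazonSpeech test split loaded; as a minimal sketch (not part of the commit), a single sample could then be transcribed with the `pipe` object named in the hunk header. The `generate_kwargs` values here are an assumption, mirroring the evaluation example below.

```python
# Sketch only: transcribe the first test sample with the pipeline built above.
# The language/task prompts are assumed to match the evaluation example below.
sample = dataset[0]["audio"]
result = pipe(sample, generate_kwargs={"language": "japanese", "task": "transcribe"})
print(result["text"])
```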
````diff
@@ -296,60 +295,45 @@ pip install --upgrade transformers datasets[audio] evaluate jiwer
 Evaluation can then be run end-to-end with the following example:
 
 ```python
-from
+from tqdm import tqdm
+
+import torch
+from transformers import pipeline
 from datasets import load_dataset, Audio
 from evaluate import load
-import torch
-from tqdm import tqdm
 
-# config
+# model config
 model_id = "kotoba-tech/kotoba-whisper-v1.0"
-dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
+model_kwargs = {"attn_implementation": "sdpa"} if torch.cuda.is_available() else {}
+generate_kwargs = {"language": "japanese", "task": "transcribe"}
+
+# data config
+generate_kwargs = {"language": "japanese", "task": "transcribe"}
+dataset_name = "japanese-asr/ja_asr.reazonspeech_test"
 audio_column = 'audio'
 text_column = 'transcription'
-batch_size = 16
 
 # load model
-
-
-
+pipe = pipeline(
+    "automatic-speech-recognition",
+    model=model_id,
+    torch_dtype=torch_dtype,
+    device=device,
+    model_kwargs=model_kwargs,
+    batch_size=16
+)
 
 # load the dataset and sample the audio with 16kHz
 dataset = load_dataset(dataset_name, split="test")
-
-
-
-
-def inference(batch):
-    # 1. Pre-process the audio data to log-mel spectrogram inputs
-    audio = [sample["array"] for sample in batch["audio"]]
-    input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
-    input_features = input_features.to(device, dtype=torch_dtype)
-    # 2. Auto-regressively generate the predicted token ids
-    pred_ids = model.generate(input_features, language="ja", max_new_tokens=128)
-    # 3. Decode the token ids to the final transcription
-    batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
-    batch["reference"] = batch[text_column]
-    return batch
-
-dataset = dataset.map(function=inference, batched=True, batch_size=batch_size)
-
-# iterate over the dataset and run inference
-all_transcriptions = []
-all_references = []
-for result in tqdm(dataset, desc="Evaluating..."):
-    all_transcriptions.append(result["transcription"])
-    all_references.append(result["reference"])
-
-# normalize predictions and references
-all_transcriptions = [transcription.replace(" ", "") for transcription in all_transcriptions]
-all_references = [reference.replace(" ", "") for reference in all_references]
+transcriptions = pipe(dataset['audio'])
+transcriptions = [i['text'].replace(" ", "") for i in transcriptions]
+references = [i.replace(" ", "") for i in dataset['transcription']]
 
 # compute the CER metric
 cer_metric = load("cer")
-cer = 100 * cer_metric.compute(predictions=all_transcriptions, references=all_references)
+cer = 100 * cer_metric.compute(predictions=transcriptions, references=references)
 print(cer)
 ```
 
````
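In the new evaluation snippet, `Audio` is imported but no explicit resampling appears in the added lines, and `generate_kwargs` is defined without being visibly passed to the pipeline call. A minimal sketch of how both could be wired in, assuming the dataset is not guaranteed to be stored at 16 kHz (an illustration, not part of the commit):

```python
from datasets import Audio

# Sketch only: decode the audio column at 16 kHz, Whisper's expected sampling rate,
# in case the dataset is stored at a different rate.
dataset = dataset.cast_column(audio_column, Audio(sampling_rate=16000))

# Pass the language/task prompts defined above explicitly when running inference.
transcriptions = pipe(dataset[audio_column], generate_kwargs=generate_kwargs)
```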