matejhornik committed on
Commit ce588ec · verified · 1 Parent(s): 10ded47

Update README.md

Files changed (1)
  1. README.md +38 -112

README.md CHANGED
@@ -135,127 +135,49 @@ For detailed training logs, metrics, and visualizations, please refer to the Wei
 
 ## How to Use
 
- You can use this model for inference with the Hugging Face `transformers` library. Make sure you have `torchaudio` and `librosa` (or `soundfile`) installed for audio processing.
 
 ```python
 from transformers import SpeechEncoderDecoderModel, AutoProcessor
 import torch
- import soundfile as sf
-
- model_id = "matejhornik/wav2vec2-base_bart-base_voxpopuli-en"
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Load the processor (feature extractor and tokenizer)
- processor = AutoProcessor.from_pretrained(model_id)
-
- # Load the model
- model = SpeechEncoderDecoderModel.from_pretrained(model_id).to(device)
-
- def transcribe_audio(audio_path):
-     """Loads audio, processes it, and transcribes it."""
-     speech_array, sampling_rate = sf.read(audio_path)
-
-     # Ensure audio is 16kHz as expected by the model
-     if sampling_rate != processor.feature_extractor.sampling_rate:
-         raise ValueError(f"Audio sampling rate {sampling_rate} does not match model's required {processor.feature_extractor.sampling_rate}Hz. Please resample.")
-
-     # Preprocess the audio
-     inputs = processor(speech_array, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt", padding=True)
-     input_features = inputs.input_features.to(device)
-     attention_mask = inputs.attention_mask.to(device)
-
-     # Generate transcription
-     with torch.no_grad():
-         predicted_ids = model.generate(input_features, attention_mask=attention_mask, max_length=128)
-
-     # Decode the transcription
-     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
-     return transcription[0]
-
- # Example usage:
- audio_file_path = "path/to/your/audio.wav"
- try:
-     transcription = transcribe_audio(audio_file_path)
-     print(f"Transcription: {transcription}")
- except ValueError as e:
-     print(e)
- except FileNotFoundError:
-     print(f"Audio file not found at: {audio_file_path}. Please provide a valid path.")
- ```
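
The removed example raises a `ValueError` on a sampling-rate mismatch instead of fixing it. A minimal sketch of the resampling step it asks for, assuming `librosa` is installed (the file path is a placeholder):

```python
import librosa

# librosa.load resamples on read: sr=16_000 matches the 16 kHz input rate
# expected by this model's feature extractor and yields a mono float32 array.
speech_array, sampling_rate = librosa.load("path/to/your/audio.wav", sr=16_000)
```

The resampled array can then go straight into the `processor(...)` call from the example above.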
185
-
186
- ## Reproducing Evaluation on VoxPopuli
187
- To reproduce the evaluation results on the VoxPopuli test set:
188
-
189
- ```python
190
 from datasets import load_dataset
- from transformers import SpeechEncoderDecoderModel, AutoProcessor
- import torch
- from jiwer import wer
- from tqdm import tqdm
 
- model_id = "matejhornik/wav2vec2-base_bart-base_voxpopuli-en"
- dataset_name = "facebook/voxpopuli"
- dataset_config = "en"
- split = "test"  # or "validation"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Load processor and model
- processor = AutoProcessor.from_pretrained(model_id)
- model = SpeechEncoderDecoderModel.from_pretrained(model_id).to(device)
- model.eval()  # Set model to evaluation mode
-
- # Load dataset
- # Note: You might need to authenticate with Hugging Face if the dataset requires it
- # from huggingface_hub import login
- voxpopuli_test = load_dataset(dataset_name, dataset_config, split=split, streaming=False)  # Set streaming=True for large datasets if memory is an issue
-
- # Preprocessing function
- def map_to_pred(batch):
-     # Ensure audio is in the correct format (array, 16kHz)
-     audio_data = batch["audio"]["array"]
-     sampling_rate = batch["audio"]["sampling_rate"]
-
-     if sampling_rate != processor.feature_extractor.sampling_rate:
-         print(f"Warning: Resampling needed or sample skipped for audio with rate {sampling_rate}")
-         # Dummy processing for now if rate mismatch
-         input_features = torch.zeros((1, 1000))  # Placeholder
-     else:
-         inputs = processor(audio_data, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
-         input_features = inputs.input_features.to(device)
-
-     with torch.no_grad():
-         predicted_ids = model.generate(input_features, max_length=128)
-
-     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
-     batch["prediction"] = transcription[0]
-     batch["reference"] = batch["normalized_text"]
-     return batch
-
- predictions = []
- references = []
-
- for sample in tqdm(voxpopuli_test):
-     try:
-         processed_sample = map_to_pred(sample)
-         predictions.append(processed_sample["prediction"])
-         references.append(processed_sample["reference"])
-     except Exception as e:
-         print(f"Error processing sample: {e}")
-
- # Calculate WER
- if predictions and references:
-     current_wer = wer(references, predictions)
-     print(f"WER on {split} set: {current_wer:.4f}")
- else:
-     print("No samples processed or an error occurred.")
-
- # Expected WER on test set: 0.0885
- # Expected WER on validation set: 0.0855
 ```
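
The removed snippet computes WER with `jiwer`. The project also pins `evaluate` (see Framework Versions below), so the same check can be done with that library; a minimal sketch, reusing the `references` and `predictions` lists built above:

```python
import evaluate

# evaluate's WER metric compares lists of reference and predicted strings
# and returns the corpus-level word error rate as a float.
wer_metric = evaluate.load("wer")
print(f"WER: {wer_metric.compute(references=references, predictions=predictions):.4f}")
```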
 
 ### Framework Versions
 
 This model was trained using:
@@ -268,6 +190,8 @@ This model was trained using:
 - Evaluate: `^0.4.3`
 - WandB: `^0.19.7`
 
 ## Citation
 
 If you use this model or findings from the thesis, please cite:
@@ -295,4 +219,6 @@ If you use this model or findings from the thesis, please cite:
 For questions, feedback, or collaboration opportunities related to this thesis or anything else, feel free to reach out:
 
 
- - **GitHub:** [hornikmatej](https://github.com/hornikmatej)
 
 
 
 ## How to Use
 
+ You can use this model for inference with the Hugging Face `transformers` library.
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hornikmatej/thesis_mit/blob/main/graphs/colab_ntb.ipynb)
 
 ```python
 from transformers import SpeechEncoderDecoderModel, AutoProcessor
 import torch
 from datasets import load_dataset
 
+ MODEL_ID = "matejhornik/wav2vec2-base_bart-base_voxpopuli-en"
+ DATASET_ID = "facebook/voxpopuli"
+ DATASET_CONFIG = "en"
+ DATASET_SPLIT = "test"  # or "validation"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
+ model = SpeechEncoderDecoderModel.from_pretrained(MODEL_ID).to(device)
+
+ print(f"Using device: {device}\nStreaming one sample from '{DATASET_ID}' "
+       f"(config: '{DATASET_CONFIG}', split: '{DATASET_SPLIT}')...")
+ streamed_dataset = load_dataset(
+     DATASET_ID,
+     DATASET_CONFIG,
+     split=DATASET_SPLIT,
+     streaming=True,
+ )
+ sample = next(iter(streamed_dataset))
+
+ audio_input = sample["audio"]["array"]
+ input_sampling_rate = sample["audio"]["sampling_rate"]
+
+ inputs = processor(audio_input, sampling_rate=input_sampling_rate, return_tensors="pt", padding=True)
+ input_features = inputs.input_values.to(device)
+
+ with torch.no_grad():
+     predicted_ids = model.generate(input_features, max_length=128)
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+ print(f"\nOriginal: {sample['normalized_text']}")
+ print(f"Transcribed: {transcription}")
  ```
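
The example above passes the dataset's native sampling rate straight to the processor, which works because VoxPopuli audio is already 16 kHz. For audio from other sources, one option is to let `datasets` resample on access; a minimal sketch, assuming the same `streamed_dataset` as in the example:

```python
from datasets import Audio

# cast_column re-decodes audio lazily at the requested rate, so each
# sample["audio"]["array"] arrives at the model's expected 16 kHz.
streamed_dataset = streamed_dataset.cast_column("audio", Audio(sampling_rate=16_000))
```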
 
+
+
  ### Framework Versions
 
 This model was trained using:
 - Evaluate: `^0.4.3`
 - WandB: `^0.19.7`
 
+ Visit the [pyproject.toml](https://github.com/hornikmatej/thesis_mit/blob/main/pyproject.toml) file for a complete list of dependencies.
+
 ## Citation
 
 If you use this model or findings from the thesis, please cite:
 
 For questions, feedback, or collaboration opportunities related to this thesis or anything else, feel free to reach out:
 
 
+ - **GitHub:** [hornikmatej](https://github.com/hornikmatej)