Upload KotobaWhisperPipeline
kotoba_whisper.py  CHANGED  (+130 -269)
@@ -1,6 +1,5 @@
import requests
-from typing import Union, Optional, Dict
-from collections import defaultdict
+from typing import Union, Optional, Dict

import torch
import numpy as np
@@ -24,25 +23,13 @@ class Punctuator:
    def __init__(self, model: str = "pcs_47lang"):
        self.punctuation_model = PunctCapSegModelONNX.from_pretrained(model)

-    def punctuate(self,
-
-
-
-
-
-
-                punctuated = punctuated.replace("。", "")
-                punctuated = punctuated[:ind] + "。" + punctuated[ind:]
-            return punctuated
-
-        text_edit = self.punctuation_model.infer([c['text'] for c in pipeline_chunk])
-        return [
-            {
-                'timestamp': c['timestamp'],
-                'speaker': c['speaker'],
-                'text': validate_punctuation(c['text'], "".join(e))
-            } for c, e in zip(pipeline_chunk, text_edit)
-        ]
+    def punctuate(self, text: str) -> str:
+        if any(p in text for p in self.ja_punctuations):
+            return text
+        punctuated = "".join(self.punctuation_model.infer([text])[0])
+        if 'unk' in punctuated.lower():
+            return text
+        return punctuated


class SpeakerDiarization:
@@ -114,104 +101,68 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
        )

    def _sanitize_parameters(self,
-                             chunk_length_s=None,
-                             stride_length_s=None,
-
-
-
-
-
-
-                             add_punctuation: bool =False,
-                             return_unique_speaker: bool =True,
+                             chunk_length_s: Optional[int] = None,
+                             stride_length_s: Optional[int] = None,
+                             generate_kwargs: Optional[Dict] = None,
+                             max_new_tokens: Optional[int] = None,
+                             add_punctuation: bool = False,
+                             return_unique_speaker: bool = True,
+                             add_silence_end: Optional[float] = None,
+                             add_silence_start: Optional[float] = None,
                             num_speakers: Optional[int] = None,
                             min_speakers: Optional[int] = None,
                             max_speakers: Optional[int] = None):
-
-
-
-
-
-
-
-
-
-
-        if generate_kwargs is
-
-                raise ValueError(
-                    "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
-                    " only 1 version"
-                )
-            forward_params.update(generate_kwargs)
-
-        postprocess_params = {}
-        if decoder_kwargs is not None:
-            postprocess_params["decoder_kwargs"] = decoder_kwargs
-        if return_timestamps is not None:
-            # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
-            if self.type == "seq2seq" and return_timestamps:
-                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
-            if self.type == "ctc_with_lm" and return_timestamps != "word":
-                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
-            if self.type == "ctc" and return_timestamps not in ["char", "word"]:
-                raise ValueError(
-                    "CTC can either predict character level timestamps, or word level timestamps. "
-                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
-                )
-            if self.type == "seq2seq_whisper" and return_timestamps == "char":
-                raise ValueError(
-                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
-                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
-                )
-            forward_params["return_timestamps"] = return_timestamps
-            postprocess_params["return_timestamps"] = return_timestamps
-        if return_language is not None:
-            if self.type != "seq2seq_whisper":
-                raise ValueError("Only Whisper can return language for now.")
-            postprocess_params["return_language"] = return_language
-        postprocess_params["return_language"] = return_language
-        postprocess_params["add_punctuation"] = add_punctuation
-        postprocess_params["return_unique_speaker"] = return_unique_speaker
-        postprocess_params["num_speakers"] = num_speakers
-        postprocess_params["min_speakers"] = min_speakers
-        postprocess_params["max_speakers"] = max_speakers
+        preprocess_params = {
+            "chunk_length_s": chunk_length_s,
+            "stride_length_s": stride_length_s,
+            "add_silence_end": add_silence_end,
+            "add_silence_start": add_silence_start,
+            "num_speakers": num_speakers,
+            "min_speakers": min_speakers,
+            "max_speakers": max_speakers,
+        }
+        postprocess_params = {"add_punctuation": add_punctuation, "return_timestamps": True, "return_language": False}
+        forward_params = {} if generate_kwargs is None else generate_kwargs
+        forward_params.update({"max_new_tokens": max_new_tokens, "return_timestamps": True})
        return preprocess_params, forward_params, postprocess_params

-    def preprocess(self,
+    def preprocess(self,
+                   inputs,
+                   chunk_length_s: Optional[int] = None,
+                   stride_length_s: Optional[int] = None,
+                   add_silence_end: Optional[float] = None,
+                   add_silence_start: Optional[float] = None,
+                   num_speakers: Optional[int] = None,
+                   min_speakers: Optional[int] = None,
+                   max_speakers: Optional[int] = None):
+
+        def _pad_audio_array(_audio):
+            if add_silence_start:
+                _audio = np.concatenate([np.zeros(int(self.feature_extractor.sampling_rate * add_silence_start)), _audio])
+            if add_silence_end:
+                _audio = np.concatenate([_audio, np.zeros(int(self.feature_extractor.sampling_rate * add_silence_end))])
+            return _audio
+
+        # load file
        if isinstance(inputs, str):
            if inputs.startswith("http://") or inputs.startswith("https://"):
-                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
-                # like http_huggingface_co.png
+                # We need to actually check for a real protocol, otherwise it's impossible to use a local file like http_huggingface_co.png
                inputs = requests.get(inputs).content
            else:
                with open(inputs, "rb") as f:
                    inputs = f.read()
-
        if isinstance(inputs, bytes):
            inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
-
-        stride = None
-        extra = {}
        if isinstance(inputs, dict):
-
-
-            # better integration
-            if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
+            # Accepting `"array"` which is the key defined in `datasets` for better integration
+            if not ("sampling_rate" in inputs and "array" in inputs):
                raise ValueError(
                    "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
-                    '"
+                    '"array" key containing the numpy array representing the audio and a "sampling_rate" key, '
                    "containing the sampling_rate associated with that array"
                )
-
-            _inputs = inputs.pop("raw", None)
-            if _inputs is None:
-                # Remove path which will not be used from `datasets`.
-                inputs.pop("path", None)
-                _inputs = inputs.pop("array", None)
            in_sampling_rate = inputs.pop("sampling_rate")
-
-            inputs = _inputs
+            inputs = inputs.pop("array", None)
            if in_sampling_rate != self.feature_extractor.sampling_rate:
                if is_torchaudio_available():
                    from torchaudio import functional as F
@@ -220,190 +171,100 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
                        "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
                        "The torchaudio package can be installed through: `pip install torchaudio`."
                    )
-
                inputs = F.resample(
                    torch.from_numpy(inputs), in_sampling_rate, self.feature_extractor.sampling_rate
                ).numpy()
-                ratio = self.feature_extractor.sampling_rate / in_sampling_rate
-            else:
-                ratio = 1
-            if stride is not None:
-                if stride[0] + stride[1] > inputs.shape[0]:
-                    raise ValueError("Stride is too large for input")

-
-                # swallowed by the `feature_extractor` later, and then batching
-                # can add extra data in the inputs, so we need to keep track
-                # of the original length in the stride so we can cut properly.
-                stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
+        # validate audio array
        if not isinstance(inputs, np.ndarray):
            raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
        if len(inputs.shape) != 1:
            raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")

-
-            if stride_length_s is None:
-                stride_length_s = chunk_length_s / 6
-
-            if isinstance(stride_length_s, (int, float)):
-                stride_length_s = [stride_length_s, stride_length_s]
-
-            # XXX: Carefuly, this variable will not exist in `seq2seq` setting.
-            # Currently chunking is not possible at this level for `seq2seq` so
-            # it's ok.
-            align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
-            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
-            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
-            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
-
-            if chunk_len < stride_left + stride_right:
-                raise ValueError("Chunk length must be superior to stride length")
-
-            for item in chunk_iter(
-                    inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
-            ):
-                item["audio_array"] = inputs
-                yield item
-        else:
-            if inputs.shape[0] > self.feature_extractor.n_samples:
-                processed = self.feature_extractor(
-                    inputs,
-                    sampling_rate=self.feature_extractor.sampling_rate,
-                    truncation=False,
-                    padding="longest",
-                    return_tensors="pt",
-                )
-            else:
-                processed = self.feature_extractor(
-                    inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-                )
-
-            if self.torch_dtype is not None:
-                processed = processed.to(dtype=self.torch_dtype)
-            if stride is not None:
-                processed["stride"] = stride
-            yield {"is_last": True, "audio_array": inputs, **processed, **extra}
-
-    def _forward(self, model_inputs, **generate_kwargs):
-        attention_mask = model_inputs.pop("attention_mask", None)
-        stride = model_inputs.pop("stride", None)
-        is_last = model_inputs.pop("is_last")
-        audio_array = model_inputs.pop("audio_array")
-        encoder = self.model.get_encoder()
-        # Consume values so we can let extra information flow freely through
-        # the pipeline (important for `partial` in microphone)
-        if "input_features" in model_inputs:
-            inputs = model_inputs.pop("input_features")
-        elif "input_values" in model_inputs:
-            inputs = model_inputs.pop("input_values")
-        else:
-            raise ValueError(
-                "Seq2Seq speech recognition model requires either a "
-                f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
-            )
-
-        # custom processing for Whisper timestamps and word-level timestamps
-        generate_kwargs["return_timestamps"] = True
-        if inputs.shape[-1] > self.feature_extractor.nb_max_frames:
-            generate_kwargs["input_features"] = inputs
-        else:
-            generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
-
-        tokens = self.model.generate(attention_mask=attention_mask, **generate_kwargs)
-        # whisper longform generation stores timestamps in "segments"
-        out = {"tokens": tokens}
-        if self.type == "seq2seq_whisper":
-            if stride is not None:
-                out["stride"] = stride
-
-        # Leftover
-        extra = model_inputs
-        return {"is_last": is_last, "audio_array": audio_array, **out, **extra}
-
-    def postprocess(self,
-                    model_outputs,
-                    decoder_kwargs: Optional[Dict] = None,
-                    return_language=None,
-                    add_punctuation: bool = False,
-                    return_unique_speaker: bool = True,
-                    num_speakers: Optional[int] = None,
-                    min_speakers: Optional[int] = None,
-                    max_speakers: Optional[int] = None,
-                    *args,
-                    **kwargs):
-        assert len(model_outputs) > 0
-        outputs = super().postprocess(
-            model_outputs=model_outputs,
-            decoder_kwargs=decoder_kwargs,
-            return_timestamps=True,
-            return_language=return_language
-        )
-        audio_array = outputs.pop("audio_array")[0]
+        # diarization
        sd = self.model_speaker_diarization(
-
+            inputs,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
            sampling_rate=self.feature_extractor.sampling_rate
        )
-        diarization_result = {s: [[i.start, i.end] for i in sd.label_timeline(s)] for s in sd.labels()}
-        timelines = sd.get_timeline()
-
-        pointer_ts = 0
-        pointer_chunk = 0
-        new_chunks = []
-        while True:
-            if pointer_ts == len(timelines):
-                ts = timelines[-1]
-                for chunk in outputs["chunks"][pointer_chunk:]:
-                    chunk["speaker"] = sd.get_labels(ts)
-                    new_chunks.append(chunk)
-                break
-            if pointer_chunk == len(outputs["chunks"]):
-                break
-            ts = timelines[pointer_ts]
-
-            chunk = outputs["chunks"][pointer_chunk]
-            if "speaker" not in chunk:
-                chunk["speaker"] = []

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # loop over audio chunks and speakers
+        labels = list(sd.labels())
+        for n, s in enumerate(labels):
+            timelines = list(sd.label_timeline(s))
+            for m, i in enumerate(timelines):
+                start = int(i.start * self.feature_extractor.sampling_rate)
+                end = int(i.end * self.feature_extractor.sampling_rate)
+                audio_array = _pad_audio_array(inputs[start: end])
+
+                if chunk_length_s is not None:
+                    stride_length_s = chunk_length_s / 6 if stride_length_s is None else stride_length_s
+                    stride_length_s = [stride_length_s, stride_length_s] if isinstance(stride_length_s, (int, float)) else stride_length_s
+                    align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
+                    chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
+                    stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
+                    stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
+                    if chunk_len < stride_left + stride_right:
+                        raise ValueError("Chunk length must be superior to stride length")
+                    for item in chunk_iter(
+                            audio_array, self.feature_extractor, chunk_len, stride_left, stride_right, self.torch_dtype
+                    ):
+                        item["speaker_id"] = s
+                        item["speaker_span"] = [i.start, i.end]
+                        item["is_last"] = m == len(timelines) - 1 and n == len(labels) - 1 and item["is_last"]
+                        yield item
                else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    if inputs.shape[0] > self.feature_extractor.n_samples:
+                        processed = self.feature_extractor(
+                            audio_array,
+                            sampling_rate=self.feature_extractor.sampling_rate,
+                            truncation=False,
+                            padding="longest",
+                            return_tensors="pt",
+                        )
+                    else:
+                        processed = self.feature_extractor(
+                            audio_array,
+                            sampling_rate=self.feature_extractor.sampling_rate,
+                            return_tensors="pt"
+                        )
+                    if self.torch_dtype is not None:
+                        processed = processed.to(dtype=self.torch_dtype)
+                    processed["speaker_id"] = s
+                    processed["speaker_span"] = [i.start, i.end]
+                    processed["is_last"] = m == len(timelines) - 1 and n == len(labels) - 1
+                    yield processed
+
+    def _forward(self, model_inputs, **generate_kwargs):
+        generate_kwargs["attention_mask"] = model_inputs.pop("attention_mask", None)
+        generate_kwargs["input_features"] = model_inputs.pop("input_features")
+        tokens = self.model.generate(**generate_kwargs)
+        return {"tokens": tokens, **model_inputs}
+
+    def postprocess(self, model_outputs, **postprocess_parameters):
+        if postprocess_parameters["add_punctuation"] and self.punctuator is None:
+            self.punctuator = Punctuator()
+        outputs = {"chunks": []}
+        for o in model_outputs:
+            text, chunks = self.tokenizer._decode_asr(
+                [o],
+                return_language=postprocess_parameters["return_language"],
+                return_timestamps=postprocess_parameters["return_timestamps"],
+                time_precision=self.feature_extractor.chunk_length / self.model.config.max_source_positions,
+            )
+            start, end = o["speaker_span"]
+            new_chunk = []
+            for c in chunks["chunks"]:
+                c["timestamp"] = [round(c["timestamp"][0] + start, 2), round(c["timestamp"][0] + end, 2)]
+                c["speaker_id"] = o["speaker_id"]
+                new_chunk.append(c)
+            outputs["chunks"] += new_chunk
+        outputs["speaker_ids"] = sorted(set([o["speaker_id"] for o in outputs["chunks"]]))
+        for s in outputs["speaker_ids"]:
+            outputs[f"chunk/{s}"] = sorted([o for o in outputs["chunks"] if o["speaker_id"] == s], key=lambda x: x["timestamp"][0])
+            outputs[f"text/{s}"] = "".join([i["text"] for i in outputs[f"chunk/{s}"]])
+            if postprocess_parameters["add_punctuation"]:
+                outputs[f"text/{s}"] = self.punctuator.punctuate(outputs[f"text/{s}"])
        return outputs
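For reference, a minimal usage sketch of the updated pipeline. It assumes the hosting model repository registers KotobaWhisperPipeline as its custom automatic-speech-recognition pipeline and is loaded with trust_remote_code=True; the repository id, audio path, and device below are illustrative and not part of this commit.

# Minimal usage sketch (illustrative repository id and file name; the keyword
# arguments map onto `_sanitize_parameters` in kotoba_whisper.py above).
import torch
from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model="kotoba-tech/kotoba-whisper-v2.2",  # illustrative repo that ships this pipeline
    torch_dtype=torch.float16,
    device="cuda:0",
    trust_remote_code=True,
)

result = pipe(
    "sample_audio.wav",
    chunk_length_s=15,        # enables chunked inference inside each speaker segment
    add_punctuation=True,     # lazily instantiates Punctuator in postprocess
    add_silence_start=0.5,    # seconds of silence padded before each speaker segment
    add_silence_end=0.5,      # seconds of silence padded after each speaker segment
    num_speakers=None,        # or pin the diarization speaker count
)

print(result["speaker_ids"])       # e.g. ["SPEAKER_00", "SPEAKER_01"]
print(result["text/SPEAKER_00"])   # concatenated (optionally punctuated) text for one speaker
print(result["chunk/SPEAKER_00"])  # timestamped chunks attributed to that speaker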

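The chunk and stride sizes used in preprocess are derived from chunk_length_s and stride_length_s (in seconds). A small sketch of that arithmetic, assuming the usual 16 kHz Whisper feature extractor and the getattr fallback of 1 for inputs_to_logits_ratio:

# Sketch of the chunk/stride arithmetic in `preprocess` (assumed 16 kHz sampling
# rate and align_to = 1; the expressions mirror the diff above).
sampling_rate = 16000
chunk_length_s = 15
stride_length_s = chunk_length_s / 6  # default when stride_length_s is not given

align_to = 1
chunk_len = int(round(chunk_length_s * sampling_rate / align_to) * align_to)     # 240000 samples
stride_left = int(round(stride_length_s * sampling_rate / align_to) * align_to)  # 40000 samples
stride_right = stride_left                                                       # 40000 samples

# preprocess raises ValueError when chunk_len < stride_left + stride_right
assert chunk_len >= stride_left + stride_right
print(chunk_len, stride_left, stride_right)  # 240000 40000 40000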
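For orientation, an illustrative (made-up) example of the dictionary shape returned by postprocess; the keys mirror what the code above builds, while the speaker labels, timestamps, and text are invented:

# Illustrative output shape only; values are invented, keys follow `postprocess`.
example_output = {
    "chunks": [
        {"text": "こんにちは", "timestamp": [0.5, 3.2], "speaker_id": "SPEAKER_00"},
        {"text": "はい", "timestamp": [3.4, 4.1], "speaker_id": "SPEAKER_01"},
    ],
    "speaker_ids": ["SPEAKER_00", "SPEAKER_01"],
    "chunk/SPEAKER_00": [{"text": "こんにちは", "timestamp": [0.5, 3.2], "speaker_id": "SPEAKER_00"}],
    "chunk/SPEAKER_01": [{"text": "はい", "timestamp": [3.4, 4.1], "speaker_id": "SPEAKER_01"}],
    "text/SPEAKER_00": "こんにちは",
    "text/SPEAKER_01": "はい",
}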