File size: 4,716 Bytes
bec7578
 
 
967aebb
c72e80d
 
 
2d00549
c72e80d
2d00549
c72e80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21a75d7
c72e80d
 
2d00549
 
 
 
 
 
 
 
 
 
c72e80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d00549
 
 
 
 
 
 
 
 
 
 
 
c72e80d
 
 
 
 
 
 
 
 
 
2d00549
 
 
 
c72e80d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d00549
 
c72e80d
 
2d00549
c72e80d
2d00549
 
 
 
 
 
c72e80d
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import subprocess
import sys

# Download the UniDic dictionary at import time, before the pipeline modules
# below are imported.
# NOTE(review): presumably required by the MeloTTS text front-end — confirm.
# The download is run with the interpreter executing this module
# (sys.executable) rather than whatever `python` resolves to on PATH, and
# without going through a shell. As with the previous os.system call, a failed
# download is not fatal here (check=False).
subprocess.run([sys.executable, "-m", "unidic", "download"], check=False)

from typing import Dict, Any, List, Generator
import torch
import os
import logging
from s2s_pipeline import main, prepare_all_args, get_default_arguments, setup_logger, initialize_queues_and_events, build_pipeline
import numpy as np
from queue import Queue, Empty
import threading

class EndpointHandler:
    """Request handler that drives the speech-to-speech pipeline.

    The pipeline is built and started once in ``__init__``. Each call to the
    instance pushes one text prompt or one raw audio buffer into the matching
    pipeline input queue, then gathers everything that shows up on
    ``send_audio_chunks_queue`` and returns it as a single byte string.

    NOTE(review): every call shares ``final_output_queue`` and the pipeline
    queues, so overlapping calls would mix their outputs — confirm that the
    hosting runtime serializes requests.
    """

    # Seconds to wait on the pipeline's output queue before silence is treated
    # as the end of the response.
    OUTPUT_TIMEOUT_S = 5

    def __init__(self, path=""):
        """Build the pipeline with default arguments and start it.

        Args:
            path: Not used by this handler; kept so callers may pass it.
        """
        (
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        ) = get_default_arguments(mode='none', tts='melo', log_level='DEBUG')
        setup_logger(self.module_kwargs.log_level)

        prepare_all_args(
            self.module_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
        )

        self.queues_and_events = initialize_queues_and_events()

        self.pipeline_manager = build_pipeline(
            self.module_kwargs,
            self.socket_receiver_kwargs,
            self.socket_sender_kwargs,
            self.vad_handler_kwargs,
            self.whisper_stt_handler_kwargs,
            self.paraformer_stt_handler_kwargs,
            self.language_model_handler_kwargs,
            self.mlx_language_model_handler_kwargs,
            self.parler_tts_handler_kwargs,
            self.melo_tts_handler_kwargs,
            self.chat_tts_handler_kwargs,
            self.queues_and_events,
        )

        self.pipeline_manager.start()

        # Queue on which the collector thread hands output chunks to __call__.
        self.final_output_queue = Queue()

        # Set by __call__; defined here so cleanup() is safe before any call.
        self.output_collector_thread = None

    def _collect_output(self):
        """Relay pipeline output to ``final_output_queue`` until it ends.

        Runs on the background thread started by ``__call__``. NumPy arrays
        are converted to raw bytes, anything else is forwarded unchanged. The
        loop stops, after putting the string ``"END"`` on
        ``final_output_queue``, when an ``"END"`` / ``b"END"`` marker arrives
        or when nothing arrives for ``OUTPUT_TIMEOUT_S`` seconds.
        """
        send_queue = self.queues_and_events['send_audio_chunks_queue']
        while True:
            try:
                output = send_queue.get(timeout=self.OUTPUT_TIMEOUT_S)
            except Empty:
                # No output within the timeout: assume processing is complete.
                self.final_output_queue.put("END")
                break

            if isinstance(output, (str, bytes)) and output in (b"END", "END"):
                self.final_output_queue.put("END")
                break
            elif isinstance(output, np.ndarray):
                self.final_output_queue.put(output.tobytes())
            else:
                self.final_output_queue.put(output)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one request through the pipeline and return the collected output.

        Args:
            data (Dict[str, Any]): Request payload. ``"input_type"`` is
                ``"text"`` (the default) or ``"speech"``; ``"input"`` is the
                text prompt or a bytes-like audio buffer that is read as
                int16 samples.

        Returns:
            Dict[str, Any]: ``{"output": <bytes>}``, the concatenation of all
            chunks produced for this request.

        Raises:
            ValueError: If ``input_type`` is neither ``"speech"`` nor
                ``"text"``. ``np.frombuffer`` may also raise for a speech
                buffer it cannot interpret as int16 samples.
        """
        input_type = data.get("input_type", "text")
        input_data = data.get("input", "")

        # Validate and prepare the input BEFORE starting the collector thread.
        # If the thread were already running when an error is raised, it would
        # later time out and leave a stale "END" in final_output_queue, which
        # would make the next request return empty output immediately.
        if input_type == "speech":
            payload = np.frombuffer(input_data, dtype=np.int16).tobytes()
            input_queue = self.queues_and_events['recv_audio_chunks_queue']
        elif input_type == "text":
            payload = input_data
            input_queue = self.queues_and_events['text_prompt_queue']
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

        # Start collecting output, then hand the input to the pipeline.
        self.output_collector_thread = threading.Thread(target=self._collect_output)
        self.output_collector_thread.start()
        input_queue.put(payload)

        # Gather chunks until the collector signals the end of the response.
        output_chunks = []
        while True:
            chunk = self.final_output_queue.get()
            if isinstance(chunk, str) and chunk == "END":
                break
            output_chunks.append(chunk)

        # Combine all audio chunks into a single byte string.
        combined_audio = b''.join(output_chunks)

        return {"output": combined_audio}

    def cleanup(self):
        """Stop the pipeline and, if one is still running, the collector thread."""
        self.pipeline_manager.stop()

        # Only wake the collector when it is actually running: it may never
        # have been started (no request handled yet) or may already have
        # finished, in which case there is nothing to signal or join.
        collector = getattr(self, "output_collector_thread", None)
        if collector is not None and collector.is_alive():
            self.queues_and_events['send_audio_chunks_queue'].put(b"END")
            collector.join()