using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using System.Text;
using Unity.Collections;

/// <summary>
/// Runs a Whisper speech-to-text pipeline on an <see cref="AudioClip"/> with Unity Sentis:
/// log-mel spectrogram -> audio encoder -> two-stage decoder -> argmax token selection.
/// The transcription is accumulated token by token and logged with <c>Debug.Log</c>.
/// </summary>
public class RunWhisper : MonoBehaviour
{
    /// <summary>
    /// Languages in the order of their language tokens; the token id is
    /// <c>50259 + (int)language</c> (see <see cref="GetLanguageCode"/>).
    /// </summary>
    public enum WhisperLanguage
    {
        English = 0, Chinese = 1, German = 2, Spanish = 3, Russian = 4, Korean = 5, French = 6, Japanese = 7,
        Portuguese = 8, Turkish = 9, Polish = 10, Catalan = 11, Dutch = 12, Arabic = 13, Swedish = 14, Italian = 15,
        Indonesian = 16, Hindi = 17, Finnish = 18, Vietnamese = 19, Hebrew = 20, Ukrainian = 21, Greek = 22, Malay = 23,
        Czech = 24, Romanian = 25, Danish = 26, Hungarian = 27, Tamil = 28, Norwegian = 29, Thai = 30, Urdu = 31,
        Croatian = 32, Bulgarian = 33, Lithuanian = 34, Latin = 35, Maori = 36, Malayalam = 37, Welsh = 38, Slovak = 39,
        Telugu = 40, Persian = 41, Latvian = 42, Bengali = 43, Serbian = 44, Azerbaijani = 45, Slovenian = 46, Kannada = 47,
        Estonian = 48, Macedonian = 49, Breton = 50, Basque = 51, Icelandic = 52, Armenian = 53, Nepali = 54, Mongolian = 55,
        Bosnian = 56, Kazakh = 57, Albanian = 58, Swahili = 59, Galician = 60, Marathi = 61, Punjabi = 62, Sinhala = 63,
        Khmer = 64, Shona = 65, Yoruba = 66, Somali = 67, Afrikaans = 68, Occitan = 69, Georgian = 70, Belarusian = 71,
        Tajik = 72, Sindhi = 73, Gujarati = 74, Amharic = 75, Yiddish = 76, Lao = 77, Uzbek = 78, Faroese = 79,
        HaitianCreole = 80, Pashto = 81, Turkmen = 82, Nynorsk = 83, Maltese = 84, Sanskrit = 85, Luxembourgish = 86,
        Myanmar = 87, Tibetan = 88, Tagalog = 89, Malagasy = 90, Assamese = 91, Tatar = 92, Hawaiian = 93, Lingala = 94,
        Hausa = 95, Bashkir = 96, Javanese = 97, Sundanese = 98,
        // NOTE(review): 50259 + 99 == 50358, which is the TRANSLATE token id declared below,
        // so this value does not select a language with the special-token layout used in this file.
        Cantonese = 99
    }

    /// <summary>Returns the language token id for <paramref name="language"/> (50259 + enum value).</summary>
    static int GetLanguageCode(WhisperLanguage language)
    {
        return 50259 + (int)language;
    }

    Worker decoder1, decoder2, encoder, spectrogram;
    Worker argmax;

    public AudioClip audioClip;

    /// <summary>Language token placed in the decoder prompt. Defaults to English.</summary>
    public WhisperLanguage spokenLanguage = WhisperLanguage.English;

    // This is how many tokens you want. It can be adjusted.
    const int maxTokens = 100;

    // Special tokens; see the added tokens file for details.
    const int END_OF_TEXT = 50257;
    const int START_OF_TRANSCRIPT = 50258;
    const int TRANSCRIBE = 50359; // speech-to-text in the specified language
    const int TRANSLATE = 50358; // speech-to-text, then translate to English
    const int NO_TIME_STAMPS = 50363;
    const int START_TIME = 50364;

    // Number of output-logit classes the argmax graph is compiled for.
    const int vocabularySize = 51865;

    // Number of decoder layers whose key/value tensors are passed from decoder1 to decoder2.
    const int decoderLayerCount = 6;

    // (decoder1 output name, decoder2 input name) for every cached key/value tensor.
    static readonly (string present, string past)[] keyValueBindings = BuildKeyValueBindings();

    // Cached single-byte encoding used to turn shifted token characters back into raw bytes.
    static readonly Encoding latin1 = Encoding.GetEncoding("ISO-8859-1");

    // Stateful decoder so a UTF-8 sequence split across two tokens is still decoded correctly.
    readonly Decoder utf8Decoder = Encoding.UTF8.GetDecoder();

    int numSamples;
    string[] tokens;

    int tokenCount = 0;
    NativeArray<int> outputTokens;

    // Used for special character decoding: entry n is the n-th byte value that is
    // not in the printable ranges tested by IsWhiteSpace.
    int[] whiteSpaceCharacters = new int[256];

    Tensor<float> encodedAudio;

    bool transcribe = false;
    bool isDestroyed = false;
    string outputString = "";

    // Maximum size of audioClip (30s at 16kHz)
    const int maxSamples = 30 * 16000;

    public ModelAsset audioDecoder1, audioDecoder2;
    public ModelAsset audioEncoder;
    public ModelAsset logMelSpectro;

    Awaitable m_Awaitable;

    NativeArray<int> lastToken;
    Tensor<int> lastTokenTensor;
    Tensor<int> tokensTensor;
    Tensor<float> audioInput;

    // Tokenizer vocabulary (JSON object mapping token text to token id).
    public TextAsset jsonFile;

    /// <summary>
    /// Creates the workers, encodes the audio clip once, then runs decoder steps until
    /// END_OF_TEXT is produced, the token buffer is nearly full, or the component is destroyed.
    /// </summary>
    public async void Start()
    {
        if (audioClip == null || audioDecoder1 == null || audioDecoder2 == null ||
            audioEncoder == null || logMelSpectro == null || jsonFile == null)
        {
            Debug.LogError("RunWhisper: audioClip, the four model assets and jsonFile must all be assigned.");
            return;
        }

        SetupWhiteSpaceShifts();
        GetTokens();

        decoder1 = new Worker(ModelLoader.Load(audioDecoder1), BackendType.GPUCompute);
        decoder2 = new Worker(ModelLoader.Load(audioDecoder2), BackendType.GPUCompute);

        // Small graph that picks the highest-scoring token on the GPU so only one int is read back.
        FunctionalGraph graph = new FunctionalGraph();
        var input = graph.AddInput(DataType.Float, new DynamicTensorShape(1, 1, vocabularySize));
        var amax = Functional.ArgMax(input, -1, false);
        var selectTokenModel = graph.Compile(amax);
        argmax = new Worker(selectTokenModel, BackendType.GPUCompute);

        encoder = new Worker(ModelLoader.Load(audioEncoder), BackendType.GPUCompute);
        spectrogram = new Worker(ModelLoader.Load(logMelSpectro), BackendType.GPUCompute);

        outputTokens = new NativeArray<int>(maxTokens, Allocator.Persistent);

        // Decoder prompt. NO_TIME_STAMPS is not written here: it starts out as lastToken
        // and is appended to outputTokens by the first InferenceStep.
        outputTokens[0] = START_OF_TRANSCRIPT;
        outputTokens[1] = GetLanguageCode(spokenLanguage);
        outputTokens[2] = TRANSCRIBE;
        tokenCount = 3;

        LoadAudio();
        EncodeAudio();
        transcribe = true;

        // Allocate for the maximum length once, then reshape to the current token count.
        tokensTensor = new Tensor<int>(new TensorShape(1, maxTokens));
        ComputeTensorData.Pin(tokensTensor);
        tokensTensor.Reshape(new TensorShape(1, tokenCount));
        tokensTensor.dataOnBackend.Upload<int>(outputTokens, tokenCount);

        lastToken = new NativeArray<int>(1, Allocator.Persistent);
        lastToken[0] = NO_TIME_STAMPS;
        lastTokenTensor = new Tensor<int>(new TensorShape(1, 1), new[] { NO_TIME_STAMPS });

        // isDestroyed/transcribe are tested first so the disposed NativeArray is never touched.
        while (!isDestroyed && transcribe && tokenCount < (outputTokens.Length - 1))
        {
            m_Awaitable = InferenceStep();
            await m_Awaitable;
        }
    }

    /// <summary>
    /// Copies the clip into a zero-padded buffer of <c>maxSamples</c> mono samples and wraps it in
    /// the <c>audioInput</c> tensor. Clips longer than the buffer are truncated; multi-channel
    /// clips are averaged down to one channel.
    /// </summary>
    void LoadAudio()
    {
        if (audioClip.frequency != 16000)
        {
            Debug.LogWarning($"RunWhisper: the audio clip has frequency {audioClip.frequency}Hz; the sample buffer is sized for 16000Hz.");
        }

        int channels = Mathf.Max(1, audioClip.channels);
        int frames = Mathf.Min(audioClip.samples, maxSamples);

        var data = new float[maxSamples];
        if (frames > 0)
        {
            // Request exactly the frames that exist so GetData never reads past the end of the clip.
            var interleaved = new float[frames * channels];
            audioClip.GetData(interleaved, 0);

            for (int frame = 0; frame < frames; frame++)
            {
                float sum = 0f;
                for (int channel = 0; channel < channels; channel++)
                {
                    sum += interleaved[frame * channels + channel];
                }
                data[frame] = sum / channels;
            }
        }

        numSamples = maxSamples;
        audioInput = new Tensor<float>(new TensorShape(1, numSamples), data);
    }

    /// <summary>Runs the spectrogram and encoder workers; the result is kept in <c>encodedAudio</c>.</summary>
    void EncodeAudio()
    {
        spectrogram.Schedule(audioInput);
        var logmel = spectrogram.PeekOutput() as Tensor<float>;
        encoder.Schedule(logmel);
        // PeekOutput returns a tensor owned by the worker, so it is not disposed separately.
        encodedAudio = encoder.PeekOutput() as Tensor<float>;
    }

    /// <summary>
    /// Runs one decoding step: decoder1 over all tokens so far, decoder2 over the last token
    /// with decoder1's key/value outputs, then argmax to select the next token.
    /// Appends the decoded text to <c>outputString</c> and logs it.
    /// </summary>
    async Awaitable InferenceStep()
    {
        decoder1.SetInput("input_ids", tokensTensor);
        decoder1.SetInput("encoder_hidden_states", encodedAudio);
        decoder1.Schedule();

        decoder2.SetInput("input_ids", lastTokenTensor);
        foreach (var (present, past) in keyValueBindings)
        {
            decoder2.SetInput(past, decoder1.PeekOutput(present));
        }
        decoder2.Schedule();

        var logits = decoder2.PeekOutput("logits") as Tensor<float>;
        argmax.Schedule(logits);
        using var selected = await argmax.PeekOutput().ReadbackAndCloneAsync() as Tensor<int>;

        // The component may have been destroyed while the readback was in flight.
        if (isDestroyed)
        {
            return;
        }

        int index = selected[0];

        // The previous "last token" joins the history; the new token becomes the last token.
        outputTokens[tokenCount] = lastToken[0];
        lastToken[0] = index;
        tokenCount++;
        tokensTensor.Reshape(new TensorShape(1, tokenCount));
        tokensTensor.dataOnBackend.Upload<int>(outputTokens, tokenCount);
        lastTokenTensor.dataOnBackend.Upload<int>(lastToken, 1);

        if (index == END_OF_TEXT)
        {
            transcribe = false;
        }
        else if (index >= 0 && index < tokens.Length && tokens[index] != null)
        {
            // Ids outside the vocabulary file are skipped.
            outputString += GetUnicodeText(tokens[index]);
        }

        Debug.Log(outputString);
    }

    /// <summary>Builds the decoder1-output / decoder2-input name pairs for every layer.</summary>
    static (string present, string past)[] BuildKeyValueBindings()
    {
        string[] kinds = { "decoder.key", "decoder.value", "encoder.key", "encoder.value" };
        var bindings = new List<(string present, string past)>(decoderLayerCount * kinds.Length);
        for (int layer = 0; layer < decoderLayerCount; layer++)
        {
            foreach (string kind in kinds)
            {
                bindings.Add(($"present.{layer}.{kind}", $"past_key_values.{layer}.{kind}"));
            }
        }
        return bindings.ToArray();
    }

    /// <summary>
    /// Loads the vocabulary from <c>jsonFile</c> into <c>tokens</c>, indexed by token id.
    /// The array is sized by the largest id so non-contiguous ids cannot index out of range.
    /// </summary>
    void GetTokens()
    {
        var vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonFile.text);
        if (vocab == null)
        {
            tokens = new string[0];
            return;
        }

        int size = 0;
        foreach (var item in vocab)
        {
            if (item.Value >= size)
            {
                size = item.Value + 1;
            }
        }

        tokens = new string[size];
        foreach (var item in vocab)
        {
            if (item.Value >= 0)
            {
                tokens[item.Key == null ? 0 : item.Value] = item.Key;
            }
        }
    }

    /// <summary>
    /// Converts a vocabulary entry to display text: maps its characters back to raw bytes and
    /// decodes those bytes as UTF-8. Incomplete trailing sequences are held until the next call.
    /// </summary>
    string GetUnicodeText(string text)
    {
        var bytes = latin1.GetBytes(ShiftCharacterDown(text));
        var chars = new char[utf8Decoder.GetCharCount(bytes, 0, bytes.Length, false)];
        int written = utf8Decoder.GetChars(bytes, 0, bytes.Length, chars, 0, false);
        return new string(chars, 0, written);
    }

    /// <summary>
    /// Replaces every character at or above U+0100 with the byte value it stands for
    /// (looked up in <c>whiteSpaceCharacters</c>); characters below U+0100 are kept.
    /// </summary>
    string ShiftCharacterDown(string text)
    {
        var shifted = new StringBuilder(text.Length);
        foreach (char letter in text)
        {
            int offset = letter - 256;
            if (offset >= 0 && offset < whiteSpaceCharacters.Length)
            {
                // U+0100 itself is the first shifted character (offset 0).
                shifted.Append((char)whiteSpaceCharacters[offset]);
            }
            else
            {
                // Below U+0100: already a byte value. Beyond the table: left as is,
                // which the single-byte encoding later turns into a replacement byte.
                shifted.Append(letter);
            }
        }
        return shifted.ToString();
    }

    /// <summary>Fills <c>whiteSpaceCharacters</c> with the byte values for which IsWhiteSpace is true, in order.</summary>
    void SetupWhiteSpaceShifts()
    {
        for (int i = 0, n = 0; i < 256; i++)
        {
            if (IsWhiteSpace((char)i))
            {
                whiteSpaceCharacters[n++] = i;
            }
        }
    }

    /// <summary>
    /// True when <paramref name="c"/> is outside the three printable ranges
    /// '!'..'~', U+00A1..U+00AC and U+00AE..U+00FF.
    /// </summary>
    bool IsWhiteSpace(char c)
    {
        // Range bounds written as escapes so they survive any file re-encoding
        // (restored to match the GPT-2 byte-to-unicode table).
        return !(('!' <= c && c <= '~') || ('\u00A1' <= c && c <= '\u00AC') || ('\u00AE' <= c && c <= '\u00FF'));
    }

    /// <summary>Stops the decoding loop and releases all workers, tensors and native arrays.</summary>
    private void OnDestroy()
    {
        isDestroyed = true;
        transcribe = false;

        // Null-conditional: Start may have returned early before these were created.
        decoder1?.Dispose();
        decoder2?.Dispose();
        encoder?.Dispose();
        spectrogram?.Dispose();
        argmax?.Dispose();
        audioInput?.Dispose();
        lastTokenTensor?.Dispose();
        tokensTensor?.Dispose();

        // Persistent native allocations are not garbage collected.
        if (outputTokens.IsCreated)
        {
            outputTokens.Dispose();
        }
        if (lastToken.IsCreated)
        {
            lastToken.Dispose();
        }
    }
}