|
using System.Collections.Generic; |
|
using UnityEngine; |
|
using Unity.Sentis; |
|
using System.Text; |
|
using Unity.Collections; |
|
|
|
/// <summary>
/// Transcribes <see cref="audioClip"/> with a Whisper model split into four Sentis models
/// (log-mel spectrogram, audio encoder, and two decoder stages). Tokens are chosen greedily
/// (argmax over the logits), one per inference step, and the running transcript is logged.
/// </summary>
public class RunWhisper : MonoBehaviour
{
    /// <summary>
    /// Languages selectable for the decoder prompt. Each value is the offset of the language
    /// token from the first language token id (see <see cref="GetLanguageCode"/>).
    /// </summary>
    public enum WhisperLanguage
    {
        English = 0,
        Chinese = 1,
        German = 2,
        Spanish = 3,
        Russian = 4,
        Korean = 5,
        French = 6,
        Japanese = 7,
        Portuguese = 8,
        Turkish = 9,
        Polish = 10,
        Catalan = 11,
        Dutch = 12,
        Arabic = 13,
        Swedish = 14,
        Italian = 15,
        Indonesian = 16,
        Hindi = 17,
        Finnish = 18,
        Vietnamese = 19,
        Hebrew = 20,
        Ukrainian = 21,
        Greek = 22,
        Malay = 23,
        Czech = 24,
        Romanian = 25,
        Danish = 26,
        Hungarian = 27,
        Tamil = 28,
        Norwegian = 29,
        Thai = 30,
        Urdu = 31,
        Croatian = 32,
        Bulgarian = 33,
        Lithuanian = 34,
        Latin = 35,
        Maori = 36,
        Malayalam = 37,
        Welsh = 38,
        Slovak = 39,
        Telugu = 40,
        Persian = 41,
        Latvian = 42,
        Bengali = 43,
        Serbian = 44,
        Azerbaijani = 45,
        Slovenian = 46,
        Kannada = 47,
        Estonian = 48,
        Macedonian = 49,
        Breton = 50,
        Basque = 51,
        Icelandic = 52,
        Armenian = 53,
        Nepali = 54,
        Mongolian = 55,
        Bosnian = 56,
        Kazakh = 57,
        Albanian = 58,
        Swahili = 59,
        Galician = 60,
        Marathi = 61,
        Punjabi = 62,
        Sinhala = 63,
        Khmer = 64,
        Shona = 65,
        Yoruba = 66,
        Somali = 67,
        Afrikaans = 68,
        Occitan = 69,
        Georgian = 70,
        Belarusian = 71,
        Tajik = 72,
        Sindhi = 73,
        Gujarati = 74,
        Amharic = 75,
        Yiddish = 76,
        Lao = 77,
        Uzbek = 78,
        Faroese = 79,
        HaitianCreole = 80,
        Pashto = 81,
        Turkmen = 82,
        Nynorsk = 83,
        Maltese = 84,
        Sanskrit = 85,
        Luxembourgish = 86,
        Myanmar = 87,
        Tibetan = 88,
        Tagalog = 89,
        Malagasy = 90,
        Assamese = 91,
        Tatar = 92,
        Hawaiian = 93,
        Lingala = 94,
        Hausa = 95,
        Bashkir = 96,
        Javanese = 97,
        Sundanese = 98,
        Cantonese = 99
    }

    /// <summary>Returns the vocabulary id of the language token for <paramref name="language"/>.</summary>
    /// <param name="language">Language to encode.</param>
    /// <returns>50259 (the English token) plus the enum offset.</returns>
    static int GetLanguageCode(WhisperLanguage language)
    {
        return 50259 + (int)language;
    }

    Worker decoder1, decoder2, encoder, spectrogram;

    // Small functional graph that picks the highest-scoring token from the logits.
    Worker argmax;

    /// <summary>Clip to transcribe. Only the first 30 s (<see cref="maxSamples"/> frames) are used.</summary>
    public AudioClip audioClip;

    /// <summary>Language token written into the decoder prompt (defaults to English).</summary>
    public WhisperLanguage speakerLanguage = WhisperLanguage.English;

    // Capacity of the token buffer, including the prompt tokens.
    const int maxTokens = 100;

    // Special token ids used when building the prompt and detecting the end of output.
    const int END_OF_TEXT = 50257;
    const int START_OF_TRANSCRIPT = 50258;
    const int TRANSCRIBE = 50359;
    const int TRANSLATE = 50358;
    const int NO_TIME_STAMPS = 50363;
    const int START_TIME = 50364;

    // Length of the last axis of the logits fed to the argmax graph.
    const int vocabSize = 51865;

    // Number of decoder layers whose key/value caches are forwarded from decoder1 to decoder2.
    // NOTE(review): must match the exported decoder models (present.0 .. present.5).
    const int numDecoderLayers = 6;

    int numSamples;

    // tokens[id] is the vocabulary string for token id (null where the vocab has no entry).
    string[] tokens;

    int tokenCount = 0;

    // Prompt plus generated tokens; the first tokenCount entries are uploaded to decoder1.
    NativeArray<int> outputTokens;

    // whiteSpaceCharacters[n] is the byte value that vocabulary character (256 + n) stands for.
    int[] whiteSpaceCharacters = new int[256];

    // Owned by the encoder worker (obtained via PeekOutput); must not be disposed here.
    Tensor<float> encodedAudio;

    bool transcribe = false;

    // Set in OnDestroy so an in-flight InferenceStep stops touching disposed resources.
    bool destroyed = false;

    string outputString = "";

    // Raw bytes of every decoded token so far. The whole buffer is UTF-8-decoded each step so
    // multi-byte characters that span several tokens come out intact.
    readonly List<byte> outputBytes = new List<byte>();

    // Presumably the sample rate the spectrogram model expects — TODO confirm against the model.
    const int sampleRate = 16000;

    // Fixed 30-second input window.
    const int maxSamples = 30 * sampleRate;

    public ModelAsset audioDecoder1, audioDecoder2;

    public ModelAsset audioEncoder;

    public ModelAsset logMelSpectro;

    Awaitable m_Awaitable;

    // Token fed to decoder2 on the next step (the previous prediction).
    NativeArray<int> lastToken;

    Tensor<int> lastTokenTensor;

    Tensor<int> tokensTensor;

    Tensor<float> audioInput;

    /// <summary>
    /// Loads the models and vocabulary, encodes the audio once, then runs inference steps until
    /// END_OF_TEXT is produced, the token buffer is full, or the component is destroyed.
    /// </summary>
    public async void Start()
    {
        SetupWhiteSpaceShifts();
        GetTokens();

        decoder1 = new Worker(ModelLoader.Load(audioDecoder1), BackendType.GPUCompute);
        decoder2 = new Worker(ModelLoader.Load(audioDecoder2), BackendType.GPUCompute);

        FunctionalGraph graph = new FunctionalGraph();
        var input = graph.AddInput(DataType.Float, new DynamicTensorShape(1, 1, vocabSize));
        var amax = Functional.ArgMax(input, -1, false);
        var selectTokenModel = graph.Compile(amax);
        argmax = new Worker(selectTokenModel, BackendType.GPUCompute);

        encoder = new Worker(ModelLoader.Load(audioEncoder), BackendType.GPUCompute);
        spectrogram = new Worker(ModelLoader.Load(logMelSpectro), BackendType.GPUCompute);

        // Prompt: start-of-transcript, language, task.
        outputTokens = new NativeArray<int>(maxTokens, Allocator.Persistent);
        outputTokens[0] = START_OF_TRANSCRIPT;
        outputTokens[1] = GetLanguageCode(speakerLanguage);
        outputTokens[2] = TRANSCRIBE;
        tokenCount = 3;

        LoadAudio();
        EncodeAudio();
        transcribe = true;

        // Allocate for the maximum length once, then reshape to the live token count.
        tokensTensor = new Tensor<int>(new TensorShape(1, maxTokens));
        ComputeTensorData.Pin(tokensTensor);
        tokensTensor.Reshape(new TensorShape(1, tokenCount));
        tokensTensor.dataOnBackend.Upload<int>(outputTokens, tokenCount);

        lastToken = new NativeArray<int>(1, Allocator.Persistent);
        lastToken[0] = NO_TIME_STAMPS;
        lastTokenTensor = new Tensor<int>(new TensorShape(1, 1), new[] { NO_TIME_STAMPS });

        // `destroyed` is checked before outputTokens.Length so a disposed buffer is never read.
        while (transcribe && !destroyed && tokenCount < outputTokens.Length - 1)
        {
            m_Awaitable = InferenceStep();
            await m_Awaitable;
        }
    }

    /// <summary>
    /// Copies up to 30 s of <see cref="audioClip"/> into <see cref="audioInput"/>. Multi-channel
    /// audio is down-mixed to mono, and shorter clips are zero-padded to the full window.
    /// </summary>
    void LoadAudio()
    {
        int channels = System.Math.Max(1, audioClip.channels);
        int frames = System.Math.Min(audioClip.samples, maxSamples);

        if (audioClip.samples > maxSamples)
        {
            Debug.LogWarning($"RunWhisper: clip has {audioClip.samples} samples; only the first {maxSamples} are transcribed.");
        }
        if (audioClip.frequency != sampleRate)
        {
            Debug.LogWarning($"RunWhisper: clip frequency is {audioClip.frequency} Hz; the model input window assumes {sampleRate} Hz.");
        }

        // Request only the frames that exist: GetData wraps around to the start of the clip when
        // asked for more, which would repeat a short clip to fill the window.
        var interleaved = new float[frames * channels];
        if (frames > 0)
        {
            audioClip.GetData(interleaved, 0);
        }

        // Samples from GetData are interleaved per channel; average them into one mono window.
        var data = new float[maxSamples];
        for (int frame = 0; frame < frames; frame++)
        {
            float sum = 0f;
            for (int ch = 0; ch < channels; ch++)
            {
                sum += interleaved[frame * channels + ch];
            }
            data[frame] = sum / channels;
        }

        numSamples = maxSamples;
        audioInput = new Tensor<float>(new TensorShape(1, numSamples), data);
    }

    /// <summary>Runs the spectrogram and encoder models and keeps the encoder output.</summary>
    void EncodeAudio()
    {
        spectrogram.Schedule(audioInput);
        var logmel = spectrogram.PeekOutput() as Tensor<float>;
        encoder.Schedule(logmel);
        encodedAudio = encoder.PeekOutput() as Tensor<float>;
    }

    /// <summary>
    /// Produces one token. decoder1 runs on the full prompt, its "present.*" caches feed
    /// decoder2 together with the last token, and argmax over the logits picks the next token.
    /// </summary>
    async Awaitable InferenceStep()
    {
        decoder1.SetInput("input_ids", tokensTensor);
        decoder1.SetInput("encoder_hidden_states", encodedAudio);
        decoder1.Schedule();

        decoder2.SetInput("input_ids", lastTokenTensor);

        // Forward each "present.{layer}.{stack}.{kind}" output of decoder1 into the matching
        // "past_key_values.{layer}.{stack}.{kind}" input of decoder2.
        foreach (var stack in new[] { "decoder", "encoder" })
        {
            for (int layer = 0; layer < numDecoderLayers; layer++)
            {
                foreach (var kind in new[] { "key", "value" })
                {
                    decoder2.SetInput($"past_key_values.{layer}.{stack}.{kind}",
                                      decoder1.PeekOutput($"present.{layer}.{stack}.{kind}"));
                }
            }
        }

        decoder2.Schedule();

        var logits = decoder2.PeekOutput("logits") as Tensor<float>;
        argmax.Schedule(logits);
        using var t_Token = await argmax.PeekOutput().ReadbackAndCloneAsync() as Tensor<int>;

        // The component may have been destroyed (and its buffers disposed) during the readback.
        if (destroyed)
        {
            return;
        }

        int index = t_Token[0];

        // The token decoder2 just consumed joins the prompt; the new prediction is fed next step.
        outputTokens[tokenCount] = lastToken[0];
        lastToken[0] = index;
        tokenCount++;
        tokensTensor.Reshape(new TensorShape(1, tokenCount));
        tokensTensor.dataOnBackend.Upload<int>(outputTokens, tokenCount);
        lastTokenTensor.dataOnBackend.Upload<int>(lastToken, 1);

        if (index == END_OF_TEXT)
        {
            transcribe = false;
        }
        else if (index >= 0 && index < tokens.Length && tokens[index] != null)
        {
            AppendTokenBytes(tokens[index], outputBytes);
            outputString = Encoding.UTF8.GetString(outputBytes.ToArray());
        }

        Debug.Log(outputString);
    }

    /// <summary>Vocabulary JSON mapping token strings to token ids.</summary>
    public TextAsset jsonFile;

    /// <summary>
    /// Builds <see cref="tokens"/> from <see cref="jsonFile"/>, indexed by token id. The array
    /// is sized by the largest id, so gaps in the id range stay null.
    /// </summary>
    void GetTokens()
    {
        var vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonFile.text);

        int maxId = -1;
        foreach (var id in vocab.Values)
        {
            if (id > maxId)
            {
                maxId = id;
            }
        }

        tokens = new string[maxId + 1];
        foreach (var item in vocab)
        {
            tokens[item.Value] = item.Key;
        }
    }

    /// <summary>
    /// Appends the raw bytes that a vocabulary string encodes. Characters below 256 stand for
    /// themselves; character 256 + n stands for whiteSpaceCharacters[n]. Anything outside that
    /// table becomes '?', the same substitute an ISO-8859-1 encoder would emit.
    /// </summary>
    /// <param name="text">Vocabulary token string.</param>
    /// <param name="destination">Byte buffer to append to.</param>
    void AppendTokenBytes(string text, List<byte> destination)
    {
        foreach (char letter in text)
        {
            int code = letter;
            if (code < 256)
            {
                destination.Add((byte)code);
            }
            else if (code - 256 < whiteSpaceCharacters.Length)
            {
                destination.Add((byte)whiteSpaceCharacters[code - 256]);
            }
            else
            {
                destination.Add((byte)'?');
            }
        }
    }

    /// <summary>
    /// Fills whiteSpaceCharacters with the byte values that <see cref="IsWhiteSpace"/> flags, in
    /// ascending order, so the n-th such byte is at index n.
    /// </summary>
    void SetupWhiteSpaceShifts()
    {
        for (int i = 0, n = 0; i < 256; i++)
        {
            if (IsWhiteSpace((char)i)) whiteSpaceCharacters[n++] = i;
        }
    }

    /// <summary>
    /// True for byte values outside the printable ranges '!'..'~', U+00A1..U+00AC and
    /// U+00AE..U+00FF. These ranges match the GPT-2 style byte-to-unicode table; written as
    /// escapes so source-file encoding cannot corrupt them.
    /// </summary>
    bool IsWhiteSpace(char c)
    {
        return !(('!' <= c && c <= '~') || ('\u00A1' <= c && c <= '\u00AC') || ('\u00AE' <= c && c <= '\u00FF'));
    }

    /// <summary>
    /// Stops the decode loop and releases workers, owned tensors, and persistent native buffers.
    /// Null-safe in case Start failed before everything was created.
    /// </summary>
    private void OnDestroy()
    {
        destroyed = true;
        transcribe = false;

        decoder1?.Dispose();
        decoder2?.Dispose();
        encoder?.Dispose();
        spectrogram?.Dispose();
        argmax?.Dispose();

        audioInput?.Dispose();
        lastTokenTensor?.Dispose();
        tokensTensor?.Dispose();

        // Allocator.Persistent memory is never reclaimed automatically.
        if (outputTokens.IsCreated) outputTokens.Dispose();
        if (lastToken.IsCreated) lastToken.Dispose();
    }
}
|
|