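// Page residue from the hosting site, kept as a comment so the file parses:
// uploader "togawa83", commit message "first commit", commit f7e5b33.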
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using System.Text;
using Unity.Collections;
public class RunWhisper : MonoBehaviour
{
/// <summary>
/// Languages that can be requested from the model.
/// The numeric value of each member is an offset: GetLanguageCode adds it to
/// 50259 to obtain the language token id, so the order and values below must
/// not be changed independently of the model's token table.
/// </summary>
public enum WhisperLanguage
{
English = 0,
Chinese = 1,
German = 2,
Spanish = 3,
Russian = 4,
Korean = 5,
French = 6,
Japanese = 7,
Portuguese = 8,
Turkish = 9,
Polish = 10,
Catalan = 11,
Dutch = 12,
Arabic = 13,
Swedish = 14,
Italian = 15,
Indonesian = 16,
Hindi = 17,
Finnish = 18,
Vietnamese = 19,
Hebrew = 20,
Ukrainian = 21,
Greek = 22,
Malay = 23,
Czech = 24,
Romanian = 25,
Danish = 26,
Hungarian = 27,
Tamil = 28,
Norwegian = 29,
Thai = 30,
Urdu = 31,
Croatian = 32,
Bulgarian = 33,
Lithuanian = 34,
Latin = 35,
Maori = 36,
Malayalam = 37,
Welsh = 38,
Slovak = 39,
Telugu = 40,
Persian = 41,
Latvian = 42,
Bengali = 43,
Serbian = 44,
Azerbaijani = 45,
Slovenian = 46,
Kannada = 47,
Estonian = 48,
Macedonian = 49,
Breton = 50,
Basque = 51,
Icelandic = 52,
Armenian = 53,
Nepali = 54,
Mongolian = 55,
Bosnian = 56,
Kazakh = 57,
Albanian = 58,
Swahili = 59,
Galician = 60,
Marathi = 61,
Punjabi = 62,
Sinhala = 63,
Khmer = 64,
Shona = 65,
Yoruba = 66,
Somali = 67,
Afrikaans = 68,
Occitan = 69,
Georgian = 70,
Belarusian = 71,
Tajik = 72,
Sindhi = 73,
Gujarati = 74,
Amharic = 75,
Yiddish = 76,
Lao = 77,
Uzbek = 78,
Faroese = 79,
HaitianCreole = 80,
Pashto = 81,
Turkmen = 82,
Nynorsk = 83,
Maltese = 84,
Sanskrit = 85,
Luxembourgish = 86,
Myanmar = 87,
Tibetan = 88,
Tagalog = 89,
Malagasy = 90,
Assamese = 91,
Tatar = 92,
Hawaiian = 93,
Lingala = 94,
Hausa = 95,
Bashkir = 96,
Javanese = 97,
Sundanese = 98,
// NOTE(review): 50259 + 99 = 50358, which is the value of the TRANSLATE
// constant in this file. Confirm the loaded model really has a Cantonese
// token (and a matching token layout) before selecting this member.
Cantonese = 99
}
/// <summary>
/// Converts a <see cref="WhisperLanguage"/> into the id of its language token.
/// Language tokens start at 50259 and are laid out in enum order.
/// </summary>
/// <param name="language">Language to look up.</param>
/// <returns>The token id, <c>50259 + (int)language</c>.</returns>
/// <exception cref="System.ArgumentOutOfRangeException">
/// The resulting id would fall outside the language-token range, i.e. below
/// the first language token or at/above <c>TRANSLATE</c>. With the constants
/// in this file that is the case for <c>WhisperLanguage.Cantonese</c>
/// (50259 + 99 == 50358 == TRANSLATE), which would otherwise silently switch
/// the task to translation instead of selecting a language.
/// </exception>
static int GetLanguageCode(WhisperLanguage language)
{
    const int firstLanguageToken = 50259;

    int code = firstLanguageToken + (int)language;
    if (code < firstLanguageToken || code >= TRANSLATE)
    {
        throw new System.ArgumentOutOfRangeException(
            nameof(language),
            language,
            "No language token exists for this value in the token layout used by this script.");
    }
    return code;
}
// Inference workers. All of them are created in Start and disposed in OnDestroy.
// decoder1 / decoder2 are the two decoder stages run by InferenceStep,
// encoder and spectrogram are run once by EncodeAudio.
Worker decoder1, decoder2, encoder, spectrogram;
// Worker compiled in Start from a small graph that takes the ArgMax over the
// last axis of the logits; used by InferenceStep to pick the next token.
Worker argmax;
// Clip to transcribe; read by LoadAudio.
public AudioClip audioClip;
// Capacity of the outputTokens buffer. Start stops decoding once tokenCount
// reaches maxTokens - 1. It can be adjusted.
const int maxTokens = 100;
// Special tokens see added tokens file for details
const int END_OF_TEXT = 50257;
const int START_OF_TRANSCRIPT = 50258;
const int TRANSCRIBE = 50359; //for speech-to-text in specified language
const int TRANSLATE = 50358; //for speech-to-text then translate to English
const int NO_TIME_STAMPS = 50363;
const int START_TIME = 50364;
// Number of samples in the audio tensor; LoadAudio sets it to maxSamples.
int numSamples;
// Token id -> token string table, filled by GetTokens.
string[] tokens;
// Number of valid entries at the front of outputTokens.
int tokenCount = 0;
// Token ids fed to decoder1 (prompt tokens followed by generated tokens).
NativeArray<int> outputTokens;
// Used for special character decoding: filled by SetupWhiteSpaceShifts and
// indexed by ShiftCharacterDown with (character code - 256).
int[] whiteSpaceCharacters = new int[256];
// Output of the encoder worker, obtained with PeekOutput in EncodeAudio.
Tensor<float> encodedAudio;
// True while the decoding loop in Start should keep running; cleared by
// InferenceStep when END_OF_TEXT is produced.
bool transcribe = false;
// Text decoded so far; InferenceStep appends to it and logs it.
string outputString = "";
// Maximum size of audioClip (30s at 16kHz)
const int maxSamples = 30 * 16000;
// Model assets loaded with ModelLoader.Load in Start.
public ModelAsset audioDecoder1, audioDecoder2;
public ModelAsset audioEncoder;
public ModelAsset logMelSpectro;
/// <summary>
/// Unity entry point. Builds the tokenizer tables, creates all workers,
/// seeds the prompt tokens, encodes the audio clip once and then runs
/// InferenceStep repeatedly until END_OF_TEXT is produced or the token
/// buffer is (almost) full.
/// </summary>
public async void Start()
{
    // Tokenizer tables.
    SetupWhiteSpaceShifts();
    GetTokens();

    // Decoder workers.
    decoder1 = new Worker(ModelLoader.Load(audioDecoder1), BackendType.GPUCompute);
    decoder2 = new Worker(ModelLoader.Load(audioDecoder2), BackendType.GPUCompute);

    // Tiny graph that reduces a (1, 1, 51865) logits tensor to the index of
    // its largest entry along the last axis.
    var argmaxGraph = new FunctionalGraph();
    var logitsInput = argmaxGraph.AddInput(DataType.Float, new DynamicTensorShape(1, 1, 51865));
    var bestToken = Functional.ArgMax(logitsInput, -1, false);
    var argmaxModel = argmaxGraph.Compile(bestToken);
    argmax = new Worker(argmaxModel, BackendType.GPUCompute);

    // Audio front end.
    encoder = new Worker(ModelLoader.Load(audioEncoder), BackendType.GPUCompute);
    spectrogram = new Worker(ModelLoader.Load(logMelSpectro), BackendType.GPUCompute);

    // Prompt: start-of-transcript, language, task. Swap the language or use
    // TRANSLATE instead of TRANSCRIBE to change what the model is asked to do.
    // NO_TIME_STAMPS is not stored here; it is supplied below as the initial
    // "last token".
    outputTokens = new NativeArray<int>(maxTokens, Allocator.Persistent);
    outputTokens[0] = START_OF_TRANSCRIPT;
    outputTokens[1] = GetLanguageCode(WhisperLanguage.English);
    outputTokens[2] = TRANSCRIBE;
    tokenCount = 3;

    LoadAudio();
    EncodeAudio();
    transcribe = true;

    // Allocate the token tensor at full capacity, pin it to the compute
    // backend, then shrink its shape to the tokens currently in use.
    tokensTensor = new Tensor<int>(new TensorShape(1, maxTokens));
    ComputeTensorData.Pin(tokensTensor);
    tokensTensor.Reshape(new TensorShape(1, tokenCount));
    tokensTensor.dataOnBackend.Upload<int>(outputTokens, tokenCount);

    lastToken = new NativeArray<int>(1, Allocator.Persistent);
    lastToken[0] = NO_TIME_STAMPS;
    lastTokenTensor = new Tensor<int>(new TensorShape(1, 1), new[] { NO_TIME_STAMPS });

    // One decoding step per iteration; one slot of the buffer is kept free.
    while (transcribe && tokenCount < (outputTokens.Length - 1))
    {
        m_Awaitable = InferenceStep();
        await m_Awaitable;
    }
}
// Awaitable of the InferenceStep currently being awaited by Start.
Awaitable m_Awaitable;
// Single-element buffer holding the most recently produced token
// (initialised to NO_TIME_STAMPS in Start).
NativeArray<int> lastToken;
// (1, 1) tensor mirror of lastToken; bound to decoder2's "input_ids".
Tensor<int> lastTokenTensor;
// (1, tokenCount) tensor mirror of outputTokens; bound to decoder1's "input_ids".
Tensor<int> tokensTensor;
// (1, numSamples) tensor with the audio samples, created by LoadAudio.
Tensor<float> audioInput;
/// <summary>
/// Copies <c>audioClip</c> into a fixed-size mono buffer of <c>maxSamples</c>
/// samples and wraps it in <c>audioInput</c> with shape (1, maxSamples).
/// </summary>
/// <remarks>
/// Clips shorter than the buffer are padded with silence, longer clips are
/// truncated, and multi-channel clips are down-mixed by averaging the
/// channels of each frame. No resampling is done; a warning is logged when
/// the clip's sample rate differs from the 16 kHz the buffer size is based on.
/// </remarks>
void LoadAudio()
{
    const int expectedSampleRate = 16000;

    // The tensor always has maxSamples entries, whatever the clip length.
    numSamples = maxSamples;
    var data = new float[maxSamples]; // zero-initialised, i.e. silence

    if (audioClip.frequency != expectedSampleRate)
    {
        Debug.LogWarning("RunWhisper: audio clip sample rate is " + audioClip.frequency +
                         " Hz but " + expectedSampleRate + " Hz is expected; the clip is not resampled.");
    }

    int channels = audioClip.channels;
    int frames = Mathf.Min(audioClip.samples, maxSamples);
    if (frames > 0 && channels > 0)
    {
        // Read only what the clip actually contains. AudioClip.GetData sizes
        // the read by the array length and is documented to wrap around to
        // the start of the clip when asked for more than is available, which
        // would repeat short clips instead of padding them with silence.
        var interleaved = new float[frames * channels];
        audioClip.GetData(interleaved, 0);

        // GetData returns interleaved channel data; average each frame to mono.
        for (int frame = 0; frame < frames; frame++)
        {
            int start = frame * channels;
            float sum = 0f;
            for (int channel = 0; channel < channels; channel++)
            {
                sum += interleaved[start + channel];
            }
            data[frame] = sum / channels;
        }
    }

    audioInput = new Tensor<float>(new TensorShape(1, numSamples), data);
}
/// <summary>
/// Runs the spectrogram worker on <c>audioInput</c>, feeds its output to the
/// encoder worker and keeps the encoder output in <c>encodedAudio</c>.
/// Both outputs are obtained with PeekOutput, so they remain owned by their
/// workers.
/// </summary>
void EncodeAudio()
{
    // Stage 1: raw samples -> spectrogram model output.
    spectrogram.Schedule(audioInput);
    var melFeatures = spectrogram.PeekOutput() as Tensor<float>;

    // Stage 2: spectrogram output -> encoder output used by every decoding step.
    encoder.Schedule(melFeatures);
    encodedAudio = encoder.PeekOutput() as Tensor<float>;
}
// Number of decoder layers whose key/value tensors are handed from decoder1
// to decoder2 on every step.
const int decoderLayerCount = 6;

// (decoder1 output name, decoder2 input name) pairs for every cached tensor.
// Built once so that no strings are allocated per decoding step.
static readonly KeyValuePair<string, string>[] cacheBindings = BuildCacheBindings();

// Builds the "present.N.*" -> "past_key_values.N.*" name table for all layers.
static KeyValuePair<string, string>[] BuildCacheBindings()
{
    string[] kinds = { "decoder.key", "decoder.value", "encoder.key", "encoder.value" };
    var bindings = new KeyValuePair<string, string>[decoderLayerCount * kinds.Length];
    int next = 0;
    for (int layer = 0; layer < decoderLayerCount; layer++)
    {
        foreach (string kind in kinds)
        {
            bindings[next++] = new KeyValuePair<string, string>(
                "present." + layer + "." + kind,
                "past_key_values." + layer + "." + kind);
        }
    }
    return bindings;
}

/// <summary>
/// Performs one decoding step: runs decoder1 on the tokens gathered so far,
/// passes its "present.*" outputs to decoder2 as "past_key_values.*" together
/// with the last token, takes the argmax of decoder2's logits, and appends
/// the result to the token buffers and to <c>outputString</c>.
/// </summary>
/// <remarks>
/// Clears <c>transcribe</c> when END_OF_TEXT is produced, or when the
/// component was destroyed while the GPU readback was pending.
/// </remarks>
async Awaitable InferenceStep()
{
    // Stage 1: all tokens so far + encoded audio.
    decoder1.SetInput("input_ids", tokensTensor);
    decoder1.SetInput("encoder_hidden_states", encodedAudio);
    decoder1.Schedule();

    // Stage 2: last token + the key/value tensors produced by stage 1.
    decoder2.SetInput("input_ids", lastTokenTensor);
    foreach (KeyValuePair<string, string> binding in cacheBindings)
    {
        var cached = decoder1.PeekOutput(binding.Key) as Tensor<float>;
        decoder2.SetInput(binding.Value, cached);
    }
    decoder2.Schedule();

    var logits = decoder2.PeekOutput("logits") as Tensor<float>;
    argmax.Schedule(logits);
    using var t_Token = await argmax.PeekOutput().ReadbackAndCloneAsync() as Tensor<int>;

    // The component may have been destroyed while the readback was pending;
    // OnDestroy disposes the tensors used below. Unity's overloaded ==
    // reports a destroyed component as null.
    if (this == null)
    {
        transcribe = false;
        return;
    }

    int index = t_Token[0];

    // The previous "last token" becomes part of the history, and the new
    // token becomes the last token for the next step.
    outputTokens[tokenCount] = lastToken[0];
    lastToken[0] = index;
    tokenCount++;
    tokensTensor.Reshape(new TensorShape(1, tokenCount));
    tokensTensor.dataOnBackend.Upload<int>(outputTokens, tokenCount);
    lastTokenTensor.dataOnBackend.Upload<int>(lastToken, 1);

    if (index == END_OF_TEXT)
    {
        transcribe = false;
    }
    else if (index < tokens.Length)
    {
        // Ids beyond the vocabulary table (e.g. special tokens) add no text.
        outputString += GetUnicodeText(tokens[index]);
    }
    Debug.Log(outputString);
}
// Tokenizer
public TextAsset jsonFile;

/// <summary>
/// Parses <c>jsonFile</c> as a JSON object mapping token string to token id
/// and inverts it into the <c>tokens</c> array (index = id, value = string).
/// The array is sized by the number of entries in the JSON object.
/// </summary>
void GetTokens()
{
    Dictionary<string, int> idsByToken =
        Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonFile.text);

    tokens = new string[idsByToken.Count];
    foreach (KeyValuePair<string, int> entry in idsByToken)
    {
        tokens[entry.Value] = entry.Key;
    }
}
// Maps each character in the range 0..255 to the byte with the same value.
static readonly Encoding latin1 = Encoding.GetEncoding("ISO-8859-1");

// Stateful UTF-8 decoder shared by successive GetUnicodeText calls. It keeps
// the bytes of an incomplete character until the rest arrives.
readonly Decoder utf8Decoder = Encoding.UTF8.GetDecoder();

/// <summary>
/// Converts one vocabulary token string into display text.
/// The token is first mapped back to raw bytes (ShiftCharacterDown followed
/// by a Latin-1 encode) and those bytes are then decoded as UTF-8.
/// </summary>
/// <param name="text">Token string from the vocabulary table.</param>
/// <returns>
/// The characters completed by this token. A token that ends in the middle
/// of a multi-byte UTF-8 character returns only the complete part; the
/// pending bytes are emitted with the token that completes the character.
/// </returns>
/// <remarks>
/// Decoding every token on its own turns characters that are split across
/// tokens into replacement characters, hence the stateful decoder.
/// </remarks>
string GetUnicodeText(string text)
{
    byte[] bytes = latin1.GetBytes(ShiftCharacterDown(text));

    // flush: false keeps a trailing partial character for the next call.
    int capacity = utf8Decoder.GetCharCount(bytes, 0, bytes.Length, false);
    var chars = new char[capacity];
    int written = utf8Decoder.GetChars(bytes, 0, bytes.Length, chars, 0, false);
    return new string(chars, 0, written);
}
/// <summary>
/// Maps the characters of a vocabulary token back into the 0..255 range.
/// Characters below 256 are kept; a character with code 256 + n is replaced
/// by entry n of <c>whiteSpaceCharacters</c>.
/// </summary>
/// <param name="text">Token string from the vocabulary table.</param>
/// <returns>A string of the same length whose characters stand for raw byte values.</returns>
string ShiftCharacterDown(string text)
{
    var shifted = new StringBuilder(text.Length);
    foreach (char letter in text)
    {
        // Code 256 is the first shifted character (table slot 0), so the
        // comparison must be strict: "<= 256" would leave it untranslated
        // and slot 0 would never be used.
        shifted.Append(letter < 256
            ? letter
            : (char)whiteSpaceCharacters[letter - 256]);
    }
    return shifted.ToString();
}
/// <summary>
/// Fills the front of <c>whiteSpaceCharacters</c> with every code in 0..255
/// for which IsWhiteSpace returns true, in ascending order. Entry n is the
/// code that ShiftCharacterDown substitutes for character 256 + n.
/// </summary>
void SetupWhiteSpaceShifts()
{
    int next = 0;
    for (int code = 0; code < 256; code++)
    {
        if (!IsWhiteSpace((char)code))
        {
            continue;
        }
        whiteSpaceCharacters[next] = code;
        next++;
    }
}
/// <summary>
/// Tells whether a character is one of those the byte-level tokenizer stores
/// shifted up by 256, i.e. anything outside the three "printable" ranges
/// '!'..'~', U+00A1..U+00AC and U+00AE..U+00FF.
/// </summary>
/// <param name="c">Character to classify.</param>
/// <returns>True when <paramref name="c"/> lies outside all three ranges.</returns>
/// <remarks>
/// The range bounds are written as \u escapes on purpose: as literal
/// characters they are corrupted when the file is re-encoded, which
/// collapses the two upper ranges so they never match and breaks the
/// shift table built by SetupWhiteSpaceShifts.
/// </remarks>
bool IsWhiteSpace(char c)
{
    bool printableAscii = '!' <= c && c <= '~';
    bool latin1Lower = '\u00A1' <= c && c <= '\u00AC';
    bool latin1Upper = '\u00AE' <= c && c <= '\u00FF';
    return !(printableAscii || latin1Lower || latin1Upper);
}
/// <summary>
/// Releases every worker, tensor and native buffer created by Start and
/// LoadAudio, and asks the decoding loop to stop.
/// </summary>
/// <remarks>
/// Each resource is checked before disposal because Start may have failed
/// before creating all of them. <c>encodedAudio</c> is not disposed here: it
/// was obtained with PeekOutput and is released with the encoder worker.
/// </remarks>
private void OnDestroy()
{
    // Stop the loop in Start from scheduling further steps.
    transcribe = false;

    decoder1?.Dispose();
    decoder2?.Dispose();
    encoder?.Dispose();
    spectrogram?.Dispose();
    argmax?.Dispose();

    audioInput?.Dispose();
    lastTokenTensor?.Dispose();
    tokensTensor?.Dispose();

    // Allocator.Persistent buffers are never freed automatically.
    if (outputTokens.IsCreated)
    {
        outputTokens.Dispose();
    }
    if (lastToken.IsCreated)
    {
        lastToken.Dispose();
    }
}
}