inference-engine-blaze-face / RunBlazeFace.cs
using UnityEngine;
using Unity.Sentis;
using UnityEngine.Video;
using UnityEngine.UI;
using Lays = Unity.Sentis.Layers;
using FF = Unity.Sentis.Functional;
/*
 * Blaze Face Inference
 * ====================
 *
 * Basic inference script for BlazeFace
 *
 * Put this script on the Main Camera
 * Drag the blazeface.sentis asset into the modelAsset field
 * Create a RawImage in the scene
 * Put a link to that image in previewUI
 * Put a video in the Assets/StreamingAssets folder and put its name in videoName
 * Or put a test image in inputImage
 * Set inputType to the appropriate input
 */
public class RunBlazeFace : MonoBehaviour
{
    //Drag the sentis model asset here:
    public ModelAsset modelAsset;
    //Drag a link to a raw image here:
    public RawImage previewUI = null;
    // Put your bounding box sprite image here
    public Sprite boundingboxSprite;
    public Texture2D borderTexture;
    // 6 optional sprite images (left eye, right eye, nose, mouth, left ear, right ear)
    public Sprite[] markerTextures;
    public string videoName = "chatting.mp4";
    //Or a test image to use when inputType is Image:
    public Texture2D inputImage;
    public InputType inputType = InputType.Video;
    Vector2Int resolution = new Vector2Int(640, 640);
    WebCamTexture webcam;
    VideoPlayer video;
    const BackendType backend = BackendType.GPUCompute;
    RenderTexture targetTexture;
    public enum InputType { Image, Video, Webcam };
    //Some adjustable parameters for the model
    [SerializeField, Range(0, 1)] float iouThreshold = 0.5f;
    [SerializeField, Range(0, 1)] float scoreThreshold = 0.5f;
    int maxOutputBoxes = 64;
    IWorker worker;
    //Holds input image size
    int size = 128;
    Model model;
    //webcam device name:
    const string deviceName = "";
    bool closing = false;
    TensorFloat anchors, centersToCorners;
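
    // Centre position and size of a box to draw, in previewUI pixels measured from the
    // centre of the preview image.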
    public struct BoundingBox
    {
        public float centerX;
        public float centerY;
        public float width;
        public float height;
    }
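
    // Note: on mobile the webcam needs camera permission before it will start. A minimal
    // sketch for Android (an addition, not part of the original sample) could look like:
    //
    //   #if UNITY_ANDROID
    //   if (!UnityEngine.Android.Permission.HasUserAuthorizedPermission(UnityEngine.Android.Permission.Camera))
    //       UnityEngine.Android.Permission.RequestUserPermission(UnityEngine.Android.Permission.Camera);
    //   #endif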
    void Start()
    {
        //(Note: if using a webcam on mobile get permissions here first)
        targetTexture = new RenderTexture(resolution.x, resolution.y, 0);
        SetupInput();
        SetupModel();
        SetupBoundingBoxSprite();
    }
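
    // Creates the selected input source: a webcam, a looping VideoPlayer streaming from
    // StreamingAssets, or a one-off blit of the test image into targetTexture.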
    void SetupInput()
    {
        switch (inputType)
        {
            case InputType.Webcam:
            {
                webcam = new WebCamTexture(deviceName, resolution.x, resolution.y);
                webcam.requestedFPS = 30;
                webcam.Play();
                break;
            }
            case InputType.Video:
            {
                video = gameObject.AddComponent<VideoPlayer>();
                video.renderMode = VideoRenderMode.APIOnly;
                video.source = VideoSource.Url;
                video.url = Application.streamingAssetsPath + "/" + videoName;
                video.isLooping = true;
                video.Play();
                break;
            }
            default:
            {
                Graphics.Blit(inputImage, targetTexture);
                break;
            }
        }
    }
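
    // If no bounding box sprite was assigned, build a simple one from borderTexture.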
    void SetupBoundingBoxSprite()
    {
        if (boundingboxSprite == null)
        {
            boundingboxSprite = Sprite.Create(borderTexture,
                new Rect(0, 0, borderTexture.width, borderTexture.height),
                new Vector2(0.5f, 0.5f)); // the pivot is given in normalized (0..1) coordinates
        }
    }
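
    // Copies the current frame of the chosen source into targetTexture. The scale and offset
    // passed to Graphics.Blit take a horizontally centred region of the source so that its
    // aspect ratio matches targetTexture, and flip it vertically if the webcam is mirrored.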
    void Update()
    {
        if (inputType == InputType.Webcam)
        {
            // Format video input
            if (!webcam.didUpdateThisFrame) return;

            var aspect1 = (float)webcam.width / webcam.height;
            var aspect2 = (float)resolution.x / resolution.y;
            var gap = aspect2 / aspect1;

            var vflip = webcam.videoVerticallyMirrored;
            var scale = new Vector2(gap, vflip ? -1 : 1);
            var offset = new Vector2((1 - gap) / 2, vflip ? 1 : 0);

            Graphics.Blit(webcam, targetTexture, scale, offset);
        }
        if (inputType == InputType.Video)
        {
            // The VideoPlayer has no texture until the first frame has been prepared
            if (video.texture == null) return;

            var aspect1 = (float)video.width / video.height;
            var aspect2 = (float)resolution.x / resolution.y;
            var gap = aspect2 / aspect1;

            var vflip = false;
            var scale = new Vector2(gap, vflip ? -1 : 1);
            var offset = new Vector2((1 - gap) / 2, vflip ? 1 : 0);

            Graphics.Blit(video.texture, targetTexture, scale, offset);
        }
        if (inputType == InputType.Image)
        {
            Graphics.Blit(inputImage, targetTexture);
        }

        if (Input.GetKeyDown(KeyCode.Escape))
        {
            closing = true;
            Application.Quit();
        }

        if (Input.GetKeyDown(KeyCode.P))
        {
            previewUI.enabled = !previewUI.enabled;
        }
    }
    void LateUpdate()
    {
        if (!closing)
        {
            RunInference(targetTexture);
        }
    }
    //Calculate the centers of the grid squares for two 16x16 grids and six 8x8 grids
    //The positions of the faces are given relative to these "anchor points"
    float[] GetGridBoxCoords()
    {
        var offsets = new float[896 * 4];
        int n = 0;
        AddGrid(offsets, 16, 2, 8, ref n);
        AddGrid(offsets, 8, 6, 16, ref n);
        return offsets;
    }
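
    // Fills offsets with (x, y, 0, 0) for each anchor. Anchor centres are expressed in
    // model-input pixels (128x128) measured from the centre of the image, so the 16x16 grid
    // with 8-pixel cells spans -60..+60 and the 8x8 grid with 16-pixel cells spans -56..+56.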
    void AddGrid(float[] offsets, int rows, int repeats, int cellWidth, ref int n)
    {
        for (int j = 0; j < repeats * rows * rows; j++)
        {
            offsets[n++] = cellWidth * ((j / repeats) % rows - (rows - 1) * 0.5f);
            offsets[n++] = cellWidth * ((j / repeats / rows) - (rows - 1) * 0.5f);
            n += 2;
        }
    }
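
    // Loads the model and uses the Functional API to append post-processing to it:
    // the anchor centres are added to the raw box regressions, the centersToCorners matrix
    // converts each (centreX, centreY, width, height) row into (x1, y1, x2, y2) corners,
    // the scores are shifted down by scoreThreshold, non-maximum suppression keeps the best
    // boxes, and the surviving indices gather the matching box coordinates and landmarks.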
    void SetupModel()
    {
        float[] offsets = GetGridBoxCoords();

        model = ModelLoader.Load(modelAsset);

        //We need to add extra layers to the model in order to aggregate the box predictions:
        size = model.inputs[0].shape.ToTensorShape()[1]; // input image size (the model expects a square input)

        anchors = new TensorFloat(new TensorShape(offsets.Length / 4, 4), offsets);

        centersToCorners = new TensorFloat(new TensorShape(4, 4),
            new float[]
            {
                 1,     0,    1,    0,
                 0,     1,    0,    1,
                -0.5f,  0,    0.5f, 0,
                 0,    -0.5f, 0,    0.5f
            });

        var model2 = Functional.Compile(
            input =>
            {
                var outputs = model.Forward(input);
                var regressors = outputs[0][0];                                              //shape=(896,16)
                var scores = outputs[1][0].Transpose(0, 1) - scoreThreshold;                 //shape=(1,896)
                var boxCoords = regressors[.., 0..4] + FunctionalTensor.FromTensor(anchors); //shape=(896,4)
                var boxCorners = FF.MatMul(boxCoords, FunctionalTensor.FromTensor(centersToCorners));
                var indices = FF.NMS(boxCorners, scores, iouThreshold);                      //shape=(N)
                var indices2 = indices.Unsqueeze(-1).BroadcastTo(new int[] { 4 });           //shape=(N,4)
                var output = FF.Gather(boxCoords, 0, indices2);                              //shape=(N,4)
                var indices3 = indices.Unsqueeze(-1).BroadcastTo(new int[] { 16 });          //shape=(N,16)
                var markersOutput = FF.Gather(regressors, 0, indices3);                      //shape=(N,16)
                return (output, markersOutput);
            },
            InputDef.FromModel(model)[0]
        );

        worker = WorkerFactory.CreateWorker(backend, model2);
    }
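
    // Draws a bounding box for each detected face plus markers for the six landmarks
    // (eyes, nose, mouth, ears), scaling from model-input pixels to previewUI pixels.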
    void DrawFaces(TensorFloat index3, TensorFloat regressors, int NMAX, Vector2 scale)
    {
        for (int n = 0; n < NMAX; n++)
        {
            //Draw bounding box of face
            var box = new BoundingBox
            {
                centerX = index3[n, 0] * scale.x,
                centerY = index3[n, 1] * scale.y,
                width = index3[n, 2] * scale.x,
                height = index3[n, 3] * scale.y
            };
            DrawBox(box, boundingboxSprite);

            if (regressors == null) continue;

            //Draw markers for eyes, ears, nose, mouth:
            for (int j = 0; j < 6; j++)
            {
                var marker = new BoundingBox
                {
                    centerX = box.centerX + (regressors[n, 4 + j * 2] - regressors[n, 0]) * scale.x,
                    centerY = box.centerY + (regressors[n, 4 + j * 2 + 1] - regressors[n, 1]) * scale.y,
                    width = 8.0f * scale.x,
                    height = 8.0f * scale.y,
                };
                DrawBox(marker, j < markerTextures.Length ? markerTextures[j] : boundingboxSprite);
            }
        }
    }
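
    // Converts the source texture into the model's input tensor, runs the worker, downloads
    // the filtered boxes and landmarks to the CPU, then redraws the annotations over previewUI.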
    void ExecuteML(Texture source)
    {
        var transform = new TextureTransform();
        transform.SetDimensions(size, size, 3);
        transform.SetTensorLayout(0, 3, 1, 2);
        using var image = TextureConverter.ToTensor(source, transform);

        worker.Execute(image);

        using var output = worker.PeekOutput("output_0") as TensorFloat;        //face coords
        using var markersOutput = worker.PeekOutput("output_1") as TensorFloat; //contains markers

        output.CompleteOperationsAndDownload();
        markersOutput.CompleteOperationsAndDownload();

        ClearAnnotations();

        Vector2 markerScale = previewUI.rectTransform.rect.size / size;
        DrawFaces(output, markersOutput, output.shape[0], markerScale);
    }
    void RunInference(Texture input)
    {
        // Face detection
        ExecuteML(input);
        previewUI.texture = input;
    }
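
    // Creates a UI Image under previewUI for one box. The position is relative to the
    // RawImage's pivot (the centre by default), with centerY negated because UI y points up
    // while the model's y axis points down.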
    public void DrawBox(BoundingBox box, Sprite sprite)
    {
        var panel = new GameObject("ObjectBox");
        panel.AddComponent<CanvasRenderer>();
        panel.AddComponent<Image>();
        panel.transform.SetParent(previewUI.transform, false);

        var img = panel.GetComponent<Image>();
        img.color = Color.white;
        img.sprite = sprite;
        img.type = Image.Type.Sliced;

        panel.transform.localPosition = new Vector3(box.centerX, -box.centerY);
        RectTransform rt = panel.GetComponent<RectTransform>();
        rt.sizeDelta = new Vector2(box.width, box.height);
    }
    public void ClearAnnotations()
    {
        foreach (Transform child in previewUI.transform)
        {
            Destroy(child.gameObject);
        }
    }
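
    // The tensors, worker and render texture hold native GPU resources, so release them
    // explicitly rather than relying on the garbage collector.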
    void CleanUp()
    {
        closing = true;
        anchors?.Dispose();
        centersToCorners?.Dispose();

        if (webcam) Destroy(webcam);
        if (video) Destroy(video);

        RenderTexture.active = null;
        targetTexture.Release();

        worker?.Dispose();
        worker = null;
    }

    void OnDestroy()
    {
        CleanUp();
    }
}