Xenova HF Staff commited on
Commit
38feeb2
·
verified ·
1 Parent(s): a4e3b06

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +78 -1
README.md CHANGED
@@ -8,4 +8,81 @@ tags:
8
  - audio
9
  ---
10
 
11
- ONNX-compatible weights for https://huggingface.co/kyutai/mimi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  - audio
9
  ---
10
 
11
+ ONNX-compatible weights for https://huggingface.co/kyutai/mimi
12
+
13
+ ## Inference sample code
14
+ ```py
15
+ import onnxruntime as ort
16
+
17
+ encoder_session = ort.InferenceSession("encoder_model.onnx")
18
+ decoder_session = ort.InferenceSession("decoder_model.onnx")
19
+
20
+ encoder_inputs = {encoder_session.get_inputs()[0].name: dummy_encoder_inputs.numpy()}
21
+ encoder_outputs = encoder_session.run(None, encoder_inputs)[0]
22
+
23
+ decoder_inputs = {decoder_session.get_inputs()[0].name: encoder_outputs}
24
+ decoder_outputs = decoder_session.run(None, decoder_inputs)[0]
25
+
26
+ # Print the results
27
+ print("Encoder Output Shape:", encoder_outputs.shape)
28
+ print("Decoder Output Shape:", decoder_outputs.shape)
29
+ ```
30
+
31
+ ## Conversion sample code
32
+ ```py
33
+ import torch
34
+ import torch.nn as nn
35
+ from transformers import MimiModel
36
+
37
+ class MimiEncoder(nn.Module):
38
+ def __init__(self, model):
39
+ super(MimiEncoder, self).__init__()
40
+ self.model = model
41
+
42
+ def forward(self, input_values, padding_mask=None):
43
+ return self.model.encode(input_values, padding_mask=padding_mask).audio_codes
44
+
45
+ class MimiDecoder(nn.Module):
46
+ def __init__(self, model):
47
+ super(MimiDecoder, self).__init__()
48
+ self.model = model
49
+
50
+ def forward(self, audio_codes, padding_mask=None):
51
+ return self.model.decode(audio_codes, padding_mask=padding_mask).audio_values
52
+
53
+ model = MimiModel.from_pretrained("kyutai/mimi")
54
+ encoder = MimiEncoder(model)
55
+ decoder = MimiDecoder(model)
56
+
57
+ dummy_encoder_inputs = torch.randn((5, 1, 82500))
58
+ torch.onnx.export(
59
+ encoder,
60
+ dummy_encoder_inputs,
61
+ "encoder_model.onnx",
62
+ export_params=True,
63
+ opset_version=14,
64
+ do_constant_folding=True,
65
+ input_names=['input_values'],
66
+ output_names=['audio_codes'],
67
+ dynamic_axes={
68
+ 'input_values': {0: 'batch_size', 1: 'num_channels', 2: 'sequence_length'},
69
+ 'audio_codes': {0: 'batch_size', 2: 'codes_length'},
70
+ },
71
+ )
72
+
73
+ dummy_decoder_inputs = torch.randint(100, (4, 32, 91))
74
+ torch.onnx.export(
75
+ decoder,
76
+ dummy_decoder_inputs,
77
+ "decoder_model.onnx",
78
+ export_params=True,
79
+ opset_version=14,
80
+ do_constant_folding=True,
81
+ input_names=['audio_codes'],
82
+ output_names=['audio_values'],
83
+ dynamic_axes={
84
+ 'audio_codes': {0: 'batch_size', 2: 'codes_length'},
85
+ 'audio_values': {0: 'batch_size', 1: 'num_channels', 2: 'sequence_length'},
86
+ },
87
+ )
88
+ ```