Update README.md
README.md (changed):
````diff
@@ -21,7 +21,6 @@ CED are simple ViT-Transformer-based models for audio tagging. Notable differenc
 - **Demo:** https://huggingface.co/spaces/mispeech/ced-base
 
 ## Install
-
 ```bash
 git clone https://github.com/jimbozhang/hf_transformers_custom_model_ced.git
 pip install -r requirements.txt
@@ -32,20 +31,21 @@ pip install -r requirements.txt
 >>> from ced_model.feature_extraction_ced import CedFeatureExtractor
 >>> from ced_model.modeling_ced import CedForAudioClassification
 
->>>
->>> feature_extractor = CedFeatureExtractor.from_pretrained(
->>> model = CedForAudioClassification.from_pretrained(
+>>> model_name = "mispeech/ced-tiny"
+>>> feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
+>>> model = CedForAudioClassification.from_pretrained(model_name)
 
 >>> import torchaudio
 >>> audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")
-
+>>> assert sampling_rate == 16000
 >>> inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
+
+>>> import torch
 >>> with torch.no_grad():
 ...     logits = model(**inputs).logits
 
->>>
->>>
->>> model.config.id2label[predicted_class_ids]
+>>> predicted_class_id = torch.argmax(logits, dim=-1).item()
+>>> model.config.id2label[predicted_class_id]
 'Finger snapping'
 ```
 
````
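Net effect of the change: the truncated `from_pretrained(` calls are completed via a `model_name` variable, the missing `import torch` is added (the example calls `torch.no_grad()`), the previously undefined `predicted_class_ids` is replaced by an explicitly computed `predicted_class_id`, and the example now checks that the input clip is 16 kHz.

For reference, here is the updated example assembled into a single runnable script. This is a minimal sketch, assuming the repository has been cloned and `pip install -r requirements.txt` has run, so that the `ced_model` package and the bundled test clip are available. One deviation from the README is flagged in the comments: where the README uses `assert sampling_rate == 16000`, the sketch resamples with `torchaudio.functional.resample` instead, an assumption that non-16 kHz input should be converted rather than rejected.

```python
import torch
import torchaudio

from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification

# Checkpoint used in the updated README example.
model_name = "mispeech/ced-tiny"
feature_extractor = CedFeatureExtractor.from_pretrained(model_name)
model = CedForAudioClassification.from_pretrained(model_name)

# Test clip bundled with the repository.
audio, sampling_rate = torchaudio.load("resources/JeD5V5aaaoI_931_932.wav")

# The README asserts sampling_rate == 16000; this sketch resamples
# non-16 kHz input instead of rejecting it (an assumption, see above).
if sampling_rate != 16000:
    audio = torchaudio.functional.resample(audio, orig_freq=sampling_rate, new_freq=16000)
    sampling_rate = 16000

inputs = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = torch.argmax(logits, dim=-1).item()
print(model.config.id2label[predicted_class_id])  # 'Finger snapping' for this clip
```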