---
base_model: google/gemma-3n-E4B-it
library_name: peft
model_name: gemma-3n-E4B-transcribe-zh-tw-1
tags:
- generated_from_trainer
- trl
- sft
licence: license
---

# Model Card for gemma-3n-E4B-transcribe-zh-tw-1

This model is a PEFT adapter for [google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it), fine-tuned for speech transcription with [TRL](https://github.com/huggingface/trl). It must be loaded on top of the base model, as in the Quick start example.

## Quick start

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor (tokenizer + audio feature extractor) comes from the base model.
processor = AutoProcessor.from_pretrained("google/gemma-3n-E4B-it")
base_model = AutoModelForCausalLM.from_pretrained("google/gemma-3n-E4B-it")
# Attach the fine-tuned PEFT adapter on top of the base weights.
model = PeftModel.from_pretrained(
    base_model, "JacobLinCool/gemma-3n-E4B-transcribe-zh-tw-1"
).to(device)
model.eval()


def transcribe(model, processor, audio):
    """Transcribe one audio file (local path or URL) and return the text."""
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant that transcribes speech accurately.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio},
                {"type": "text", "text": "Transcribe this audio."},
            ],
        },
    ]

    # Tokenize the prompt and extract the audio features in one call.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every tensor to the device; only floating-point tensors
    # (the audio features) are cast to the model's dtype.
    inputs = inputs.to(device, dtype=model.dtype)

    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=128)

    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    prediction = processor.batch_decode(
        outputs[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return prediction.strip()


if __name__ == "__main__":
    prediction = transcribe(model, processor, "/workspace/audio.mp3")
    print(prediction)
```
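
To transcribe several files, the Quick start transcription function can be called once per path. The sketch below is a hypothetical helper, `transcribe_files`, not part of this repository; it assumes a callable `transcribe_fn(path) -> str`, which you could build by binding `model` and `processor` to the Quick start function with `functools.partial`. It records per-file errors instead of stopping at the first unreadable file.

```python
from typing import Callable, Dict, Iterable, Tuple


def transcribe_files(
    paths: Iterable[str], transcribe_fn: Callable[[str], str]
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Hypothetical helper: transcribe several audio files one at a time.

    `transcribe_fn` is assumed to take a single audio path and return the
    transcript. Returns (transcripts, failures), both keyed by path.
    """
    transcripts: Dict[str, str] = {}
    failures: Dict[str, str] = {}
    for path in paths:
        try:
            transcripts[path] = transcribe_fn(path)
        except Exception as exc:
            # Record the error and continue so one unreadable file
            # does not abort the whole batch.
            failures[path] = repr(exc)
    return transcripts, failures
```

Files are processed sequentially, one `generate` call each, which keeps memory use the same as the single-file example at the cost of throughput.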

## Training procedure

This model was trained with supervised fine-tuning (SFT).
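
The card does not document the training data or the prompt format used for SFT. As an illustration only, a conversational SFT record consistent with the Quick start prompt might pair an audio clip with its reference transcript as the assistant turn. The helper name `build_sft_example`, the `messages` field layout, and the reuse of the Quick start system and user prompts are all assumptions, not the author's confirmed setup.

```python
from typing import Any, Dict, List

# Assumption: the same prompts as the Quick start example. The prompts
# actually used to train this adapter are not documented.
SYSTEM_PROMPT = "You are an assistant that transcribes speech accurately."
USER_PROMPT = "Transcribe this audio."


def build_sft_example(audio_path: str, transcript: str) -> Dict[str, List[Dict[str, Any]]]:
    """Hypothetical helper: turn one (audio, transcript) pair into a chat record.

    The assistant turn holds the reference transcript, i.e. the text the
    model is trained to produce during supervised fine-tuning.
    """
    return {
        "messages": [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio_path},
                    {"type": "text", "text": USER_PROMPT},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": transcript}]},
        ]
    }
```

A list of such records could be wrapped with `datasets.Dataset.from_list` before training; turning the audio into model features would likely also need a custom data collator, which is not shown here.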

### Framework versions

- PEFT: 0.15.2
- TRL: 0.19.0
- Transformers: 4.53.0
- PyTorch: 2.8.0.dev20250319+cu128
- Datasets: 3.6.0
- Tokenizers: 0.21.2
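
When reproducing the training environment or debugging a loading error, it can help to compare installed package versions against the list above. This is a minimal sketch using the standard library's `importlib.metadata`; it assumes the PyPI distribution names `peft`, `trl`, `transformers`, `torch`, `datasets`, and `tokenizers` correspond to the listed frameworks, and `find_mismatches` is a hypothetical helper.

```python
from importlib import metadata
from typing import Callable, Dict, Optional

# Versions listed in this model card, keyed by (assumed) PyPI distribution name.
EXPECTED_VERSIONS: Dict[str, str] = {
    "peft": "0.15.2",
    "trl": "0.19.0",
    "transformers": "4.53.0",
    "torch": "2.8.0.dev20250319+cu128",
    "datasets": "3.6.0",
    "tokenizers": "0.21.2",
}


def installed_version(name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(
    expected: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Optional[str]]:
    """Map each package whose installed version differs to the version found.

    A value of None means the package is not installed at all.
    """
    mismatches: Dict[str, Optional[str]] = {}
    for name, wanted in expected.items():
        found = lookup(name)
        if found != wanted:
            mismatches[name] = found
    return mismatches


if __name__ == "__main__":
    for name, found in find_mismatches(EXPECTED_VERSIONS).items():
        print(f"{name}: expected {EXPECTED_VERSIONS[name]}, found {found}")
```

The comparison is exact string equality on purpose: it flags every difference, including newer releases that may well work, so the output is best read as a checklist rather than as an error report.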

## Citations

Cite TRL as:

```bibtex
@misc{vonwerra2022trl,
    title = {{TRL: Transformer Reinforcement Learning}},
    author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
    year = 2020,
    journal = {GitHub repository},
    publisher = {GitHub},
    howpublished = {\url{https://github.com/huggingface/trl}}
}
```