|
---
language: en
license: mit
tags:
- summarization
- nlp
- transformer
- text-generation
- huggingface
datasets:
- cnn_dailymail
metrics:
- rouge
widget:
- text: "The quick brown fox jumps over the lazy dog. This is a sample article for testing summarization."
---
|
|
|
# Text Summarization Model |
|
|
|
## Model Overview |
|
This is a **text summarization model** built on a sequence-to-sequence (Seq2Seq) architecture.

It was trained on the **CNN/DailyMail dataset (version 3.0.0)** and generates concise summaries of news articles and other long-form text.
|
|
|
**Intended Use:** |
|
- Summarizing articles, documents, or reports. |
|
- Extracting key points from text for quick understanding. |
|
|
|
**Limitations & Biases:** |
|
- May struggle with extremely long articles or highly technical content. |
|
- Generated summaries may occasionally miss nuanced details. |
|
|
|
--- |
|
|
|
|
|
## Training Details |
|
- **Dataset**: CNN/DailyMail (version 3.0.0)

- **Preprocessing**: Inputs truncated at 512 tokens; target summaries capped at 150 tokens.
|
- **Hyperparameters**: |
|
- Optimizer: AdamW (PyTorch) |
|
- Learning rate: 2e-5 |
|
- Batch size: 4 (per device) |
|
- Epochs: 10 |
|
- **Evaluation Metrics**: ROUGE-1, ROUGE-2, ROUGE-L |
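The hyperparameters above map directly onto `transformers.Seq2SeqTrainingArguments`; a minimal configuration sketch (the output directory name is illustrative, not part of this model's actual training script):

```python
from transformers import Seq2SeqTrainingArguments

# Hypothetical output directory; hyperparameter values match those listed above
training_args = Seq2SeqTrainingArguments(
    output_dir="summarization-model",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    num_train_epochs=10,
    optim="adamw_torch",         # AdamW, PyTorch implementation
    predict_with_generate=True,  # generate summaries during evaluation for ROUGE
)
```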
|
|
|
--- |
|
|
|
## Evaluation Results |
|
| Metric     | Score (%) |
|------------|-----------|
| ROUGE-1    | 83.3      |
| ROUGE-2    | 60.0      |
| ROUGE-L    | 83.3      |
| ROUGE-Lsum | 83.3      |
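For intuition, ROUGE-N F1 is the harmonic mean of n-gram precision and recall between a candidate summary and a reference. A simplified pure-Python sketch (ignoring stemming and the LCS-based ROUGE-L variant):

```python
from collections import Counter


def rouge_n_f1(candidate: str, reference: str, n: int = 1) -> float:
    """Simplified ROUGE-N F1: n-gram overlap between candidate and reference."""

    def ngrams(text: str, n: int) -> Counter:
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    cand, ref = ngrams(candidate, n), ngrams(reference, n)
    if not cand or not ref:
        return 0.0
    overlap = sum((cand & ref).values())  # clipped n-gram matches
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


print(round(rouge_n_f1("the cat sat", "the cat sat on the mat"), 3))  # → 0.667
```

The reported scores were computed with the standard ROUGE implementation, which additionally handles stemming and sentence splitting (for ROUGE-Lsum).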
|
|
|
|
|
--- |
|
|
|
## Example Usage |
|
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the tokenizer and model (replace with the actual model repo id)
tokenizer = AutoTokenizer.from_pretrained("your-username/your-model-name")
model = AutoModelForSeq2SeqLM.from_pretrained("your-username/your-model-name")

text = "The stock market saw a significant drop today due to rising inflation concerns. Investors are cautious ahead of the Federal Reserve's upcoming decision."

# Tokenize with the same 512-token input limit used during training
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

# Generate a summary with beam search, capped at 150 tokens
summary_ids = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True)

print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```