---
license: apache-2.0
language: en
pipeline_tag: image-to-text
---

# TotalText-STDR: End-to-End Scene Text Detection and Recognition

This repository contains the official models and inference pipeline for the TotalText Scene Text Detection and Recognition (STDR) project. It detects and transcribes text in images, including curved text.

The pipeline combines a fine-tuned Differentiable Binarization (DBNet) model for text detection with a pretrained attention-based model (TPS-ResNet-BiLSTM-Attn) for text recognition.

## Models

### Text Detection

- **Architecture**: Differentiable Binarization (DBNet) with a ResNet-50 backbone.
- **Pretraining**: Pretrained on the SynthText dataset.
- **Fine-tuning**: Fine-tuned on the Total-Text dataset for high precision on curved and oriented text.
- **Framework**: PyTorch
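
The "differentiable binarization" step that gives DBNet its name can be sketched in a few lines. This is an illustration of the published formulation, a steep sigmoid of the difference between the probability map and a learned threshold map, and not code taken from this repository. The function name `approximate_binarization` is hypothetical, and `k=50` is the amplification factor reported in the DBNet paper.

```python
import numpy as np


def approximate_binarization(prob_map, thresh_map, k=50.0):
    """Approximate binary map B = 1 / (1 + exp(-k * (P - T))).

    prob_map   -- per-pixel text probability P
    thresh_map -- per-pixel threshold T
    k          -- amplification factor (assumed default: 50, as in the DBNet paper)
    """
    return 1.0 / (1.0 + np.exp(-k * (prob_map - thresh_map)))


# Toy 2x2 maps: pixels well above the threshold approach 1, pixels below approach 0.
prob = np.array([[0.9, 0.5], [0.1, 0.3]])
thresh = np.full_like(prob, 0.3)
print(approximate_binarization(prob, thresh).round(3))
```

Because the sigmoid is smooth, gradients flow through the binarization during training, which is what lets the threshold map be learned jointly with the probability map.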

### Text Recognition

- **Architecture**: TPS-ResNet-BiLSTM-Attention.
- **Training**: Pretrained on a large-scale dataset of real and synthetic word images.
- **Framework**: PyTorch

## How to Use

The end-to-end inference logic is encapsulated in the `OCR_Pipeline` class in `pipeline.py`.

### 1. Installation

First, clone the repository and install the required dependencies:

```bash
git clone https://huggingface.co/sakshamhooda/TotalText-STDR
cd TotalText-STDR

# Install dependencies (a virtual environment is recommended).
# Note: make sure the installed PyTorch build matches your CUDA setup.
pip install -r requirements.txt
```
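
After installation, a quick pre-flight check can catch the two most common setup problems: a PyTorch build that does not see the GPU, and missing weight files. The sketch below is not part of the repository. The helper names `find_missing` and `torch_summary` are hypothetical, and the file paths are simply the ones used in the inference example in this README.

```python
import importlib.util
from pathlib import Path

# Paths copied from the inference example in this README.
REQUIRED_FILES = [
    "runs/dbnet_detector/dbnet_best_tt_1.pth",
    "recognition-ptr-weights/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth",
    "config/charset_totaltext.txt",
]


def find_missing(paths, root="."):
    """Return the entries of `paths` that are not regular files under `root`."""
    return [p for p in paths if not (Path(root) / p).is_file()]


def torch_summary():
    """Describe the installed PyTorch build, or report that it is absent."""
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"
    import torch  # imported lazily so the check itself never crashes

    return f"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}"


if __name__ == "__main__":
    print(torch_summary())
    missing = find_missing(REQUIRED_FILES)
    if missing:
        print("Missing files:")
        for path in missing:
            print(f"  {path}")
    else:
        print("All required files are present.")
```

Run it from the repository root so that the relative paths resolve correctly.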

### 2. Inference

You can run the pipeline on an image using the following Python script. Make sure the model weights are present in the repository at the paths set in the configuration block.


```python
import cv2
import numpy as np
from pathlib import Path

from pipeline import OCR_Pipeline

# --- Configuration ---
DETECTOR_CKPT = "runs/dbnet_detector/dbnet_best_tt_1.pth"
RECOGNIZER_CKPT = "recognition-ptr-weights/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth"
CHARSET_PATH = "config/charset_totaltext.txt"
IMAGE_PATH = "Total-Text-Dataset/test/img/img4.jpg"  # Example image

# --- Initialization ---
pipeline = OCR_Pipeline(
    det_model_path=DETECTOR_CKPT,
    rec_model_path=RECOGNIZER_CKPT,
    charset_path=CHARSET_PATH,
)

# --- Run Inference ---
print(f"Running inference on: {IMAGE_PATH}")
input_image = cv2.imread(IMAGE_PATH)
if input_image is None:  # cv2.imread returns None instead of raising on failure
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")

results, heatmap = pipeline.run(input_image)

# --- Visualize and Print Results ---
print(f"Found {len(results)} text instances.")

output_image = input_image.copy()
for res in results:
    poly = np.array(res['polygon']).astype(np.int32)
    text = res['text']

    # Draw the detected polygon outline in green.
    cv2.polylines(output_image, [poly], isClosed=True, color=(0, 255, 0), thickness=2)
    # Anchor the label at the first vertex; OpenCV expects plain Python ints here.
    origin = (int(poly[0][0]), int(poly[0][1]))
    cv2.putText(output_image, text, origin, cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 0, 0), 2)

# Save the output
output_path = Path("./pipeline_output.jpg")
cv2.imwrite(str(output_path), output_image)
print(f"Output image with results saved to: {output_path}")
```
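
If the detections need to be consumed by another tool, the `results` list from the script above can be written out as JSON. The following is a minimal sketch under one assumption: each entry is a dict with a `'polygon'` (a sequence of `(x, y)` pairs) and a `'text'` string, which is how the script above uses them. The helper name `results_to_json` and the sample data are made up for illustration.

```python
import json


def results_to_json(results):
    """Serialize a list of {'polygon': ..., 'text': ...} dicts to a JSON string."""
    records = []
    for res in results:
        # Convert every coordinate to a plain float so NumPy scalars serialize too.
        polygon = [[float(x), float(y)] for x, y in res["polygon"]]
        records.append({"text": str(res["text"]), "polygon": polygon})
    return json.dumps(records, indent=2)


# Made-up sample in the same shape as the `results` list used above.
sample = [{"polygon": [[10, 20], [110, 20], [110, 60], [10, 60]], "text": "EXIT"}]
print(results_to_json(sample))
```

Converting coordinates to plain floats avoids the `TypeError` that `json.dumps` raises on NumPy scalar types.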

## Project Information

This project was developed to provide a high-precision OCR solution for the Total-Text dataset. Experiments were tracked with Weights & Biases (W&B), and models were versioned with MLflow. For more details on the training process, see the original project source.