# ViTCM-LLM / app.py
import streamlit as st
import torch
from transformers import AutoProcessor, AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from peft import PeftModel
from PIL import Image
# Page configuration
st.set_page_config(
    page_title="ViTCM_LLM Tongue Diagnosis",
    page_icon="🖼️",
    layout="wide",
)
# Title
st.title("🖼️ ViTCM_LLM Tongue Diagnosis")
st.markdown("**ViTCM_LLM - Traditional Chinese Medicine Tongue Diagnosis Model**")
# Model loading
@st.cache_resource
def load_model():
    """Load the ViTCM_LLM model (Qwen2.5-VL base + LoRA adapter) for TCM tongue diagnosis."""
    try:
        # Tokenizer and processor for the Qwen2.5-VL base model
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-32B-Instruct")
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-32B-Instruct")
        # Base model: Qwen2.5-VL is a vision-language model, so it must be loaded
        # with Qwen2_5_VLForConditionalGeneration; AutoModelForCausalLM cannot
        # instantiate the vision tower.
        base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-32B-Instruct",
            torch_dtype=torch.float16,
            device_map="auto",
        )
        # Attach the ViTCM_LLM LoRA adapter on top of the base weights
        model = PeftModel.from_pretrained(base_model, "Mark-CHAE/shezhen")
        model.eval()
        return model, tokenizer, processor
    except Exception as e:
        st.error(f"Model loading failed: {e}")
        return None, None, None
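# NOTE: the 32B base model in fp16 needs roughly 64 GB of GPU memory. If that
# is not available, 4-bit quantization via bitsandbytes is a common fallback.
# A minimal sketch (assumes the optional bitsandbytes package is installed)
# that could replace the from_pretrained() call above:
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_compute_dtype=torch.float16,
#   )
#   base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#       "Qwen/Qwen2.5-VL-32B-Instruct",
#       quantization_config=bnb_config,
#       device_map="auto",
#   )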
# Sidebar
with st.sidebar:
    st.header("⚙️ Settings")
    # Inference parameters
    max_new_tokens = st.slider("Max new tokens", 100, 1024, 512)
    temperature = st.slider("Temperature", 0.1, 2.0, 0.7, 0.1)
    top_p = st.slider("Top-p", 0.1, 1.0, 0.9, 0.05)
    # Model load button
    if st.button("🚀 Load Model", type="primary"):
        with st.spinner("Loading ViTCM_LLM model..."):
            model, tokenizer, processor = load_model()
            if model is not None:
                st.session_state.model = model
                st.session_state.tokenizer = tokenizer
                st.session_state.processor = processor
                st.session_state.model_loaded = True
                st.success("✅ ViTCM_LLM model loaded successfully!")
# Main content
if not st.session_state.get('model_loaded', False):
    st.info("👈 Click the 'Load Model' button in the sidebar to start tongue diagnosis.")
    st.stop()
# Image upload
st.header("📸 Tongue Image Upload")
uploaded_file = st.file_uploader(
    "Upload a tongue image for TCM diagnosis",
    type=['png', 'jpg', 'jpeg'],
)
if uploaded_file is not None:
    # Display image (convert to RGB so RGBA/grayscale uploads also work)
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded tongue image", use_container_width=True)
    # Question input
    st.header("❓ Tongue Diagnosis Question")
    question = st.text_area(
        "Ask a question about the tongue image for TCM diagnosis",
        placeholder="e.g., 根据图片判断舌诊内容 (Describe the tongue diagnosis findings in the image)",
        height=100,
    )
    # Analyze button
    if st.button("🔍 Analyze Tongue", type="primary") and question.strip():
        with st.spinner("Analyzing tongue for TCM diagnosis..."):
            try:
                # Construct the prompt in Qwen2.5-VL chat format. The image
                # placeholder is <|vision_start|><|image_pad|><|vision_end|>;
                # the processor expands <|image_pad|> into the visual tokens.
                prompt = (
                    "<|im_start|>user\n"
                    "<|vision_start|><|image_pad|><|vision_end|>"
                    f"{question}<|im_end|>\n"
                    "<|im_start|>assistant\n"
                )
                # Tokenize the text and preprocess the image together, then
                # move the tensors onto the model's device
                inputs = st.session_state.processor(
                    text=[prompt],
                    images=[image],
                    return_tensors="pt",
                ).to(st.session_state.model.device)
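                # NOTE: a less brittle alternative to hand-written chat markup
                # is the processor's built-in chat template. A rough sketch
                # (same variables as above, not wired into the app):
                #
                #   messages = [{"role": "user", "content": [
                #       {"type": "image"},
                #       {"type": "text", "text": question},
                #   ]}]
                #   prompt = st.session_state.processor.apply_chat_template(
                #       messages, tokenize=False, add_generation_prompt=True
                #   )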
                # Inference: max_new_tokens caps only the generated tokens
                # (max_length would also count the prompt)
                with torch.no_grad():
                    outputs = st.session_state.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        do_sample=True,
                        pad_token_id=st.session_state.tokenizer.eos_token_id,
                    )
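                # NOTE: for long answers the UI feels more responsive with
                # token streaming. A rough sketch (not wired in) that would
                # replace the generate() call above, using transformers'
                # TextIteratorStreamer together with st.write_stream:
                #
                #   from threading import Thread
                #   from transformers import TextIteratorStreamer
                #   streamer = TextIteratorStreamer(
                #       st.session_state.tokenizer,
                #       skip_prompt=True, skip_special_tokens=True,
                #   )
                #   Thread(target=st.session_state.model.generate,
                #          kwargs=dict(**inputs, streamer=streamer,
                #                      max_new_tokens=max_new_tokens)).start()
                #   answer = st.write_stream(streamer)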
                # Decode only the newly generated tokens; splitting the full
                # decoded string on "<|im_start|>assistant" would fail because
                # skip_special_tokens=True strips those markers.
                generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
                answer = st.session_state.tokenizer.decode(
                    generated_ids, skip_special_tokens=True
                ).strip()
                # Display results
                st.header("💡 TCM Tongue Diagnosis")
                st.markdown(f"**Question:** {question}")
                st.markdown(f"**Diagnosis:** {answer}")
            except Exception as e:
                st.error(f"Error occurred during tongue analysis: {e}")
# Usage examples
with st.expander("📚 Tongue Diagnosis Examples"):
st.markdown("""
### Tongue Diagnosis Questions:
- 根据图片判断舌诊内容
- 分析舌头的颜色和形状
- 判断舌苔的厚薄和颜色
- 分析舌头的裂纹和斑点
- 评估舌头的整体健康状况
""")
# Model information
with st.expander("ℹ️ Model Information"):
st.markdown("""
### ViTCM_LLM - Traditional Chinese Medicine Tongue Diagnosis Model
- **Base Model**: Qwen/Qwen2.5-VL-32B-Instruct
- **Adapter**: Mark-CHAE/shezhen (ViTCM_LLM)
- **Language**: Chinese
- **License**: Apache-2.0
- **Specialization**: Traditional Chinese Medicine Tongue Diagnosis
""")
st.markdown("---")
st.markdown("**ViTCM_LLM Tongue Diagnosis** | Powered by Qwen2.5-VL-32B-Instruct")