| |
| """ |
| Script to analyze and compare training results from multiple model runs. |
| """ |
|
|
| import json |
| import os |
| from pathlib import Path |
|
|
def load_metadata(run_dir):
    """Load the metadata.json for a single training run.

    Args:
        run_dir: Path to a run directory (str or pathlib.Path).

    Returns:
        dict parsed from ``metadata.json``, or None when the file is missing.
    """
    metadata_path = Path(run_dir) / "metadata.json"
    # EAFP: open directly instead of exists()+open, which is two stat calls
    # and racy if the file disappears between the check and the open.
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return None
|
|
def analyze_all_runs(runs_dir="runs"):
    """Collect normalized result records from every training-run directory.

    Args:
        runs_dir: Directory containing one sub-directory per training run.
            Defaults to ``"runs"`` for backward compatibility.

    Returns:
        list of dicts, one per run with a readable ``metadata.json``;
        missing metadata fields fall back to 0 / sensible defaults.
        An absent *runs_dir* yields an empty list, not an error.
    """
    base = Path(runs_dir)
    results = []

    # A missing runs directory simply means "no results yet".
    if not base.is_dir():
        return results

    for run_dir in base.glob("*"):
        if run_dir.is_dir():
            metadata = load_metadata(run_dir)
            if metadata:
                results.append({
                    'run_id': run_dir.name,
                    'model': metadata.get('classifier', 'Unknown'),
                    # Dataset is inferred from the config name; anything
                    # without "VNTC" is treated as the banking corpus.
                    'dataset': 'VNTC' if 'VNTC' in metadata.get('config_name', '') else 'UTS2017_Bank',
                    'max_features': metadata.get('max_features', 0),
                    'ngram_range': metadata.get('ngram_range', [1, 1]),
                    'train_accuracy': metadata.get('train_accuracy', 0),
                    'test_accuracy': metadata.get('test_accuracy', 0),
                    'train_time': metadata.get('train_time', 0),
                    'prediction_time': metadata.get('prediction_time', 0),
                    'train_samples': metadata.get('train_samples', 0),
                    'test_samples': metadata.get('test_samples', 0),
                })

    return results
|
|
def print_comparison_table(results):
    """Print per-dataset comparison tables and best-model summary lines.

    Args:
        results: list of run dicts as produced by ``analyze_all_runs()``.
    """
    print("\n" + "="*120)
    print("VIETNAMESE TEXT CLASSIFICATION - MODEL COMPARISON RESULTS")
    print("="*120)

    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']

    # The two dataset sections were previously copy-pasted; one helper
    # keeps the formatting in a single place.
    _print_dataset_section(
        "VNTC Dataset (Vietnamese News Classification):", vntc_results)
    _print_dataset_section(
        "UTS2017_Bank Dataset (Vietnamese Banking Text Classification):", bank_results)

    print("="*120)

    if vntc_results:
        best_vntc = max(vntc_results, key=lambda x: x['test_accuracy'])
        print(f"\nBest VNTC model: {best_vntc['model']} with {best_vntc['test_accuracy']:.4f} test accuracy")

    if bank_results:
        best_bank = max(bank_results, key=lambda x: x['test_accuracy'])
        print(f"Best UTS2017_Bank model: {best_bank['model']} with {best_bank['test_accuracy']:.4f} test accuracy")


def _print_dataset_section(title, dataset_results):
    """Print one dataset's table, best test accuracy first; no-op when empty."""
    if not dataset_results:
        return

    print(f"\n{title}")
    print("-"*120)
    print(f"{'Model':<20} {'Features':<10} {'N-gram':<10} {'Train Acc':<12} {'Test Acc':<12} {'Train Time':<12} {'Pred Time':<12}")
    print("-"*120)

    for result in sorted(dataset_results, key=lambda x: x['test_accuracy'], reverse=True):
        model = result['model'][:18]
        # max_features == 0 means "not recorded" in metadata, shown as N/A.
        features = f"{result['max_features']//1000}k" if result['max_features'] > 0 else "N/A"
        ngram = f"{result['ngram_range'][0]}-{result['ngram_range'][1]}"
        train_acc = f"{result['train_accuracy']:.4f}"
        test_acc = f"{result['test_accuracy']:.4f}"
        train_time = f"{result['train_time']:.1f}s"
        pred_time = f"{result['prediction_time']:.1f}s"

        print(f"{model:<20} {features:<10} {ngram:<10} {train_acc:<12} {test_acc:<12} {train_time:<12} {pred_time:<12}")
|
|
def main():
    """Entry point: load every run's metadata and report comparisons."""
    print("Analyzing Vietnamese Text Classification Training Results...")

    results = analyze_all_runs()
    if not results:
        print("No training results found in runs/ directory.")
        return

    print(f"Found {len(results)} training runs.")
    print_comparison_table(results)

    # Per-dataset counts and mean test accuracy.
    vntc = [r for r in results if r['dataset'] == 'VNTC']
    bank = [r for r in results if r['dataset'] == 'UTS2017_Bank']

    print("\nSummary:")
    print(f"- VNTC runs: {len(vntc)}")
    print(f"- UTS2017_Bank runs: {len(bank)}")

    for label, subset in (("VNTC", vntc), ("UTS2017_Bank", bank)):
        if subset:
            mean_acc = sum(r['test_accuracy'] for r in subset) / len(subset)
            print(f"- Average {label} test accuracy: {mean_acc:.4f}")


if __name__ == "__main__":
    main()