File size: 5,435 Bytes
a8639ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import re
def extract_run_name(filename):
    """Extract the run name from a TensorBoard CSV export filename.

    The run name is the underscore-free segment that sits directly before
    the ``_tensorboard.csv`` suffix. A trailing ``-loss`` on that segment is
    dropped, as the optional ``-loss`` group in the pattern intends.

    Args:
        filename: Path or bare file name of the CSV file.

    Returns:
        The extracted run name. For names that do not end in
        ``_tensorboard.csv`` the second underscore-separated field (up to
        its first ``-``) is used; if the name contains no underscore at all,
        the file name without its extension is returned.
    """
    basename = os.path.basename(filename)
    # The capture is lazy on purpose: a greedy ``[^_]+`` would consume the
    # ``-loss`` suffix itself, so the optional group could never strip it.
    match = re.search(r'_([^_]+?)(?:-loss)?_tensorboard\.csv$', basename)
    if match:
        return match.group(1)

    fields = basename.split('_')
    if len(fields) > 1:
        return fields[1].split('-')[0]
    # No underscore to split on: use the bare stem instead of raising
    # IndexError on fields[1].
    return os.path.splitext(basename)[0]
def setup_plot_style():
    """Configure matplotlib's global rcParams for publication-style figures."""
    typography = {
        'font.family': 'serif',
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'legend.fontsize': 10,
    }
    canvas = {
        'figure.dpi': 300,
        'figure.figsize': (10, 6),
        'lines.linewidth': 2.5,
    }
    grid = {
        'axes.grid': True,
        'grid.linestyle': '--',
        'grid.alpha': 0.6,
    }
    # Hide the top and right spines for a cleaner frame.
    spines = {
        'axes.spines.top': False,
        'axes.spines.right': False,
    }
    for settings in (typography, canvas, grid, spines):
        plt.rcParams.update(settings)
def get_metric_label(metric_name):
    """Translate a metric identifier into a display label.

    Known metrics have hand-written labels; any other identifier is
    title-cased with its underscores turned into spaces.
    """
    known_labels = {
        'loss_epoch': 'Loss',
        'perplexityval_epoch': 'Validation Perplexity',
        'topkacc_epoch': 'Top-K Accuracy',
        'acc_trainstep': 'Training Accuracy',
    }
    derived_label = metric_name.replace('_', ' ').title()
    if metric_name in known_labels:
        return known_labels[metric_name]
    return derived_label
def get_color_mapping(run_names):
    """Assign each run name a color from a fixed ten-color palette.

    Names are sorted before colors are handed out, so the same set of runs
    always yields the same mapping. When there are more runs than palette
    entries, the palette wraps around and colors repeat.
    """
    palette = (
        "#e6194b",  # Red
        "#f58231",  # Orange
        "#ffe119",  # Yellow
        "#bfef45",  # Lime
        "#3cb44b",  # Green
        "#42d4f4",  # Cyan
        "#4363d8",  # Blue
        "#911eb4",  # Purple
        "#f032e6",  # Magenta
        "#a9a9a9",  # Grey
    )
    palette_size = len(palette)

    mapping = {}
    for position, run in enumerate(sorted(run_names)):
        mapping[run] = palette[position % palette_size]
    return mapping
def plot_metric(metric_dir, color_mapping, output_dir):
    """Plot every run found in one metric directory and save the figure.

    Each ``*.csv`` file in ``metric_dir`` is drawn as one line (its ``Value``
    column against its ``Step`` column) on a shared figure, which is written
    to ``<output_dir>/<metric_name>_plot.png``.

    Args:
        metric_dir: Directory holding the CSV files of a single metric; its
            final path component is used as the metric name.
        color_mapping: Mapping from run name to line color. Runs missing
            from the mapping are drawn in gray.
        output_dir: Directory the PNG is written to; created if needed.
    """
    # normpath drops a trailing separator, which would otherwise make
    # basename() return '' and produce an unlabeled "_plot.png".
    metric_name = os.path.basename(os.path.normpath(metric_dir))
    csv_files = glob.glob(os.path.join(metric_dir, '*.csv'))
    if not csv_files:
        print(f"No CSV files found in {metric_dir}")
        return

    metric_label = get_metric_label(metric_name)
    fig = plt.figure(figsize=(12, 7))
    try:
        for csv_file in sorted(csv_files):
            try:
                df = pd.read_csv(csv_file)
                run_name = extract_run_name(csv_file)
                color = color_mapping.get(run_name, 'gray')
                # The x-axis is always the 'Step' column of the CSV.
                plt.plot(df['Step'], df['Value'], label=run_name, color=color, alpha=0.9)
            except Exception as e:
                # Best effort: one unreadable or malformed CSV must not
                # prevent the remaining runs from being plotted.
                print(f"Error processing {csv_file}: {e}")

        plt.xlabel('Step')
        plt.ylabel(metric_label)
        comparison = "Epoch" if "epoch" in metric_name else "Step"
        plt.title(f'{metric_label} vs. {comparison}', fontweight='bold')

        # Use a two-column legend once there are more than five files.
        plt.legend(loc='best', frameon=True, fancybox=True, framealpha=0.9,
                   shadow=True, borderpad=1, ncol=2 if len(csv_files) > 5 else 1)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{metric_name}_plot.png')
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Saved plot to {output_path}")
    finally:
        # Always release the figure, even if styling or saving failed, so
        # figures do not accumulate across metrics.
        plt.close(fig)
def main():
    """Generate one comparison plot per metric directory beside this script.

    Metric directories are the sub-directories of ``runs_jsons`` located
    next to this file; plots are written to a sibling ``plots`` directory.
    If no metric directory is found, a message is printed and nothing is
    plotted.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.join(script_dir, 'runs_jsons')
    output_dir = os.path.join(script_dir, 'plots')
    os.makedirs(output_dir, exist_ok=True)

    setup_plot_style()

    # Sorted so metrics are processed in a stable order from run to run.
    metric_dirs = sorted(
        d for d in glob.glob(os.path.join(base_dir, '*')) if os.path.isdir(d)
    )
    if not metric_dirs:
        # Do not report success when there was nothing to plot.
        print(f"No metric directories found in {base_dir}")
        return

    # Gather run names across all metrics first so that a run keeps the
    # same color in every plot.
    all_run_names = {
        extract_run_name(csv_file)
        for metric_dir in metric_dirs
        for csv_file in glob.glob(os.path.join(metric_dir, '*.csv'))
    }
    color_mapping = get_color_mapping(all_run_names)

    for metric_dir in metric_dirs:
        plot_metric(metric_dir, color_mapping, output_dir)
    print(f"All plots have been generated in {output_dir}")
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|