Add AGILLM4 uploader local pruning
Browse files
upload_agillm4_checkpoints.py
CHANGED
|
@@ -83,6 +83,56 @@ def latest_file(glob_root: Path, pattern: str) -> Path | None:
|
|
| 83 |
return max(files, key=lambda p: p.stat().st_mtime) if files else None
|
| 84 |
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
def status_json(script: Path, log: Path, save_dir: Path) -> dict[str, Any]:
|
| 87 |
result = subprocess.run(
|
| 88 |
[sys.executable, "-u", str(script), "status", "--json", "--log", str(log), "--save_dir", str(save_dir)],
|
|
@@ -167,6 +217,8 @@ def main() -> int:
|
|
| 167 |
parser.add_argument("--delta-interval-sec", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_DELTA_INTERVAL_SEC", str(24 * 3600))))
|
| 168 |
parser.add_argument("--keep-full", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_FULL", "2")))
|
| 169 |
parser.add_argument("--keep-delta", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_DELTA", "2")))
|
|
|
|
|
|
|
| 170 |
parser.add_argument("--tail-lines", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_TAIL_LINES", "5000")))
|
| 171 |
args = parser.parse_args()
|
| 172 |
|
|
@@ -175,14 +227,22 @@ def main() -> int:
|
|
| 175 |
args.stage.mkdir(parents=True, exist_ok=True)
|
| 176 |
state = load_json(args.state, {})
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
status = status_json(args.script, args.log, args.save_dir)
|
| 179 |
status["upload_policy"] = {
|
| 180 |
"full_interval_sec": args.full_interval_sec,
|
| 181 |
"delta_interval_sec": args.delta_interval_sec,
|
| 182 |
"keep_full_current_files": args.keep_full,
|
| 183 |
"keep_delta_current_files": args.keep_delta,
|
|
|
|
|
|
|
| 184 |
"note": "Small status/log tail uploads are frequent; multi-GB deltas/full checkpoints are rate-limited for HF public storage.",
|
| 185 |
}
|
|
|
|
| 186 |
status_path = args.stage / "status.json"
|
| 187 |
save_json(status_path, status)
|
| 188 |
upload_file(api, args.repo, status_path, f"{prefix}/status/status.json", "Update AGILLM4 training status")
|
|
@@ -197,7 +257,7 @@ def main() -> int:
|
|
| 197 |
upload_file(api, args.repo, args.stage / "latest.json", f"{prefix}/status/latest.json", "Update AGILLM4 latest checkpoint metadata")
|
| 198 |
|
| 199 |
newest_delta = latest_file(args.save_dir, "*_delta_step*.pt")
|
| 200 |
-
newest_full =
|
| 201 |
maybe_upload_large(api, args.repo, state, "delta", newest_delta, f"{prefix}/checkpoints/deltas", args.delta_interval_sec, args.keep_delta)
|
| 202 |
maybe_upload_large(api, args.repo, state, "full", newest_full, f"{prefix}/checkpoints/full", args.full_interval_sec, args.keep_full)
|
| 203 |
|
|
|
|
| 83 |
return max(files, key=lambda p: p.stat().st_mtime) if files else None
|
| 84 |
|
| 85 |
|
| 86 |
+
def latest_file_excluding(glob_root: Path, pattern: str, excluded_name_parts: tuple[str, ...]) -> Path | None:
|
| 87 |
+
files = [
|
| 88 |
+
p for p in glob_root.glob(pattern)
|
| 89 |
+
if p.is_file() and not any(part in p.name for part in excluded_name_parts)
|
| 90 |
+
]
|
| 91 |
+
return max(files, key=lambda p: p.stat().st_mtime) if files else None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def checkpoint_artifacts(path: Path) -> list[Path]:
    """List a checkpoint path together with its companion file paths.

    The result is, in order: ``path`` itself; ``path`` with its final suffix
    replaced by ``.sha256``; and ``path`` with ``.upload.sha256`` and ``.dtmp``
    appended after its existing suffix. No filesystem access is performed, so
    the returned paths may or may not exist.
    """
    # These are appended to the current suffix rather than replacing it.
    appended_suffixes = (".upload.sha256", ".dtmp")
    artifacts = [path, path.with_suffix(".sha256")]
    artifacts.extend(path.with_suffix(path.suffix + extra) for extra in appended_suffixes)
    return artifacts
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def prune_local_checkpoints(
    save_dir: Path,
    pattern: str,
    keep: int,
    label: str,
    excluded_name_parts: tuple[str, ...] = (),
) -> list[str]:
    """Delete all but the ``keep`` newest checkpoints matching ``pattern``.

    Candidates are non-empty regular files in ``save_dir`` that match
    ``pattern`` and whose name contains none of ``excluded_name_parts``.
    They are ordered by modification time; every candidate except the
    ``keep`` newest is removed together with the companion paths reported
    by ``checkpoint_artifacts``.

    Args:
        save_dir: Directory that is globbed for checkpoints.
        pattern: Glob pattern selecting checkpoint files.
        keep: Number of newest candidates to retain. ``0`` removes every
            candidate; a negative value disables pruning entirely.
        label: Human-readable kind (e.g. ``"delta"``) used in log lines only.
        excluded_name_parts: Substrings that disqualify a file name.

    Returns:
        The paths (as strings) of every artifact that was actually deleted.
        Failures are logged and skipped; they never raise.
    """
    if keep < 0:
        return []

    # Stat every candidate exactly once and tolerate files that disappear
    # between glob() and stat() (another process may be cleaning the same
    # directory); an unguarded stat() here would abort the whole run.
    dated: list[tuple[float, Path]] = []
    for candidate in save_dir.glob(pattern):
        if any(part in candidate.name for part in excluded_name_parts):
            continue
        if not candidate.is_file():
            continue
        try:
            info = candidate.stat()
        except OSError:
            continue
        # Zero-byte files are neither pruned nor counted towards ``keep``.
        if info.st_size > 0:
            dated.append((info.st_mtime, candidate))

    # Sort on mtime only so that ties keep glob order (stable sort).
    dated.sort(key=lambda pair: pair[0])
    victims = [path for _, path in dated[:max(0, len(dated) - keep)]]

    deleted: list[str] = []
    for path in victims:
        for artifact in checkpoint_artifacts(path):
            try:
                artifact.unlink()
            except FileNotFoundError:
                # Companion file was never written, or is already gone.
                continue
            except OSError as exc:
                print(f"[upload] WARN local {label} prune failed for {artifact}: {exc}", flush=True)
            else:
                deleted.append(str(artifact))
    if deleted:
        print(f"[upload] pruned {len(deleted)} local {label} artifacts", flush=True)
    return deleted
|
| 134 |
+
|
| 135 |
+
|
| 136 |
def status_json(script: Path, log: Path, save_dir: Path) -> dict[str, Any]:
|
| 137 |
result = subprocess.run(
|
| 138 |
[sys.executable, "-u", str(script), "status", "--json", "--log", str(log), "--save_dir", str(save_dir)],
|
|
|
|
| 217 |
parser.add_argument("--delta-interval-sec", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_DELTA_INTERVAL_SEC", str(24 * 3600))))
|
| 218 |
parser.add_argument("--keep-full", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_FULL", "2")))
|
| 219 |
parser.add_argument("--keep-delta", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_DELTA", "2")))
|
| 220 |
+
parser.add_argument("--local-keep-full", type=int, default=int(os.environ.get("AGILLM4_LOCAL_KEEP_FULL", "1")))
|
| 221 |
+
parser.add_argument("--local-keep-delta", type=int, default=int(os.environ.get("AGILLM4_LOCAL_KEEP_DELTA", "1")))
|
| 222 |
parser.add_argument("--tail-lines", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_TAIL_LINES", "5000")))
|
| 223 |
args = parser.parse_args()
|
| 224 |
|
|
|
|
| 227 |
args.stage.mkdir(parents=True, exist_ok=True)
|
| 228 |
state = load_json(args.state, {})
|
| 229 |
|
| 230 |
+
local_pruned = {
|
| 231 |
+
"delta": prune_local_checkpoints(args.save_dir, "*_delta_step*.pt", args.local_keep_delta, "delta"),
|
| 232 |
+
"full": prune_local_checkpoints(args.save_dir, "*_step*.pt", args.local_keep_full, "full", ("_delta_step",)),
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
status = status_json(args.script, args.log, args.save_dir)
|
| 236 |
status["upload_policy"] = {
|
| 237 |
"full_interval_sec": args.full_interval_sec,
|
| 238 |
"delta_interval_sec": args.delta_interval_sec,
|
| 239 |
"keep_full_current_files": args.keep_full,
|
| 240 |
"keep_delta_current_files": args.keep_delta,
|
| 241 |
+
"local_keep_full_files": args.local_keep_full,
|
| 242 |
+
"local_keep_delta_files": args.local_keep_delta,
|
| 243 |
"note": "Small status/log tail uploads are frequent; multi-GB deltas/full checkpoints are rate-limited for HF public storage.",
|
| 244 |
}
|
| 245 |
+
status["local_pruned"] = local_pruned
|
| 246 |
status_path = args.stage / "status.json"
|
| 247 |
save_json(status_path, status)
|
| 248 |
upload_file(api, args.repo, status_path, f"{prefix}/status/status.json", "Update AGILLM4 training status")
|
|
|
|
| 257 |
upload_file(api, args.repo, args.stage / "latest.json", f"{prefix}/status/latest.json", "Update AGILLM4 latest checkpoint metadata")
|
| 258 |
|
| 259 |
newest_delta = latest_file(args.save_dir, "*_delta_step*.pt")
|
| 260 |
+
newest_full = latest_file_excluding(args.save_dir, "*_step*.pt", ("_delta_step",))
|
| 261 |
maybe_upload_large(api, args.repo, state, "delta", newest_delta, f"{prefix}/checkpoints/deltas", args.delta_interval_sec, args.keep_delta)
|
| 262 |
maybe_upload_large(api, args.repo, state, "full", newest_full, f"{prefix}/checkpoints/full", args.full_interval_sec, args.keep_full)
|
| 263 |
|