OpenTransformer committed on
Commit
2fde3ca
·
verified ·
1 Parent(s): 408a86d

Add AGILLM4 uploader local pruning

Browse files
Files changed (1) hide show
  1. upload_agillm4_checkpoints.py +61 -1
upload_agillm4_checkpoints.py CHANGED
@@ -83,6 +83,56 @@ def latest_file(glob_root: Path, pattern: str) -> Path | None:
83
  return max(files, key=lambda p: p.stat().st_mtime) if files else None
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def status_json(script: Path, log: Path, save_dir: Path) -> dict[str, Any]:
87
  result = subprocess.run(
88
  [sys.executable, "-u", str(script), "status", "--json", "--log", str(log), "--save_dir", str(save_dir)],
@@ -167,6 +217,8 @@ def main() -> int:
167
  parser.add_argument("--delta-interval-sec", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_DELTA_INTERVAL_SEC", str(24 * 3600))))
168
  parser.add_argument("--keep-full", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_FULL", "2")))
169
  parser.add_argument("--keep-delta", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_DELTA", "2")))
 
 
170
  parser.add_argument("--tail-lines", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_TAIL_LINES", "5000")))
171
  args = parser.parse_args()
172
 
@@ -175,14 +227,22 @@ def main() -> int:
175
  args.stage.mkdir(parents=True, exist_ok=True)
176
  state = load_json(args.state, {})
177
 
 
 
 
 
 
178
  status = status_json(args.script, args.log, args.save_dir)
179
  status["upload_policy"] = {
180
  "full_interval_sec": args.full_interval_sec,
181
  "delta_interval_sec": args.delta_interval_sec,
182
  "keep_full_current_files": args.keep_full,
183
  "keep_delta_current_files": args.keep_delta,
 
 
184
  "note": "Small status/log tail uploads are frequent; multi-GB deltas/full checkpoints are rate-limited for HF public storage.",
185
  }
 
186
  status_path = args.stage / "status.json"
187
  save_json(status_path, status)
188
  upload_file(api, args.repo, status_path, f"{prefix}/status/status.json", "Update AGILLM4 training status")
@@ -197,7 +257,7 @@ def main() -> int:
197
  upload_file(api, args.repo, args.stage / "latest.json", f"{prefix}/status/latest.json", "Update AGILLM4 latest checkpoint metadata")
198
 
199
  newest_delta = latest_file(args.save_dir, "*_delta_step*.pt")
200
- newest_full = latest_file(args.save_dir, "*_step*.pt")
201
  maybe_upload_large(api, args.repo, state, "delta", newest_delta, f"{prefix}/checkpoints/deltas", args.delta_interval_sec, args.keep_delta)
202
  maybe_upload_large(api, args.repo, state, "full", newest_full, f"{prefix}/checkpoints/full", args.full_interval_sec, args.keep_full)
203
 
 
83
  return max(files, key=lambda p: p.stat().st_mtime) if files else None
84
 
85
 
86
+ def latest_file_excluding(glob_root: Path, pattern: str, excluded_name_parts: tuple[str, ...]) -> Path | None:
87
+ files = [
88
+ p for p in glob_root.glob(pattern)
89
+ if p.is_file() and not any(part in p.name for part in excluded_name_parts)
90
+ ]
91
+ return max(files, key=lambda p: p.stat().st_mtime) if files else None
92
+
93
+
94
def checkpoint_artifacts(path: Path) -> list[Path]:
    """List ``path`` together with the sidecar paths derived from its name.

    The result is, in order: ``path`` itself, ``path`` with its last suffix
    replaced by ``.sha256``, and ``path`` with ``.upload.sha256`` and
    ``.dtmp`` appended after its existing suffix. No filesystem access is
    performed; the returned paths may or may not exist.
    """
    # Note the asymmetry: the plain digest *replaces* the suffix
    # (``x.pt`` -> ``x.sha256``) while the other two *append* to it
    # (``x.pt`` -> ``x.pt.upload.sha256`` / ``x.pt.dtmp``).
    replaced_digest = path.with_suffix(".sha256")
    appended = [
        path.with_suffix(path.suffix + extra)
        for extra in (".upload.sha256", ".dtmp")
    ]
    return [path, replaced_digest, *appended]
101
+
102
+
103
def prune_local_checkpoints(
    save_dir: Path,
    pattern: str,
    keep: int,
    label: str,
    excluded_name_parts: tuple[str, ...] = (),
) -> list[str]:
    """Delete all but the newest ``keep`` local checkpoints matching ``pattern``.

    Candidates are non-empty regular files under ``save_dir`` that match
    ``pattern`` and whose name contains none of ``excluded_name_parts``.
    They are ordered by modification time; the oldest ones beyond ``keep``
    are removed together with the sidecar paths reported by
    ``checkpoint_artifacts``.

    Args:
        save_dir: Directory that is globbed for checkpoints.
        pattern: Glob pattern selecting checkpoint files.
        keep: Number of newest checkpoints to retain. ``0`` removes every
            candidate; a negative value disables pruning entirely.
        label: Short tag used only in the log messages.
        excluded_name_parts: Substrings that disqualify a file name.

    Returns:
        The paths (as strings) of every artifact that was actually deleted.

    Pruning is best-effort: files that vanish while being listed or deleted
    are skipped, and other OS-level deletion failures are logged as warnings
    rather than raised.
    """
    if keep < 0:
        return []

    # Stat each candidate exactly once and tolerate files disappearing
    # between glob() and stat(); an unguarded stat() here would abort the
    # whole run instead of just skipping the vanished file.
    dated: list[tuple[float, Path]] = []
    for path in save_dir.glob(pattern):
        if any(part in path.name for part in excluded_name_parts):
            continue
        try:
            if not path.is_file():
                continue
            info = path.stat()
        except OSError:
            continue
        if info.st_size > 0:
            dated.append((info.st_mtime, path))

    # Oldest first; the sort is stable so ties keep their glob order.
    dated.sort(key=lambda item: item[0])
    victims = dated[:max(0, len(dated) - keep)]

    deleted: list[str] = []
    for _, path in victims:
        for artifact in checkpoint_artifacts(path):
            try:
                artifact.unlink()
            except FileNotFoundError:
                # Sidecar never existed or was removed concurrently.
                continue
            except OSError as exc:
                print(f"[upload] WARN local {label} prune failed for {artifact}: {exc}", flush=True)
            else:
                deleted.append(str(artifact))
    if deleted:
        print(f"[upload] pruned {len(deleted)} local {label} artifacts", flush=True)
    return deleted
134
+
135
+
136
  def status_json(script: Path, log: Path, save_dir: Path) -> dict[str, Any]:
137
  result = subprocess.run(
138
  [sys.executable, "-u", str(script), "status", "--json", "--log", str(log), "--save_dir", str(save_dir)],
 
217
  parser.add_argument("--delta-interval-sec", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_DELTA_INTERVAL_SEC", str(24 * 3600))))
218
  parser.add_argument("--keep-full", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_FULL", "2")))
219
  parser.add_argument("--keep-delta", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_KEEP_DELTA", "2")))
220
+ parser.add_argument("--local-keep-full", type=int, default=int(os.environ.get("AGILLM4_LOCAL_KEEP_FULL", "1")))
221
+ parser.add_argument("--local-keep-delta", type=int, default=int(os.environ.get("AGILLM4_LOCAL_KEEP_DELTA", "1")))
222
  parser.add_argument("--tail-lines", type=int, default=int(os.environ.get("AGILLM4_UPLOAD_TAIL_LINES", "5000")))
223
  args = parser.parse_args()
224
 
 
227
  args.stage.mkdir(parents=True, exist_ok=True)
228
  state = load_json(args.state, {})
229
 
230
+ local_pruned = {
231
+ "delta": prune_local_checkpoints(args.save_dir, "*_delta_step*.pt", args.local_keep_delta, "delta"),
232
+ "full": prune_local_checkpoints(args.save_dir, "*_step*.pt", args.local_keep_full, "full", ("_delta_step",)),
233
+ }
234
+
235
  status = status_json(args.script, args.log, args.save_dir)
236
  status["upload_policy"] = {
237
  "full_interval_sec": args.full_interval_sec,
238
  "delta_interval_sec": args.delta_interval_sec,
239
  "keep_full_current_files": args.keep_full,
240
  "keep_delta_current_files": args.keep_delta,
241
+ "local_keep_full_files": args.local_keep_full,
242
+ "local_keep_delta_files": args.local_keep_delta,
243
  "note": "Small status/log tail uploads are frequent; multi-GB deltas/full checkpoints are rate-limited for HF public storage.",
244
  }
245
+ status["local_pruned"] = local_pruned
246
  status_path = args.stage / "status.json"
247
  save_json(status_path, status)
248
  upload_file(api, args.repo, status_path, f"{prefix}/status/status.json", "Update AGILLM4 training status")
 
257
  upload_file(api, args.repo, args.stage / "latest.json", f"{prefix}/status/latest.json", "Update AGILLM4 latest checkpoint metadata")
258
 
259
  newest_delta = latest_file(args.save_dir, "*_delta_step*.pt")
260
+ newest_full = latest_file_excluding(args.save_dir, "*_step*.pt", ("_delta_step",))
261
  maybe_upload_large(api, args.repo, state, "delta", newest_delta, f"{prefix}/checkpoints/deltas", args.delta_interval_sec, args.keep_delta)
262
  maybe_upload_large(api, args.repo, state, "full", newest_full, f"{prefix}/checkpoints/full", args.full_interval_sec, args.keep_full)
263
 
Free AI Image Generator No sign-up. Instant results. Open Now