diff --git a/consts.ts b/consts.ts new file mode 100644 index 0000000000000000000000000000000000000000..5d34e9caecf3f843e34c24f341de91f6c715f17e --- /dev/null +++ b/consts.ts @@ -0,0 +1 @@ +export const HUB_URL = "https://huggingface.co"; diff --git a/error.ts b/error.ts new file mode 100644 index 0000000000000000000000000000000000000000..0da5b2dd9686318e8bafb3c8c9f155decf038d2d --- /dev/null +++ b/error.ts @@ -0,0 +1,49 @@ +import type { JsonObject } from "./vendor/type-fest/basic"; + +export async function createApiError( + response: Response, + opts?: { requestId?: string; message?: string } +): Promise { + const error = new HubApiError(response.url, response.status, response.headers.get("X-Request-Id") ?? opts?.requestId); + + error.message = `Api error with status ${error.statusCode}${opts?.message ? `. ${opts.message}` : ""}`; + + const trailer = [`URL: ${error.url}`, error.requestId ? `Request ID: ${error.requestId}` : undefined] + .filter(Boolean) + .join(". "); + + if (response.headers.get("Content-Type")?.startsWith("application/json")) { + const json = await response.json(); + error.message = json.error || json.message || error.message; + if (json.error_description) { + error.message = error.message ? error.message + `: ${json.error_description}` : json.error_description; + } + error.data = json; + } else { + error.data = { message: await response.text() }; + } + + error.message += `. ${trailer}`; + + throw error; +} + +/** + * Error thrown when an API call to the Hugging Face Hub fails. + */ +export class HubApiError extends Error { + statusCode: number; + url: string; + requestId?: string; + data?: JsonObject; + + constructor(url: string, statusCode: number, requestId?: string, message?: string) { + super(message); + + this.statusCode = statusCode; + this.requestId = requestId; + this.url = url; + } +} + +export class InvalidApiResponseFormatError extends Error {} diff --git a/index.ts b/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..a73655c797efac022292a0bc127a9a4b745138ce --- /dev/null +++ b/index.ts @@ -0,0 +1,25 @@ +export * from "./lib"; +// Typescript 5 will add 'export type *' +export type { + AccessToken, + AccessTokenRole, + AuthType, + Credentials, + PipelineType, + RepoDesignation, + RepoFullName, + RepoId, + RepoType, + SpaceHardwareFlavor, + SpaceResourceConfig, + SpaceResourceRequirement, + SpaceRuntime, + SpaceSdk, + SpaceStage, +} from "./types/public"; +export { HubApiError, InvalidApiResponseFormatError } from "./error"; +/** + * Only exported for E2Es convenience + */ +export { sha256 as __internal_sha256 } from "./utils/sha256"; +export { XetBlob as __internal_XetBlob } from "./utils/XetBlob"; diff --git a/lib/cache-management.spec.ts b/lib/cache-management.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..3bfed63d92ae1ea23689f7859eb95fc854fc2424 --- /dev/null +++ b/lib/cache-management.spec.ts @@ -0,0 +1,137 @@ +import { describe, test, expect, vi, beforeEach } from "vitest"; +import { + scanCacheDir, + scanCachedRepo, + scanSnapshotDir, + parseRepoType, + getBlobStat, + type CachedFileInfo, +} from "./cache-management"; +import { stat, readdir, realpath, lstat } from "node:fs/promises"; +import type { Dirent, Stats } from "node:fs"; +import { join } from "node:path"; + +// Mocks +vi.mock("node:fs/promises"); + +beforeEach(() => { + vi.resetAllMocks(); + vi.restoreAllMocks(); +}); + +describe("scanCacheDir", () => { + test("should throw an error if cacheDir is not a directory", async () => { + 
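+      // stat() is mocked to report a non-directory path, so scanCacheDir should reject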
vi.mocked(stat).mockResolvedValueOnce({ + isDirectory: () => false, + } as Stats); + + await expect(scanCacheDir("/fake/dir")).rejects.toThrow("Scan cache expects a directory"); + }); + + test("empty directory should return an empty set of repository and no warnings", async () => { + vi.mocked(stat).mockResolvedValueOnce({ + isDirectory: () => true, + } as Stats); + + // mock empty cache folder + vi.mocked(readdir).mockResolvedValue([]); + + const result = await scanCacheDir("/fake/dir"); + + // cacheDir must have been read + expect(readdir).toHaveBeenCalledWith("/fake/dir"); + + expect(result.warnings.length).toBe(0); + expect(result.repos).toHaveLength(0); + expect(result.size).toBe(0); + }); +}); + +describe("scanCachedRepo", () => { + test("should throw an error for invalid repo path", async () => { + await expect(() => { + return scanCachedRepo("/fake/repo_path"); + }).rejects.toThrow("Repo path is not a valid HuggingFace cache directory"); + }); + + test("should throw an error if the snapshot folder does not exist", async () => { + vi.mocked(readdir).mockResolvedValue([]); + vi.mocked(stat).mockResolvedValue({ + isDirectory: () => false, + } as Stats); + + await expect(() => { + return scanCachedRepo("/fake/cacheDir/models--hello-world--name"); + }).rejects.toThrow("Snapshots dir doesn't exist in cached repo"); + }); + + test("should properly parse the repository name", async () => { + const repoPath = "/fake/cacheDir/models--hello-world--name"; + vi.mocked(readdir).mockResolvedValue([]); + vi.mocked(stat).mockResolvedValue({ + isDirectory: () => true, + } as Stats); + + const result = await scanCachedRepo(repoPath); + expect(readdir).toHaveBeenCalledWith(join(repoPath, "refs"), { + withFileTypes: true, + }); + + expect(result.id.name).toBe("hello-world/name"); + expect(result.id.type).toBe("model"); + }); +}); + +describe("scanSnapshotDir", () => { + test("should scan a valid snapshot directory", async () => { + const cachedFiles: CachedFileInfo[] = []; + const blobStats = new Map(); + vi.mocked(readdir).mockResolvedValueOnce([{ name: "file1", isDirectory: () => false } as Dirent]); + + vi.mocked(realpath).mockResolvedValueOnce("/fake/realpath"); + vi.mocked(lstat).mockResolvedValueOnce({ size: 1024, atimeMs: Date.now(), mtimeMs: Date.now() } as Stats); + + await scanSnapshotDir("/fake/revision", cachedFiles, blobStats); + + expect(cachedFiles).toHaveLength(1); + expect(blobStats.size).toBe(1); + }); +}); + +describe("getBlobStat", () => { + test("should retrieve blob stat if already cached", async () => { + const blobStats = new Map([["/fake/blob", { size: 1024 } as Stats]]); + const result = await getBlobStat("/fake/blob", blobStats); + + expect(lstat).not.toHaveBeenCalled(); + expect(result.size).toBe(1024); + }); + + test("should fetch and cache blob stat if not cached", async () => { + const blobStats = new Map(); + vi.mocked(lstat).mockResolvedValueOnce({ size: 2048 } as Stats); + + const result = await getBlobStat("/fake/blob", blobStats); + + expect(result.size).toBe(2048); + expect(blobStats.size).toBe(1); + }); +}); + +describe("parseRepoType", () => { + test("should parse models repo type", () => { + expect(parseRepoType("models")).toBe("model"); + }); + + test("should parse dataset repo type", () => { + expect(parseRepoType("datasets")).toBe("dataset"); + }); + + test("should parse space repo type", () => { + expect(parseRepoType("spaces")).toBe("space"); + }); + + test("should throw an error for invalid repo type", () => { + expect(() => 
parseRepoType("invalid")).toThrowError("Invalid repo type: invalid"); + }); +}); diff --git a/lib/cache-management.ts b/lib/cache-management.ts new file mode 100644 index 0000000000000000000000000000000000000000..84b66407797ed89697272edd2609f3782ba05fc3 --- /dev/null +++ b/lib/cache-management.ts @@ -0,0 +1,265 @@ +import { homedir } from "node:os"; +import { join, basename } from "node:path"; +import { stat, readdir, readFile, realpath, lstat } from "node:fs/promises"; +import type { Stats } from "node:fs"; +import type { RepoType, RepoId } from "../types/public"; + +function getDefaultHome(): string { + return join(homedir(), ".cache"); +} + +function getDefaultCachePath(): string { + return join(process.env["HF_HOME"] ?? join(process.env["XDG_CACHE_HOME"] ?? getDefaultHome(), "huggingface"), "hub"); +} + +function getHuggingFaceHubCache(): string { + return process.env["HUGGINGFACE_HUB_CACHE"] ?? getDefaultCachePath(); +} + +export function getHFHubCachePath(): string { + return process.env["HF_HUB_CACHE"] ?? getHuggingFaceHubCache(); +} + +const FILES_TO_IGNORE: string[] = [".DS_Store"]; + +export const REPO_ID_SEPARATOR: string = "--"; + +export function getRepoFolderName({ name, type }: RepoId): string { + const parts = [`${type}s`, ...name.split("/")]; + return parts.join(REPO_ID_SEPARATOR); +} + +export interface CachedFileInfo { + path: string; + /** + * Underlying file - which `path` is symlinked to + */ + blob: { + size: number; + path: string; + lastModifiedAt: Date; + lastAccessedAt: Date; + }; +} + +export interface CachedRevisionInfo { + commitOid: string; + path: string; + size: number; + files: CachedFileInfo[]; + refs: string[]; + + lastModifiedAt: Date; +} + +export interface CachedRepoInfo { + id: RepoId; + path: string; + size: number; + filesCount: number; + revisions: CachedRevisionInfo[]; + + lastAccessedAt: Date; + lastModifiedAt: Date; +} + +export interface HFCacheInfo { + size: number; + repos: CachedRepoInfo[]; + warnings: Error[]; +} + +export async function scanCacheDir(cacheDir: string | undefined = undefined): Promise { + if (!cacheDir) cacheDir = getHFHubCachePath(); + + const s = await stat(cacheDir); + if (!s.isDirectory()) { + throw new Error( + `Scan cache expects a directory but found a file: ${cacheDir}. 
Please use \`cacheDir\` argument or set \`HF_HUB_CACHE\` environment variable.` + ); + } + + const repos: CachedRepoInfo[] = []; + const warnings: Error[] = []; + + const directories = await readdir(cacheDir); + for (const repo of directories) { + // skip .locks folder + if (repo === ".locks") continue; + + // get the absolute path of the repo + const absolute = join(cacheDir, repo); + + // ignore non-directory element + const s = await stat(absolute); + if (!s.isDirectory()) { + continue; + } + + try { + const cached = await scanCachedRepo(absolute); + repos.push(cached); + } catch (err: unknown) { + warnings.push(err as Error); + } + } + + return { + repos: repos, + size: [...repos.values()].reduce((sum, repo) => sum + repo.size, 0), + warnings: warnings, + }; +} + +export async function scanCachedRepo(repoPath: string): Promise { + // get the directory name + const name = basename(repoPath); + if (!name.includes(REPO_ID_SEPARATOR)) { + throw new Error(`Repo path is not a valid HuggingFace cache directory: ${name}`); + } + + // parse the repoId from directory name + const [type, ...remaining] = name.split(REPO_ID_SEPARATOR); + const repoType = parseRepoType(type); + const repoId = remaining.join("/"); + + const snapshotsPath = join(repoPath, "snapshots"); + const refsPath = join(repoPath, "refs"); + + const snapshotStat = await stat(snapshotsPath); + if (!snapshotStat.isDirectory()) { + throw new Error(`Snapshots dir doesn't exist in cached repo ${snapshotsPath}`); + } + + // Check if the refs directory exists and scan it + const refsByHash: Map = new Map(); + const refsStat = await stat(refsPath); + if (refsStat.isDirectory()) { + await scanRefsDir(refsPath, refsByHash); + } + + // Scan snapshots directory and collect cached revision information + const cachedRevisions: CachedRevisionInfo[] = []; + const blobStats: Map = new Map(); // Store blob stats + + const snapshotDirs = await readdir(snapshotsPath); + for (const dir of snapshotDirs) { + if (FILES_TO_IGNORE.includes(dir)) continue; // Ignore unwanted files + + const revisionPath = join(snapshotsPath, dir); + const revisionStat = await stat(revisionPath); + if (!revisionStat.isDirectory()) { + throw new Error(`Snapshots folder corrupted. Found a file: ${revisionPath}`); + } + + const cachedFiles: CachedFileInfo[] = []; + await scanSnapshotDir(revisionPath, cachedFiles, blobStats); + + const revisionLastModified = + cachedFiles.length > 0 + ? Math.max(...[...cachedFiles].map((file) => file.blob.lastModifiedAt.getTime())) + : revisionStat.mtimeMs; + + cachedRevisions.push({ + commitOid: dir, + files: cachedFiles, + refs: refsByHash.get(dir) || [], + size: [...cachedFiles].reduce((sum, file) => sum + file.blob.size, 0), + path: revisionPath, + lastModifiedAt: new Date(revisionLastModified), + }); + + refsByHash.delete(dir); + } + + // Verify that all refs refer to a valid revision + if (refsByHash.size > 0) { + throw new Error( + `Reference(s) refer to missing commit hashes: ${JSON.stringify(Object.fromEntries(refsByHash))} (${repoPath})` + ); + } + + const repoStats = await stat(repoPath); + const repoLastAccessed = + blobStats.size > 0 ? Math.max(...[...blobStats.values()].map((stat) => stat.atimeMs)) : repoStats.atimeMs; + + const repoLastModified = + blobStats.size > 0 ? 
Math.max(...[...blobStats.values()].map((stat) => stat.mtimeMs)) : repoStats.mtimeMs; + + // Return the constructed CachedRepoInfo object + return { + id: { + name: repoId, + type: repoType, + }, + path: repoPath, + filesCount: blobStats.size, + revisions: cachedRevisions, + size: [...blobStats.values()].reduce((sum, stat) => sum + stat.size, 0), + lastAccessedAt: new Date(repoLastAccessed), + lastModifiedAt: new Date(repoLastModified), + }; +} + +export async function scanRefsDir(refsPath: string, refsByHash: Map): Promise { + const refFiles = await readdir(refsPath, { withFileTypes: true }); + for (const refFile of refFiles) { + const refFilePath = join(refsPath, refFile.name); + if (refFile.isDirectory()) continue; // Skip directories + + const commitHash = await readFile(refFilePath, "utf-8"); + const refName = refFile.name; + if (!refsByHash.has(commitHash)) { + refsByHash.set(commitHash, []); + } + refsByHash.get(commitHash)?.push(refName); + } +} + +export async function scanSnapshotDir( + revisionPath: string, + cachedFiles: CachedFileInfo[], + blobStats: Map +): Promise { + const files = await readdir(revisionPath, { withFileTypes: true }); + for (const file of files) { + if (file.isDirectory()) continue; // Skip directories + + const filePath = join(revisionPath, file.name); + const blobPath = await realpath(filePath); + const blobStat = await getBlobStat(blobPath, blobStats); + + cachedFiles.push({ + path: filePath, + blob: { + path: blobPath, + size: blobStat.size, + lastAccessedAt: new Date(blobStat.atimeMs), + lastModifiedAt: new Date(blobStat.mtimeMs), + }, + }); + } +} + +export async function getBlobStat(blobPath: string, blobStats: Map): Promise { + const blob = blobStats.get(blobPath); + if (!blob) { + const statResult = await lstat(blobPath); + blobStats.set(blobPath, statResult); + return statResult; + } + return blob; +} + +export function parseRepoType(type: string): RepoType { + switch (type) { + case "models": + return "model"; + case "datasets": + return "dataset"; + case "spaces": + return "space"; + default: + throw new TypeError(`Invalid repo type: ${type}`); + } +} diff --git a/lib/check-repo-access.spec.ts b/lib/check-repo-access.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..12ad5cd92a2cf3549b1e62f4125b875336b55d6a --- /dev/null +++ b/lib/check-repo-access.spec.ts @@ -0,0 +1,34 @@ +import { assert, describe, expect, it } from "vitest"; +import { checkRepoAccess } from "./check-repo-access"; +import { HubApiError } from "../error"; +import { TEST_ACCESS_TOKEN, TEST_HUB_URL } from "../test/consts"; + +describe("checkRepoAccess", () => { + it("should throw 401 when accessing unexisting repo unauthenticated", async () => { + try { + await checkRepoAccess({ repo: { name: "i--d/dont", type: "model" } }); + assert(false, "should have thrown"); + } catch (err) { + expect(err).toBeInstanceOf(HubApiError); + expect((err as HubApiError).statusCode).toBe(401); + } + }); + + it("should throw 404 when accessing unexisting repo authenticated", async () => { + try { + await checkRepoAccess({ + repo: { name: "i--d/dont", type: "model" }, + hubUrl: TEST_HUB_URL, + accessToken: TEST_ACCESS_TOKEN, + }); + assert(false, "should have thrown"); + } catch (err) { + expect(err).toBeInstanceOf(HubApiError); + expect((err as HubApiError).statusCode).toBe(404); + } + }); + + it("should not throw when accessing public repo", async () => { + await checkRepoAccess({ repo: { name: "openai-community/gpt2", type: "model" } }); + }); +}); diff --git 
a/lib/check-repo-access.ts b/lib/check-repo-access.ts new file mode 100644 index 0000000000000000000000000000000000000000..3107c9bd7e9006fe68c38c40d2c220b474390cf1 --- /dev/null +++ b/lib/check-repo-access.ts @@ -0,0 +1,32 @@ +import { HUB_URL } from "../consts"; +// eslint-disable-next-line @typescript-eslint/no-unused-vars +import { createApiError, type HubApiError } from "../error"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { toRepoId } from "../utils/toRepoId"; + +/** + * Check if we have read access to a repository. + * + * Throw a {@link HubApiError} error if we don't have access. HubApiError.statusCode will be 401, 403 or 404. + */ +export async function checkRepoAccess( + params: { + repo: RepoDesignation; + hubUrl?: string; + fetch?: typeof fetch; + } & Partial +): Promise { + const accessToken = params && checkCredentials(params); + const repoId = toRepoId(params.repo); + + const response = await (params.fetch || fetch)(`${params?.hubUrl || HUB_URL}/api/${repoId.type}s/${repoId.name}`, { + headers: { + ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}), + }, + }); + + if (!response.ok) { + throw await createApiError(response); + } +} diff --git a/lib/commit.spec.ts b/lib/commit.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..024155bbc7e69744c227e65c17f40d1d7ad2e56b --- /dev/null +++ b/lib/commit.spec.ts @@ -0,0 +1,271 @@ +import { assert, it, describe } from "vitest"; + +import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import type { CommitFile } from "./commit"; +import { commit } from "./commit"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { downloadFile } from "./download-file"; +import { fileDownloadInfo } from "./file-download-info"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { isFrontend } from "../utils/isFrontend"; + +const lfsContent = "O123456789".repeat(100_000); + +describe("commit", () => { + it("should commit to a repo with blobs", async function () { + const tokenizerJsonUrl = new URL( + "https://huggingface.co/spaces/aschen/push-model-from-web/raw/main/mobilenet/model.json" + ); + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo: RepoId = { + name: repoName, + type: "model", + }; + + await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + repo, + license: "mit", + }); + + try { + const readme1 = await downloadFile({ repo, path: "README.md", hubUrl: TEST_HUB_URL }); + assert(readme1, "Readme doesn't exist"); + + const nodeOperation: CommitFile[] = isFrontend + ? 
[] + : [ + { + operation: "addOrUpdate", + path: "tsconfig.json", + content: (await import("node:url")).pathToFileURL("./tsconfig.json") as URL, + }, + ]; + + await commit({ + repo, + title: "Some commit", + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + operations: [ + { + operation: "addOrUpdate", + content: new Blob(["This is me"]), + path: "test.txt", + }, + { + operation: "addOrUpdate", + content: new Blob([lfsContent]), + path: "test.lfs.txt", + }, + ...nodeOperation, + { + operation: "addOrUpdate", + content: tokenizerJsonUrl, + path: "lamaral.json", + }, + { + operation: "delete", + path: "README.md", + }, + ], + // To test web workers in the front-end + useWebWorkers: { minSize: 5_000 }, + }); + + const fileContent = await downloadFile({ repo, path: "test.txt", hubUrl: TEST_HUB_URL }); + assert.strictEqual(await fileContent?.text(), "This is me"); + + const lfsFileContent = await downloadFile({ repo, path: "test.lfs.txt", hubUrl: TEST_HUB_URL }); + assert.strictEqual(await lfsFileContent?.text(), lfsContent); + + const lfsFileUrl = `${TEST_HUB_URL}/${repoName}/raw/main/test.lfs.txt`; + const lfsFilePointer = await fetch(lfsFileUrl); + assert.strictEqual(lfsFilePointer.status, 200); + assert.strictEqual( + (await lfsFilePointer.text()).trim(), + ` +version https://git-lfs.github.com/spec/v1 +oid sha256:a3bbce7ee1df7233d85b5f4d60faa3755f93f537804f8b540c72b0739239ddf8 +size ${lfsContent.length} + `.trim() + ); + + if (!isFrontend) { + const fileUrlContent = await downloadFile({ repo, path: "tsconfig.json", hubUrl: TEST_HUB_URL }); + assert.strictEqual( + await fileUrlContent?.text(), + (await import("node:fs")).readFileSync("./tsconfig.json", "utf-8") + ); + } + + const webResourceContent = await downloadFile({ repo, path: "lamaral.json", hubUrl: TEST_HUB_URL }); + assert.strictEqual(await webResourceContent?.text(), await (await fetch(tokenizerJsonUrl)).text()); + + const readme2 = await downloadFile({ repo, path: "README.md", hubUrl: TEST_HUB_URL }); + assert.strictEqual(readme2, null); + } finally { + await deleteRepo({ + repo: { + name: repoName, + type: "model", + }, + hubUrl: TEST_HUB_URL, + credentials: { accessToken: TEST_ACCESS_TOKEN }, + }); + } + }, 60_000); + + it("should commit a full repo from HF with web urls", async function () { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo: RepoId = { + name: repoName, + type: "model", + }; + + await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + repo, + hubUrl: TEST_HUB_URL, + }); + + try { + const FILES_TO_UPLOAD = [ + `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/model.json`, + `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/group1-shard1of2`, + `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/group1-shard2of2`, + `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/coffee.jpg`, + `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/README.md`, + ]; + + const operations: CommitFile[] = await Promise.all( + FILES_TO_UPLOAD.map(async (file) => { + return { + operation: "addOrUpdate", + path: file.slice(file.indexOf("main/") + "main/".length), + // upload remote file + content: new URL(file), + }; + }) + ); + await commit({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + title: "upload model", + operations, + }); + + const LFSSize = (await fileDownloadInfo({ repo, path: 
"mobilenet/group1-shard1of2", hubUrl: TEST_HUB_URL })) + ?.size; + + assert.strictEqual(LFSSize, 4_194_304); + + const pointerFile = await downloadFile({ + repo, + path: "mobilenet/group1-shard1of2", + raw: true, + hubUrl: TEST_HUB_URL, + }); + + // Make sure SHA is computed properly as well + assert.strictEqual( + (await pointerFile?.text())?.trim(), + ` +version https://git-lfs.github.com/spec/v1 +oid sha256:3fb621eb9b37478239504ee083042d5b18699e8b8618e569478b03b119a85a69 +size 4194304 + `.trim() + ); + } finally { + await deleteRepo({ + repo: { + name: repoName, + type: "model", + }, + hubUrl: TEST_HUB_URL, + credentials: { accessToken: TEST_ACCESS_TOKEN }, + }); + } + // https://huggingfacejs-push-model-from-web.hf.space/ + }, 60_000); + + it("should be able to create a PR and then commit to it", async function () { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo: RepoId = { + name: repoName, + type: "model", + }; + + await createRepo({ + credentials: { + accessToken: TEST_ACCESS_TOKEN, + }, + repo, + hubUrl: TEST_HUB_URL, + }); + + try { + const pr = await commit({ + repo, + credentials: { + accessToken: TEST_ACCESS_TOKEN, + }, + hubUrl: TEST_HUB_URL, + title: "Create PR", + isPullRequest: true, + operations: [ + { + operation: "addOrUpdate", + content: new Blob(["This is me"]), + path: "test.txt", + }, + ], + }); + + if (!pr) { + throw new Error("PR creation failed"); + } + + if (!pr.pullRequestUrl) { + throw new Error("No pull request url"); + } + + const prNumber = pr.pullRequestUrl.split("/").pop(); + const prRef = `refs/pr/${prNumber}`; + + await commit({ + repo, + credentials: { + accessToken: TEST_ACCESS_TOKEN, + }, + hubUrl: TEST_HUB_URL, + branch: prRef, + title: "Some commit", + operations: [ + { + operation: "addOrUpdate", + content: new URL( + `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/group1-shard1of2` + ), + path: "mobilenet/group1-shard1of2", + }, + ], + }); + + assert(commit, "PR commit failed"); + } finally { + await deleteRepo({ + repo: { + name: repoName, + type: "model", + }, + hubUrl: TEST_HUB_URL, + credentials: { accessToken: TEST_ACCESS_TOKEN }, + }); + } + }, 60_000); +}); diff --git a/lib/commit.ts b/lib/commit.ts new file mode 100644 index 0000000000000000000000000000000000000000..a7acb3bcbe3eee22822a18a533332ebca19b422a --- /dev/null +++ b/lib/commit.ts @@ -0,0 +1,609 @@ +import { HUB_URL } from "../consts"; +import { HubApiError, createApiError, InvalidApiResponseFormatError } from "../error"; +import type { + ApiCommitHeader, + ApiCommitLfsFile, + ApiCommitOperation, + ApiLfsBatchRequest, + ApiLfsBatchResponse, + ApiLfsCompleteMultipartRequest, + ApiPreuploadRequest, + ApiPreuploadResponse, +} from "../types/api/api-commit"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { chunk } from "../utils/chunk"; +import { promisesQueue } from "../utils/promisesQueue"; +import { promisesQueueStreaming } from "../utils/promisesQueueStreaming"; +import { sha256 } from "../utils/sha256"; +import { toRepoId } from "../utils/toRepoId"; +import { WebBlob } from "../utils/WebBlob"; +import { eventToGenerator } from "../utils/eventToGenerator"; +import { base64FromBytes } from "../utils/base64FromBytes"; +import { isFrontend } from "../utils/isFrontend"; +import { createBlobs } from "../utils/createBlobs"; + +const CONCURRENT_SHAS = 5; +const CONCURRENT_LFS_UPLOADS = 5; +const 
MULTIPART_PARALLEL_UPLOAD = 5; + +export interface CommitDeletedEntry { + operation: "delete"; + path: string; +} + +export type ContentSource = Blob | URL; + +export interface CommitFile { + operation: "addOrUpdate"; + path: string; + content: ContentSource; + // forceLfs?: boolean +} + +type CommitBlob = Omit & { content: Blob }; + +// TODO: find a nice way to handle LFS & non-LFS files in an uniform manner, see https://github.com/huggingface/moon-landing/issues/4370 +// export type CommitRenameFile = { +// operation: "rename"; +// path: string; +// oldPath: string; +// content?: ContentSource; +// }; + +export type CommitOperation = CommitDeletedEntry | CommitFile /* | CommitRenameFile */; +type CommitBlobOperation = Exclude | CommitBlob; + +export type CommitParams = { + title: string; + description?: string; + repo: RepoDesignation; + operations: CommitOperation[]; + /** @default "main" */ + branch?: string; + /** + * Parent commit. Optional + * + * - When opening a PR: will use parentCommit as the parent commit + * - When committing on a branch: Will make sure that there were no intermediate commits + */ + parentCommit?: string; + isPullRequest?: boolean; + hubUrl?: string; + /** + * Whether to use web workers to compute SHA256 hashes. + * + * @default false + */ + useWebWorkers?: boolean | { minSize?: number; poolSize?: number }; + /** + * Maximum depth of folders to upload. Files deeper than this will be ignored + * + * @default 5 + */ + maxFolderDepth?: number; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + abortSignal?: AbortSignal; + // Credentials are optional due to custom fetch functions or cookie auth +} & Partial; + +export interface CommitOutput { + pullRequestUrl?: string; + commit: { + oid: string; + url: string; + }; + hookOutput: string; +} + +function isFileOperation(op: CommitOperation): op is CommitBlob { + const ret = op.operation === "addOrUpdate"; + + if (ret && !(op.content instanceof Blob)) { + throw new TypeError("Precondition failed: op.content should be a Blob"); + } + + return ret; +} + +export type CommitProgressEvent = + | { + event: "phase"; + phase: "preuploading" | "uploadingLargeFiles" | "committing"; + } + | { + event: "fileProgress"; + path: string; + progress: number; + state: "hashing" | "uploading"; + }; + +/** + * Internal function for now, used by commit. 
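+ *
+ * A minimal consumption sketch (placeholder repo name and token, and a repo the token can write to):
+ *
+ * ```ts
+ * for await (const event of commitIter({
+ *   repo: { type: "model", name: "my-user/my-model" },
+ *   accessToken: "hf_...",
+ *   title: "Add a file",
+ *   operations: [{ operation: "addOrUpdate", path: "file.txt", content: new Blob(["hello"]) }],
+ * })) {
+ *   if (event.event === "fileProgress") {
+ *     console.log(event.path, event.state, event.progress);
+ *   }
+ * }
+ * ```
+ *
+ * The `commit()` wrapper below drains this generator and only returns the final `CommitOutput`.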
+ * + * Can be exposed later to offer fine-tuned progress info + */ +export async function* commitIter(params: CommitParams): AsyncGenerator { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + yield { event: "phase", phase: "preuploading" }; + + const lfsShas = new Map(); + + const abortController = new AbortController(); + const abortSignal = abortController.signal; + + // Polyfill see https://discuss.huggingface.co/t/why-cant-i-upload-a-parquet-file-to-my-dataset-error-o-throwifaborted-is-not-a-function/62245 + if (!abortSignal.throwIfAborted) { + abortSignal.throwIfAborted = () => { + if (abortSignal.aborted) { + throw new DOMException("Aborted", "AbortError"); + } + }; + } + + if (params.abortSignal) { + params.abortSignal.addEventListener("abort", () => abortController.abort()); + } + + try { + const allOperations = ( + await Promise.all( + params.operations.map(async (operation) => { + if (operation.operation !== "addOrUpdate") { + return operation; + } + + if (!(operation.content instanceof URL)) { + /** TS trick to enforce `content` to be a `Blob` */ + return { ...operation, content: operation.content }; + } + + const lazyBlobs = await createBlobs(operation.content, operation.path, { + fetch: params.fetch, + maxFolderDepth: params.maxFolderDepth, + }); + + abortSignal?.throwIfAborted(); + + return lazyBlobs.map((blob) => ({ + ...operation, + content: blob.blob, + path: blob.path, + })); + }) + ) + ).flat(1); + + const gitAttributes = allOperations.filter(isFileOperation).find((op) => op.path === ".gitattributes")?.content; + + for (const operations of chunk(allOperations.filter(isFileOperation), 100)) { + const payload: ApiPreuploadRequest = { + gitAttributes: gitAttributes && (await gitAttributes.text()), + files: await Promise.all( + operations.map(async (operation) => ({ + path: operation.path, + size: operation.content.size, + sample: base64FromBytes(new Uint8Array(await operation.content.slice(0, 512).arrayBuffer())), + })) + ), + }; + + abortSignal?.throwIfAborted(); + + const res = await (params.fetch ?? fetch)( + `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/preupload/${encodeURIComponent( + params.branch ?? "main" + )}` + (params.isPullRequest ? 
"?create_pr=1" : ""), + { + method: "POST", + headers: { + ...(accessToken && { Authorization: `Bearer ${accessToken}` }), + "Content-Type": "application/json", + }, + body: JSON.stringify(payload), + signal: abortSignal, + } + ); + + if (!res.ok) { + throw await createApiError(res); + } + + const json: ApiPreuploadResponse = await res.json(); + + for (const file of json.files) { + if (file.uploadMode === "lfs") { + lfsShas.set(file.path, null); + } + } + } + + yield { event: "phase", phase: "uploadingLargeFiles" }; + + for (const operations of chunk( + allOperations.filter(isFileOperation).filter((op) => lfsShas.has(op.path)), + 100 + )) { + const shas = yield* eventToGenerator< + { event: "fileProgress"; state: "hashing"; path: string; progress: number }, + string[] + >((yieldCallback, returnCallback, rejectCallack) => { + return promisesQueue( + operations.map((op) => async () => { + const iterator = sha256(op.content, { useWebWorker: params.useWebWorkers, abortSignal: abortSignal }); + let res: IteratorResult; + do { + res = await iterator.next(); + if (!res.done) { + yieldCallback({ event: "fileProgress", path: op.path, progress: res.value, state: "hashing" }); + } + } while (!res.done); + const sha = res.value; + lfsShas.set(op.path, res.value); + return sha; + }), + CONCURRENT_SHAS + ).then(returnCallback, rejectCallack); + }); + + abortSignal?.throwIfAborted(); + + const payload: ApiLfsBatchRequest = { + operation: "upload", + // multipart is a custom protocol for HF + transfers: ["basic", "multipart"], + hash_algo: "sha_256", + ...(!params.isPullRequest && { + ref: { + name: params.branch ?? "main", + }, + }), + objects: operations.map((op, i) => ({ + oid: shas[i], + size: op.content.size, + })), + }; + + const res = await (params.fetch ?? fetch)( + `${params.hubUrl ?? HUB_URL}/${repoId.type === "model" ? "" : repoId.type + "s/"}${ + repoId.name + }.git/info/lfs/objects/batch`, + { + method: "POST", + headers: { + ...(accessToken && { Authorization: `Bearer ${accessToken}` }), + Accept: "application/vnd.git-lfs+json", + "Content-Type": "application/vnd.git-lfs+json", + }, + body: JSON.stringify(payload), + signal: abortSignal, + } + ); + + if (!res.ok) { + throw await createApiError(res); + } + + const json: ApiLfsBatchResponse = await res.json(); + const batchRequestId = res.headers.get("X-Request-Id") || undefined; + + const shaToOperation = new Map(operations.map((op, i) => [shas[i], op])); + + yield* eventToGenerator((yieldCallback, returnCallback, rejectCallback) => { + return promisesQueueStreaming( + json.objects.map((obj) => async () => { + const op = shaToOperation.get(obj.oid); + + if (!op) { + throw new InvalidApiResponseFormatError("Unrequested object ID in response"); + } + + abortSignal?.throwIfAborted(); + + if (obj.error) { + const errorMessage = `Error while doing LFS batch call for ${operations[shas.indexOf(obj.oid)].path}: ${ + obj.error.message + }${batchRequestId ? 
` - Request ID: ${batchRequestId}` : ""}`; + throw new HubApiError(res.url, obj.error.code, batchRequestId, errorMessage); + } + if (!obj.actions?.upload) { + // Already uploaded + yieldCallback({ + event: "fileProgress", + path: op.path, + progress: 1, + state: "uploading", + }); + return; + } + yieldCallback({ + event: "fileProgress", + path: op.path, + progress: 0, + state: "uploading", + }); + const content = op.content; + const header = obj.actions.upload.header; + if (header?.chunk_size) { + const chunkSize = parseInt(header.chunk_size); + + // multipart upload + // parts are in upload.header['00001'] to upload.header['99999'] + + const completionUrl = obj.actions.upload.href; + const parts = Object.keys(header).filter((key) => /^[0-9]+$/.test(key)); + + if (parts.length !== Math.ceil(content.size / chunkSize)) { + throw new Error("Invalid server response to upload large LFS file, wrong number of parts"); + } + + const completeReq: ApiLfsCompleteMultipartRequest = { + oid: obj.oid, + parts: parts.map((part) => ({ + partNumber: +part, + etag: "", + })), + }; + + // Defined here so that it's not redefined at each iteration (and the caller can tell it's for the same file) + const progressCallback = (progress: number) => + yieldCallback({ event: "fileProgress", path: op.path, progress, state: "uploading" }); + + await promisesQueueStreaming( + parts.map((part) => async () => { + abortSignal?.throwIfAborted(); + + const index = parseInt(part) - 1; + const slice = content.slice(index * chunkSize, (index + 1) * chunkSize); + + const res = await (params.fetch ?? fetch)(header[part], { + method: "PUT", + /** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */ + body: slice instanceof WebBlob && isFrontend ? await slice.arrayBuffer() : slice, + signal: abortSignal, + ...({ + progressHint: { + path: op.path, + part: index, + numParts: parts.length, + progressCallback, + }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any), + }); + + if (!res.ok) { + throw await createApiError(res, { + requestId: batchRequestId, + message: `Error while uploading part ${part} of ${ + operations[shas.indexOf(obj.oid)].path + } to LFS storage`, + }); + } + + const eTag = res.headers.get("ETag"); + + if (!eTag) { + throw new Error("Cannot get ETag of part during multipart upload"); + } + + completeReq.parts[Number(part) - 1].etag = eTag; + }), + MULTIPART_PARALLEL_UPLOAD + ); + + abortSignal?.throwIfAborted(); + + const res = await (params.fetch ?? fetch)(completionUrl, { + method: "POST", + body: JSON.stringify(completeReq), + headers: { + Accept: "application/vnd.git-lfs+json", + "Content-Type": "application/vnd.git-lfs+json", + }, + signal: abortSignal, + }); + + if (!res.ok) { + throw await createApiError(res, { + requestId: batchRequestId, + message: `Error completing multipart upload of ${ + operations[shas.indexOf(obj.oid)].path + } to LFS storage`, + }); + } + + yieldCallback({ + event: "fileProgress", + path: op.path, + progress: 1, + state: "uploading", + }); + } else { + const res = await (params.fetch ?? fetch)(obj.actions.upload.href, { + method: "PUT", + headers: { + ...(batchRequestId ? { "X-Request-Id": batchRequestId } : undefined), + }, + /** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */ + body: content instanceof WebBlob && isFrontend ? 
await content.arrayBuffer() : content, + signal: abortSignal, + ...({ + progressHint: { + path: op.path, + progressCallback: (progress: number) => + yieldCallback({ + event: "fileProgress", + path: op.path, + progress, + state: "uploading", + }), + }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any), + }); + + if (!res.ok) { + throw await createApiError(res, { + requestId: batchRequestId, + message: `Error while uploading ${operations[shas.indexOf(obj.oid)].path} to LFS storage`, + }); + } + + yieldCallback({ + event: "fileProgress", + path: op.path, + progress: 1, + state: "uploading", + }); + } + }), + CONCURRENT_LFS_UPLOADS + ).then(returnCallback, rejectCallback); + }); + } + + abortSignal?.throwIfAborted(); + + yield { event: "phase", phase: "committing" }; + + return yield* eventToGenerator( + async (yieldCallback, returnCallback, rejectCallback) => + (params.fetch ?? fetch)( + `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commit/${encodeURIComponent( + params.branch ?? "main" + )}` + (params.isPullRequest ? "?create_pr=1" : ""), + { + method: "POST", + headers: { + ...(accessToken && { Authorization: `Bearer ${accessToken}` }), + "Content-Type": "application/x-ndjson", + }, + body: [ + { + key: "header", + value: { + summary: params.title, + description: params.description, + parentCommit: params.parentCommit, + } satisfies ApiCommitHeader, + }, + ...((await Promise.all( + allOperations.map((operation) => { + if (isFileOperation(operation)) { + const sha = lfsShas.get(operation.path); + if (sha) { + return { + key: "lfsFile", + value: { + path: operation.path, + algo: "sha256", + size: operation.content.size, + oid: sha, + } satisfies ApiCommitLfsFile, + }; + } + } + + return convertOperationToNdJson(operation); + }) + )) satisfies ApiCommitOperation[]), + ] + .map((x) => JSON.stringify(x)) + .join("\n"), + signal: abortSignal, + ...({ + progressHint: { + progressCallback: (progress: number) => { + // For now, we display equal progress for all files + // We could compute the progress based on the size of `convertOperationToNdJson` for each of the files instead + for (const op of allOperations) { + if (isFileOperation(op) && !lfsShas.has(op.path)) { + yieldCallback({ + event: "fileProgress", + path: op.path, + progress, + state: "uploading", + }); + } + } + }, + }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any), + } + ) + .then(async (res) => { + if (!res.ok) { + throw await createApiError(res); + } + + const json = await res.json(); + + returnCallback({ + pullRequestUrl: json.pullRequestUrl, + commit: { + oid: json.commitOid, + url: json.commitUrl, + }, + hookOutput: json.hookOutput, + }); + }) + .catch(rejectCallback) + ); + } catch (err) { + // For parallel requests, cancel them all if one fails + abortController.abort(); + throw err; + } +} + +export async function commit(params: CommitParams): Promise { + const iterator = commitIter(params); + let res = await iterator.next(); + while (!res.done) { + res = await iterator.next(); + } + return res.value; +} + +async function convertOperationToNdJson(operation: CommitBlobOperation): Promise { + switch (operation.operation) { + case "addOrUpdate": { + // todo: handle LFS + return { + key: "file", + value: { + content: base64FromBytes(new Uint8Array(await operation.content.arrayBuffer())), + path: operation.path, + encoding: "base64", + }, + }; + } + // case "rename": { + // // todo: detect when remote file is already LFS, and in that case rename as LFS + // 
return { + // key: "file", + // value: { + // content: operation.content, + // path: operation.path, + // oldPath: operation.oldPath + // } + // }; + // } + case "delete": { + return { + key: "deletedFile", + value: { + path: operation.path, + }, + }; + } + default: + throw new TypeError("Unknown operation: " + (operation as { operation: string }).operation); + } +} diff --git a/lib/count-commits.spec.ts b/lib/count-commits.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..f60754789b7a1b6f1c99765e9fa5d77689478eb8 --- /dev/null +++ b/lib/count-commits.spec.ts @@ -0,0 +1,16 @@ +import { assert, it, describe } from "vitest"; +import { countCommits } from "./count-commits"; + +describe("countCommits", () => { + it("should fetch paginated commits from the repo", async () => { + const count = await countCommits({ + repo: { + name: "openai-community/gpt2", + type: "model", + }, + revision: "607a30d783dfa663caf39e06633721c8d4cfcd7e", + }); + + assert.equal(count, 26); + }); +}); diff --git a/lib/count-commits.ts b/lib/count-commits.ts new file mode 100644 index 0000000000000000000000000000000000000000..0e133253273d4b1ee689c91e174dd11021846e10 --- /dev/null +++ b/lib/count-commits.ts @@ -0,0 +1,35 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { toRepoId } from "../utils/toRepoId"; + +export async function countCommits( + params: { + repo: RepoDesignation; + /** + * Revision to list commits from. Defaults to the default branch. + */ + revision?: string; + hubUrl?: string; + fetch?: typeof fetch; + } & Partial +): Promise { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + + // Could upgrade to 1000 commits per page + const url: string | undefined = `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commits/${ + params.revision ?? "main" + }?limit=1`; + + const res: Response = await (params.fetch ?? fetch)(url, { + headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {}, + }); + + if (!res.ok) { + throw await createApiError(res); + } + + return parseInt(res.headers.get("x-total-count") ?? 
"0", 10); +} diff --git a/lib/create-branch.spec.ts b/lib/create-branch.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..b616fb4cebfd888a23c33bfe3a7b7334d575fd2d --- /dev/null +++ b/lib/create-branch.spec.ts @@ -0,0 +1,159 @@ +import { assert, it, describe } from "vitest"; +import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { createBranch } from "./create-branch"; +import { uploadFile } from "./upload-file"; +import { downloadFile } from "./download-file"; + +describe("createBranch", () => { + it("should create a new branch from the default branch", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + repo, + }); + + await uploadFile({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + file: { + path: "file.txt", + content: new Blob(["file content"]), + }, + }); + + await createBranch({ + repo, + branch: "new-branch", + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + + const content = await downloadFile({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + path: "file.txt", + revision: "new-branch", + }); + + assert.equal(await content?.text(), "file content"); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); + + it("should create an empty branch", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + repo, + }); + + await uploadFile({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + file: { + path: "file.txt", + content: new Blob(["file content"]), + }, + }); + + await createBranch({ + repo, + branch: "empty-branch", + empty: true, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + + const content = await downloadFile({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + path: "file.txt", + revision: "empty-branch", + }); + + assert.equal(content, null); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); + + it("should overwrite an existing branch", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + repo, + }); + + await uploadFile({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + file: { + path: "file.txt", + content: new Blob(["file content"]), + }, + }); + + await createBranch({ + repo, + branch: "overwrite-branch", + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + + await createBranch({ + repo, + branch: "overwrite-branch", + overwrite: true, + empty: true, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + + const content = await downloadFile({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + path: "file.txt", + revision: "overwrite-branch", + }); + + assert.equal(content, 
null); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); +}); diff --git a/lib/create-branch.ts b/lib/create-branch.ts new file mode 100644 index 0000000000000000000000000000000000000000..100e4d1b94283f51c83783efd6a99b144e3c31f1 --- /dev/null +++ b/lib/create-branch.ts @@ -0,0 +1,54 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { AccessToken, RepoDesignation } from "../types/public"; +import { toRepoId } from "../utils/toRepoId"; + +export async function createBranch(params: { + repo: RepoDesignation; + /** + * Revision to create the branch from. Defaults to the default branch. + * + * Use empty: true to create an empty branch. + */ + revision?: string; + hubUrl?: string; + accessToken?: AccessToken; + fetch?: typeof fetch; + /** + * The name of the branch to create + */ + branch: string; + /** + * Use this to create an empty branch, with no commits. + */ + empty?: boolean; + /** + * Use this to overwrite the branch if it already exists. + * + * If you only specify `overwrite` and no `revision`/`empty`, and the branch already exists, it will be a no-op. + */ + overwrite?: boolean; +}): Promise { + const repoId = toRepoId(params.repo); + const res = await (params.fetch ?? fetch)( + `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/branch/${encodeURIComponent(params.branch)}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + ...(params.accessToken && { + Authorization: `Bearer ${params.accessToken}`, + }), + }, + body: JSON.stringify({ + startingPoint: params.revision, + ...(params.empty && { emptyBranch: true }), + overwrite: params.overwrite, + }), + } + ); + + if (!res.ok) { + throw await createApiError(res); + } +} diff --git a/lib/create-repo.spec.ts b/lib/create-repo.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..92d4d6b51a112205f08d236d46b136f3f2647e8c --- /dev/null +++ b/lib/create-repo.spec.ts @@ -0,0 +1,103 @@ +import { assert, it, describe, expect } from "vitest"; + +import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { downloadFile } from "./download-file"; + +describe("createRepo", () => { + it("should create a repo", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + repo: { + name: repoName, + type: "model", + }, + hubUrl: TEST_HUB_URL, + files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }], + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + const content = await downloadFile({ + repo: { + name: repoName, + type: "model", + }, + path: ".gitattributes", + hubUrl: TEST_HUB_URL, + }); + + assert(content); + assert.strictEqual(await content.text(), "*.html filter=lfs diff=lfs merge=lfs -text"); + + await deleteRepo({ + repo: { + name: repoName, + type: "model", + }, + credentials: { accessToken: TEST_ACCESS_TOKEN }, + hubUrl: TEST_HUB_URL, + }); + }); + + it("should throw a client error when trying to create a repo without a fully-qualified name", async () => { + const tryCreate = createRepo({ + repo: { name: "canonical", type: "model" }, + credentials: { accessToken: TEST_ACCESS_TOKEN }, + hubUrl: TEST_HUB_URL, 
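+      // "canonical" has no namespace, so createRepo should reject with a TypeError before any request is sent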
+ }); + + await expect(tryCreate).rejects.toBeInstanceOf(TypeError); + }); + + it("should create a model with a string as name", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + repo: repoName, + files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }], + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + await deleteRepo({ + repo: { + name: repoName, + type: "model", + }, + hubUrl: TEST_HUB_URL, + credentials: { accessToken: TEST_ACCESS_TOKEN }, + }); + }); + + it("should create a dataset with a string as name", async () => { + const repoName = `datasets/${TEST_USER}/TEST-${insecureRandomString()}`; + + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + repo: repoName, + files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }], + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + await deleteRepo({ + repo: repoName, + hubUrl: TEST_HUB_URL, + credentials: { accessToken: TEST_ACCESS_TOKEN }, + }); + }); +}); diff --git a/lib/create-repo.ts b/lib/create-repo.ts new file mode 100644 index 0000000000000000000000000000000000000000..c0323dc1120e50d6c087f4467d26cd8f5a689888 --- /dev/null +++ b/lib/create-repo.ts @@ -0,0 +1,78 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiCreateRepoPayload } from "../types/api/api-create-repo"; +import type { CredentialsParams, RepoDesignation, SpaceSdk } from "../types/public"; +import { base64FromBytes } from "../utils/base64FromBytes"; +import { checkCredentials } from "../utils/checkCredentials"; +import { toRepoId } from "../utils/toRepoId"; + +export async function createRepo( + params: { + repo: RepoDesignation; + /** + * If unset, will follow the organization's default setting. (typically public, except for some Enterprise organizations) + */ + private?: boolean; + license?: string; + /** + * Only a few lightweight files are supported at repo creation + */ + files?: Array<{ content: ArrayBuffer | Blob; path: string }>; + /** @required for when {@link repo.type} === "space" */ + sdk?: SpaceSdk; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & CredentialsParams +): Promise<{ repoUrl: string }> { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + const [namespace, repoName] = repoId.name.split("/"); + + if (!namespace || !repoName) { + throw new TypeError( + `"${repoId.name}" is not a fully qualified repo name. It should be of the form "{namespace}/{repoName}".` + ); + } + + const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/repos/create`, { + method: "POST", + body: JSON.stringify({ + name: repoName, + private: params.private, + organization: namespace, + license: params.license, + ...(repoId.type === "space" + ? { + type: "space", + sdk: "static", + } + : { + type: repoId.type, + }), + files: params.files + ? await Promise.all( + params.files.map(async (file) => ({ + encoding: "base64", + path: file.path, + content: base64FromBytes( + new Uint8Array(file.content instanceof Blob ? 
await file.content.arrayBuffer() : file.content) + ), + })) + ) + : undefined, + } satisfies ApiCreateRepoPayload), + headers: { + Authorization: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!res.ok) { + throw await createApiError(res); + } + const output = await res.json(); + return { repoUrl: output.url }; +} diff --git a/lib/dataset-info.spec.ts b/lib/dataset-info.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..ae235e5e83a308a19c0e4f60692651424d02a902 --- /dev/null +++ b/lib/dataset-info.spec.ts @@ -0,0 +1,56 @@ +import { describe, expect, it } from "vitest"; +import { datasetInfo } from "./dataset-info"; +import type { DatasetEntry } from "./list-datasets"; +import type { ApiDatasetInfo } from "../types/api/api-dataset"; + +describe("datasetInfo", () => { + it("should return the dataset info", async () => { + const info = await datasetInfo({ + name: "nyu-mll/glue", + }); + expect(info).toEqual({ + id: "621ffdd236468d709f181e3f", + downloads: expect.any(Number), + gated: false, + name: "nyu-mll/glue", + updatedAt: expect.any(Date), + likes: expect.any(Number), + private: false, + }); + }); + + it("should return the dataset info with author", async () => { + const info: DatasetEntry & Pick = await datasetInfo({ + name: "nyu-mll/glue", + additionalFields: ["author"], + }); + expect(info).toEqual({ + id: "621ffdd236468d709f181e3f", + downloads: expect.any(Number), + gated: false, + name: "nyu-mll/glue", + updatedAt: expect.any(Date), + likes: expect.any(Number), + private: false, + author: "nyu-mll", + }); + }); + + it("should return the dataset info for a specific revision", async () => { + const info: DatasetEntry & Pick = await datasetInfo({ + name: "nyu-mll/glue", + revision: "cb2099c76426ff97da7aa591cbd317d91fb5fcb7", + additionalFields: ["sha"], + }); + expect(info).toEqual({ + id: "621ffdd236468d709f181e3f", + downloads: expect.any(Number), + gated: false, + name: "nyu-mll/glue", + updatedAt: expect.any(Date), + likes: expect.any(Number), + private: false, + sha: "cb2099c76426ff97da7aa591cbd317d91fb5fcb7", + }); + }); +}); diff --git a/lib/dataset-info.ts b/lib/dataset-info.ts new file mode 100644 index 0000000000000000000000000000000000000000..542b5aa0f4cac268a88e8273beb66591d02b5dfa --- /dev/null +++ b/lib/dataset-info.ts @@ -0,0 +1,61 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiDatasetInfo } from "../types/api/api-dataset"; +import type { CredentialsParams } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { pick } from "../utils/pick"; +import { type DATASET_EXPANDABLE_KEYS, DATASET_EXPAND_KEYS, type DatasetEntry } from "./list-datasets"; + +export async function datasetInfo< + const T extends Exclude<(typeof DATASET_EXPANDABLE_KEYS)[number], (typeof DATASET_EXPAND_KEYS)[number]> = never, +>( + params: { + name: string; + hubUrl?: string; + additionalFields?: T[]; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. + */ + revision?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. 
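+     *
+     * A sketch of such a wrapper (the header name is purely illustrative), which can then be passed as `fetch`:
+     *
+     * ```ts
+     * const fetchWithHeader: typeof fetch = (input, init) => {
+     *   const headers = new Headers(init?.headers);
+     *   headers.set("X-Custom-Header", "value"); // hypothetical header
+     *   return fetch(input, { ...init, headers });
+     * };
+     * ```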
+ */ + fetch?: typeof fetch; + } & Partial +): Promise> { + const accessToken = params && checkCredentials(params); + + const search = new URLSearchParams([ + ...DATASET_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]), + ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []), + ]).toString(); + + const response = await (params.fetch || fetch)( + `${params?.hubUrl || HUB_URL}/api/datasets/${params.name}/revision/${encodeURIComponent( + params.revision ?? "HEAD" + )}?${search.toString()}`, + { + headers: { + ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}), + Accepts: "application/json", + }, + } + ); + + if (!response.ok) { + throw await createApiError(response); + } + + const data = await response.json(); + + return { + ...(params?.additionalFields && pick(data, params.additionalFields)), + id: data._id, + name: data.id, + private: data.private, + downloads: data.downloads, + likes: data.likes, + gated: data.gated, + updatedAt: new Date(data.lastModified), + } as DatasetEntry & Pick; +} diff --git a/lib/delete-branch.spec.ts b/lib/delete-branch.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..dcd253214b671d2393a343bd74447495442e659f --- /dev/null +++ b/lib/delete-branch.spec.ts @@ -0,0 +1,43 @@ +import { it, describe } from "vitest"; +import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { createBranch } from "./create-branch"; +import { deleteBranch } from "./delete-branch"; + +describe("deleteBranch", () => { + it("should delete an existing branch", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + repo, + }); + + await createBranch({ + repo, + branch: "branch-to-delete", + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + + await deleteBranch({ + repo, + branch: "branch-to-delete", + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); +}); diff --git a/lib/delete-branch.ts b/lib/delete-branch.ts new file mode 100644 index 0000000000000000000000000000000000000000..70227b185e06507e2f4973a20b2b8c56044ac969 --- /dev/null +++ b/lib/delete-branch.ts @@ -0,0 +1,32 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { AccessToken, RepoDesignation } from "../types/public"; +import { toRepoId } from "../utils/toRepoId"; + +export async function deleteBranch(params: { + repo: RepoDesignation; + /** + * The name of the branch to delete + */ + branch: string; + hubUrl?: string; + accessToken?: AccessToken; + fetch?: typeof fetch; +}): Promise { + const repoId = toRepoId(params.repo); + const res = await (params.fetch ?? fetch)( + `${params.hubUrl ?? 
HUB_URL}/api/${repoId.type}s/${repoId.name}/branch/${encodeURIComponent(params.branch)}`, + { + method: "DELETE", + headers: { + ...(params.accessToken && { + Authorization: `Bearer ${params.accessToken}`, + }), + }, + } + ); + + if (!res.ok) { + throw await createApiError(res); + } +} diff --git a/lib/delete-file.spec.ts b/lib/delete-file.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..ec38a916d808bc905da02bfa83e862e272c96432 --- /dev/null +++ b/lib/delete-file.spec.ts @@ -0,0 +1,64 @@ +import { assert, it, describe } from "vitest"; + +import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { deleteFile } from "./delete-file"; +import { downloadFile } from "./download-file"; + +describe("deleteFile", () => { + it("should delete a file", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + repo, + files: [ + { path: "file1", content: new Blob(["file1"]) }, + { path: "file2", content: new Blob(["file2"]) }, + ], + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + let content = await downloadFile({ + hubUrl: TEST_HUB_URL, + repo, + path: "file1", + }); + + assert.strictEqual(await content?.text(), "file1"); + + await deleteFile({ path: "file1", repo, accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL }); + + content = await downloadFile({ + repo, + path: "file1", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(content, null); + + content = await downloadFile({ + repo, + path: "file2", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), "file2"); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); +}); diff --git a/lib/delete-file.ts b/lib/delete-file.ts new file mode 100644 index 0000000000000000000000000000000000000000..58b19f1983721279670ba093835177e9adca0981 --- /dev/null +++ b/lib/delete-file.ts @@ -0,0 +1,35 @@ +import type { CredentialsParams } from "../types/public"; +import type { CommitOutput, CommitParams } from "./commit"; +import { commit } from "./commit"; + +export function deleteFile( + params: { + repo: CommitParams["repo"]; + path: string; + commitTitle?: CommitParams["title"]; + commitDescription?: CommitParams["description"]; + hubUrl?: CommitParams["hubUrl"]; + fetch?: CommitParams["fetch"]; + branch?: CommitParams["branch"]; + isPullRequest?: CommitParams["isPullRequest"]; + parentCommit?: CommitParams["parentCommit"]; + } & CredentialsParams +): Promise { + return commit({ + ...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }), + repo: params.repo, + operations: [ + { + operation: "delete", + path: params.path, + }, + ], + title: params.commitTitle ?? 
`Delete ${params.path}`, + description: params.commitDescription, + hubUrl: params.hubUrl, + branch: params.branch, + isPullRequest: params.isPullRequest, + parentCommit: params.parentCommit, + fetch: params.fetch, + }); +} diff --git a/lib/delete-files.spec.ts b/lib/delete-files.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..8124d9afa0f8d926746cfbc38265a77056fe0b78 --- /dev/null +++ b/lib/delete-files.spec.ts @@ -0,0 +1,81 @@ +import { assert, it, describe } from "vitest"; + +import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { deleteFiles } from "./delete-files"; +import { downloadFile } from "./download-file"; + +describe("deleteFiles", () => { + it("should delete multiple files", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + repo, + files: [ + { path: "file1", content: new Blob(["file1"]) }, + { path: "file2", content: new Blob(["file2"]) }, + { path: "file3", content: new Blob(["file3"]) }, + ], + hubUrl: TEST_HUB_URL, + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + let content = await downloadFile({ + repo, + path: "file1", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), "file1"); + + content = await downloadFile({ + repo, + path: "file2", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), "file2"); + + await deleteFiles({ paths: ["file1", "file2"], repo, accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL }); + + content = await downloadFile({ + repo, + path: "file1", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(content, null); + + content = await downloadFile({ + repo, + path: "file2", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(content, null); + + content = await downloadFile({ + repo, + path: "file3", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), "file3"); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); +}); diff --git a/lib/delete-files.ts b/lib/delete-files.ts new file mode 100644 index 0000000000000000000000000000000000000000..956bd473e49c06ee30dc9b39aa8fc600d243a04c --- /dev/null +++ b/lib/delete-files.ts @@ -0,0 +1,33 @@ +import type { CredentialsParams } from "../types/public"; +import type { CommitOutput, CommitParams } from "./commit"; +import { commit } from "./commit"; + +export function deleteFiles( + params: { + repo: CommitParams["repo"]; + paths: string[]; + commitTitle?: CommitParams["title"]; + commitDescription?: CommitParams["description"]; + hubUrl?: CommitParams["hubUrl"]; + branch?: CommitParams["branch"]; + isPullRequest?: CommitParams["isPullRequest"]; + parentCommit?: CommitParams["parentCommit"]; + fetch?: CommitParams["fetch"]; + } & CredentialsParams +): Promise { + return commit({ + ...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }), + repo: params.repo, + operations: params.paths.map((path) => ({ + operation: "delete", + path, + })), + title: params.commitTitle ?? 
`Deletes ${params.paths.length} files`, + description: params.commitDescription, + hubUrl: params.hubUrl, + branch: params.branch, + isPullRequest: params.isPullRequest, + parentCommit: params.parentCommit, + fetch: params.fetch, + }); +} diff --git a/lib/delete-repo.ts b/lib/delete-repo.ts new file mode 100644 index 0000000000000000000000000000000000000000..7b34d1b8e918ea902f5a783079d32e3e5ea35854 --- /dev/null +++ b/lib/delete-repo.ts @@ -0,0 +1,37 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { toRepoId } from "../utils/toRepoId"; + +export async function deleteRepo( + params: { + repo: RepoDesignation; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & CredentialsParams +): Promise { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + const [namespace, repoName] = repoId.name.split("/"); + + const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/repos/delete`, { + method: "DELETE", + body: JSON.stringify({ + name: repoName, + organization: namespace, + type: repoId.type, + }), + headers: { + Authorization: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + }); + + if (!res.ok) { + throw await createApiError(res); + } +} diff --git a/lib/download-file-to-cache-dir.spec.ts b/lib/download-file-to-cache-dir.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..fb407c4c2e462b99465e7260f8640174ef3cafdb --- /dev/null +++ b/lib/download-file-to-cache-dir.spec.ts @@ -0,0 +1,306 @@ +import { expect, test, describe, vi, beforeEach } from "vitest"; +import type { RepoDesignation, RepoId } from "../types/public"; +import { dirname, join } from "node:path"; +import { lstat, mkdir, stat, symlink, rename } from "node:fs/promises"; +import { pathsInfo } from "./paths-info"; +import { createWriteStream, type Stats } from "node:fs"; +import { getHFHubCachePath, getRepoFolderName } from "./cache-management"; +import { toRepoId } from "../utils/toRepoId"; +import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; +import { createSymlink } from "../utils/symlink"; + +vi.mock("node:fs/promises", () => ({ + rename: vi.fn(), + symlink: vi.fn(), + lstat: vi.fn(), + mkdir: vi.fn(), + stat: vi.fn(), +})); + +vi.mock("node:fs", () => ({ + createWriteStream: vi.fn(), +})); + +vi.mock("./paths-info", () => ({ + pathsInfo: vi.fn(), +})); + +vi.mock("../utils/symlink", () => ({ + createSymlink: vi.fn(), +})); + +const DUMMY_REPO: RepoId = { + name: "hello-world", + type: "model", +}; + +const DUMMY_ETAG = "dummy-etag"; + +// utility test method to get blob file path +function _getBlobFile(params: { + repo: RepoDesignation; + etag: string; + cacheDir?: string; // default to {@link getHFHubCache} +}) { + return join(params.cacheDir ?? getHFHubCachePath(), getRepoFolderName(toRepoId(params.repo)), "blobs", params.etag); +} + +// utility test method to get snapshot file path +function _getSnapshotFile(params: { + repo: RepoDesignation; + path: string; + revision: string; + cacheDir?: string; // default to {@link getHFHubCache} +}) { + return join( + params.cacheDir ?? 
getHFHubCachePath(), + getRepoFolderName(toRepoId(params.repo)), + "snapshots", + params.revision, + params.path + ); +} + +describe("downloadFileToCacheDir", () => { + const fetchMock: typeof fetch = vi.fn(); + beforeEach(() => { + vi.resetAllMocks(); + // mock 200 request + vi.mocked(fetchMock).mockResolvedValue( + new Response("dummy-body", { + status: 200, + headers: { + etag: DUMMY_ETAG, + "Content-Range": "bytes 0-54/55", + }, + }) + ); + + // prevent to use caching + vi.mocked(stat).mockRejectedValue(new Error("Do not exists")); + vi.mocked(lstat).mockRejectedValue(new Error("Do not exists")); + }); + + test("should throw an error if fileDownloadInfo return nothing", async () => { + await expect(async () => { + await downloadFileToCacheDir({ + repo: DUMMY_REPO, + path: "/README.md", + fetch: fetchMock, + }); + }).rejects.toThrowError("cannot get path info for /README.md"); + + expect(pathsInfo).toHaveBeenCalledWith( + expect.objectContaining({ + repo: DUMMY_REPO, + paths: ["/README.md"], + fetch: fetchMock, + }) + ); + }); + + test("existing symlinked and blob should not re-download it", async () => { + // ///snapshots/README.md + const expectPointer = _getSnapshotFile({ + repo: DUMMY_REPO, + path: "/README.md", + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + // stat ensure a symlink and the pointed file exists + vi.mocked(stat).mockResolvedValue({} as Stats); // prevent default mocked reject + + const output = await downloadFileToCacheDir({ + repo: DUMMY_REPO, + path: "/README.md", + fetch: fetchMock, + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + expect(stat).toHaveBeenCalledOnce(); + // Get call argument for stat + const starArg = vi.mocked(stat).mock.calls[0][0]; + + expect(starArg).toBe(expectPointer); + expect(fetchMock).not.toHaveBeenCalledWith(); + + expect(output).toBe(expectPointer); + }); + + test("existing symlinked and blob with default revision should not re-download it", async () => { + // ///snapshots/README.md + const expectPointer = _getSnapshotFile({ + repo: DUMMY_REPO, + path: "/README.md", + revision: "main", + }); + // stat ensure a symlink and the pointed file exists + vi.mocked(stat).mockResolvedValue({} as Stats); // prevent default mocked reject + vi.mocked(lstat).mockResolvedValue({} as Stats); + vi.mocked(pathsInfo).mockResolvedValue([ + { + oid: DUMMY_ETAG, + size: 55, + path: "README.md", + type: "file", + lastCommit: { + date: new Date(), + id: "main", + title: "Commit msg", + }, + }, + ]); + + const output = await downloadFileToCacheDir({ + repo: DUMMY_REPO, + path: "/README.md", + fetch: fetchMock, + }); + + expect(stat).toHaveBeenCalledOnce(); + expect(symlink).not.toHaveBeenCalledOnce(); + // Get call argument for stat + const starArg = vi.mocked(stat).mock.calls[0][0]; + + expect(starArg).toBe(expectPointer); + expect(fetchMock).not.toHaveBeenCalledWith(); + + expect(output).toBe(expectPointer); + }); + + test("existing blob should only create the symlink", async () => { + // ///snapshots/README.md + const expectPointer = _getSnapshotFile({ + repo: DUMMY_REPO, + path: "/README.md", + revision: "dummy-commit-hash", + }); + // //blobs/ + const expectedBlob = _getBlobFile({ + repo: DUMMY_REPO, + etag: DUMMY_ETAG, + }); + + // mock existing blob only no symlink + vi.mocked(lstat).mockResolvedValue({} as Stats); + // mock pathsInfo resolve content + vi.mocked(pathsInfo).mockResolvedValue([ + { + oid: DUMMY_ETAG, + size: 55, + path: "README.md", + type: "file", + lastCommit: { + date: new Date(), + id: 
"dummy-commit-hash", + title: "Commit msg", + }, + }, + ]); + + const output = await downloadFileToCacheDir({ + repo: DUMMY_REPO, + path: "/README.md", + fetch: fetchMock, + }); + + // should have check for the blob + expect(lstat).toHaveBeenCalled(); + expect(vi.mocked(lstat).mock.calls[0][0]).toBe(expectedBlob); + + // symlink should have been created + expect(createSymlink).toHaveBeenCalledOnce(); + // no download done + expect(fetchMock).not.toHaveBeenCalled(); + + expect(output).toBe(expectPointer); + }); + + test("expect resolve value to be the pointer path of downloaded file", async () => { + // ///snapshots/README.md + const expectPointer = _getSnapshotFile({ + repo: DUMMY_REPO, + path: "/README.md", + revision: "dummy-commit-hash", + }); + // //blobs/ + const expectedBlob = _getBlobFile({ + repo: DUMMY_REPO, + etag: DUMMY_ETAG, + }); + + vi.mocked(pathsInfo).mockResolvedValue([ + { + oid: DUMMY_ETAG, + size: 55, + path: "README.md", + type: "file", + lastCommit: { + date: new Date(), + id: "dummy-commit-hash", + title: "Commit msg", + }, + }, + ]); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.mocked(createWriteStream).mockReturnValue(async function* () {} as any); + + const output = await downloadFileToCacheDir({ + repo: DUMMY_REPO, + path: "/README.md", + fetch: fetchMock, + }); + + // expect blobs and snapshots folder to have been mkdir + expect(vi.mocked(mkdir).mock.calls[0][0]).toBe(dirname(expectedBlob)); + expect(vi.mocked(mkdir).mock.calls[1][0]).toBe(dirname(expectPointer)); + + expect(output).toBe(expectPointer); + }); + + test("should write fetch response to blob", async () => { + // ///snapshots/README.md + const expectPointer = _getSnapshotFile({ + repo: DUMMY_REPO, + path: "/README.md", + revision: "dummy-commit-hash", + }); + // //blobs/ + const expectedBlob = _getBlobFile({ + repo: DUMMY_REPO, + etag: DUMMY_ETAG, + }); + + // mock pathsInfo resolve content + vi.mocked(pathsInfo).mockResolvedValue([ + { + oid: DUMMY_ETAG, + size: 55, + path: "README.md", + type: "file", + lastCommit: { + date: new Date(), + id: "dummy-commit-hash", + title: "Commit msg", + }, + }, + ]); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.mocked(createWriteStream).mockReturnValue(async function* () {} as any); + + await downloadFileToCacheDir({ + repo: DUMMY_REPO, + path: "/README.md", + fetch: fetchMock, + }); + + const incomplete = `${expectedBlob}.incomplete`; + // 1. should write fetch#response#body to incomplete file + expect(createWriteStream).toHaveBeenCalledWith(incomplete); + // 2. should rename the incomplete to the blob expected name + expect(rename).toHaveBeenCalledWith(incomplete, expectedBlob); + // 3. 
should create symlink pointing to blob + expect(createSymlink).toHaveBeenCalledWith({ sourcePath: expectedBlob, finalPath: expectPointer }); + }); +}); diff --git a/lib/download-file-to-cache-dir.ts b/lib/download-file-to-cache-dir.ts new file mode 100644 index 0000000000000000000000000000000000000000..a7b67d9d214c86be8decc5ad68686d772ad1de1e --- /dev/null +++ b/lib/download-file-to-cache-dir.ts @@ -0,0 +1,138 @@ +import { getHFHubCachePath, getRepoFolderName } from "./cache-management"; +import { dirname, join } from "node:path"; +import { rename, lstat, mkdir, stat } from "node:fs/promises"; +import type { CommitInfo, PathInfo } from "./paths-info"; +import { pathsInfo } from "./paths-info"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { toRepoId } from "../utils/toRepoId"; +import { downloadFile } from "./download-file"; +import { createSymlink } from "../utils/symlink"; +import { Readable } from "node:stream"; +import type { ReadableStream } from "node:stream/web"; +import { pipeline } from "node:stream/promises"; +import { createWriteStream } from "node:fs"; + +export const REGEX_COMMIT_HASH: RegExp = new RegExp("^[0-9a-f]{40}$"); + +function getFilePointer(storageFolder: string, revision: string, relativeFilename: string): string { + const snapshotPath = join(storageFolder, "snapshots"); + return join(snapshotPath, revision, relativeFilename); +} + +/** + * handy method to check if a file exists, or the pointer of a symlinks exists + * @param path + * @param followSymlinks + */ +async function exists(path: string, followSymlinks?: boolean): Promise { + try { + if (followSymlinks) { + await stat(path); + } else { + await lstat(path); + } + return true; + } catch (err: unknown) { + return false; + } +} + +/** + * Download a given file if it's not already present in the local cache. + * @param params + * @return the symlink to the blob object + */ +export async function downloadFileToCacheDir( + params: { + repo: RepoDesignation; + path: string; + /** + * If true, will download the raw git file. + * + * For example, when calling on a file stored with Git LFS, the pointer file will be downloaded instead. + */ + raw?: boolean; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. + * + * @default "main" + */ + revision?: string; + hubUrl?: string; + cacheDir?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { + // get revision provided or default to main + const revision = params.revision ?? "main"; + const cacheDir = params.cacheDir ?? getHFHubCachePath(); + // get repo id + const repoId = toRepoId(params.repo); + // get storage folder + const storageFolder = join(cacheDir, getRepoFolderName(repoId)); + + let commitHash: string | undefined; + + // if user provides a commitHash as revision, and they already have the file on disk, shortcut everything. 
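+ // e.g. "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7" (40 hex chars) matches REGEX_COMMIT_HASH and can be used
+ // directly as the snapshot folder name, whereas a branch name such as "main" does not match and is first
+ // resolved to a commit hash through the pathsInfo call below.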
+ if (REGEX_COMMIT_HASH.test(revision)) { + commitHash = revision; + const pointerPath = getFilePointer(storageFolder, revision, params.path); + if (await exists(pointerPath, true)) return pointerPath; + } + + const pathsInformation: (PathInfo & { lastCommit: CommitInfo })[] = await pathsInfo({ + ...params, + paths: [params.path], + revision: revision, + expand: true, + }); + if (!pathsInformation || pathsInformation.length !== 1) throw new Error(`cannot get path info for ${params.path}`); + + let etag: string; + if (pathsInformation[0].lfs) { + etag = pathsInformation[0].lfs.oid; // get the LFS pointed file oid + } else { + etag = pathsInformation[0].oid; // get the repo file if not a LFS pointer + } + + const pointerPath = getFilePointer(storageFolder, commitHash ?? pathsInformation[0].lastCommit.id, params.path); + const blobPath = join(storageFolder, "blobs", etag); + + // if we have the pointer file, we can shortcut the download + if (await exists(pointerPath, true)) return pointerPath; + + // mkdir blob and pointer path parent directory + await mkdir(dirname(blobPath), { recursive: true }); + await mkdir(dirname(pointerPath), { recursive: true }); + + // We might already have the blob but not the pointer + // shortcut the download if needed + if (await exists(blobPath)) { + // create symlinks in snapshot folder to blob object + await createSymlink({ sourcePath: blobPath, finalPath: pointerPath }); + return pointerPath; + } + + const incomplete = `${blobPath}.incomplete`; + console.debug(`Downloading ${params.path} to ${incomplete}`); + + const blob: Blob | null = await downloadFile({ + ...params, + revision: commitHash, + }); + + if (!blob) { + throw new Error(`invalid response for file ${params.path}`); + } + + await pipeline(Readable.fromWeb(blob.stream() as ReadableStream), createWriteStream(incomplete)); + + // rename .incomplete file to expect blob + await rename(incomplete, blobPath); + // create symlinks in snapshot folder to blob object + await createSymlink({ sourcePath: blobPath, finalPath: pointerPath }); + return pointerPath; +} diff --git a/lib/download-file.spec.ts b/lib/download-file.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..c25a8b69b8c338500bba328111f6d3f8e3fd3427 --- /dev/null +++ b/lib/download-file.spec.ts @@ -0,0 +1,82 @@ +import { expect, test, describe, assert } from "vitest"; +import { downloadFile } from "./download-file"; +import { deleteRepo } from "./delete-repo"; +import { createRepo } from "./create-repo"; +import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts"; +import { insecureRandomString } from "../utils/insecureRandomString"; + +describe("downloadFile", () => { + test("should download regular file", async () => { + const blob = await downloadFile({ + repo: { + type: "model", + name: "openai-community/gpt2", + }, + path: "README.md", + }); + + const text = await blob?.slice(0, 1000).text(); + assert( + text?.includes(`--- +language: en +tags: +- exbert + +license: mit +--- + + +# GPT-2 + +Test the whole generation capabilities here: https://transformer.huggingface.co/doc/gpt2-large`) + ); + }); + test("should downoad xet file", async () => { + const blob = await downloadFile({ + repo: { + type: "model", + name: "celinah/xet-experiments", + }, + path: "large_text.txt", + }); + + const text = await blob?.slice(0, 100).text(); + expect(text).toMatch("this is a text file.".repeat(10).slice(0, 100)); + }); + + test("should download private file", async () => { + const repoName = 
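+ // illustrative note: the "datasets/" prefix marks this string repo designation as a dataset repo
+ // (presumably resolved by toRepoId; without the prefix the string would be treated as a model repo)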
`datasets/${TEST_USER}/TEST-${insecureRandomString()}`; + + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + private: true, + repo: repoName, + files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }], + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + try { + const blob = await downloadFile({ + repo: repoName, + path: ".gitattributes", + hubUrl: TEST_HUB_URL, + accessToken: TEST_ACCESS_TOKEN, + }); + + assert(blob, "File should be found"); + + const text = await blob?.text(); + assert.strictEqual(text, "*.html filter=lfs diff=lfs merge=lfs -text"); + } finally { + await deleteRepo({ + repo: repoName, + hubUrl: TEST_HUB_URL, + accessToken: TEST_ACCESS_TOKEN, + }); + } + }); +}); diff --git a/lib/download-file.ts b/lib/download-file.ts new file mode 100644 index 0000000000000000000000000000000000000000..5174bc09da2b12985c92de2acf2a508e5ce90278 --- /dev/null +++ b/lib/download-file.ts @@ -0,0 +1,77 @@ +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { WebBlob } from "../utils/WebBlob"; +import { XetBlob } from "../utils/XetBlob"; +import type { FileDownloadInfoOutput } from "./file-download-info"; +import { fileDownloadInfo } from "./file-download-info"; + +/** + * @returns null when the file doesn't exist + */ +export async function downloadFile( + params: { + repo: RepoDesignation; + path: string; + /** + * If true, will download the raw git file. + * + * For example, when calling on a file stored with Git LFS, the pointer file will be downloaded instead. + */ + raw?: boolean; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. + * + * @default "main" + */ + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + /** + * Whether to use the xet protocol to download the file (if applicable). + * + * Currently there's experimental support for it, so it's not enabled by default. + * + * It will be enabled automatically in a future minor version. + * + * @default false + */ + xet?: boolean; + /** + * Can save an http request if provided + */ + downloadInfo?: FileDownloadInfoOutput; + } & Partial +): Promise { + const accessToken = checkCredentials(params); + + const info = + params.downloadInfo ?? + (await fileDownloadInfo({ + accessToken, + repo: params.repo, + path: params.path, + revision: params.revision, + hubUrl: params.hubUrl, + fetch: params.fetch, + raw: params.raw, + })); + + if (!info) { + return null; + } + + if (info.xet && params.xet) { + return new XetBlob({ + refreshUrl: info.xet.refreshUrl.href, + reconstructionUrl: info.xet.reconstructionUrl.href, + fetch: params.fetch, + accessToken, + size: info.size, + }); + } + + return new WebBlob(new URL(info.url), 0, info.size, "", true, params.fetch ?? 
fetch, accessToken); +} diff --git a/lib/file-download-info.spec.ts b/lib/file-download-info.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..d2be156626c267cbbb642f6a8e0af47b2cbbd130 --- /dev/null +++ b/lib/file-download-info.spec.ts @@ -0,0 +1,59 @@ +import { assert, it, describe } from "vitest"; +import { fileDownloadInfo } from "./file-download-info"; + +describe("fileDownloadInfo", () => { + it("should fetch LFS file info", async () => { + const info = await fileDownloadInfo({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + path: "tf_model.h5", + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + assert.strictEqual(info?.size, 536063208); + assert.strictEqual(info?.etag, '"a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2"'); + }); + + it("should fetch raw LFS pointer info", async () => { + const info = await fileDownloadInfo({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + path: "tf_model.h5", + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + raw: true, + }); + + assert.strictEqual(info?.size, 134); + assert.strictEqual(info?.etag, '"9eb98c817f04b051b3bcca591bcd4e03cec88018"'); + }); + + it("should fetch non-LFS file info", async () => { + const info = await fileDownloadInfo({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + path: "tokenizer_config.json", + revision: "1a7dd4986e3dab699c24ca19b2afd0f5e1a80f37", + }); + + assert.strictEqual(info?.size, 28); + assert.strictEqual(info?.etag, '"a661b1a138dac6dc5590367402d100765010ffd6"'); + }); + + it("should fetch xet file info", async () => { + const info = await fileDownloadInfo({ + repo: { + type: "model", + name: "celinah/xet-experiments", + }, + path: "large_text.txt", + }); + assert.strictEqual(info?.size, 62914580); + assert.strictEqual(info?.etag, '"c27f98578d9363b27db0bc1cbd9c692f8e6e90ae98c38cee7bc0a88829debd17"'); + }); +}); diff --git a/lib/file-download-info.ts b/lib/file-download-info.ts new file mode 100644 index 0000000000000000000000000000000000000000..6ff4f6cebb26ae98d9d1aa509b3fefb3f494d23e --- /dev/null +++ b/lib/file-download-info.ts @@ -0,0 +1,151 @@ +import { HUB_URL } from "../consts"; +import { createApiError, InvalidApiResponseFormatError } from "../error"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { parseLinkHeader } from "../utils/parseLinkHeader"; +import { toRepoId } from "../utils/toRepoId"; + +export interface XetFileInfo { + hash: string; + refreshUrl: URL; + /** + * Can be directly used instead of the hash. + */ + reconstructionUrl: URL; +} + +export interface FileDownloadInfoOutput { + size: number; + etag: string; + xet?: XetFileInfo; + // URL to fetch (with the access token if private file) + url: string; +} +/** + * @returns null when the file doesn't exist + */ +export async function fileDownloadInfo( + params: { + repo: RepoDesignation; + path: string; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. 
+ */ + fetch?: typeof fetch; + /** + * To get the raw pointer file behind a LFS file + */ + raw?: boolean; + /** + * To avoid the content-disposition header in the `downloadLink` for LFS files + * + * So that on browsers you can use the URL in an iframe for example + */ + noContentDisposition?: boolean; + } & Partial +): Promise { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + + const hubUrl = params.hubUrl ?? HUB_URL; + const url = + `${hubUrl}/${repoId.type === "model" ? "" : `${repoId.type}s/`}${repoId.name}/${ + params.raw ? "raw" : "resolve" + }/${encodeURIComponent(params.revision ?? "main")}/${params.path}` + + (params.noContentDisposition ? "?noContentDisposition=1" : ""); + + const resp = await (params.fetch ?? fetch)(url, { + method: "GET", + headers: { + ...(accessToken && { + Authorization: `Bearer ${accessToken}`, + }), + Range: "bytes=0-0", + Accept: "application/vnd.xet-fileinfo+json, */*", + }, + }); + + if (resp.status === 404 && resp.headers.get("X-Error-Code") === "EntryNotFound") { + return null; + } + + if (!resp.ok) { + throw await createApiError(resp); + } + + let size: number | undefined; + let xetInfo: XetFileInfo | undefined; + + if (resp.headers.get("Content-Type")?.includes("application/vnd.xet-fileinfo+json")) { + size = parseInt(resp.headers.get("X-Linked-Size") ?? "invalid"); + if (isNaN(size)) { + throw new InvalidApiResponseFormatError("Invalid file size received in X-Linked-Size header"); + } + + const hash = resp.headers.get("X-Xet-Hash"); + const links = parseLinkHeader(resp.headers.get("Link") ?? ""); + + const reconstructionUrl = (() => { + try { + return new URL(links["xet-reconstruction-info"]); + } catch { + return null; + } + })(); + const refreshUrl = (() => { + try { + return new URL(links["xet-auth"]); + } catch { + return null; + } + })(); + + if (!hash) { + throw new InvalidApiResponseFormatError("No hash received in X-Xet-Hash header"); + } + + if (!reconstructionUrl || !refreshUrl) { + throw new InvalidApiResponseFormatError("No xet-reconstruction-info or xet-auth link header"); + } + xetInfo = { + hash, + refreshUrl, + reconstructionUrl, + }; + } + + if (size === undefined || isNaN(size)) { + const contentRangeHeader = resp.headers.get("content-range"); + + if (!contentRangeHeader) { + throw new InvalidApiResponseFormatError("Expected size information"); + } + + const [, parsedSize] = contentRangeHeader.split("/"); + size = parseInt(parsedSize); + + if (isNaN(size)) { + throw new InvalidApiResponseFormatError("Invalid file size received"); + } + } + + const etag = resp.headers.get("X-Linked-ETag") ?? resp.headers.get("ETag") ?? undefined; + + if (!etag) { + throw new InvalidApiResponseFormatError("Expected ETag"); + } + + return { + etag, + size, + xet: xetInfo, + // Cannot use resp.url in case it's a S3 url and the user adds an Authorization header to it. + url: + resp.url && + (new URL(resp.url).origin === new URL(hubUrl).origin || resp.headers.get("X-Cache")?.endsWith(" cloudfront")) + ? 
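+ // a same-origin (or CloudFront-cached) response URL can be reused safely; otherwise fall back to the
+ // hub URL so an Authorization header added by the caller is never sent to a signed S3 URL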
resp.url + : url, + }; +} diff --git a/lib/file-exists.spec.ts b/lib/file-exists.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..e20acdf3bc13d233d9e3a0d0a5e0cb35e7b7a3bb --- /dev/null +++ b/lib/file-exists.spec.ts @@ -0,0 +1,30 @@ +import { assert, it, describe } from "vitest"; +import { fileExists } from "./file-exists"; + +describe("fileExists", () => { + it("should return true for file that exists", async () => { + const info = await fileExists({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + path: "tf_model.h5", + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + assert(info, "file should exist"); + }); + + it("should return false for file that does not exist", async () => { + const info = await fileExists({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + path: "tf_model.h5dadazdzazd", + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + assert(!info, "file should not exist"); + }); +}); diff --git a/lib/file-exists.ts b/lib/file-exists.ts new file mode 100644 index 0000000000000000000000000000000000000000..64acf1dd39c7ba512586e77ecce1149adfa264b9 --- /dev/null +++ b/lib/file-exists.ts @@ -0,0 +1,41 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { toRepoId } from "../utils/toRepoId"; + +export async function fileExists( + params: { + repo: RepoDesignation; + path: string; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + + const hubUrl = params.hubUrl ?? HUB_URL; + const url = `${hubUrl}/${repoId.type === "model" ? "" : `${repoId.type}s/`}${repoId.name}/raw/${encodeURIComponent( + params.revision ?? "main" + )}/${params.path}`; + + const resp = await (params.fetch ?? fetch)(url, { + method: "HEAD", + headers: accessToken ? 
{ Authorization: `Bearer ${accessToken}` } : {}, + }); + + if (resp.status === 404) { + return false; + } + + if (!resp.ok) { + throw await createApiError(resp); + } + + return true; +} diff --git a/lib/index.ts b/lib/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..cd777784f7d604283d601f535892c43544b0a78b --- /dev/null +++ b/lib/index.ts @@ -0,0 +1,32 @@ +export * from "./cache-management"; +export * from "./check-repo-access"; +export * from "./commit"; +export * from "./count-commits"; +export * from "./create-repo"; +export * from "./create-branch"; +export * from "./dataset-info"; +export * from "./delete-branch"; +export * from "./delete-file"; +export * from "./delete-files"; +export * from "./delete-repo"; +export * from "./download-file"; +export * from "./download-file-to-cache-dir"; +export * from "./file-download-info"; +export * from "./file-exists"; +export * from "./list-commits"; +export * from "./list-datasets"; +export * from "./list-files"; +export * from "./list-models"; +export * from "./list-spaces"; +export * from "./model-info"; +export * from "./oauth-handle-redirect"; +export * from "./oauth-login-url"; +export * from "./parse-safetensors-metadata"; +export * from "./paths-info"; +export * from "./repo-exists"; +export * from "./snapshot-download"; +export * from "./space-info"; +export * from "./upload-file"; +export * from "./upload-files"; +export * from "./upload-files-with-progress"; +export * from "./who-am-i"; diff --git a/lib/list-commits.spec.ts b/lib/list-commits.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..a1f4dd5e5555923c40d02fc2024d765fea07b4ff --- /dev/null +++ b/lib/list-commits.spec.ts @@ -0,0 +1,117 @@ +import { assert, it, describe } from "vitest"; +import type { CommitData } from "./list-commits"; +import { listCommits } from "./list-commits"; + +describe("listCommits", () => { + it("should fetch paginated commits from the repo", async () => { + const commits: CommitData[] = []; + for await (const commit of listCommits({ + repo: { + name: "openai-community/gpt2", + type: "model", + }, + revision: "607a30d783dfa663caf39e06633721c8d4cfcd7e", + batchSize: 5, + })) { + commits.push(commit); + } + + assert.equal(commits.length, 26); + assert.deepEqual(commits.slice(0, 6), [ + { + oid: "607a30d783dfa663caf39e06633721c8d4cfcd7e", + title: "Adds the tokenizer configuration file (#80)", + message: "\n\n\n- Adds tokenizer_config.json file (db6d57930088fb63e52c010bd9ac77c955ac55e7)\n\n", + authors: [ + { + username: "lysandre", + avatarUrl: + "https://cdn-avatars.huggingface.co/v1/production/uploads/5e3aec01f55e2b62848a5217/PMKS0NNB4MJQlTSFzh918.jpeg", + }, + ], + date: new Date("2024-02-19T10:57:45.000Z"), + }, + { + oid: "11c5a3d5811f50298f278a704980280950aedb10", + title: "Adding ONNX file of this model (#60)", + message: "\n\n\n- Adding ONNX file of this model (9411f419c589519e1a46c94ac7789ea20fd7c322)\n\n", + authors: [ + { + username: "fxmarty", + avatarUrl: + "https://cdn-avatars.huggingface.co/v1/production/uploads/1651743336129-624c60cba8ec93a7ac188b56.png", + }, + ], + date: new Date("2023-06-30T02:19:43.000Z"), + }, + { + oid: "e7da7f221d5bf496a48136c0cd264e630fe9fcc8", + title: "Update generation_config.json", + message: "", + authors: [ + { + username: "joaogante", + avatarUrl: "https://cdn-avatars.huggingface.co/v1/production/uploads/1641203017724-noauth.png", + }, + ], + date: new Date("2022-12-16T15:44:21.000Z"), + }, + { + oid: "f27b190eeac4c2302d24068eabf5e9d6044389ae", + title: 
"Add note that this is the smallest version of the model (#18)", + message: + "\n\n\n- Add note that this is the smallest version of the model (611838ef095a5bb35bf2027d05e1194b7c9d37ac)\n\n\nCo-authored-by: helen \n", + authors: [ + { + username: "sgugger", + avatarUrl: + "https://cdn-avatars.huggingface.co/v1/production/uploads/1593126474392-5ef50182b71947201082a4e5.jpeg", + }, + { + username: "mathemakitten", + avatarUrl: + "https://cdn-avatars.huggingface.co/v1/production/uploads/1658248499901-6079afe2d2cd8c150e6ae05e.jpeg", + }, + ], + date: new Date("2022-11-23T12:55:26.000Z"), + }, + { + oid: "0dd7bcc7a64e4350d8859c9a2813132fbf6ae591", + title: "Our very first generation_config.json (#17)", + message: + "\n\n\n- Our very first generation_config.json (671851b7e9d56ef062890732065d7bd5f4628bd6)\n\n\nCo-authored-by: Joao Gante \n", + authors: [ + { + username: "sgugger", + avatarUrl: + "https://cdn-avatars.huggingface.co/v1/production/uploads/1593126474392-5ef50182b71947201082a4e5.jpeg", + }, + { + username: "joaogante", + avatarUrl: "https://cdn-avatars.huggingface.co/v1/production/uploads/1641203017724-noauth.png", + }, + ], + date: new Date("2022-11-18T18:19:30.000Z"), + }, + { + oid: "75e09b43581151bd1d9ef6700faa605df408979f", + title: "Upload model.safetensors with huggingface_hub (#12)", + message: + "\n\n\n- Upload model.safetensors with huggingface_hub (ba2f794b2e4ea09ef932a6628fa0815dfaf09661)\n\n\nCo-authored-by: Nicolas Patry \n", + authors: [ + { + username: "julien-c", + avatarUrl: + "https://cdn-avatars.huggingface.co/v1/production/uploads/5dd96eb166059660ed1ee413/NQtzmrDdbG0H8qkZvRyGk.jpeg", + }, + { + username: "Narsil", + avatarUrl: + "https://cdn-avatars.huggingface.co/v1/production/uploads/1608285816082-5e2967b819407e3277369b95.png", + }, + ], + date: new Date("2022-10-20T09:34:54.000Z"), + }, + ]); + }); +}); diff --git a/lib/list-commits.ts b/lib/list-commits.ts new file mode 100644 index 0000000000000000000000000000000000000000..2bbcc99f71fb21f21acee2afe402cf29ada30bea --- /dev/null +++ b/lib/list-commits.ts @@ -0,0 +1,70 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiCommitData } from "../types/api/api-commit"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { parseLinkHeader } from "../utils/parseLinkHeader"; +import { toRepoId } from "../utils/toRepoId"; + +export interface CommitData { + oid: string; + title: string; + message: string; + authors: Array<{ username: string; avatarUrl: string }>; + date: Date; +} + +export async function* listCommits( + params: { + repo: RepoDesignation; + /** + * Revision to list commits from. Defaults to the default branch. + */ + revision?: string; + hubUrl?: string; + /** + * Number of commits to fetch from the hub each http call. Defaults to 100. Can be set to 1000. + */ + batchSize?: number; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): AsyncGenerator { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + + // Could upgrade to 1000 commits per page + let url: string | undefined = `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commits/${ + params.revision ?? "main" + }?limit=${params.batchSize ?? 100}`; + + while (url) { + const res: Response = await (params.fetch ?? fetch)(url, { + headers: accessToken ? 
{ Authorization: `Bearer ${accessToken}` } : {}, + }); + + if (!res.ok) { + throw await createApiError(res); + } + + const resJson: ApiCommitData[] = await res.json(); + for (const commit of resJson) { + yield { + oid: commit.id, + title: commit.title, + message: commit.message, + authors: commit.authors.map((author) => ({ + username: author.user, + avatarUrl: author.avatar, + })), + date: new Date(commit.date), + }; + } + + const linkHeader = res.headers.get("Link"); + + url = linkHeader ? parseLinkHeader(linkHeader).next : undefined; + } +} diff --git a/lib/list-datasets.spec.ts b/lib/list-datasets.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..24993130ba7649ce083b993da16488096bad7549 --- /dev/null +++ b/lib/list-datasets.spec.ts @@ -0,0 +1,47 @@ +import { describe, expect, it } from "vitest"; +import type { DatasetEntry } from "./list-datasets"; +import { listDatasets } from "./list-datasets"; + +describe("listDatasets", () => { + it("should list datasets from hf-doc-builder", async () => { + const results: DatasetEntry[] = []; + + for await (const entry of listDatasets({ search: { owner: "hf-doc-build" } })) { + if (entry.name === "hf-doc-build/doc-build-dev-test") { + continue; + } + if (typeof entry.downloads === "number") { + entry.downloads = 0; + } + if (typeof entry.likes === "number") { + entry.likes = 0; + } + if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) { + entry.updatedAt = new Date(0); + } + + results.push(entry); + } + + expect(results).deep.equal([ + { + id: "6356b19985da6f13863228bd", + name: "hf-doc-build/doc-build", + private: false, + gated: false, + downloads: 0, + likes: 0, + updatedAt: new Date(0), + }, + { + id: "636a1b69f2f9ec4289c4c19e", + name: "hf-doc-build/doc-build-dev", + gated: false, + private: false, + downloads: 0, + likes: 0, + updatedAt: new Date(0), + }, + ]); + }); +}); diff --git a/lib/list-datasets.ts b/lib/list-datasets.ts new file mode 100644 index 0000000000000000000000000000000000000000..fecfa8c321229c93e102d776eb8a7700c6699eb8 --- /dev/null +++ b/lib/list-datasets.ts @@ -0,0 +1,121 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiDatasetInfo } from "../types/api/api-dataset"; +import type { CredentialsParams } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { parseLinkHeader } from "../utils/parseLinkHeader"; +import { pick } from "../utils/pick"; + +export const DATASET_EXPAND_KEYS = [ + "private", + "downloads", + "gated", + "likes", + "lastModified", +] as const satisfies readonly (keyof ApiDatasetInfo)[]; + +export const DATASET_EXPANDABLE_KEYS = [ + "author", + "cardData", + "citation", + "createdAt", + "disabled", + "description", + "downloads", + "downloadsAllTime", + "gated", + "gitalyUid", + "lastModified", + "likes", + "paperswithcode_id", + "private", + // "siblings", + "sha", + "tags", +] as const satisfies readonly (keyof ApiDatasetInfo)[]; + +export interface DatasetEntry { + id: string; + name: string; + private: boolean; + downloads: number; + gated: false | "auto" | "manual"; + likes: number; + updatedAt: Date; +} + +export async function* listDatasets< + const T extends Exclude<(typeof DATASET_EXPANDABLE_KEYS)[number], (typeof DATASET_EXPAND_KEYS)[number]> = never, +>( + params?: { + search?: { + /** + * Will search in the dataset name for matches + */ + query?: string; + owner?: string; + tags?: string[]; + }; + hubUrl?: string; + additionalFields?: T[]; + /** + * Set 
to limit the number of models returned. + */ + limit?: number; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): AsyncGenerator> { + const accessToken = params && checkCredentials(params); + let totalToFetch = params?.limit ?? Infinity; + const search = new URLSearchParams([ + ...Object.entries({ + limit: String(Math.min(totalToFetch, 500)), + ...(params?.search?.owner ? { author: params.search.owner } : undefined), + ...(params?.search?.query ? { search: params.search.query } : undefined), + }), + ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []), + ...DATASET_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]), + ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []), + ]).toString(); + let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/datasets` + (search ? "?" + search : ""); + + while (url) { + const res: Response = await (params?.fetch ?? fetch)(url, { + headers: { + accept: "application/json", + ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined), + }, + }); + + if (!res.ok) { + throw await createApiError(res); + } + + const items: ApiDatasetInfo[] = await res.json(); + + for (const item of items) { + yield { + ...(params?.additionalFields && pick(item, params.additionalFields)), + id: item._id, + name: item.id, + private: item.private, + downloads: item.downloads, + likes: item.likes, + gated: item.gated, + updatedAt: new Date(item.lastModified), + } as DatasetEntry & Pick; + totalToFetch--; + if (totalToFetch <= 0) { + return; + } + } + + const linkHeader = res.headers.get("Link"); + + url = linkHeader ? parseLinkHeader(linkHeader).next : undefined; + // Could update limit in url to fetch less items if not all items of next page are needed. 
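+ // i.e. keep following the Link header's rel="next" URL until there is no next page or the requested
+ // limit has been reached (the generator returns early above once totalToFetch hits zero).
+ // Illustrative usage (owner and limit are placeholder values):
+ //   for await (const dataset of listDatasets({ search: { owner: "hf-doc-build" }, limit: 10 })) {
+ //     console.log(dataset.name, dataset.downloads);
+ //   }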
+ } +} diff --git a/lib/list-files.spec.ts b/lib/list-files.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..7014193075715583b74fea53ab4a7875a9732934 --- /dev/null +++ b/lib/list-files.spec.ts @@ -0,0 +1,173 @@ +import { assert, it, describe } from "vitest"; +import type { ListFileEntry } from "./list-files"; +import { listFiles } from "./list-files"; + +describe("listFiles", () => { + it("should fetch the list of files from the repo", async () => { + const cursor = listFiles({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + const files: ListFileEntry[] = []; + + for await (const entry of cursor) { + files.push(entry); + } + + assert.deepStrictEqual(files, [ + { + oid: "dc08351d4dc0732d9c8af04070ced089b201ce2f", + path: ".gitattributes", + size: 345, + type: "file", + }, + { + oid: "fca794a5f07ff8f963fe8b61e3694b0fb7f955df", + path: "config.json", + size: 313, + type: "file", + }, + { + lfs: { + oid: "097417381d6c7230bd9e3557456d726de6e83245ec8b24f529f60198a67b203a", + size: 440473133, + pointerSize: 134, + }, + xetHash: "2d8408d3a894d02517d04956e2f7546ff08362594072f3527ce144b5212a3296", + oid: "ba5d19791be1dd7992e33bd61f20207b0f7f50a5", + path: "pytorch_model.bin", + size: 440473133, + type: "file", + }, + { + lfs: { + oid: "a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2", + size: 536063208, + pointerSize: 134, + }, + xetHash: "879c5715c18a0b7f051dd33f70f0a5c8dd1522e0a43f6f75520f16167f29279b", + oid: "9eb98c817f04b051b3bcca591bcd4e03cec88018", + path: "tf_model.h5", + size: 536063208, + type: "file", + }, + { + oid: "fb140275c155a9c7c5a3b3e0e77a9e839594a938", + path: "vocab.txt", + size: 231508, + type: "file", + }, + ]); + }); + + it("should fetch the list of files from the repo, including last commit", async () => { + const cursor = listFiles({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + expand: true, + }); + + const files: ListFileEntry[] = []; + + for await (const entry of cursor) { + delete entry.securityFileStatus; // flaky + files.push(entry); + } + + assert.deepStrictEqual(files, [ + { + lastCommit: { + date: "2018-11-14T23:35:08.000Z", + id: "504939aa53e8ce310dba3dd2296dbe266c575de4", + title: "initial commit", + }, + oid: "dc08351d4dc0732d9c8af04070ced089b201ce2f", + path: ".gitattributes", + size: 345, + type: "file", + }, + { + lastCommit: { + date: "2019-06-18T09:06:51.000Z", + id: "bb3c1c3256d2598217df9889a14a2e811587891d", + title: "Update config.json", + }, + oid: "fca794a5f07ff8f963fe8b61e3694b0fb7f955df", + path: "config.json", + size: 313, + type: "file", + }, + { + lastCommit: { + date: "2019-06-18T09:06:34.000Z", + id: "3d2477d72b675a999d1b13ca822aaaf4908634ad", + title: "Update pytorch_model.bin", + }, + lfs: { + oid: "097417381d6c7230bd9e3557456d726de6e83245ec8b24f529f60198a67b203a", + size: 440473133, + pointerSize: 134, + }, + xetHash: "2d8408d3a894d02517d04956e2f7546ff08362594072f3527ce144b5212a3296", + oid: "ba5d19791be1dd7992e33bd61f20207b0f7f50a5", + path: "pytorch_model.bin", + size: 440473133, + type: "file", + }, + { + lastCommit: { + date: "2019-09-23T19:48:44.000Z", + id: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + title: "Update tf_model.h5", + }, + lfs: { + oid: "a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2", + size: 536063208, + pointerSize: 134, + }, + xetHash: "879c5715c18a0b7f051dd33f70f0a5c8dd1522e0a43f6f75520f16167f29279b", + 
oid: "9eb98c817f04b051b3bcca591bcd4e03cec88018", + path: "tf_model.h5", + size: 536063208, + type: "file", + }, + { + lastCommit: { + date: "2018-11-14T23:35:08.000Z", + id: "2f07d813ca87c8c709147704c87210359ccf2309", + title: "Update vocab.txt", + }, + oid: "fb140275c155a9c7c5a3b3e0e77a9e839594a938", + path: "vocab.txt", + size: 231508, + type: "file", + }, + ]); + }); + + it("should fetch the list of files from the repo, including subfolders", async () => { + const cursor = listFiles({ + repo: { + name: "xsum", + type: "dataset", + }, + revision: "0f3ea2f2b55fcb11e71fb1e3aec6822e44ddcb0f", + recursive: true, + }); + + const files: ListFileEntry[] = []; + + for await (const entry of cursor) { + files.push(entry); + } + + assert(files.some((file) => file.path === "data/XSUM-EMNLP18-Summary-Data-Original.tar.gz")); + }); +}); diff --git a/lib/list-files.ts b/lib/list-files.ts new file mode 100644 index 0000000000000000000000000000000000000000..2bf76f7817106c8cdc3ae18140b0a4e196d67811 --- /dev/null +++ b/lib/list-files.ts @@ -0,0 +1,94 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiIndexTreeEntry } from "../types/api/api-index-tree"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { parseLinkHeader } from "../utils/parseLinkHeader"; +import { toRepoId } from "../utils/toRepoId"; + +export interface ListFileEntry { + type: "file" | "directory" | "unknown"; + size: number; + path: string; + oid: string; + lfs?: { + oid: string; + size: number; + /** Size of the raw pointer file, 100~200 bytes */ + pointerSize: number; + }; + /** + * Xet-backed hash, a new protocol replacing LFS for big files. + */ + xetHash?: string; + /** + * Only fetched if `expand` is set to `true` in the `listFiles` call. + */ + lastCommit?: { + date: string; + id: string; + title: string; + }; + /** + * Only fetched if `expand` is set to `true` in the `listFiles` call. + */ + securityFileStatus?: unknown; +} + +/** + * List files in a folder. To list ALL files in the directory, call it + * with {@link params.recursive} set to `true`. + */ +export async function* listFiles( + params: { + repo: RepoDesignation; + /** + * Do we want to list files in subdirectories? + */ + recursive?: boolean; + /** + * Eg 'data' for listing all files in the 'data' folder. Leave it empty to list all + * files in the repo. + */ + path?: string; + /** + * Fetch `lastCommit` and `securityFileStatus` for each file. + */ + expand?: boolean; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): AsyncGenerator { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + let url: string | undefined = `${params.hubUrl || HUB_URL}/api/${repoId.type}s/${repoId.name}/tree/${ + params.revision || "main" + }${params.path ? "/" + params.path : ""}?recursive=${!!params.recursive}&expand=${!!params.expand}`; + + while (url) { + const res: Response = await (params.fetch ?? fetch)(url, { + headers: { + accept: "application/json", + ...(accessToken ? 
{ Authorization: `Bearer ${accessToken}` } : undefined), + }, + }); + + if (!res.ok) { + throw await createApiError(res); + } + + const items: ApiIndexTreeEntry[] = await res.json(); + + for (const item of items) { + yield item; + } + + const linkHeader = res.headers.get("Link"); + + url = linkHeader ? parseLinkHeader(linkHeader).next : undefined; + } +} diff --git a/lib/list-models.spec.ts b/lib/list-models.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..93cbe812d9b7127ccb4d246d56b65be50ee245e1 --- /dev/null +++ b/lib/list-models.spec.ts @@ -0,0 +1,118 @@ +import { describe, expect, it } from "vitest"; +import type { ModelEntry } from "./list-models"; +import { listModels } from "./list-models"; + +describe("listModels", () => { + it("should list models for depth estimation", async () => { + const results: ModelEntry[] = []; + + for await (const entry of listModels({ + search: { owner: "Intel", task: "depth-estimation" }, + })) { + if (typeof entry.downloads === "number") { + entry.downloads = 0; + } + if (typeof entry.likes === "number") { + entry.likes = 0; + } + if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) { + entry.updatedAt = new Date(0); + } + + if (!["Intel/dpt-large", "Intel/dpt-hybrid-midas"].includes(entry.name)) { + expect(entry.task).to.equal("depth-estimation"); + continue; + } + + results.push(entry); + } + + results.sort((a, b) => a.id.localeCompare(b.id)); + + expect(results).deep.equal([ + { + id: "621ffdc136468d709f17e709", + name: "Intel/dpt-large", + private: false, + gated: false, + downloads: 0, + likes: 0, + task: "depth-estimation", + updatedAt: new Date(0), + }, + { + id: "638f07977559bf9a2b2b04ac", + name: "Intel/dpt-hybrid-midas", + gated: false, + private: false, + downloads: 0, + likes: 0, + task: "depth-estimation", + updatedAt: new Date(0), + }, + ]); + }); + + it("should list indonesian models with gguf format", async () => { + let count = 0; + for await (const entry of listModels({ + search: { tags: ["gguf", "id"] }, + additionalFields: ["tags"], + limit: 2, + })) { + count++; + expect(entry.tags).to.include("gguf"); + expect(entry.tags).to.include("id"); + } + + expect(count).to.equal(2); + }); + + it("should search model by name", async () => { + let count = 0; + for await (const entry of listModels({ + search: { query: "t5" }, + limit: 10, + })) { + count++; + expect(entry.name.toLocaleLowerCase()).to.include("t5"); + } + + expect(count).to.equal(10); + }); + + it("should search model by inference provider", async () => { + let count = 0; + for await (const entry of listModels({ + search: { inferenceProviders: ["together"] }, + additionalFields: ["inferenceProviderMapping"], + limit: 10, + })) { + count++; + if (Array.isArray(entry.inferenceProviderMapping)) { + expect(entry.inferenceProviderMapping.map(({ provider }) => provider)).to.include("together"); + } + } + + expect(count).to.equal(10); + }); + + it("should search model by several inference providers", async () => { + let count = 0; + const inferenceProviders = ["together", "replicate"]; + for await (const entry of listModels({ + search: { inferenceProviders }, + additionalFields: ["inferenceProviderMapping"], + limit: 10, + })) { + count++; + if (Array.isArray(entry.inferenceProviderMapping)) { + expect( + entry.inferenceProviderMapping.filter(({ provider }) => inferenceProviders.includes(provider)).length + ).toBeGreaterThan(0); + } + } + + expect(count).to.equal(10); + }); +}); diff --git a/lib/list-models.ts b/lib/list-models.ts new 
file mode 100644 index 0000000000000000000000000000000000000000..edd317ff1812909c26707b3961bf193def72741e --- /dev/null +++ b/lib/list-models.ts @@ -0,0 +1,139 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiModelInfo } from "../types/api/api-model"; +import type { CredentialsParams, PipelineType } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { parseLinkHeader } from "../utils/parseLinkHeader"; +import { pick } from "../utils/pick"; + +export const MODEL_EXPAND_KEYS = [ + "pipeline_tag", + "private", + "gated", + "downloads", + "likes", + "lastModified", +] as const satisfies readonly (keyof ApiModelInfo)[]; + +export const MODEL_EXPANDABLE_KEYS = [ + "author", + "cardData", + "config", + "createdAt", + "disabled", + "downloads", + "downloadsAllTime", + "gated", + "gitalyUid", + "inferenceProviderMapping", + "lastModified", + "library_name", + "likes", + "model-index", + "pipeline_tag", + "private", + "safetensors", + "sha", + // "siblings", + "spaces", + "tags", + "transformersInfo", +] as const satisfies readonly (keyof ApiModelInfo)[]; + +export interface ModelEntry { + id: string; + name: string; + private: boolean; + gated: false | "auto" | "manual"; + task?: PipelineType; + likes: number; + downloads: number; + updatedAt: Date; +} + +export async function* listModels< + const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never, +>( + params?: { + search?: { + /** + * Will search in the model name for matches + */ + query?: string; + owner?: string; + task?: PipelineType; + tags?: string[]; + /** + * Will search for models that have one of the inference providers in the list. + */ + inferenceProviders?: string[]; + }; + hubUrl?: string; + additionalFields?: T[]; + /** + * Set to limit the number of models returned. + */ + limit?: number; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): AsyncGenerator> { + const accessToken = params && checkCredentials(params); + let totalToFetch = params?.limit ?? Infinity; + const search = new URLSearchParams([ + ...Object.entries({ + limit: String(Math.min(totalToFetch, 500)), + ...(params?.search?.owner ? { author: params.search.owner } : undefined), + ...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined), + ...(params?.search?.query ? { search: params.search.query } : undefined), + ...(params?.search?.inferenceProviders + ? { inference_provider: params.search.inferenceProviders.join(",") } + : undefined), + }), + ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []), + ...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]), + ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []), + ]).toString(); + let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`; + + while (url) { + const res: Response = await (params?.fetch ?? fetch)(url, { + headers: { + accept: "application/json", + ...(accessToken ? 
{ Authorization: `Bearer ${accessToken}` } : undefined), + }, + }); + + if (!res.ok) { + throw await createApiError(res); + } + + const items: ApiModelInfo[] = await res.json(); + + for (const item of items) { + yield { + ...(params?.additionalFields && pick(item, params.additionalFields)), + id: item._id, + name: item.id, + private: item.private, + task: item.pipeline_tag, + downloads: item.downloads, + gated: item.gated, + likes: item.likes, + updatedAt: new Date(item.lastModified), + } as ModelEntry & Pick; + totalToFetch--; + + if (totalToFetch <= 0) { + return; + } + } + + const linkHeader = res.headers.get("Link"); + + url = linkHeader ? parseLinkHeader(linkHeader).next : undefined; + // Could update url to reduce the limit if we don't need the whole 500 of the next batch. + } +} diff --git a/lib/list-spaces.spec.ts b/lib/list-spaces.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..3cc5999137a7837ae4312374a7a89ec46985dff3 --- /dev/null +++ b/lib/list-spaces.spec.ts @@ -0,0 +1,40 @@ +import { describe, expect, it } from "vitest"; +import type { SpaceEntry } from "./list-spaces"; +import { listSpaces } from "./list-spaces"; + +describe("listSpaces", () => { + it("should list spaces for Microsoft", async () => { + const results: SpaceEntry[] = []; + + for await (const entry of listSpaces({ + search: { owner: "microsoft" }, + additionalFields: ["subdomain"], + })) { + if (entry.name !== "microsoft/visual_chatgpt") { + continue; + } + if (typeof entry.likes === "number") { + entry.likes = 0; + } + if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) { + entry.updatedAt = new Date(0); + } + + results.push(entry); + } + + results.sort((a, b) => a.id.localeCompare(b.id)); + + expect(results).deep.equal([ + { + id: "6409a392bbc73d022c58c980", + name: "microsoft/visual_chatgpt", + private: false, + likes: 0, + sdk: "gradio", + subdomain: "microsoft-visual-chatgpt", + updatedAt: new Date(0), + }, + ]); + }); +}); diff --git a/lib/list-spaces.ts b/lib/list-spaces.ts new file mode 100644 index 0000000000000000000000000000000000000000..a14e7e301a439161bed34fde231faf19786fbee1 --- /dev/null +++ b/lib/list-spaces.ts @@ -0,0 +1,111 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiSpaceInfo } from "../types/api/api-space"; +import type { CredentialsParams, SpaceSdk } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { parseLinkHeader } from "../utils/parseLinkHeader"; +import { pick } from "../utils/pick"; + +export const SPACE_EXPAND_KEYS = [ + "sdk", + "likes", + "private", + "lastModified", +] as const satisfies readonly (keyof ApiSpaceInfo)[]; +export const SPACE_EXPANDABLE_KEYS = [ + "author", + "cardData", + "datasets", + "disabled", + "gitalyUid", + "lastModified", + "createdAt", + "likes", + "private", + "runtime", + "sdk", + // "siblings", + "sha", + "subdomain", + "tags", + "models", +] as const satisfies readonly (keyof ApiSpaceInfo)[]; + +export interface SpaceEntry { + id: string; + name: string; + sdk?: SpaceSdk; + likes: number; + private: boolean; + updatedAt: Date; + // Use additionalFields to fetch the fields from ApiSpaceInfo +} + +export async function* listSpaces< + const T extends Exclude<(typeof SPACE_EXPANDABLE_KEYS)[number], (typeof SPACE_EXPAND_KEYS)[number]> = never, +>( + params?: { + search?: { + /** + * Will search in the space name for matches + */ + query?: string; + owner?: string; + tags?: string[]; + }; + hubUrl?: string; 
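+ // hubUrl defaults to HUB_URL when omitted.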
+ /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + /** + * Additional fields to fetch from huggingface.co. + */ + additionalFields?: T[]; + } & Partial +): AsyncGenerator> { + const accessToken = params && checkCredentials(params); + const search = new URLSearchParams([ + ...Object.entries({ + limit: "500", + ...(params?.search?.owner ? { author: params.search.owner } : undefined), + ...(params?.search?.query ? { search: params.search.query } : undefined), + }), + ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []), + ...[...SPACE_EXPAND_KEYS, ...(params?.additionalFields ?? [])].map( + (val) => ["expand", val] satisfies [string, string] + ), + ]).toString(); + let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/spaces?${search}`; + + while (url) { + const res: Response = await (params?.fetch ?? fetch)(url, { + headers: { + accept: "application/json", + ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined), + }, + }); + + if (!res.ok) { + throw await createApiError(res); + } + + const items: ApiSpaceInfo[] = await res.json(); + + for (const item of items) { + yield { + ...(params?.additionalFields && pick(item, params.additionalFields)), + id: item._id, + name: item.id, + sdk: item.sdk, + likes: item.likes, + private: item.private, + updatedAt: new Date(item.lastModified), + } as SpaceEntry & Pick; + } + + const linkHeader = res.headers.get("Link"); + + url = linkHeader ? parseLinkHeader(linkHeader).next : undefined; + } +} diff --git a/lib/model-info.spec.ts b/lib/model-info.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..3657e964d76a877f933bad7d9bf332a39dcad708 --- /dev/null +++ b/lib/model-info.spec.ts @@ -0,0 +1,59 @@ +import { describe, expect, it } from "vitest"; +import { modelInfo } from "./model-info"; +import type { ModelEntry } from "./list-models"; +import type { ApiModelInfo } from "../types/api/api-model"; + +describe("modelInfo", () => { + it("should return the model info", async () => { + const info = await modelInfo({ + name: "openai-community/gpt2", + }); + expect(info).toEqual({ + id: "621ffdc036468d709f17434d", + downloads: expect.any(Number), + gated: false, + name: "openai-community/gpt2", + updatedAt: expect.any(Date), + likes: expect.any(Number), + task: "text-generation", + private: false, + }); + }); + + it("should return the model info with author", async () => { + const info: ModelEntry & Pick = await modelInfo({ + name: "openai-community/gpt2", + additionalFields: ["author"], + }); + expect(info).toEqual({ + id: "621ffdc036468d709f17434d", + downloads: expect.any(Number), + author: "openai-community", + gated: false, + name: "openai-community/gpt2", + updatedAt: expect.any(Date), + likes: expect.any(Number), + task: "text-generation", + private: false, + }); + }); + + it("should return the model info for a specific revision", async () => { + const info: ModelEntry & Pick = await modelInfo({ + name: "openai-community/gpt2", + additionalFields: ["sha"], + revision: "f27b190eeac4c2302d24068eabf5e9d6044389ae", + }); + expect(info).toEqual({ + id: "621ffdc036468d709f17434d", + downloads: expect.any(Number), + gated: false, + name: "openai-community/gpt2", + updatedAt: expect.any(Date), + likes: expect.any(Number), + task: "text-generation", + private: false, + sha: "f27b190eeac4c2302d24068eabf5e9d6044389ae", + }); + }); +}); diff --git a/lib/model-info.ts b/lib/model-info.ts new file mode 100644 index 
0000000000000000000000000000000000000000..4e4291c3b405a10a9a66515939afd4dc8118d639 --- /dev/null +++ b/lib/model-info.ts @@ -0,0 +1,62 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiModelInfo } from "../types/api/api-model"; +import type { CredentialsParams } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { pick } from "../utils/pick"; +import { MODEL_EXPAND_KEYS, type MODEL_EXPANDABLE_KEYS, type ModelEntry } from "./list-models"; + +export async function modelInfo< + const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never, +>( + params: { + name: string; + hubUrl?: string; + additionalFields?: T[]; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. + */ + revision?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise> { + const accessToken = params && checkCredentials(params); + + const search = new URLSearchParams([ + ...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]), + ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []), + ]).toString(); + + const response = await (params.fetch || fetch)( + `${params?.hubUrl || HUB_URL}/api/models/${params.name}/revision/${encodeURIComponent( + params.revision ?? "HEAD" + )}?${search.toString()}`, + { + headers: { + ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}), + Accepts: "application/json", + }, + } + ); + + if (!response.ok) { + throw await createApiError(response); + } + + const data = await response.json(); + + return { + ...(params?.additionalFields && pick(data, params.additionalFields)), + id: data._id, + name: data.id, + private: data.private, + task: data.pipeline_tag, + downloads: data.downloads, + gated: data.gated, + likes: data.likes, + updatedAt: new Date(data.lastModified), + } as ModelEntry & Pick; +} diff --git a/lib/oauth-handle-redirect.spec.ts b/lib/oauth-handle-redirect.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..f06f6a40e0dcdb6621ea75dedd369c3771b11561 --- /dev/null +++ b/lib/oauth-handle-redirect.spec.ts @@ -0,0 +1,60 @@ +import { describe, expect, it } from "vitest"; +import { TEST_COOKIE, TEST_HUB_URL } from "../test/consts"; +import { oauthLoginUrl } from "./oauth-login-url"; +import { oauthHandleRedirect } from "./oauth-handle-redirect"; + +describe("oauthHandleRedirect", () => { + it("should work", async () => { + const localStorage = { + nonce: undefined, + codeVerifier: undefined, + }; + const url = await oauthLoginUrl({ + clientId: "dummy-app", + redirectUrl: "http://localhost:3000", + localStorage, + scopes: "openid profile email", + hubUrl: TEST_HUB_URL, + }); + const resp = await fetch(url, { + method: "POST", + headers: { + Cookie: `token=${TEST_COOKIE}`, + }, + redirect: "manual", + }); + if (resp.status !== 303) { + throw new Error(`Failed to fetch url ${url}: ${resp.status} ${resp.statusText}`); + } + const location = resp.headers.get("Location"); + if (!location) { + throw new Error(`No location header in response`); + } + const result = await oauthHandleRedirect({ + redirectedUrl: location, + codeVerifier: localStorage.codeVerifier, + nonce: localStorage.nonce, + hubUrl: TEST_HUB_URL, + }); + + if (!result) { + throw new Error("Expected result to be defined"); + } + 
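// oauthHandleRedirect exchanges the authorization code for an access token and fetches the OpenID user info: +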
expect(result.accessToken).toEqual(expect.any(String)); + expect(result.accessTokenExpiresAt).toBeInstanceOf(Date); + expect(result.accessTokenExpiresAt.getTime()).toBeGreaterThan(Date.now()); + expect(result.scope).toEqual(expect.any(String)); + expect(result.userInfo).toEqual({ + sub: "62f264b9f3c90f4b6514a269", + name: "@huggingface/hub CI bot", + preferred_username: "hub.js", + email_verified: true, + email: "eliott@huggingface.co", + isPro: false, + picture: "https://hub-ci.huggingface.co/avatars/934b830e9fdaa879487852f79eef7165.svg", + profile: "https://hub-ci.huggingface.co/hub.js", + website: "https://github.com/huggingface/hub.js", + orgs: [], + }); + }); +}); diff --git a/lib/oauth-handle-redirect.ts b/lib/oauth-handle-redirect.ts new file mode 100644 index 0000000000000000000000000000000000000000..4df6311ec480d3a2f41e3c0c3746c67bc2260f9c --- /dev/null +++ b/lib/oauth-handle-redirect.ts @@ -0,0 +1,334 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; + +export interface UserInfo { + /** + * OpenID Connect field. Unique identifier for the user, even in case of rename. + */ + sub: string; + /** + * OpenID Connect field. The user's full name. + */ + name: string; + /** + * OpenID Connect field. The user's username. + */ + preferred_username: string; + /** + * OpenID Connect field, available if scope "email" was granted. + */ + email_verified?: boolean; + /** + * OpenID Connect field, available if scope "email" was granted. + */ + email?: string; + /** + * OpenID Connect field. The user's profile picture URL. + */ + picture: string; + /** + * OpenID Connect field. The user's profile URL. + */ + profile: string; + /** + * OpenID Connect field. The user's website URL. + */ + website?: string; + + /** + * Hugging Face field. Whether the user is a pro user. + */ + isPro: boolean; + /** + * Hugging Face field. Whether the user has a payment method set up. Needs "read-billing" scope. + */ + canPay?: boolean; + /** + * Hugging Face field. The user's orgs + */ + orgs?: Array<{ + /** + * OpenID Connect field. Unique identifier for the org. + */ + sub: string; + /** + * OpenID Connect field. The org's full name. + */ + name: string; + /** + * OpenID Connect field. The org's username. + */ + preferred_username: string; + /** + * OpenID Connect field. The org's profile picture URL. + */ + picture: string; + + /** + * Hugging Face field. Whether the org is an enterprise org. + */ + isEnterprise: boolean; + /** + * Hugging Face field. Whether the org has a payment method set up. Needs "read-billing" scope, and the user needs to approve access to the org in the OAuth page. + */ + canPay?: boolean; + /** + * Hugging Face field. The user's role in the org. The user needs to approve access to the org in the OAuth page. + */ + roleInOrg?: string; + /** + * HuggingFace field. When the user granted the oauth app access to the org, but didn't complete SSO. + * + * Should never happen directly after the oauth flow. + */ + pendingSSO?: boolean; + /** + * HuggingFace field. When the user granted the oauth app access to the org, but didn't complete MFA. + * + * Should never happen directly after the oauth flow. + */ + missingMFA?: boolean; + }>; +} + +export interface OAuthResult { + accessToken: string; + accessTokenExpiresAt: Date; + userInfo: UserInfo; + /** + * State passed to the OAuth provider in the original request to the OAuth provider. 
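+ * This is the `state` value originally passed to `oauthLoginUrl`, if any.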
+ */ + state?: string; + /** + * Granted scope + */ + scope: string; +} + +/** + * To call after the OAuth provider redirects back to the app. + * + * There is also a helper function {@link oauthHandleRedirectIfPresent}, which will call `oauthHandleRedirect` if the URL contains an oauth code + * in the query parameters and return `false` otherwise. + */ +export async function oauthHandleRedirect(opts?: { + /** + * The URL of the hub. Defaults to {@link HUB_URL}. + */ + hubUrl?: string; + /** + * The URL to analyze. + * + * @default window.location.href + */ + redirectedUrl?: string; + /** + * nonce generated by oauthLoginUrl + * + * @default localStorage.getItem("huggingface.co:oauth:nonce") + */ + nonce?: string; + /** + * codeVerifier generated by oauthLoginUrl + * + * @default localStorage.getItem("huggingface.co:oauth:code_verifier") + */ + codeVerifier?: string; +}): Promise { + if (typeof window === "undefined" && !opts?.redirectedUrl) { + throw new Error("oauthHandleRedirect is only available in the browser, unless you provide redirectedUrl"); + } + if (typeof localStorage === "undefined" && (!opts?.nonce || !opts?.codeVerifier)) { + throw new Error( + "oauthHandleRedirect requires localStorage to be available, unless you provide nonce and codeVerifier" + ); + } + + const redirectedUrl = opts?.redirectedUrl ?? window.location.href; + const searchParams = (() => { + try { + return new URL(redirectedUrl).searchParams; + } catch (err) { + throw new Error("Failed to parse redirected URL: " + redirectedUrl); + } + })(); + + const [error, errorDescription] = [searchParams.get("error"), searchParams.get("error_description")]; + + if (error) { + throw new Error(`${error}: ${errorDescription}`); + } + + const code = searchParams.get("code"); + const nonce = opts?.nonce ?? localStorage.getItem("huggingface.co:oauth:nonce"); + + if (!code) { + throw new Error("Missing oauth code from query parameters in redirected URL: " + redirectedUrl); + } + + if (!nonce) { + throw new Error("Missing oauth nonce from localStorage"); + } + + const codeVerifier = opts?.codeVerifier ?? 
localStorage.getItem("huggingface.co:oauth:code_verifier"); + + if (!codeVerifier) { + throw new Error("Missing oauth code_verifier from localStorage"); + } + + const state = searchParams.get("state"); + + if (!state) { + throw new Error("Missing oauth state from query parameters in redirected URL"); + } + + let parsedState: { nonce: string; redirectUri: string; state?: string }; + + try { + parsedState = JSON.parse(state); + } catch { + throw new Error("Invalid oauth state in redirected URL, unable to parse JSON: " + state); + } + + if (parsedState.nonce !== nonce) { + throw new Error("Invalid oauth state in redirected URL"); + } + + const hubUrl = opts?.hubUrl || HUB_URL; + + const openidConfigUrl = `${new URL(hubUrl).origin}/.well-known/openid-configuration`; + const openidConfigRes = await fetch(openidConfigUrl, { + headers: { + Accept: "application/json", + }, + }); + + if (!openidConfigRes.ok) { + throw await createApiError(openidConfigRes); + } + + const openidConfig: { + authorization_endpoint: string; + token_endpoint: string; + userinfo_endpoint: string; + } = await openidConfigRes.json(); + + const tokenRes = await fetch(openidConfig.token_endpoint, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + }, + body: new URLSearchParams({ + grant_type: "authorization_code", + code, + redirect_uri: parsedState.redirectUri, + code_verifier: codeVerifier, + }).toString(), + }); + + if (!opts?.codeVerifier) { + localStorage.removeItem("huggingface.co:oauth:code_verifier"); + } + if (!opts?.nonce) { + localStorage.removeItem("huggingface.co:oauth:nonce"); + } + + if (!tokenRes.ok) { + throw await createApiError(tokenRes); + } + + const token: { + access_token: string; + expires_in: number; + id_token: string; + // refresh_token: string; + scope: string; + token_type: string; + } = await tokenRes.json(); + + const accessTokenExpiresAt = new Date(Date.now() + token.expires_in * 1000); + + const userInfoRes = await fetch(openidConfig.userinfo_endpoint, { + headers: { + Authorization: `Bearer ${token.access_token}`, + }, + }); + + if (!userInfoRes.ok) { + throw await createApiError(userInfoRes); + } + + const userInfo: UserInfo = await userInfoRes.json(); + + return { + accessToken: token.access_token, + accessTokenExpiresAt, + userInfo: userInfo, + state: parsedState.state, + scope: token.scope, + }; +} + +// if (code && !nonce) { +// console.warn("Missing oauth nonce from localStorage"); +// } + +/** + * To call after the OAuth provider redirects back to the app. + * + * It returns false if the URL does not contain an oauth code in the query parameters, otherwise + * it calls {@link oauthHandleRedirect}. + * + * Depending on your app, you may want to call {@link oauthHandleRedirect} directly instead. + */ +export async function oauthHandleRedirectIfPresent(opts?: { + /** + * The URL of the hub. Defaults to {@link HUB_URL}. + */ + hubUrl?: string; + /** + * The URL to analyze. 
+ * + * @default window.location.href + */ + redirectedUrl?: string; + /** + * nonce generated by oauthLoginUrl + * + * @default localStorage.getItem("huggingface.co:oauth:nonce") + */ + nonce?: string; + /** + * codeVerifier generated by oauthLoginUrl + * + * @default localStorage.getItem("huggingface.co:oauth:code_verifier") + */ + codeVerifier?: string; +}): Promise { + if (typeof window === "undefined" && !opts?.redirectedUrl) { + throw new Error("oauthHandleRedirect is only available in the browser, unless you provide redirectedUrl"); + } + if (typeof localStorage === "undefined" && (!opts?.nonce || !opts?.codeVerifier)) { + throw new Error( + "oauthHandleRedirect requires localStorage to be available, unless you provide nonce and codeVerifier" + ); + } + const searchParams = new URLSearchParams(opts?.redirectedUrl ?? window.location.search); + + if (searchParams.has("error")) { + return oauthHandleRedirect(opts); + } + + if (searchParams.has("code")) { + if (!localStorage.getItem("huggingface.co:oauth:nonce")) { + console.warn( + "Missing oauth nonce from localStorage. This can happen when the user refreshes the page after logging in, without changing the URL." + ); + return false; + } + + return oauthHandleRedirect(opts); + } + + return false; +} diff --git a/lib/oauth-login-url.ts b/lib/oauth-login-url.ts new file mode 100644 index 0000000000000000000000000000000000000000..0594d09a13ccd9b0998f9a8990e2ae01212c2174 --- /dev/null +++ b/lib/oauth-login-url.ts @@ -0,0 +1,166 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import { base64FromBytes } from "../utils/base64FromBytes"; + +/** + * Use "Sign in with Hub" to authenticate a user, and get oauth user info / access token. + * + * Returns an url to redirect to. After the user is redirected back to your app, call `oauthHandleRedirect` to get the oauth user info / access token. + * + * When called from inside a static Space with OAuth enabled, it will load the config from the space, otherwise you need to at least specify + * the client ID of your OAuth App. + * + * @example + * ```ts + * import { oauthLoginUrl, oauthHandleRedirectIfPresent } from "@huggingface/hub"; + * + * const oauthResult = await oauthHandleRedirectIfPresent(); + * + * if (!oauthResult) { + * // If the user is not logged in, redirect to the login page + * window.location.href = await oauthLoginUrl(); + * } + * + * // You can use oauthResult.accessToken, oauthResult.accessTokenExpiresAt and oauthResult.userInfo + * console.log(oauthResult); + * ``` + * + * (Theoretically, this function could be used to authenticate a user for any OAuth provider supporting PKCE and OpenID Connect by changing `hubUrl`, + * but it is currently only tested with the Hugging Face Hub.) + */ +export async function oauthLoginUrl(opts?: { + /** + * OAuth client ID. + * + * For static Spaces, you can omit this and it will be loaded from the Space config, as long as `hf_oauth: true` is present in the README.md's metadata. + * For other Spaces, it is available to the backend in the OAUTH_CLIENT_ID environment variable, as long as `hf_oauth: true` is present in the README.md's metadata. + * + * You can also create a Developer Application at https://huggingface.co/settings/connected-applications and use its client ID. + */ + clientId?: string; + hubUrl?: string; + /** + * OAuth scope, a list of space-separated scopes. 
+ * + * For static Spaces, you can omit this and it will be loaded from the Space config, as long as `hf_oauth: true` is present in the README.md's metadata. + * For other Spaces, it is available to the backend in the OAUTH_SCOPES environment variable, as long as `hf_oauth: true` is present in the README.md's metadata. + * + * Defaults to "openid profile". + * + * You can also create a Developer Application at https://huggingface.co/settings/connected-applications and use its scopes. + * + * See https://huggingface.co/docs/hub/oauth for a list of available scopes. + */ + scopes?: string; + /** + * Redirect URI, defaults to the current URL. + * + * For Spaces, any URL within the Space is allowed. + * + * For Developer Applications, you can add any URL you want to the list of allowed redirect URIs at https://huggingface.co/settings/connected-applications. + */ + redirectUrl?: string; + /** + * State to pass to the OAuth provider, which will be returned in the call to `oauthLogin` after the redirect. + */ + state?: string; + /** + * If provided, will be filled with the code verifier and nonce used for the OAuth flow, + * instead of using localStorage. + * + * When calling {@link `oauthHandleRedirectIfPresent`} or {@link `oauthHandleRedirect`} you will need to provide the same values. + */ + localStorage?: { + codeVerifier?: string; + nonce?: string; + }; +}): Promise { + if (typeof window === "undefined" && (!opts?.redirectUrl || !opts?.clientId)) { + throw new Error("oauthLogin is only available in the browser, unless you provide clientId and redirectUrl"); + } + if (typeof localStorage === "undefined" && !opts?.localStorage) { + throw new Error( + "oauthLogin requires localStorage to be available in the context, unless you provide a localStorage empty object as argument" + ); + } + + const hubUrl = opts?.hubUrl || HUB_URL; + const openidConfigUrl = `${new URL(hubUrl).origin}/.well-known/openid-configuration`; + const openidConfigRes = await fetch(openidConfigUrl, { + headers: { + Accept: "application/json", + }, + }); + + if (!openidConfigRes.ok) { + throw await createApiError(openidConfigRes); + } + + const opendidConfig: { + authorization_endpoint: string; + token_endpoint: string; + userinfo_endpoint: string; + } = await openidConfigRes.json(); + + const newNonce = globalThis.crypto.randomUUID(); + // Two random UUIDs concatenated together, because min length is 43 and max length is 128 + const newCodeVerifier = globalThis.crypto.randomUUID() + globalThis.crypto.randomUUID(); + + if (opts?.localStorage) { + if (opts.localStorage.codeVerifier !== undefined && opts.localStorage.codeVerifier !== null) { + throw new Error( + "localStorage.codeVerifier must be initially set to null or undefined, and will be filled by oauthLoginUrl" + ); + } + if (opts.localStorage.nonce !== undefined && opts.localStorage.nonce !== null) { + throw new Error( + "localStorage.nonce must be initially set to null or undefined, and will be filled by oauthLoginUrl" + ); + } + opts.localStorage.codeVerifier = newCodeVerifier; + opts.localStorage.nonce = newNonce; + } else { + localStorage.setItem("huggingface.co:oauth:nonce", newNonce); + localStorage.setItem("huggingface.co:oauth:code_verifier", newCodeVerifier); + } + + const redirectUri = opts?.redirectUrl || (typeof window !== "undefined" ? 
window.location.href : undefined); + if (!redirectUri) { + throw new Error("Missing redirectUrl"); + } + const state = JSON.stringify({ + nonce: newNonce, + redirectUri, + state: opts?.state, + }); + + const variables: Record | null = + // @ts-expect-error window.huggingface is defined inside static Spaces. + typeof window !== "undefined" ? window.huggingface?.variables ?? null : null; + + const clientId = opts?.clientId || variables?.OAUTH_CLIENT_ID; + + if (!clientId) { + if (variables) { + throw new Error("Missing clientId, please add hf_oauth: true to the README.md's metadata in your static Space"); + } + throw new Error("Missing clientId"); + } + + const challenge = base64FromBytes( + new Uint8Array(await globalThis.crypto.subtle.digest("SHA-256", new TextEncoder().encode(newCodeVerifier))) + ) + .replace(/[+]/g, "-") + .replace(/[/]/g, "_") + .replace(/=/g, ""); + + return `${opendidConfig.authorization_endpoint}?${new URLSearchParams({ + client_id: clientId, + scope: opts?.scopes || variables?.OAUTH_SCOPES || "openid profile", + response_type: "code", + redirect_uri: redirectUri, + state, + code_challenge: challenge, + code_challenge_method: "S256", + }).toString()}`; +} diff --git a/lib/parse-safetensors-metadata.spec.ts b/lib/parse-safetensors-metadata.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..d96f5ed6507b5ca8b80c05fa1d33b076a7a985fc --- /dev/null +++ b/lib/parse-safetensors-metadata.spec.ts @@ -0,0 +1,122 @@ +import { assert, it, describe } from "vitest"; +import { parseSafetensorsMetadata, parseSafetensorsShardFilename } from "./parse-safetensors-metadata"; +import { sum } from "../utils/sum"; + +describe("parseSafetensorsMetadata", () => { + it("fetch info for single-file (with the default conventional filename)", async () => { + const parse = await parseSafetensorsMetadata({ + repo: "bert-base-uncased", + computeParametersCount: true, + revision: "86b5e0934494bd15c9632b12f734a8a67f723594", + }); + + assert(!parse.sharded); + assert.deepStrictEqual(parse.header.__metadata__, { format: "pt" }); + + // Example of one tensor (the header contains many tensors) + + assert.deepStrictEqual(parse.header["bert.embeddings.LayerNorm.beta"], { + dtype: "F32", + shape: [768], + data_offsets: [0, 3072], + }); + + assert.deepStrictEqual(parse.parameterCount, { F32: 110_106_428 }); + assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 110_106_428); + // total params = 110m + }); + + it("fetch info for sharded (with the default conventional filename)", async () => { + const parse = await parseSafetensorsMetadata({ + repo: "bigscience/bloom", + computeParametersCount: true, + revision: "053d9cd9fbe814e091294f67fcfedb3397b954bb", + }); + + assert(parse.sharded); + + assert.strictEqual(Object.keys(parse.headers).length, 72); + // This model has 72 shards! 
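+ // parseShardedIndex downloads the shard headers listed in the index's weight_map (up to 20 in parallel) and keys them by shard filename.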
+ + // Example of one tensor inside one file + + assert.deepStrictEqual(parse.headers["model_00012-of-00072.safetensors"]["h.10.input_layernorm.weight"], { + dtype: "BF16", + shape: [14336], + data_offsets: [3288649728, 3288678400], + }); + + assert.deepStrictEqual(parse.parameterCount, { BF16: 176_247_271_424 }); + assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 176_247_271_424); + // total params = 176B + }); + + it("fetch info for single-file with multiple dtypes", async () => { + const parse = await parseSafetensorsMetadata({ + repo: "roberta-base", + computeParametersCount: true, + revision: "e2da8e2f811d1448a5b465c236feacd80ffbac7b", + }); + + assert(!parse.sharded); + + assert.deepStrictEqual(parse.parameterCount, { F32: 124_697_433, I64: 514 }); + assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 124_697_947); + // total params = 124m + }); + + it("fetch info for single-file with file path", async () => { + const parse = await parseSafetensorsMetadata({ + repo: "CompVis/stable-diffusion-v1-4", + computeParametersCount: true, + path: "unet/diffusion_pytorch_model.safetensors", + revision: "133a221b8aa7292a167afc5127cb63fb5005638b", + }); + + assert(!parse.sharded); + assert.deepStrictEqual(parse.header.__metadata__, { format: "pt" }); + + // Example of one tensor (the header contains many tensors) + + assert.deepStrictEqual(parse.header["up_blocks.3.resnets.0.norm2.bias"], { + dtype: "F32", + shape: [320], + data_offsets: [3_409_382_416, 3_409_383_696], + }); + + assert.deepStrictEqual(parse.parameterCount, { F32: 859_520_964 }); + assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964); + }); + + it("fetch info for sharded (with the default conventional filename) with file path", async () => { + const parse = await parseSafetensorsMetadata({ + repo: "Alignment-Lab-AI/ALAI-gemma-7b", + computeParametersCount: true, + path: "7b/1/model.safetensors.index.json", + revision: "37e307261fe97bbf8b2463d61dbdd1a10daa264c", + }); + + assert(parse.sharded); + + assert.strictEqual(Object.keys(parse.headers).length, 4); + + assert.deepStrictEqual(parse.headers["model-00004-of-00004.safetensors"]["model.layers.24.mlp.up_proj.weight"], { + dtype: "BF16", + shape: [24576, 3072], + data_offsets: [301996032, 452990976], + }); + + assert.deepStrictEqual(parse.parameterCount, { BF16: 8_537_680_896 }); + assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896); + }); + + it("should detect sharded safetensors filename", async () => { + const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors + const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename); + + assert.strictEqual(safetensorsShardFileInfo?.prefix, "model_"); + assert.strictEqual(safetensorsShardFileInfo?.basePrefix, "model"); + assert.strictEqual(safetensorsShardFileInfo?.shard, "00005"); + assert.strictEqual(safetensorsShardFileInfo?.total, "00072"); + }); +}); diff --git a/lib/parse-safetensors-metadata.ts b/lib/parse-safetensors-metadata.ts new file mode 100644 index 0000000000000000000000000000000000000000..ca43a00883420b7e2ac1bad7b33fbf6cb723e83c --- /dev/null +++ b/lib/parse-safetensors-metadata.ts @@ -0,0 +1,274 @@ +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { omit } from "../utils/omit"; +import { toRepoId } from "../utils/toRepoId"; +import { typedEntries } from 
"../utils/typedEntries"; +import { downloadFile } from "./download-file"; +import { fileExists } from "./file-exists"; +import { promisesQueue } from "../utils/promisesQueue"; +import type { SetRequired } from "../vendor/type-fest/set-required"; + +export const SAFETENSORS_FILE = "model.safetensors"; +export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json"; +/// We advise model/library authors to use the filenames above for convention inside model repos, +/// but in some situations safetensors weights have different filenames. +export const RE_SAFETENSORS_FILE = /\.safetensors$/; +export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/; +export const RE_SAFETENSORS_SHARD_FILE = + /^(?(?.*?)[_-])(?\d{5})-of-(?\d{5})\.safetensors$/; +export interface SafetensorsShardFileInfo { + prefix: string; + basePrefix: string; + shard: string; + total: string; +} +export function parseSafetensorsShardFilename(filename: string): SafetensorsShardFileInfo | null { + const match = RE_SAFETENSORS_SHARD_FILE.exec(filename); + if (match && match.groups) { + return { + prefix: match.groups["prefix"], + basePrefix: match.groups["basePrefix"], + shard: match.groups["shard"], + total: match.groups["total"], + }; + } + return null; +} + +const PARALLEL_DOWNLOADS = 20; +const MAX_HEADER_LENGTH = 25_000_000; + +class SafetensorParseError extends Error {} + +type FileName = string; + +export type TensorName = string; +export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL"; + +export interface TensorInfo { + dtype: Dtype; + shape: number[]; + data_offsets: [number, number]; +} + +export type SafetensorsFileHeader = Record & { + __metadata__: Record; +}; + +export interface SafetensorsIndexJson { + dtype?: string; + /// ^there's sometimes a dtype but it looks inconsistent. + metadata?: Record; + /// ^ why the naming inconsistency? + weight_map: Record; +} + +export type SafetensorsShardedHeaders = Record; + +export type SafetensorsParseFromRepo = + | { + sharded: false; + header: SafetensorsFileHeader; + parameterCount?: Partial>; + } + | { + sharded: true; + index: SafetensorsIndexJson; + headers: SafetensorsShardedHeaders; + parameterCount?: Partial>; + }; + +async function parseSingleFile( + path: string, + params: { + repo: RepoDesignation; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { + const blob = await downloadFile({ ...params, path }); + + if (!blob) { + throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors header length.`); + } + + const bufLengthOfHeaderLE = await blob.slice(0, 8).arrayBuffer(); + const lengthOfHeader = new DataView(bufLengthOfHeaderLE).getBigUint64(0, true); + // ^little-endian + if (lengthOfHeader <= 0) { + throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`); + } + if (lengthOfHeader > MAX_HEADER_LENGTH) { + throw new SafetensorParseError( + `Failed to parse file ${path}: safetensor header is too big. Maximum supported size is ${MAX_HEADER_LENGTH} bytes.` + ); + } + + try { + // no validation for now, we assume it's a valid FileHeader. 
+ const header: SafetensorsFileHeader = JSON.parse(await blob.slice(8, 8 + Number(lengthOfHeader)).text()); + return header; + } catch (err) { + throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is not valid JSON.`); + } +} + +async function parseShardedIndex( + path: string, + params: { + repo: RepoDesignation; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> { + const indexBlob = await downloadFile({ + ...params, + path, + }); + + if (!indexBlob) { + throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`); + } + + // no validation for now, we assume it's a valid IndexJson. + let index: SafetensorsIndexJson; + try { + index = JSON.parse(await indexBlob.slice(0, 10_000_000).text()); + } catch (error) { + throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`); + } + + const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1); + const filenames = [...new Set(Object.values(index.weight_map))]; + const shardedMap: SafetensorsShardedHeaders = Object.fromEntries( + await promisesQueue( + filenames.map( + (filename) => async () => + [filename, await parseSingleFile(pathPrefix + filename, params)] satisfies [string, SafetensorsFileHeader] + ), + PARALLEL_DOWNLOADS + ) + ); + return { index, headers: shardedMap }; +} + +/** + * Analyze model.safetensors.index.json or model.safetensors from a model hosted + * on Hugging Face using smart range requests to extract its metadata. + */ +export async function parseSafetensorsMetadata( + params: { + /** Only models are supported */ + repo: RepoDesignation; + /** + * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists). + */ + path?: string; + /** + * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType + * + * @default false + */ + computeParametersCount: true; + hubUrl?: string; + revision?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise>; +export async function parseSafetensorsMetadata( + params: { + /** Only models are supported */ + repo: RepoDesignation; + /** + * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType + * + * @default false + */ + path?: string; + computeParametersCount?: boolean; + hubUrl?: string; + revision?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise; +export async function parseSafetensorsMetadata( + params: { + repo: RepoDesignation; + path?: string; + computeParametersCount?: boolean; + hubUrl?: string; + revision?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { + const repoId = toRepoId(params.repo); + + if (repoId.type !== "model") { + throw new TypeError("Only model repos should contain safetensors files."); + } + + if (RE_SAFETENSORS_FILE.test(params.path ?? 
"") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) { + const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params); + return { + sharded: false, + header, + ...(params.computeParametersCount && { + parameterCount: computeNumOfParamsByDtypeSingleFile(header), + }), + }; + } else if ( + RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") || + (await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE })) + ) { + const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params); + return { + sharded: true, + index, + headers, + ...(params.computeParametersCount && { + parameterCount: computeNumOfParamsByDtypeSharded(headers), + }), + }; + } else { + throw new Error("model id does not seem to contain safetensors weights"); + } +} + +function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial> { + const counter: Partial> = {}; + const tensors = omit(header, "__metadata__"); + + for (const [, v] of typedEntries(tensors)) { + if (v.shape.length === 0) { + continue; + } + counter[v.dtype] = (counter[v.dtype] ?? 0) + v.shape.reduce((a, b) => a * b); + } + return counter; +} + +function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial> { + const counter: Partial> = {}; + for (const header of Object.values(shardedMap)) { + for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) { + counter[k] = (counter[k] ?? 0) + (v ?? 0); + } + } + return counter; +} diff --git a/lib/paths-info.spec.ts b/lib/paths-info.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..837f4a192454183c5eb6792e7e403aafc52f2479 --- /dev/null +++ b/lib/paths-info.spec.ts @@ -0,0 +1,75 @@ +import { expect, it, describe } from "vitest"; +import type { CommitInfo, PathInfo, SecurityFileStatus } from "./paths-info"; +import { pathsInfo } from "./paths-info"; + +describe("pathsInfo", () => { + it("should fetch LFS path info", async () => { + const result: PathInfo[] = await pathsInfo({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + paths: ["tf_model.h5"], + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + expect(result).toHaveLength(1); + + const modelPathInfo = result[0]; + expect(modelPathInfo.path).toBe("tf_model.h5"); + expect(modelPathInfo.type).toBe("file"); + // lfs pointer, therefore lfs should be defined + expect(modelPathInfo?.lfs).toBeDefined(); + expect(modelPathInfo?.lfs?.oid).toBe("a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2"); + expect(modelPathInfo?.lfs?.size).toBe(536063208); + expect(modelPathInfo?.lfs?.pointerSize).toBe(134); + + // should not include expand info + expect(modelPathInfo.lastCommit).toBeUndefined(); + expect(modelPathInfo.securityFileStatus).toBeUndefined(); + }); + + it("expand parmas should fetch lastCommit and securityFileStatus", async () => { + const result: (PathInfo & { + lastCommit: CommitInfo; + securityFileStatus: SecurityFileStatus; + })[] = await pathsInfo({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + paths: ["tf_model.h5"], + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + expand: true, // include + }); + + expect(result).toHaveLength(1); + + const modelPathInfo = result[0]; + + // should include expand info + expect(modelPathInfo.lastCommit).toBeDefined(); + expect(modelPathInfo.securityFileStatus).toBeDefined(); + + expect(modelPathInfo.lastCommit.id).toBe("dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7"); + 
expect(modelPathInfo.lastCommit.title).toBe("Update tf_model.h5"); + expect(modelPathInfo.lastCommit.date.getTime()).toBe(1569268124000); // 2019-09-23T19:48:44.000Z + }); + + it("non-LFS pointer should have lfs undefined", async () => { + const result: PathInfo[] = await pathsInfo({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + paths: ["config.json"], + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + expect(result).toHaveLength(1); + + const modelPathInfo = result[0]; + expect(modelPathInfo.path).toBe("config.json"); + expect(modelPathInfo.lfs).toBeUndefined(); + }); +}); diff --git a/lib/paths-info.ts b/lib/paths-info.ts new file mode 100644 index 0000000000000000000000000000000000000000..ec455f882d22c3899dcfd7721c8b1df3a1388f0a --- /dev/null +++ b/lib/paths-info.ts @@ -0,0 +1,124 @@ +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { toRepoId } from "../utils/toRepoId"; +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; + +export interface LfsPathInfo { + oid: string; + size: number; + pointerSize: number; +} + +export interface CommitInfo { + id: string; + title: string; + date: Date; +} + +export interface SecurityFileStatus { + status: string; +} + +export interface PathInfo { + path: string; + type: string; + oid: string; + size: number; + /** + * Only defined when path is LFS pointer + */ + lfs?: LfsPathInfo; + lastCommit?: CommitInfo; + securityFileStatus?: SecurityFileStatus; +} + +// Define the overloaded signatures +export function pathsInfo( + params: { + repo: RepoDesignation; + paths: string[]; + expand: true; // if expand true + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise<(PathInfo & { lastCommit: CommitInfo; securityFileStatus: SecurityFileStatus })[]>; +export function pathsInfo( + params: { + repo: RepoDesignation; + paths: string[]; + expand?: boolean; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise; + +export async function pathsInfo( + params: { + repo: RepoDesignation; + paths: string[]; + expand?: boolean; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + + const hubUrl = params.hubUrl ?? HUB_URL; + + const url = `${hubUrl}/api/${repoId.type}s/${repoId.name}/paths-info/${encodeURIComponent( + params.revision ?? "main" + )}`; + + const resp = await (params.fetch ?? 
fetch)(url, { + method: "POST", + headers: { + ...(accessToken && { + Authorization: `Bearer ${accessToken}`, + }), + Accept: "application/json", + "Content-Type": "application/json", + }, + body: JSON.stringify({ + paths: params.paths, + expand: params.expand, + }), + }); + + if (!resp.ok) { + throw await createApiError(resp); + } + + const json: unknown = await resp.json(); + if (!Array.isArray(json)) throw new Error("malformed response: expected array"); + + return json.map((item: PathInfo) => ({ + path: item.path, + lfs: item.lfs, + type: item.type, + oid: item.oid, + size: item.size, + // expand fields + securityFileStatus: item.securityFileStatus, + lastCommit: item.lastCommit + ? { + date: new Date(item.lastCommit.date), + title: item.lastCommit.title, + id: item.lastCommit.id, + } + : undefined, + })); +} diff --git a/lib/repo-exists.spec.ts b/lib/repo-exists.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..c4bbb192f236db689fbed854df5609d5d6c091e9 --- /dev/null +++ b/lib/repo-exists.spec.ts @@ -0,0 +1,13 @@ +import { describe, expect, it } from "vitest"; +import { repoExists } from "./repo-exists"; + +describe("repoExists", () => { + it("should check if a repo exists", async () => { + const exists1 = await repoExists({ repo: { type: "model", name: "openai-community/gpt2" } }); + + expect(exists1).toBe(true); + + const exists2 = await repoExists({ repo: { type: "model", name: "openai-community/gpt9000" } }); + expect(exists2).toBe(false); + }); +}); diff --git a/lib/repo-exists.ts b/lib/repo-exists.ts new file mode 100644 index 0000000000000000000000000000000000000000..b53cae0e25fbe3ae47ba8e51e1d5d7554dca54c8 --- /dev/null +++ b/lib/repo-exists.ts @@ -0,0 +1,43 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { RepoDesignation } from "../types/public"; +import { toRepoId } from "../utils/toRepoId"; + +export async function repoExists(params: { + repo: RepoDesignation; + + hubUrl?: string; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. + */ + revision?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + accessToken?: string; +}): Promise { + const repoId = toRepoId(params.repo); + + const res = await (params.fetch ?? fetch)( + `${params.hubUrl ?? 
HUB_URL}/api/${repoId.type}s/${repoId.name}?expand[]=likes`, + { + method: "GET", + headers: { + ...(params.accessToken && { + Authorization: `Bearer ${params.accessToken}`, + }), + }, + } + ); + + if (res.status === 404 || res.status === 401) { + return false; + } + + if (!res.ok) { + throw await createApiError(res); + } + + return true; +} diff --git a/lib/snapshot-download.spec.ts b/lib/snapshot-download.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..0c44cc842467bac518e2a8795a8643189494e9e9 --- /dev/null +++ b/lib/snapshot-download.spec.ts @@ -0,0 +1,275 @@ +import { expect, test, describe, vi, beforeEach } from "vitest"; +import { dirname, join } from "node:path"; +import { mkdir, writeFile } from "node:fs/promises"; +import { getHFHubCachePath } from "./cache-management"; +import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; +import { snapshotDownload } from "./snapshot-download"; +import type { ListFileEntry } from "./list-files"; +import { listFiles } from "./list-files"; +import { modelInfo } from "./model-info"; +import type { ModelEntry } from "./list-models"; +import type { ApiModelInfo } from "../types/api/api-model"; +import { datasetInfo } from "./dataset-info"; +import type { DatasetEntry } from "./list-datasets"; +import type { ApiDatasetInfo } from "../types/api/api-dataset"; +import { spaceInfo } from "./space-info"; +import type { SpaceEntry } from "./list-spaces"; +import type { ApiSpaceInfo } from "../types/api/api-space"; + +vi.mock("node:fs/promises", () => ({ + writeFile: vi.fn(), + mkdir: vi.fn(), +})); + +vi.mock("./space-info", () => ({ + spaceInfo: vi.fn(), +})); + +vi.mock("./dataset-info", () => ({ + datasetInfo: vi.fn(), +})); + +vi.mock("./model-info", () => ({ + modelInfo: vi.fn(), +})); + +vi.mock("./list-files", () => ({ + listFiles: vi.fn(), +})); + +vi.mock("./download-file-to-cache-dir", () => ({ + downloadFileToCacheDir: vi.fn(), +})); + +const DUMMY_SHA = "dummy-sha"; + +// utility method to transform an array of ListFileEntry to an AsyncGenerator +async function* toAsyncGenerator(content: ListFileEntry[]): AsyncGenerator { + for (const entry of content) { + yield Promise.resolve(entry); + } +} + +beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(listFiles).mockReturnValue(toAsyncGenerator([])); + + // mock repo info + vi.mocked(modelInfo).mockResolvedValue({ + sha: DUMMY_SHA, + } as ModelEntry & ApiModelInfo); + vi.mocked(datasetInfo).mockResolvedValue({ + sha: DUMMY_SHA, + } as DatasetEntry & ApiDatasetInfo); + vi.mocked(spaceInfo).mockResolvedValue({ + sha: DUMMY_SHA, + } as SpaceEntry & ApiSpaceInfo); +}); + +describe("snapshotDownload", () => { + test("empty AsyncGenerator should not call downloadFileToCacheDir", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + + expect(downloadFileToCacheDir).not.toHaveBeenCalled(); + }); + + test("repo type model should use modelInfo", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "model", + }, + }); + expect(modelInfo).toHaveBeenCalledOnce(); + expect(modelInfo).toHaveBeenCalledWith({ + name: "foo/bar", + additionalFields: ["sha"], + revision: "main", + repo: { + name: "foo/bar", + type: "model", + }, + }); + }); + + test("repo type dataset should use datasetInfo", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "dataset", + }, + }); + expect(datasetInfo).toHaveBeenCalledOnce(); + expect(datasetInfo).toHaveBeenCalledWith({ + name: "foo/bar", + 
additionalFields: ["sha"], + revision: "main", + repo: { + name: "foo/bar", + type: "dataset", + }, + }); + }); + + test("repo type space should use spaceInfo", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + expect(spaceInfo).toHaveBeenCalledOnce(); + expect(spaceInfo).toHaveBeenCalledWith({ + name: "foo/bar", + additionalFields: ["sha"], + revision: "main", + repo: { + name: "foo/bar", + type: "space", + }, + }); + }); + + test("commitHash should be saved to ref folder", async () => { + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + revision: "dummy-revision", + }); + + // cross-platform testing + const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "refs", "dummy-revision"); + expect(mkdir).toHaveBeenCalledWith(dirname(expectedPath), { recursive: true }); + expect(writeFile).toHaveBeenCalledWith(expectedPath, DUMMY_SHA); + }); + + test("directory ListFileEntry should mkdir it", async () => { + vi.mocked(listFiles).mockReturnValue( + toAsyncGenerator([ + { + oid: "dummy-etag", + type: "directory", + path: "potatoes", + size: 0, + lastCommit: { + date: new Date().toISOString(), + id: DUMMY_SHA, + title: "feat: best commit", + }, + }, + ]) + ); + + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + + // cross-platform testing + const expectedPath = join(getHFHubCachePath(), "spaces--foo--bar", "snapshots", DUMMY_SHA, "potatoes"); + expect(mkdir).toHaveBeenCalledWith(expectedPath, { recursive: true }); + }); + + test("files in ListFileEntry should download them", async () => { + const entries: ListFileEntry[] = Array.from({ length: 10 }, (_, i) => ({ + oid: `dummy-etag-${i}`, + type: "file", + path: `file-${i}.txt`, + size: i, + lastCommit: { + date: new Date().toISOString(), + id: DUMMY_SHA, + title: "feat: best commit", + }, + })); + vi.mocked(listFiles).mockReturnValue(toAsyncGenerator(entries)); + + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + }); + + for (const entry of entries) { + expect(downloadFileToCacheDir).toHaveBeenCalledWith( + expect.objectContaining({ + repo: { + name: "foo/bar", + type: "space", + }, + path: entry.path, + revision: DUMMY_SHA, + }) + ); + } + }); + + test("custom params should be propagated", async () => { + // fetch mock + const fetchMock: typeof fetch = vi.fn(); + const hubMock = "https://foor.bar"; + const accessTokenMock = "dummy-access-token"; + + vi.mocked(listFiles).mockReturnValue( + toAsyncGenerator([ + { + oid: `dummy-etag`, + type: "file", + path: `file.txt`, + size: 10, + lastCommit: { + date: new Date().toISOString(), + id: DUMMY_SHA, + title: "feat: best commit", + }, + }, + ]) + ); + + await snapshotDownload({ + repo: { + name: "foo/bar", + type: "space", + }, + hubUrl: hubMock, + fetch: fetchMock, + accessToken: accessTokenMock, + }); + + expect(spaceInfo).toHaveBeenCalledWith( + expect.objectContaining({ + fetch: fetchMock, + hubUrl: hubMock, + accessToken: accessTokenMock, + }) + ); + + // list files should receive custom fetch + expect(listFiles).toHaveBeenCalledWith( + expect.objectContaining({ + fetch: fetchMock, + hubUrl: hubMock, + accessToken: accessTokenMock, + }) + ); + + // download file to cache should receive custom fetch + expect(downloadFileToCacheDir).toHaveBeenCalledWith( + expect.objectContaining({ + fetch: fetchMock, + hubUrl: hubMock, + accessToken: accessTokenMock, + }) + ); + }); +}); diff --git a/lib/snapshot-download.ts b/lib/snapshot-download.ts new file 
mode 100644 index 0000000000000000000000000000000000000000..b3e30c13f140e7e7305efe24bba441d93fecafbb --- /dev/null +++ b/lib/snapshot-download.ts @@ -0,0 +1,124 @@ +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { listFiles } from "./list-files"; +import { getHFHubCachePath, getRepoFolderName } from "./cache-management"; +import { spaceInfo } from "./space-info"; +import { datasetInfo } from "./dataset-info"; +import { modelInfo } from "./model-info"; +import { toRepoId } from "../utils/toRepoId"; +import { join, dirname } from "node:path"; +import { mkdir, writeFile } from "node:fs/promises"; +import { downloadFileToCacheDir } from "./download-file-to-cache-dir"; + +export const DEFAULT_REVISION = "main"; + +/** + * Downloads an entire repository at a given revision in the cache directory {@link getHFHubCachePath}. + * You can list all cached repositories using {@link scanCachedRepo} + * @remarks It uses internally {@link downloadFileToCacheDir}. + */ +export async function snapshotDownload( + params: { + repo: RepoDesignation; + cacheDir?: string; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. + * + * @default "main" + */ + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { + let cacheDir: string; + if (params.cacheDir) { + cacheDir = params.cacheDir; + } else { + cacheDir = getHFHubCachePath(); + } + + let revision: string; + if (params.revision) { + revision = params.revision; + } else { + revision = DEFAULT_REVISION; + } + + const repoId = toRepoId(params.repo); + + // get repository revision value (sha) + let repoInfo: { sha: string }; + switch (repoId.type) { + case "space": + repoInfo = await spaceInfo({ + ...params, + name: repoId.name, + additionalFields: ["sha"], + revision: revision, + }); + break; + case "dataset": + repoInfo = await datasetInfo({ + ...params, + name: repoId.name, + additionalFields: ["sha"], + revision: revision, + }); + break; + case "model": + repoInfo = await modelInfo({ + ...params, + name: repoId.name, + additionalFields: ["sha"], + revision: revision, + }); + break; + default: + throw new Error(`invalid repository type ${repoId.type}`); + } + + const commitHash: string = repoInfo.sha; + + // get storage folder + const storageFolder = join(cacheDir, getRepoFolderName(repoId)); + const snapshotFolder = join(storageFolder, "snapshots", commitHash); + + // if passed revision is not identical to commit_hash + // then revision has to be a branch name or tag name. + // In that case store a ref. 
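  // (Illustrative aside, not part of the original diff: for a repo { type: "space", name: "foo/bar" }
  // whose revision "main" resolves to commit <sha>, the resulting cache layout, matching the
  // expectations in snapshot-download.spec.ts above, is roughly:
  //   <cacheDir>/spaces--foo--bar/refs/main           a file containing "<sha>"
  //   <cacheDir>/spaces--foo--bar/snapshots/<sha>/    the downloaded snapshot, returned at the end.)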
+ if (revision !== commitHash) { + const refPath = join(storageFolder, "refs", revision); + await mkdir(dirname(refPath), { recursive: true }); + await writeFile(refPath, commitHash); + } + + const cursor = listFiles({ + ...params, + repo: params.repo, + recursive: true, + revision: repoInfo.sha, + }); + + for await (const entry of cursor) { + switch (entry.type) { + case "file": + await downloadFileToCacheDir({ + ...params, + path: entry.path, + revision: commitHash, + cacheDir: cacheDir, + }); + break; + case "directory": + await mkdir(join(snapshotFolder, entry.path), { recursive: true }); + break; + default: + throw new Error(`unknown entry type: ${entry.type}`); + } + } + + return snapshotFolder; +} diff --git a/lib/space-info.spec.ts b/lib/space-info.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..ea966f98bceea2577e268164f8512bc822a8f82e --- /dev/null +++ b/lib/space-info.spec.ts @@ -0,0 +1,53 @@ +import { describe, expect, it } from "vitest"; +import { spaceInfo } from "./space-info"; +import type { SpaceEntry } from "./list-spaces"; +import type { ApiSpaceInfo } from "../types/api/api-space"; + +describe("spaceInfo", () => { + it("should return the space info", async () => { + const info = await spaceInfo({ + name: "huggingfacejs/client-side-oauth", + }); + expect(info).toEqual({ + id: "659835e689010f9c7aed608d", + name: "huggingfacejs/client-side-oauth", + updatedAt: expect.any(Date), + likes: expect.any(Number), + private: false, + sdk: "static", + }); + }); + + it("should return the space info with author", async () => { + const info: SpaceEntry & Pick = await spaceInfo({ + name: "huggingfacejs/client-side-oauth", + additionalFields: ["author"], + }); + expect(info).toEqual({ + id: "659835e689010f9c7aed608d", + name: "huggingfacejs/client-side-oauth", + updatedAt: expect.any(Date), + likes: expect.any(Number), + private: false, + sdk: "static", + author: "huggingfacejs", + }); + }); + + it("should return the space info for a given revision", async () => { + const info: SpaceEntry & Pick = await spaceInfo({ + name: "huggingfacejs/client-side-oauth", + additionalFields: ["sha"], + revision: "e410a9ff348e6bed393b847711e793282d7c672e", + }); + expect(info).toEqual({ + id: "659835e689010f9c7aed608d", + name: "huggingfacejs/client-side-oauth", + updatedAt: expect.any(Date), + likes: expect.any(Number), + private: false, + sdk: "static", + sha: "e410a9ff348e6bed393b847711e793282d7c672e", + }); + }); +}); diff --git a/lib/space-info.ts b/lib/space-info.ts new file mode 100644 index 0000000000000000000000000000000000000000..94223538253dc8e6037af67436e32126c0fdda58 --- /dev/null +++ b/lib/space-info.ts @@ -0,0 +1,61 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiSpaceInfo } from "../types/api/api-space"; +import type { CredentialsParams } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { pick } from "../utils/pick"; +import type { SPACE_EXPANDABLE_KEYS, SpaceEntry } from "./list-spaces"; +import { SPACE_EXPAND_KEYS } from "./list-spaces"; + +export async function spaceInfo< + const T extends Exclude<(typeof SPACE_EXPANDABLE_KEYS)[number], (typeof SPACE_EXPAND_KEYS)[number]> = never, +>( + params: { + name: string; + hubUrl?: string; + additionalFields?: T[]; + /** + * An optional Git revision id which can be a branch name, a tag, or a commit hash. 
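   * (Illustrative examples, not part of the original diff: a branch such as "main", a tag, or a full
   * commit sha like the one used in space-info.spec.ts above.)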
+ */ + revision?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise> { + const accessToken = params && checkCredentials(params); + + const search = new URLSearchParams([ + ...SPACE_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]), + ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []), + ]).toString(); + + const response = await (params.fetch || fetch)( + `${params?.hubUrl || HUB_URL}/api/spaces/${params.name}/revision/${encodeURIComponent( + params.revision ?? "HEAD" + )}?${search.toString()}`, + { + headers: { + ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}), + Accepts: "application/json", + }, + } + ); + + if (!response.ok) { + throw await createApiError(response); + } + + const data = await response.json(); + + return { + ...(params?.additionalFields && pick(data, params.additionalFields)), + id: data._id, + name: data.id, + sdk: data.sdk, + likes: data.likes, + private: data.private, + updatedAt: new Date(data.lastModified), + } as SpaceEntry & Pick; +} diff --git a/lib/upload-file.spec.ts b/lib/upload-file.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..af75349631bedfc55a7f1e53111c307bf72e1688 --- /dev/null +++ b/lib/upload-file.spec.ts @@ -0,0 +1,98 @@ +import { assert, it, describe } from "vitest"; + +import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { downloadFile } from "./download-file"; +import { uploadFile } from "./upload-file"; + +describe("uploadFile", () => { + it("should upload a file", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + repo, + hubUrl: TEST_HUB_URL, + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + await uploadFile({ + accessToken: TEST_ACCESS_TOKEN, + repo, + file: { content: new Blob(["file1"]), path: "file1" }, + hubUrl: TEST_HUB_URL, + }); + await uploadFile({ + accessToken: TEST_ACCESS_TOKEN, + repo, + file: new URL("https://huggingface.co/gpt2/raw/main/config.json"), + hubUrl: TEST_HUB_URL, + }); + + let content = await downloadFile({ + repo, + path: "file1", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), "file1"); + + content = await downloadFile({ + repo, + path: "config.json", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual( + (await content?.text())?.trim(), + ` +{ + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 50256, + "embd_pdrop": 0.1, + "eos_token_id": 50256, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, + "n_head": 12, + "n_layer": 12, + "n_positions": 1024, + "resid_pdrop": 0.1, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "vocab_size": 50257 +} + `.trim() + ); + } 
finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); +}); diff --git a/lib/upload-file.ts b/lib/upload-file.ts new file mode 100644 index 0000000000000000000000000000000000000000..290cfcb03ce86e1f937605bee64260e4644d1447 --- /dev/null +++ b/lib/upload-file.ts @@ -0,0 +1,47 @@ +import type { CredentialsParams } from "../types/public"; +import type { CommitOutput, CommitParams, ContentSource } from "./commit"; +import { commit } from "./commit"; + +export function uploadFile( + params: { + repo: CommitParams["repo"]; + file: URL | File | { path: string; content: ContentSource }; + commitTitle?: CommitParams["title"]; + commitDescription?: CommitParams["description"]; + hubUrl?: CommitParams["hubUrl"]; + branch?: CommitParams["branch"]; + isPullRequest?: CommitParams["isPullRequest"]; + parentCommit?: CommitParams["parentCommit"]; + fetch?: CommitParams["fetch"]; + useWebWorkers?: CommitParams["useWebWorkers"]; + abortSignal?: CommitParams["abortSignal"]; + } & Partial +): Promise { + const path = + params.file instanceof URL + ? params.file.pathname.split("/").at(-1) ?? "file" + : "path" in params.file + ? params.file.path + : params.file.name; + + return commit({ + ...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }), + repo: params.repo, + operations: [ + { + operation: "addOrUpdate", + path, + content: "content" in params.file ? params.file.content : params.file, + }, + ], + title: params.commitTitle ?? `Add ${path}`, + description: params.commitDescription, + hubUrl: params.hubUrl, + branch: params.branch, + isPullRequest: params.isPullRequest, + parentCommit: params.parentCommit, + fetch: params.fetch, + useWebWorkers: params.useWebWorkers, + abortSignal: params.abortSignal, + }); +} diff --git a/lib/upload-files-with-progress.spec.ts b/lib/upload-files-with-progress.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..50a1b4d3802c6f3efedbcfca9d728d7febe046cd --- /dev/null +++ b/lib/upload-files-with-progress.spec.ts @@ -0,0 +1,168 @@ +import { assert, it, describe } from "vitest"; + +import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { downloadFile } from "./download-file"; +import { uploadFilesWithProgress } from "./upload-files-with-progress"; +import type { CommitOutput, CommitProgressEvent } from "./commit"; + +describe("uploadFilesWithProgress", () => { + it("should upload files", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + const lfsContent = "O123456789".repeat(100_000); + + try { + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + repo, + hubUrl: TEST_HUB_URL, + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + const it = uploadFilesWithProgress({ + accessToken: TEST_ACCESS_TOKEN, + repo, + files: [ + { content: new Blob(["file1"]), path: "file1" }, + new URL("https://huggingface.co/gpt2/raw/main/config.json"), + // Large file + { + content: new Blob([lfsContent]), + path: "test.lfs.txt", + }, + ], + useWebWorkers: { + minSize: 1_000, + }, + hubUrl: TEST_HUB_URL, + }); + + let res: IteratorResult; + let progressEvents: CommitProgressEvent[] = []; 
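      // (Illustrative note, not part of the original diff: the manual `it.next()` loop below is
      // equivalent here to `for await (const event of it) progressEvents.push(event);`; the explicit
      // form simply makes the final `done` result of the generator visible.)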
+ + do { + res = await it.next(); + if (!res.done) { + progressEvents.push(res.value); + } + } while (!res.done); + + // const intermediateHashingEvents = progressEvents.filter( + // (e) => e.event === "fileProgress" && e.type === "hashing" && e.progress !== 0 && e.progress !== 1 + // ); + // if (isFrontend) { + // assert(intermediateHashingEvents.length > 0); + // } + // const intermediateUploadEvents = progressEvents.filter( + // (e) => e.event === "fileProgress" && e.type === "uploading" && e.progress !== 0 && e.progress !== 1 + // ); + // if (isFrontend) { + // assert(intermediateUploadEvents.length > 0, "There should be at least one intermediate upload event"); + // } + progressEvents = progressEvents.filter((e) => e.event !== "fileProgress" || e.progress === 0 || e.progress === 1); + + assert.deepStrictEqual(progressEvents, [ + { + event: "phase", + phase: "preuploading", + }, + { + event: "phase", + phase: "uploadingLargeFiles", + }, + { + event: "fileProgress", + path: "test.lfs.txt", + progress: 0, + state: "hashing", + }, + { + event: "fileProgress", + path: "test.lfs.txt", + progress: 1, + state: "hashing", + }, + { + event: "fileProgress", + path: "test.lfs.txt", + progress: 0, + state: "uploading", + }, + { + event: "fileProgress", + path: "test.lfs.txt", + progress: 1, + state: "uploading", + }, + { + event: "phase", + phase: "committing", + }, + ]); + + let content = await downloadFile({ + repo, + path: "file1", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), "file1"); + + content = await downloadFile({ + repo, + path: "config.json", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual( + (await content?.text())?.trim(), + ` +{ + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 50256, + "embd_pdrop": 0.1, + "eos_token_id": 50256, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, + "n_head": 12, + "n_layer": 12, + "n_positions": 1024, + "resid_pdrop": 0.1, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "vocab_size": 50257 +} + `.trim() + ); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); +}); diff --git a/lib/upload-files-with-progress.ts b/lib/upload-files-with-progress.ts new file mode 100644 index 0000000000000000000000000000000000000000..f0a0af25251f17fadf69fed149a4235e64d0a555 --- /dev/null +++ b/lib/upload-files-with-progress.ts @@ -0,0 +1,154 @@ +import type { CredentialsParams } from "../types/public"; +import { typedInclude } from "../utils/typedInclude"; +import type { CommitOutput, CommitParams, CommitProgressEvent, ContentSource } from "./commit"; +import { commitIter } from "./commit"; + +const multipartUploadTracking = new WeakMap< + (progress: number) => void, + { + numParts: number; + partsProgress: Record; + } +>(); + +/** + * Uploads with progress + * + * Needs XMLHttpRequest to be available for progress events for uploads + * Set useWebWorkers to true in order to have progress events for hashing + */ +export async function* uploadFilesWithProgress( + params: { + repo: CommitParams["repo"]; + files: Array; + commitTitle?: CommitParams["title"]; + commitDescription?: CommitParams["description"]; + hubUrl?: 
CommitParams["hubUrl"]; + branch?: CommitParams["branch"]; + isPullRequest?: CommitParams["isPullRequest"]; + parentCommit?: CommitParams["parentCommit"]; + abortSignal?: CommitParams["abortSignal"]; + maxFolderDepth?: CommitParams["maxFolderDepth"]; + /** + * Set this to true in order to have progress events for hashing + */ + useWebWorkers?: CommitParams["useWebWorkers"]; + } & Partial +): AsyncGenerator { + return yield* commitIter({ + ...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }), + repo: params.repo, + operations: params.files.map((file) => ({ + operation: "addOrUpdate", + path: file instanceof URL ? file.pathname.split("/").at(-1) ?? "file" : "path" in file ? file.path : file.name, + content: "content" in file ? file.content : file, + })), + title: params.commitTitle ?? `Add ${params.files.length} files`, + description: params.commitDescription, + hubUrl: params.hubUrl, + branch: params.branch, + isPullRequest: params.isPullRequest, + parentCommit: params.parentCommit, + useWebWorkers: params.useWebWorkers, + abortSignal: params.abortSignal, + fetch: async (input, init) => { + if (!init) { + return fetch(input); + } + + if ( + !typedInclude(["PUT", "POST"], init.method) || + !("progressHint" in init) || + !init.progressHint || + typeof XMLHttpRequest === "undefined" || + typeof input !== "string" || + (!(init.body instanceof ArrayBuffer) && + !(init.body instanceof Blob) && + !(init.body instanceof File) && + typeof init.body !== "string") + ) { + return fetch(input, init); + } + + const progressHint = init.progressHint as { + progressCallback: (progress: number) => void; + } & (Record | { part: number; numParts: number }); + const progressCallback = progressHint.progressCallback; + + const xhr = new XMLHttpRequest(); + + xhr.upload.addEventListener("progress", (event) => { + if (event.lengthComputable) { + if (progressHint.part !== undefined) { + let tracking = multipartUploadTracking.get(progressCallback); + if (!tracking) { + tracking = { numParts: progressHint.numParts, partsProgress: {} }; + multipartUploadTracking.set(progressCallback, tracking); + } + tracking.partsProgress[progressHint.part] = event.loaded / event.total; + let totalProgress = 0; + for (const partProgress of Object.values(tracking.partsProgress)) { + totalProgress += partProgress; + } + if (totalProgress === tracking.numParts) { + progressCallback(0.9999999999); + } else { + progressCallback(totalProgress / tracking.numParts); + } + } else { + if (event.loaded === event.total) { + progressCallback(0.9999999999); + } else { + progressCallback(event.loaded / event.total); + } + } + } + }); + + xhr.open(init.method, input, true); + + if (init.headers) { + const headers = new Headers(init.headers); + headers.forEach((value, key) => { + xhr.setRequestHeader(key, value); + }); + } + + init.signal?.throwIfAborted(); + xhr.send(init.body); + + return new Promise((resolve, reject) => { + xhr.addEventListener("load", () => { + resolve( + new Response(xhr.responseText, { + status: xhr.status, + statusText: xhr.statusText, + headers: Object.fromEntries( + xhr + .getAllResponseHeaders() + .trim() + .split("\n") + .map((header) => [header.slice(0, header.indexOf(":")), header.slice(header.indexOf(":") + 1).trim()]) + ), + }) + ); + }); + xhr.addEventListener("error", () => { + reject(new Error(xhr.statusText)); + }); + + if (init.signal) { + init.signal.addEventListener("abort", () => { + xhr.abort(); + + try { + init.signal?.throwIfAborted(); + } catch (err) { + 
reject(err); + } + }); + } + }); + }, + }); +} diff --git a/lib/upload-files.fs.spec.ts b/lib/upload-files.fs.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..415d71fd98fbebc0851e3115e06abafcefad15d4 --- /dev/null +++ b/lib/upload-files.fs.spec.ts @@ -0,0 +1,71 @@ +import { assert, it, describe } from "vitest"; + +import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { downloadFile } from "./download-file"; +import { uploadFiles } from "./upload-files"; +import { mkdir } from "fs/promises"; +import { writeFile } from "fs/promises"; +import { pathToFileURL } from "url"; +import { tmpdir } from "os"; + +describe("uploadFiles", () => { + it("should upload local folder", async () => { + const tmpDir = tmpdir(); + + await mkdir(`${tmpDir}/test-folder/sub`, { recursive: true }); + + await writeFile(`${tmpDir}/test-folder/sub/file1.txt`, "file1"); + await writeFile(`${tmpDir}/test-folder/sub/file2.txt`, "file2"); + + await writeFile(`${tmpDir}/test-folder/file3.txt`, "file3"); + await writeFile(`${tmpDir}/test-folder/file4.txt`, "file4"); + + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + repo, + hubUrl: TEST_HUB_URL, + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + await uploadFiles({ + accessToken: TEST_ACCESS_TOKEN, + repo, + files: [pathToFileURL(`${tmpDir}/test-folder`)], + hubUrl: TEST_HUB_URL, + }); + + let content = await downloadFile({ + repo, + path: "test-folder/sub/file1.txt", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), "file1"); + + content = await downloadFile({ + repo, + path: "test-folder/file3.txt", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), `file3`); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); +}); diff --git a/lib/upload-files.spec.ts b/lib/upload-files.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..94258ad1b2a8ac0254bb78973962cfe81a82a22b --- /dev/null +++ b/lib/upload-files.spec.ts @@ -0,0 +1,95 @@ +import { assert, it, describe } from "vitest"; + +import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts"; +import type { RepoId } from "../types/public"; +import { insecureRandomString } from "../utils/insecureRandomString"; +import { createRepo } from "./create-repo"; +import { deleteRepo } from "./delete-repo"; +import { downloadFile } from "./download-file"; +import { uploadFiles } from "./upload-files"; + +describe("uploadFiles", () => { + it("should upload files", async () => { + const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`; + const repo = { type: "model", name: repoName } satisfies RepoId; + + try { + const result = await createRepo({ + accessToken: TEST_ACCESS_TOKEN, + repo, + hubUrl: TEST_HUB_URL, + }); + + assert.deepStrictEqual(result, { + repoUrl: `${TEST_HUB_URL}/${repoName}`, + }); + + await uploadFiles({ + accessToken: TEST_ACCESS_TOKEN, + repo, + files: [ + { content: new Blob(["file1"]), path: "file1" }, + new URL("https://huggingface.co/gpt2/raw/main/config.json"), + ], + hubUrl: 
TEST_HUB_URL, + }); + + let content = await downloadFile({ + repo, + path: "file1", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual(await content?.text(), "file1"); + + content = await downloadFile({ + repo, + path: "config.json", + hubUrl: TEST_HUB_URL, + }); + + assert.strictEqual( + (await content?.text())?.trim(), + ` +{ + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 50256, + "embd_pdrop": 0.1, + "eos_token_id": 50256, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, + "n_head": 12, + "n_layer": 12, + "n_positions": 1024, + "resid_pdrop": 0.1, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 + } + }, + "vocab_size": 50257 +} + `.trim() + ); + } finally { + await deleteRepo({ + repo, + accessToken: TEST_ACCESS_TOKEN, + hubUrl: TEST_HUB_URL, + }); + } + }); +}); diff --git a/lib/upload-files.ts b/lib/upload-files.ts new file mode 100644 index 0000000000000000000000000000000000000000..4eda11015eae63630bc54a6b0d30744468a81ba2 --- /dev/null +++ b/lib/upload-files.ts @@ -0,0 +1,39 @@ +import type { CredentialsParams } from "../types/public"; +import type { CommitOutput, CommitParams, ContentSource } from "./commit"; +import { commit } from "./commit"; + +export function uploadFiles( + params: { + repo: CommitParams["repo"]; + files: Array; + commitTitle?: CommitParams["title"]; + commitDescription?: CommitParams["description"]; + hubUrl?: CommitParams["hubUrl"]; + branch?: CommitParams["branch"]; + isPullRequest?: CommitParams["isPullRequest"]; + parentCommit?: CommitParams["parentCommit"]; + fetch?: CommitParams["fetch"]; + useWebWorkers?: CommitParams["useWebWorkers"]; + maxFolderDepth?: CommitParams["maxFolderDepth"]; + abortSignal?: CommitParams["abortSignal"]; + } & Partial +): Promise { + return commit({ + ...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }), + repo: params.repo, + operations: params.files.map((file) => ({ + operation: "addOrUpdate", + path: file instanceof URL ? file.pathname.split("/").at(-1) ?? "file" : "path" in file ? file.path : file.name, + content: "content" in file ? file.content : file, + })), + title: params.commitTitle ?? 
`Add ${params.files.length} files`, + description: params.commitDescription, + hubUrl: params.hubUrl, + branch: params.branch, + isPullRequest: params.isPullRequest, + parentCommit: params.parentCommit, + fetch: params.fetch, + useWebWorkers: params.useWebWorkers, + abortSignal: params.abortSignal, + }); +} diff --git a/lib/who-am-i.spec.ts b/lib/who-am-i.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..387373923975299545ce2fef139ae0772ea8e58a --- /dev/null +++ b/lib/who-am-i.spec.ts @@ -0,0 +1,35 @@ +import { assert, it, describe } from "vitest"; +import { TEST_ACCESS_TOKEN, TEST_HUB_URL } from "../test/consts"; +import { whoAmI } from "./who-am-i"; + +describe("whoAmI", () => { + it("should fetch identity info", async () => { + const info = await whoAmI({ accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL }); + + if (info.auth.accessToken?.createdAt instanceof Date) { + info.auth.accessToken.createdAt = new Date(0); + } + + assert.deepStrictEqual(info, { + type: "user", + id: "62f264b9f3c90f4b6514a269", + name: "hub.js", + fullname: "@huggingface/hub CI bot", + email: "eliott@huggingface.co", + emailVerified: true, + canPay: false, + isPro: false, + periodEnd: null, + avatarUrl: "/avatars/934b830e9fdaa879487852f79eef7165.svg", + orgs: [], + auth: { + type: "access_token", + accessToken: { + createdAt: new Date(0), + displayName: "ci-hub.js", + role: "write", + }, + }, + }); + }); +}); diff --git a/lib/who-am-i.ts b/lib/who-am-i.ts new file mode 100644 index 0000000000000000000000000000000000000000..5f4c1845ba89c7c8608ab24eace258abdbdaeb71 --- /dev/null +++ b/lib/who-am-i.ts @@ -0,0 +1,91 @@ +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; +import type { ApiWhoAmIReponse } from "../types/api/api-who-am-i"; +import type { AccessTokenRole, AuthType, CredentialsParams } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; + +export interface WhoAmIUser { + /** Unique ID persistent across renames */ + id: string; + type: "user"; + email: string; + emailVerified: boolean; + isPro: boolean; + orgs: WhoAmIOrg[]; + name: string; + fullname: string; + canPay: boolean; + avatarUrl: string; + /** + * Unix timestamp in seconds + */ + periodEnd: number | null; +} + +export interface WhoAmIOrg { + /** Unique ID persistent across renames */ + id: string; + type: "org"; + name: string; + fullname: string; + email: string | null; + canPay: boolean; + avatarUrl: string; + /** + * Unix timestamp in seconds + */ + periodEnd: number | null; +} + +export interface WhoAmIApp { + id: string; + type: "app"; + name: string; + scope?: { + entities: string[]; + role: "admin" | "write" | "contributor" | "read"; + }; +} + +export type WhoAmI = WhoAmIApp | WhoAmIOrg | WhoAmIUser; +export interface AuthInfo { + type: AuthType; + accessToken?: { + displayName: string; + role: AccessTokenRole; + createdAt: Date; + }; + expiresAt?: Date; +} + +export async function whoAmI( + params: { + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & CredentialsParams +): Promise { + const accessToken = checkCredentials(params); + + const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? 
HUB_URL}/api/whoami-v2`, { + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }); + + if (!res.ok) { + throw await createApiError(res); + } + + const response: ApiWhoAmIReponse & { + auth: AuthInfo; + } = await res.json(); + + if (typeof response.auth.accessToken?.createdAt === "string") { + response.auth.accessToken.createdAt = new Date(response.auth.accessToken.createdAt); + } + + return response; +} diff --git a/test/consts.ts b/test/consts.ts new file mode 100644 index 0000000000000000000000000000000000000000..6b8b7983d62f140076851c03c901ab2da6c6bff3 --- /dev/null +++ b/test/consts.ts @@ -0,0 +1,4 @@ +export const TEST_HUB_URL = "https://hub-ci.huggingface.co"; +export const TEST_USER = "hub.js"; +export const TEST_ACCESS_TOKEN = "hf_hub.js"; +export const TEST_COOKIE = "huggingface-hub.js-cookie"; diff --git a/types/api/api-commit.ts b/types/api/api-commit.ts new file mode 100644 index 0000000000000000000000000000000000000000..b5fbfec4b4de4e0566bcc52927071ade954023e8 --- /dev/null +++ b/types/api/api-commit.ts @@ -0,0 +1,191 @@ +export interface ApiLfsBatchRequest { + /// github.com/git-lfs/git-lfs/blob/master/docs/api/batch.md + operation: "download" | "upload"; + transfers?: string[]; + /** + * Optional object describing the server ref that the objects belong to. Note: Added in v2.4. + * + * We use this object for QOL and to fail early for users when they're trying to push to the wrong reference. + * But it does nothing for security. + */ + ref?: { + name: string; + } | null; + objects: { + oid: string; + /** + * Integer byte size of the LFS object. Must be at least zero. + */ + size: number; + }[]; + /** + * The hash algorithm used to name Git LFS objects. Optional; defaults to sha256 if not specified. + * */ + hash_algo?: string; +} + +export interface ApiLfsBatchResponse { + transfer?: ApiLfsResponseTransfer; + objects: ApiLfsResponseObject[]; +} + +export type ApiLfsResponseTransfer = "basic" | "multipart"; + +export interface ApiLfsCompleteMultipartRequest { + oid: string; + parts: { etag: string; partNumber: number }[]; +} + +export interface ApiLfsResponseObject { + /** + * Optional boolean specifying whether the request + * for this specific object is authenticated. + * If omitted or false, Git LFS will attempt to find credentials for this URL. + */ + authenticated?: boolean; + oid: string; + /** + * Integer byte size of the LFS object. Must be at least zero. + */ + size: number; + /** + * Applicable actions depend on which `operation` is specified in the request. + * How these properties are interpreted depends on which transfer adapter + * the client will be using. + */ + actions?: { + /** + * Download operations MUST specify a download action, + * or an object error if the object cannot be downloaded for some reason + */ + download?: ApiLfsAction; + /** + * Upload operations can specify an upload and a verify action. + * The upload action describes how to upload the object. + */ + upload?: ApiLfsAction; + /** + * The LFS client will hit this URL after a successful upload. + * Servers can use this for extra verification, if needed. 
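     * (Illustrative sketch, not part of the original diff: a verify action carries the same fields as
     * ApiLfsAction below, e.g. { href: "https://...", header: { Authorization: "..." }, expires_in: 3600 }.)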
+ */ + verify?: ApiLfsAction; + }; + /** + * If there are problems accessing individual objects, servers should continue + * to return a 200 status code, and provide per-object errors + */ + error?: { + code: number; + message: string; + }; +} + +export interface ApiLfsAction { + href: string; + /** + * Optional hash of String HTTP header key/value pairs to apply to the request + */ + header?: { [key: string]: string } & { chunk_size?: string }; + /** + * Whole number of seconds after local client time when transfer will expire. + * Preferred over `expires_at` if both are provided. + * Maximum of 2147483647, minimum of -2147483647. + */ + expires_in?: number; + /** + * String uppercase RFC 3339-formatted timestamp with second precision + * for when the given action expires (usually due to a temporary token). + */ + expires_at?: string; +} + +export interface ApiPreuploadRequest { + /** + * Optional, otherwise takes the existing content of `.gitattributes` for the revision. + * + * Provide this parameter if you plan to modify `.gitattributes` yourself at the same + * time as uploading LFS files. + * + * Note that this is not needed if you solely rely on automatic LFS detection from HF: the commit endpoint + * will automatically edit the `.gitattributes` file to track the files passed to its `lfsFiles` param. + */ + gitAttributes?: string; + files: Array<{ + /** + * Path of the LFS file + */ + path: string; + /** + * Full size of the LFS file + */ + size: number; + /** + * Base64-encoded sample of the first 512 bytes of the file + */ + sample: string; + }>; +} + +export interface ApiPreuploadResponse { + files: Array<{ + path: string; + uploadMode: "lfs" | "regular"; + }>; +} + +export interface ApiCommitHeader { + summary: string; + description?: string; + /** + * Parent commit. 
Optional + * + * - When opening a PR: will use parentCommit as the parent commit + * - When committing on a branch: Will make sure that there were no intermediate commits + */ + parentCommit?: string; +} + +export interface ApiCommitDeletedEntry { + path: string; +} + +export interface ApiCommitLfsFile { + path: string; + oldPath?: string; + /** Required if {@link oldPath} is not set */ + algo?: "sha256"; + /** Required if {@link oldPath} is not set */ + oid?: string; + size?: number; +} + +export interface ApiCommitFile { + /** Required if {@link oldPath} is not set */ + content?: string; + path: string; + oldPath?: string; + encoding?: "utf-8" | "base64"; +} + +export type ApiCommitOperation = + | { + key: "file"; + value: ApiCommitFile; + } + | { + key: "lfsFile"; + value: ApiCommitLfsFile; + } + | { + key: "deletedFile"; + value: ApiCommitDeletedEntry; + }; + +export interface ApiCommitData { + id: string; + title: string; + message: string; + authors: Array<{ user: string; avatar: string }>; + date: string; + formatted?: string; +} diff --git a/types/api/api-create-repo.ts b/types/api/api-create-repo.ts new file mode 100644 index 0000000000000000000000000000000000000000..f701f2d524ca67c86db5739fa224c949f27ca201 --- /dev/null +++ b/types/api/api-create-repo.ts @@ -0,0 +1,25 @@ +import type { SetRequired } from "../../vendor/type-fest/set-required"; +import type { RepoType, SpaceHardwareFlavor, SpaceSdk } from "../public"; +import type { ApiCommitFile } from "./api-commit"; + +export type ApiCreateRepoPayload = { + name: string; + canonical?: boolean; + license?: string; + template?: string; + organization?: string; + /** @default false */ + private?: boolean; + lfsmultipartthresh?: number; + files?: SetRequired[]; +} & ( + | { + type: Exclude; + } + | { + type: "space"; + hardware?: SpaceHardwareFlavor; + sdk: SpaceSdk; + sdkVersion?: string; + } +); diff --git a/types/api/api-dataset.ts b/types/api/api-dataset.ts new file mode 100644 index 0000000000000000000000000000000000000000..43b0978537d8855745fd668443b3fc102060b3e4 --- /dev/null +++ b/types/api/api-dataset.ts @@ -0,0 +1,89 @@ +import type { License } from "../public"; + +export interface ApiDatasetInfo { + _id: string; + id: string; + arxivIds?: string[]; + author?: string; + cardExists?: true; + cardError?: unknown; + cardData?: ApiDatasetMetadata; + contributors?: Array<{ user: string; _id: string }>; + disabled: boolean; + discussionsDisabled: boolean; + gated: false | "auto" | "manual"; + gitalyUid: string; + lastAuthor: { email: string; user?: string }; + lastModified: string; // date + likes: number; + likesRecent: number; + private: boolean; + updatedAt: string; // date + createdAt: string; // date + tags: string[]; + paperswithcode_id?: string; + sha: string; + files?: string[]; + citation?: string; + description?: string; + downloads: number; + downloadsAllTime: number; + previewable?: boolean; + doi?: { id: string; commit: string }; +} + +export interface ApiDatasetMetadata { + licenses?: undefined; + license?: License | License[]; + license_name?: string; + license_link?: "LICENSE" | "LICENSE.md" | string; + license_details?: string; + languages?: undefined; + language?: string | string[]; + language_bcp47?: string[]; + language_details?: string; + tags?: string[]; + task_categories?: string[]; + task_ids?: string[]; + config_names?: string[]; + configs?: { + config_name: string; + data_files?: + | string + | string[] + | { + split: string; + path: string | string[]; + }[]; + data_dir?: string; + }[]; + benchmark?: 
string; + paperswithcode_id?: string | null; + pretty_name?: string; + viewer?: boolean; + viewer_display_urls?: boolean; + thumbnail?: string | null; + description?: string | null; + annotations_creators?: string[]; + language_creators?: string[]; + multilinguality?: string[]; + size_categories?: string[]; + source_datasets?: string[]; + extra_gated_prompt?: string; + extra_gated_fields?: { + /** + * "text" | "checkbox" | "date_picker" | "country" | "ip_location" | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } | { type: "select", options: Array } Property + */ + [x: string]: + | "text" + | "checkbox" + | "date_picker" + | "country" + | "ip_location" + | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } + | { type: "select"; options: Array }; + }; + extra_gated_heading?: string; + extra_gated_description?: string; + extra_gated_button_content?: string; +} diff --git a/types/api/api-index-tree.ts b/types/api/api-index-tree.ts new file mode 100644 index 0000000000000000000000000000000000000000..a467218cf288dbbdcb90662e72736fcabfa0246a --- /dev/null +++ b/types/api/api-index-tree.ts @@ -0,0 +1,46 @@ +export interface ApiIndexTreeEntry { + type: "file" | "directory" | "unknown"; + size: number; + path: string; + oid: string; + lfs?: { + oid: string; + size: number; + /** Size of the raw pointer file, 100~200 bytes */ + pointerSize: number; + }; + lastCommit?: { + date: string; + id: string; + title: string; + }; + security?: ApiFileScanResult; +} + +export interface ApiFileScanResult { + /** namespaced by repo type (models/, datasets/, spaces/) */ + repositoryId: string; + blobId: string; + name: string; + safe: boolean; + avScan?: ApiAVScan; + pickleImportScan?: ApiPickleImportScan; +} + +interface ApiAVScan { + virusFound: boolean; + virusNames?: string[]; +} + +type ApiSafetyLevel = "innocuous" | "suspicious" | "dangerous"; + +interface ApiPickleImport { + module: string; + name: string; + safety: ApiSafetyLevel; +} + +interface ApiPickleImportScan { + highestSafetyLevel: ApiSafetyLevel; + imports: ApiPickleImport[]; +} diff --git a/types/api/api-model.ts b/types/api/api-model.ts new file mode 100644 index 0000000000000000000000000000000000000000..5a052f69b2e9a2f0438c7439bdce6ca6a5076518 --- /dev/null +++ b/types/api/api-model.ts @@ -0,0 +1,273 @@ +import type { ModelLibraryKey, TransformersInfo, WidgetType } from "@huggingface/tasks"; +import type { License, PipelineType } from "../public"; + +export interface ApiModelInfo { + _id: string; + id: string; + arxivIds: string[]; + author?: string; + cardData?: ApiModelMetadata; + cardError: unknown; + cardExists?: true; + config: unknown; + contributors: Array<{ user: string; _id: string }>; + disabled: boolean; + discussionsDisabled: boolean; + doi?: { id: string; commit: string }; + downloads: number; + downloadsAllTime: number; + files: string[]; + gitalyUid: string; + inferenceProviderMapping: Partial< + Record + >; + lastAuthor: { email: string; user?: string }; + lastModified: string; // convert to date + library_name?: ModelLibraryKey; + likes: number; + likesRecent: number; + private: boolean; + gated: false | "auto" | "manual"; + sha: string; + spaces: string[]; + updatedAt: string; // convert to date + createdAt: string; // convert to date + pipeline_tag: PipelineType; + tags: string[]; + "model-index": unknown; + safetensors?: { + parameters: Record; + total: number; + }; + transformersInfo?: TransformersInfo; +} + +export interface ApiModelIndex { + name: string; + results: { + 
task: { + /** + * Example: automatic-speech-recognition +Use task id from https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasksData.ts + */ + type: string; + /** + * Example: Speech Recognition + */ + name?: string; + }; + /** + * This will switch to required at some point. +in any case, we need them to link to PWC + */ + dataset?: { + /** + * Example: common_voice. Use dataset id from https://hf.co/datasets + */ + type: string; + /** + * A pretty name for the dataset. Example: Common Voice zh-CN +Also encode config params into the name if relevant. + */ + name: string; + /** + * Optional. The name of the dataset configuration used in `load_dataset()` + */ + config?: string; + /** + * Optional. Example: test + */ + split?: string; + /** + * Optional. Example: 5503434ddd753f426f4b38109466949a1217c2bb + */ + revision?: string; + args?: + | string + | { + /** + * String Property + */ + [x: string]: string; + }; + }; + metrics: { + /** + * Example: wer. Use metric id from https://hf.co/metrics + */ + type: string; + /** + * Required. Example: 20.0 or "20.0 ± 1.2" + */ + value: unknown; + /** + * Example: Test WER + */ + name?: string; + /** + * Optional. The name of the metric configuration used in `load_metric()`. + */ + config?: string; + args?: + | string + | { + /** + * String Property + */ + [x: string]: string; + }; + /** + * [Automatically computed, do not set] Dynamically overridden by huggingface in API calls to indicate if it was verified by Hugging Face. + */ + verified?: boolean; + /** + * Generated by Hugging Face to prove the results are valid. + */ + verifyToken?: string; + }[]; + /** + * The source for this evaluation result. + */ + source?: { + /** + * Example: Open LLM Leaderboard + */ + name?: string; + /** + * Example: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard + */ + url: string; + }; + }[]; +} + +export interface ApiWidgetExampleFromModelcard { + example_title?: string; + group?: string; + text?: string; + src?: string; + table?: { + /** + * (string | number)[] Property + */ + [x: string]: (string | number)[]; + }; + structured_data?: { + /** + * (string | number)[] Property + */ + [x: string]: (string | number)[]; + }; + candidate_labels?: string; + messages?: { + role: "system" | "user" | "assistant"; + content: string; + }[]; + multi_class?: boolean; + source_sentence?: string; + sentences?: string[]; + parameters?: { + aggregation_strategy?: string; + top_k?: number; + top_p?: number; + temperature?: number; + max_new_tokens?: number; + do_sample?: boolean; + negative_prompt?: string; + guidance_scale?: number; + num_inference_steps?: number; + }; + output?: + | { + label: string; + score: number; + }[] + | { + answer: string; + score: number; + } + | { + text: string; + } + | { + url: string; + }; +} + +export interface ApiModelMetadata { + datasets?: string | string[]; + license?: License | License[]; + license_name?: string; + license_link?: "LICENSE" | "LICENSE.md" | string; + license_details?: string; + inference?: + | boolean + | { + parameters?: { + aggregation_strategy?: string; + top_k?: number; + top_p?: number; + temperature?: number; + max_new_tokens?: number; + do_sample?: boolean; + negative_prompt?: string; + guidance_scale?: number; + num_inference_steps?: number; + }; + }; + language?: string | string[]; + language_bcp47?: string[]; + language_details?: string; + tags?: string[]; + pipeline_tag?: string; + co2_eq_emissions?: + | number + | { + /** + * Emissions in grams of CO2 + */ + emissions: number; 
+ /** + * source of the information, either directly from AutoTrain, code carbon or from a scientific article documenting the model + */ + source?: string; + /** + * pre-training or fine-tuning + */ + training_type?: string; + /** + * as granular as possible, for instance Quebec, Canada or Brooklyn, NY, USA + */ + geographical_location?: string; + /** + * how much compute and what kind, e.g. 8 v100 GPUs + */ + hardware_used?: string; + }; + library_name?: string; + thumbnail?: string | null; + description?: string | null; + mask_token?: string; + widget?: ApiWidgetExampleFromModelcard[]; + "model-index"?: ApiModelIndex[]; + finetuned_from?: string; + base_model?: string | string[]; + instance_prompt?: string | null; + extra_gated_prompt?: string; + extra_gated_fields?: { + /** + * "text" | "checkbox" | "date_picker" | "country" | "ip_location" | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } | { type: "select", options: Array } Property + */ + [x: string]: + | "text" + | "checkbox" + | "date_picker" + | "country" + | "ip_location" + | { type: "text" | "checkbox" | "date_picker" | "country" | "ip_location" } + | { type: "select"; options: Array }; + }; + extra_gated_heading?: string; + extra_gated_description?: string; + extra_gated_button_content?: string; +} diff --git a/types/api/api-space.ts b/types/api/api-space.ts new file mode 100644 index 0000000000000000000000000000000000000000..151677f4b36aa268f6ec95340f838a58b4a0f0f5 --- /dev/null +++ b/types/api/api-space.ts @@ -0,0 +1,93 @@ +import type { License, SpaceRuntime, SpaceSdk } from "../public"; + +type Color = "red" | "yellow" | "green" | "blue" | "indigo" | "purple" | "pink" | "gray"; + +export interface ApiSpaceInfo { + _id: string; + id: string; + arxivIds?: string[]; + author: string; + cardExists?: true; + cardError?: unknown; + cardData?: unknown; + contributors?: Array<{ user: string; _id: string }>; + disabled: boolean; + discussionsDisabled: boolean; + duplicationDisabled: boolean; + gated: false | "auto" | "manual"; + gitalyUid: string; + lastAuthor: { email: string; user?: string }; + lastModified: string; // date + likes: number; + likesRecent: number; + private: boolean; + updatedAt: string; // date + createdAt: string; // date + tags: string[]; + sha: string; + subdomain: string; + title: string; + emoji: string; + colorFrom: Color; + colorTo: Color; + pinned: boolean; + siblings: Array<{ rfilename: string }>; + sdk?: SpaceSdk; + runtime?: SpaceRuntime; + models?: string[]; + datasets?: string[]; + originSpace?: { _id: string; authorId: string }; +} + +export interface ApiSpaceMetadata { + license?: License | License[]; + tags?: string[]; + title?: string; + colorFrom?: "red" | "yellow" | "green" | "blue" | "indigo" | "purple" | "pink" | "gray"; + colorTo?: "red" | "yellow" | "green" | "blue" | "indigo" | "purple" | "pink" | "gray"; + emoji?: string; + sdk?: "streamlit" | "gradio" | "docker" | "static"; + sdk_version?: string | string; + python_version?: string | string; + fullWidth?: boolean; + header?: "mini" | "default"; + app_file?: string; + app_port?: number; + base_path?: string; + models?: string[]; + datasets?: string[]; + pinned?: boolean; + metaTitle?: string; + description?: string; + thumbnail?: string; + /** + * If enabled, will associate an oauth app to the Space, adding variables and secrets to the Space's environment + */ + hf_oauth?: boolean; + /** + * The expiration of access tokens for your oauth app in minutes. max 30 days (43,200 minutes). 
Defaults to 8 hours (480 minutes) + */ + hf_oauth_expiration_minutes?: number; + /** + * OAuth scopes to request. By default you have access to the user's profile, you can request access to their repos or inference-api. + */ + hf_oauth_scopes?: ("email" | "read-repos" | "write-repos" | "manage-repos" | "inference-api")[]; + suggested_hardware?: + | "cpu-basic" + | "zero-a10g" + | "cpu-upgrade" + | "cpu-xl" + | "t4-small" + | "t4-medium" + | "a10g-small" + | "a10g-large" + | "a10g-largex2" + | "a10g-largex4" + | "a100-large"; + suggested_storage?: "small" | "medium" | "large"; + custom_headers?: { + "cross-origin-embedder-policy"?: "unsafe-none" | "require-corp" | "credentialless"; + "cross-origin-opener-policy"?: "same-origin" | "same-origin-allow-popups" | "unsafe-none"; + "cross-origin-resource-policy"?: "same-site" | "same-origin" | "cross-origin"; + }; +} diff --git a/types/api/api-who-am-i.ts b/types/api/api-who-am-i.ts new file mode 100644 index 0000000000000000000000000000000000000000..1cb75c2112b25604bff526dc783753d6cd3f5863 --- /dev/null +++ b/types/api/api-who-am-i.ts @@ -0,0 +1,51 @@ +import type { AccessTokenRole, AuthType } from "../public"; + +interface ApiWhoAmIBase { + /** Unique ID persistent across renames */ + id: string; + type: "user" | "org" | "app"; + name: string; +} + +interface ApiWhoAmIEntityBase extends ApiWhoAmIBase { + fullname: string; + email: string | null; + canPay: boolean; + avatarUrl: string; + /** + * Unix timestamp in seconds + */ + periodEnd: number | null; +} + +interface ApiWhoAmIOrg extends ApiWhoAmIEntityBase { + type: "org"; +} + +interface ApiWhoAmIUser extends ApiWhoAmIEntityBase { + type: "user"; + email: string; + emailVerified: boolean; + isPro: boolean; + orgs: ApiWhoAmIOrg[]; +} + +interface ApiWhoAmIApp extends ApiWhoAmIBase { + type: "app"; + name: string; + scope?: { + entities: string[]; + role: AccessTokenRole; + }; +} + +export type ApiWhoAmIReponse = ApiWhoAmIUser | ApiWhoAmIOrg | ApiWhoAmIApp; + +export interface ApiWhoAmIAuthInfo { + type: AuthType; + accessToken?: { + displayName: string; + expiration?: string; + role: AccessTokenRole; + }; +} diff --git a/types/public.ts b/types/public.ts new file mode 100644 index 0000000000000000000000000000000000000000..6a5b4a3004405da0a46cf28c6316bf4e3add1369 --- /dev/null +++ b/types/public.ts @@ -0,0 +1,184 @@ +import type { PipelineType } from "@huggingface/tasks"; + +export type RepoType = "space" | "dataset" | "model"; + +export interface RepoId { + name: string; + type: RepoType; +} + +export type RepoFullName = string | `spaces/${string}` | `datasets/${string}`; + +export type RepoDesignation = RepoId | RepoFullName; + +/** Actually `hf_${string}`, but for convenience, using the string type */ +export type AccessToken = string; + +/** + * @deprecated Use `AccessToken` instead. Pass { accessToken: "hf_..." } instead of { credentials: { accessToken: "hf_..." 
} } + */ +export interface Credentials { + accessToken: AccessToken; +} + +export type CredentialsParams = + | { + accessToken?: undefined; + /** + * @deprecated Use `accessToken` instead + */ + credentials: Credentials; + } + | { + accessToken: AccessToken; + /** + * @deprecated Use `accessToken` instead + */ + credentials?: undefined; + }; + +export type SpaceHardwareFlavor = + | "cpu-basic" + | "cpu-upgrade" + | "t4-small" + | "t4-medium" + | "l4x1" + | "l4x4" + | "a10g-small" + | "a10g-large" + | "a10g-largex2" + | "a10g-largex4" + | "a100-large" + | "v5e-1x1" + | "v5e-2x2" + | "v5e-2x4"; + +export type SpaceSdk = "streamlit" | "gradio" | "docker" | "static"; + +export type SpaceStage = + | "NO_APP_FILE" + | "CONFIG_ERROR" + | "BUILDING" + | "BUILD_ERROR" + | "RUNNING" + | "RUNNING_BUILDING" + | "RUNTIME_ERROR" + | "DELETING" + | "PAUSED" + | "SLEEPING"; + +export type AccessTokenRole = "admin" | "write" | "contributor" | "read"; + +export type AuthType = "access_token" | "app_token" | "app_token_as_user"; + +export type { PipelineType }; + +export interface SpaceRuntime { + stage: SpaceStage; + sdk?: SpaceSdk; + sdkVersion?: string; + errorMessage?: string; + hardware?: { + current: SpaceHardwareFlavor | null; + currentPrettyName?: string; + requested: SpaceHardwareFlavor | null; + requestedPrettyName?: string; + }; + /** when calling /spaces, those props are only fetched if ?full=true */ + resources?: SpaceResourceConfig; + /** in seconds */ + gcTimeout?: number | null; +} + +export interface SpaceResourceRequirement { + cpu?: string; + memory?: string; + gpu?: string; + gpuModel?: string; + ephemeral?: string; +} + +export interface SpaceResourceConfig { + requests: SpaceResourceRequirement; + limits: SpaceResourceRequirement; + replicas?: number; + throttled?: boolean; + is_custom?: boolean; +} + +export type License = + | "apache-2.0" + | "mit" + | "openrail" + | "bigscience-openrail-m" + | "creativeml-openrail-m" + | "bigscience-bloom-rail-1.0" + | "bigcode-openrail-m" + | "afl-3.0" + | "artistic-2.0" + | "bsl-1.0" + | "bsd" + | "bsd-2-clause" + | "bsd-3-clause" + | "bsd-3-clause-clear" + | "c-uda" + | "cc" + | "cc0-1.0" + | "cc-by-2.0" + | "cc-by-2.5" + | "cc-by-3.0" + | "cc-by-4.0" + | "cc-by-sa-3.0" + | "cc-by-sa-4.0" + | "cc-by-nc-2.0" + | "cc-by-nc-3.0" + | "cc-by-nc-4.0" + | "cc-by-nd-4.0" + | "cc-by-nc-nd-3.0" + | "cc-by-nc-nd-4.0" + | "cc-by-nc-sa-2.0" + | "cc-by-nc-sa-3.0" + | "cc-by-nc-sa-4.0" + | "cdla-sharing-1.0" + | "cdla-permissive-1.0" + | "cdla-permissive-2.0" + | "wtfpl" + | "ecl-2.0" + | "epl-1.0" + | "epl-2.0" + | "etalab-2.0" + | "eupl-1.1" + | "agpl-3.0" + | "gfdl" + | "gpl" + | "gpl-2.0" + | "gpl-3.0" + | "lgpl" + | "lgpl-2.1" + | "lgpl-3.0" + | "isc" + | "lppl-1.3c" + | "ms-pl" + | "mpl-2.0" + | "odc-by" + | "odbl" + | "openrail++" + | "osl-3.0" + | "postgresql" + | "ofl-1.1" + | "ncsa" + | "unlicense" + | "zlib" + | "pddl" + | "lgpl-lr" + | "deepfloyd-if-license" + | "llama2" + | "llama3" + | "llama3.1" + | "llama3.2" + | "llama3.3" + | "gemma" + | "apple-ascl" + | "apple-amlr" + | "unknown" + | "other"; diff --git a/utils/FileBlob.spec.ts b/utils/FileBlob.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..2ed51d8e38e817bc69d0924e3f4b47885e85c5d9 --- /dev/null +++ b/utils/FileBlob.spec.ts @@ -0,0 +1,45 @@ +import { open, stat } from "node:fs/promises"; +import { TextDecoder } from "node:util"; +import { describe, expect, it } from "vitest"; +import { FileBlob } from "./FileBlob"; + +describe("FileBlob", () => { + it("should 
create a FileBlob with a slice on the entire file", async () => { + const file = await open("package.json", "r"); + const { size } = await stat("package.json"); + + const fileBlob = await FileBlob.create("package.json"); + + expect(fileBlob).toMatchObject({ + path: "package.json", + start: 0, + end: size, + }); + expect(fileBlob.size).toBe(size); + expect(fileBlob.type).toBe(""); + const text = await fileBlob.text(); + const expectedText = (await file.read(Buffer.alloc(size), 0, size)).buffer.toString("utf8"); + expect(text).toBe(expectedText); + const result = await fileBlob.stream().getReader().read(); + expect(new TextDecoder().decode(result.value)).toBe(expectedText); + }); + + it("should create a slice on the file", async () => { + const file = await open("package.json", "r"); + const fileBlob = await FileBlob.create("package.json"); + + const slice = fileBlob.slice(10, 20); + + expect(slice).toMatchObject({ + path: "package.json", + start: 10, + end: 20, + }); + expect(slice.size).toBe(10); + const sliceText = await slice.text(); + const expectedText = (await file.read(Buffer.alloc(10), 0, 10, 10)).buffer.toString("utf8"); + expect(sliceText).toBe(expectedText); + const result = await slice.stream().getReader().read(); + expect(new TextDecoder().decode(result.value)).toBe(expectedText); + }); +}); diff --git a/utils/FileBlob.ts b/utils/FileBlob.ts new file mode 100644 index 0000000000000000000000000000000000000000..e783ca6fa6ad79ad35604cb8f9e4f2ab388acc01 --- /dev/null +++ b/utils/FileBlob.ts @@ -0,0 +1,118 @@ +import { createReadStream } from "node:fs"; +import { open, stat } from "node:fs/promises"; +import { Readable } from "node:stream"; +import type { FileHandle } from "node:fs/promises"; +import { fileURLToPath } from "node:url"; + +/** + * @internal + * + * A FileBlob is a replacement for the Blob class that allows to lazy read files + * in order to preserve memory. + * + * It is a drop-in replacement for the Blob class, so you can use it as a Blob. + * + * The main difference is the instantiation, which is done asynchronously using the `FileBlob.create` method. + * + * @example + * const fileBlob = await FileBlob.create("path/to/package.json"); + * + * await fetch("https://aschen.tech", { method: "POST", body: fileBlob }); + */ +export class FileBlob extends Blob { + /** + * Creates a new FileBlob on the provided file. + * + * @param path Path to the file to be lazy readed + */ + static async create(path: string | URL): Promise { + path = path instanceof URL ? fileURLToPath(path) : path; + + const { size } = await stat(path); + + const fileBlob = new FileBlob(path, 0, size); + + return fileBlob; + } + + private path: string; + private start: number; + private end: number; + + private constructor(path: string, start: number, end: number) { + super(); + + this.path = path; + this.start = start; + this.end = end; + } + + /** + * Returns the size of the blob. + */ + override get size(): number { + return this.end - this.start; + } + + /** + * Returns a new instance of FileBlob that is a slice of the current one. + * + * The slice is inclusive of the start and exclusive of the end. + * + * The slice method does not supports negative start/end. 
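   * (Illustrative example, not part of the original diff: fileBlob.slice(10, 20) yields a 10-byte
   * FileBlob covering bytes [10, 20) of the file, as exercised in FileBlob.spec.ts above.)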
+ * + * @param start beginning of the slice + * @param end end of the slice + */ + override slice(start = 0, end = this.size): FileBlob { + if (start < 0 || end < 0) { + new TypeError("Unsupported negative start/end on FileBlob.slice"); + } + + const slice = new FileBlob(this.path, this.start + start, Math.min(this.start + end, this.end)); + + return slice; + } + + /** + * Read the part of the file delimited by the FileBlob and returns it as an ArrayBuffer. + */ + override async arrayBuffer(): Promise { + const slice = await this.execute((file) => file.read(Buffer.alloc(this.size), 0, this.size, this.start)); + + return slice.buffer; + } + + /** + * Read the part of the file delimited by the FileBlob and returns it as a string. + */ + override async text(): Promise { + const buffer = (await this.arrayBuffer()) as Buffer; + + return buffer.toString("utf8"); + } + + /** + * Returns a stream around the part of the file delimited by the FileBlob. + */ + override stream(): ReturnType { + return Readable.toWeb(createReadStream(this.path, { start: this.start, end: this.end - 1 })) as ReturnType< + Blob["stream"] + >; + } + + /** + * We are opening and closing the file for each action to prevent file descriptor leaks. + * + * It is an intended choice of developer experience over performances. + */ + private async execute(action: (file: FileHandle) => Promise) { + const file = await open(this.path, "r"); + + try { + return await action(file); + } finally { + await file.close(); + } + } +} diff --git a/utils/RangeList.spec.ts b/utils/RangeList.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..e05f85a8fa14b41b2c4549738dd9b8d67e5b30f1 --- /dev/null +++ b/utils/RangeList.spec.ts @@ -0,0 +1,96 @@ +import { describe, it, expect } from "vitest"; +import { RangeList } from "./RangeList"; + +describe("RangeList", () => { + it("should add a single range", () => { + const rangeList = new RangeList(); + rangeList.add(1, 100); + + const ranges = rangeList.getAllRanges(); + expect(ranges).toHaveLength(1); + expect(ranges[0]).toEqual({ + start: 1, + end: 100, + refCount: 1, + data: null, + }); + }); + + it("should handle overlapping ranges", () => { + const rangeList = new RangeList(); + rangeList.add(1, 100); + rangeList.add(30, 50); + + const ranges = rangeList.getAllRanges(); + expect(ranges).toHaveLength(3); + expect(ranges).toEqual([ + { start: 1, end: 30, refCount: 1, data: null }, + { start: 30, end: 50, refCount: 2, data: null }, + { start: 50, end: 100, refCount: 1, data: null }, + ]); + }); + + it("should remove a range at existing boundaries", () => { + const rangeList = new RangeList(); + rangeList.add(1, 100); + rangeList.add(30, 50); + rangeList.remove(30, 50); + + const ranges = rangeList.getAllRanges(); + expect(ranges).toHaveLength(3); + expect(ranges).toEqual([ + { start: 1, end: 30, refCount: 1, data: null }, + { start: 30, end: 50, refCount: 1, data: null }, + { start: 50, end: 100, refCount: 1, data: null }, + ]); + }); + + it("should throw error when removing range at non-existing boundaries", () => { + const rangeList = new RangeList(); + rangeList.add(1, 100); + rangeList.add(30, 50); + + expect(() => rangeList.remove(2, 50)).toThrow("Range boundaries must match existing boundaries"); + }); + + it("should get ranges within boundaries", () => { + const rangeList = new RangeList(); + rangeList.add(1, 100); + rangeList.add(30, 50); + + const ranges = rangeList.getRanges(30, 100); + expect(ranges).toHaveLength(2); + expect(ranges).toEqual([ + { start: 30, end: 50, 
refCount: 2, data: null }, + { start: 50, end: 100, refCount: 1, data: null }, + ]); + }); + + it("should throw error when end is less than or equal to start", () => { + const rangeList = new RangeList(); + + expect(() => rangeList.add(100, 1)).toThrow("End must be greater than start"); + expect(() => rangeList.add(1, 1)).toThrow("End must be greater than start"); + expect(() => rangeList.remove(100, 1)).toThrow("End must be greater than start"); + expect(() => rangeList.remove(1, 1)).toThrow("End must be greater than start"); + expect(() => rangeList.getRanges(100, 1)).toThrow("End must be greater than start"); + expect(() => rangeList.getRanges(1, 1)).toThrow("End must be greater than start"); + }); + + it("should handle multiple overlapping ranges", () => { + const rangeList = new RangeList(); + rangeList.add(1, 100); + rangeList.add(30, 50); + rangeList.add(40, 60); + + const ranges = rangeList.getAllRanges(); + expect(ranges).toHaveLength(5); + expect(ranges).toEqual([ + { start: 1, end: 30, refCount: 1, data: null }, + { start: 30, end: 40, refCount: 2, data: null }, + { start: 40, end: 50, refCount: 3, data: null }, + { start: 50, end: 60, refCount: 2, data: null }, + { start: 60, end: 100, refCount: 1, data: null }, + ]); + }); +}); diff --git a/utils/RangeList.ts b/utils/RangeList.ts new file mode 100644 index 0000000000000000000000000000000000000000..69137fe69f08f0a97f9d9fb217c096b36254e07e --- /dev/null +++ b/utils/RangeList.ts @@ -0,0 +1,179 @@ +/** + * Code generated with this prompt by Cursor: + * + * I want to build a class to manage ranges + * + * I can add ranges to it with a start& an end (both integer, end > start). It should store those ranges efficently. + * + * When several ranges overlap, eg [1, 100] and [30, 50], I want the class to split the range into non-overlapping ranges, and add a "ref counter" to the ranges. For example, [1, 30], [30, 50] * 2, [50, 100] + * + * I also want to be able to remove ranges, it will decrease the ref counter or remove the range altogether. I can only remove ranges at existing boundaries. For example, with the [1, 30], [30, 50] * 2, [50, 100] configuration + * + * - removing [1, 100] => the only range remaning is [30, 50] + * - removing [2, 50] => error, because "2' is not a boundary + * - removing [30, 50] => [1, 30], [30, 50], [50, 100] (do not "merge" the ranges back together) + * + * I want to be able to associate data to each range. And I want to be able to get the ranges inside boundaries. For example , with [1, 30], [30, 50] * 2, [50, 100] configuration + * + * - getting [30, 100] => I receive [30, 50] * 2, [50, 100], and I can get / modify the data assocaited to each range by accessing their data prop. Note the "*2" is just the ref counter, there is onlly one range object for the interval returned + * - getting [2, 50] => I get [30, 50] * 2 + * + * ---- + * + * Could optimize with binary search, but the ranges we want to handle are not that many. + */ +interface Range { + start: number; + end: number; + refCount: number; + data: T | null; +} + +export class RangeList { + private ranges: Range[] = []; + + /** + * Add a range to the list. If it overlaps with existing ranges, + * it will split them and increment reference counts accordingly. 
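+ *
+ * @example
+ * // Splitting behaviour, mirroring RangeList.spec.ts:
+ * const ranges = new RangeList();
+ * ranges.add(1, 100);
+ * ranges.add(30, 50);
+ * // getAllRanges() now returns [1, 30) x1, [30, 50) x2 and [50, 100) x1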
+ */ + add(start: number, end: number): void { + if (end <= start) { + throw new TypeError("End must be greater than start"); + } + + // Find all ranges that overlap with the new range + const overlappingRanges: { index: number; range: Range }[] = []; + for (let i = 0; i < this.ranges.length; i++) { + const range = this.ranges[i]; + if (start < range.end && end > range.start) { + overlappingRanges.push({ index: i, range }); + } + if (range.data !== null) { + throw new Error("Overlapping range already has data"); + } + } + + if (overlappingRanges.length === 0) { + // No overlaps, just add the new range + this.ranges.push({ start, end, refCount: 1, data: null }); + this.ranges.sort((a, b) => a.start - b.start); + return; + } + + // Handle overlaps by splitting ranges + const newRanges: Range[] = []; + let currentPos = start; + + for (let i = 0; i < overlappingRanges.length; i++) { + const { range } = overlappingRanges[i]; + + // Add range before overlap if exists + if (currentPos < range.start) { + newRanges.push({ + start: currentPos, + end: range.start, + refCount: 1, + data: null, + }); + } else if (range.start < currentPos) { + newRanges.push({ + start: range.start, + end: currentPos, + refCount: range.refCount, + data: null, + }); + } + + // Add overlapping part with increased ref count + newRanges.push({ + start: Math.max(currentPos, range.start), + end: Math.min(end, range.end), + refCount: range.refCount + 1, + data: null, + }); + + // Add remaining part of existing range if exists + if (range.end > end) { + newRanges.push({ + start: end, + end: range.end, + refCount: range.refCount, + data: null, + }); + } + + currentPos = Math.max(currentPos, range.end); + } + + // Add remaining part after last overlap if exists + if (currentPos < end) { + newRanges.push({ + start: currentPos, + end, + refCount: 1, + data: null, + }); + } + + // Remove old overlapping ranges and insert new ones + const firstIndex = overlappingRanges[0].index; + const lastIndex = overlappingRanges[overlappingRanges.length - 1].index; + this.ranges.splice(firstIndex, lastIndex - firstIndex + 1, ...newRanges); + this.ranges.sort((a, b) => a.start - b.start); + } + + /** + * Remove a range from the list. The range must start and end at existing boundaries. + */ + remove(start: number, end: number): void { + if (end <= start) { + throw new TypeError("End must be greater than start"); + } + + // Find ranges that need to be modified + const affectedRanges: { index: number; range: Range }[] = []; + for (let i = 0; i < this.ranges.length; i++) { + const range = this.ranges[i]; + if (start < range.end && end > range.start) { + affectedRanges.push({ index: i, range }); + } + } + + if (affectedRanges.length === 0) { + throw new Error("No ranges found to remove"); + } + + // Verify boundaries match + if (start !== affectedRanges[0].range.start || end !== affectedRanges[affectedRanges.length - 1].range.end) { + throw new Error("Range boundaries must match existing boundaries"); + } + + // Todo: also check if there's a gap in the middle but it should not happen with our usage + + for (let i = 0; i < affectedRanges.length; i++) { + const { range } = affectedRanges[i]; + + range.refCount--; + } + + this.ranges = this.ranges.filter((range) => range.refCount > 0); + } + + /** + * Get all ranges within the specified boundaries. 
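+ *
+ * The returned objects are the stored ranges themselves, so callers can read or
+ * assign their `data` property.
+ *
+ * @example
+ * // Continuing the add() example above ([1, 30) x1, [30, 50) x2, [50, 100) x1):
+ * ranges.getRanges(30, 100); // => the [30, 50) x2 and [50, 100) x1 ranges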
+ */ + getRanges(start: number, end: number): Range[] { + if (end <= start) { + throw new TypeError("End must be greater than start"); + } + + return this.ranges.filter((range) => start < range.end && end > range.start); + } + + /** + * Get all ranges in the list + */ + getAllRanges(): Range[] { + return [...this.ranges]; + } +} diff --git a/utils/WebBlob.spec.ts b/utils/WebBlob.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..242a51e08eddd94fbf49b6858cf5b0d9db081b41 --- /dev/null +++ b/utils/WebBlob.spec.ts @@ -0,0 +1,95 @@ +import { describe, expect, it, beforeAll } from "vitest"; +import { WebBlob } from "./WebBlob"; + +describe("WebBlob", () => { + const resourceUrl = new URL("https://huggingface.co/spaces/aschen/push-model-from-web/raw/main/mobilenet/model.json"); + let fullText: string; + let size: number; + let contentType: string; + + beforeAll(async () => { + const response = await fetch(resourceUrl, { method: "HEAD" }); + size = Number(response.headers.get("content-length")); + contentType = response.headers.get("content-type") || ""; + fullText = await (await fetch(resourceUrl)).text(); + }); + + it("should create a WebBlob with a slice on the entire resource", async () => { + const webBlob = await WebBlob.create(resourceUrl, { cacheBelow: 0, accessToken: undefined }); + + expect(webBlob).toMatchObject({ + url: resourceUrl, + start: 0, + end: size, + contentType, + }); + expect(webBlob).toBeInstanceOf(WebBlob); + expect(webBlob.size).toBe(size); + expect(webBlob.type).toBe(contentType); + + const text = await webBlob.text(); + expect(text).toBe(fullText); + + const streamText = await new Response(webBlob.stream()).text(); + expect(streamText).toBe(fullText); + }); + + it("should create a WebBlob with a slice on the entire resource, cached", async () => { + const webBlob = await WebBlob.create(resourceUrl, { cacheBelow: 1_000_000, accessToken: undefined }); + + expect(webBlob).not.toBeInstanceOf(WebBlob); + expect(webBlob.size).toBe(size); + expect(webBlob.type.replace(/;\s*charset=utf-8/, "")).toBe(contentType.replace(/;\s*charset=utf-8/, "")); + + const text = await webBlob.text(); + expect(text).toBe(fullText); + + const streamText = await new Response(webBlob.stream()).text(); + expect(streamText).toBe(fullText); + }); + + it("should lazy load a LFS file hosted on Hugging Face", async () => { + const zephyrUrl = + "https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha/resolve/main/model-00001-of-00008.safetensors"; + const url = new URL(zephyrUrl); + const webBlob = await WebBlob.create(url); + + expect(webBlob.size).toBe(1_889_587_040); + expect(webBlob).toBeInstanceOf(WebBlob); + expect(webBlob).toMatchObject({ url }); + expect(await webBlob.slice(10, 22).text()).toBe("__metadata__"); + }); + + it("should lazy load a Xet file hosted on Hugging Face", async () => { + const stableDiffusionUrl = + "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors"; + const url = new URL(stableDiffusionUrl); + const webBlob = await WebBlob.create(url); + + expect(webBlob.size).toBe(5_135_149_760); + expect(webBlob).toBeInstanceOf(WebBlob); + expect(webBlob).toMatchObject({ url }); + expect(await webBlob.slice(10, 22).text()).toBe("__metadata__"); + }); + + it("should create a slice on the file", async () => { + const expectedText = fullText.slice(10, 20); + + const slice = (await WebBlob.create(resourceUrl, { cacheBelow: 0, accessToken: undefined })).slice(10, 20); + + 
expect(slice).toMatchObject({ + url: resourceUrl, + start: 10, + end: 20, + contentType, + }); + expect(slice.size).toBe(10); + expect(slice.type).toBe(contentType); + + const sliceText = await slice.text(); + expect(sliceText).toBe(expectedText); + + const streamText = await new Response(slice.stream()).text(); + expect(streamText).toBe(expectedText); + }); +}); diff --git a/utils/WebBlob.ts b/utils/WebBlob.ts new file mode 100644 index 0000000000000000000000000000000000000000..364bd95094032f12a22740b3d4f80fb8e98d27e1 --- /dev/null +++ b/utils/WebBlob.ts @@ -0,0 +1,139 @@ +/** + * WebBlob is a Blob implementation for web resources that supports range requests. + */ + +import { createApiError } from "../error"; + +interface WebBlobCreateOptions { + /** + * @default 1_000_000 + * + * Objects below that size will immediately be fetched and put in RAM, rather + * than streamed ad-hoc + */ + cacheBelow?: number; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + accessToken: string | undefined; +} + +export class WebBlob extends Blob { + static async create(url: URL, opts?: WebBlobCreateOptions): Promise { + const customFetch = opts?.fetch ?? fetch; + const response = await customFetch(url, { + method: "HEAD", + ...(opts?.accessToken && { + headers: { + Authorization: `Bearer ${opts.accessToken}`, + }, + }), + }); + + const size = Number(response.headers.get("content-length")); + const contentType = response.headers.get("content-type") || ""; + const supportRange = response.headers.get("accept-ranges") === "bytes"; + + if (!supportRange || size < (opts?.cacheBelow ?? 1_000_000)) { + return await (await customFetch(url)).blob(); + } + + return new WebBlob(url, 0, size, contentType, true, customFetch, opts?.accessToken); + } + + private url: URL; + private start: number; + private end: number; + private contentType: string; + private full: boolean; + private fetch: typeof fetch; + private accessToken: string | undefined; + + constructor( + url: URL, + start: number, + end: number, + contentType: string, + full: boolean, + customFetch: typeof fetch, + accessToken: string | undefined + ) { + super([]); + + this.url = url; + this.start = start; + this.end = end; + this.contentType = contentType; + this.full = full; + this.fetch = customFetch; + this.accessToken = accessToken; + } + + override get size(): number { + return this.end - this.start; + } + + override get type(): string { + return this.contentType; + } + + override slice(start = 0, end = this.size): WebBlob { + if (start < 0 || end < 0) { + new TypeError("Unsupported negative start/end on WebBlob.slice"); + } + + const slice = new WebBlob( + this.url, + this.start + start, + Math.min(this.start + end, this.end), + this.contentType, + start === 0 && end === this.size ? 
this.full : false, + this.fetch, + this.accessToken + ); + + return slice; + } + + override async arrayBuffer(): Promise { + const result = await this.fetchRange(); + + return result.arrayBuffer(); + } + + override async text(): Promise { + const result = await this.fetchRange(); + + return result.text(); + } + + override stream(): ReturnType { + const stream = new TransformStream(); + + this.fetchRange() + .then((response) => response.body?.pipeThrough(stream)) + .catch((error) => stream.writable.abort(error.message)); + + return stream.readable; + } + + private fetchRange(): Promise { + const fetch = this.fetch; // to avoid this.fetch() which is bound to the instance instead of globalThis + if (this.full) { + return fetch(this.url, { + ...(this.accessToken && { + headers: { + Authorization: `Bearer ${this.accessToken}`, + }, + }), + }).then((resp) => (resp.ok ? resp : createApiError(resp))); + } + return fetch(this.url, { + headers: { + Range: `bytes=${this.start}-${this.end - 1}`, + ...(this.accessToken && { Authorization: `Bearer ${this.accessToken}` }), + }, + }).then((resp) => (resp.ok ? resp : createApiError(resp))); + } +} diff --git a/utils/XetBlob.spec.ts b/utils/XetBlob.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..e3233fab6dd6ed86aeeb0fd5beefee204528010e --- /dev/null +++ b/utils/XetBlob.spec.ts @@ -0,0 +1,882 @@ +import { describe, expect, it } from "vitest"; +import type { ReconstructionInfo } from "./XetBlob"; +import { bg4_regoup_bytes, XetBlob } from "./XetBlob"; +import { sum } from "./sum"; + +describe("XetBlob", () => { + it("should lazy load the first 22 bytes", async () => { + const blob = new XetBlob({ + hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", + size: 5_234_139_343, + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", + }); + + expect(await blob.slice(10, 22).text()).toBe("__metadata__"); + }); + + it("should load the first chunk correctly", async () => { + let xorbCount = 0; + const blob = new XetBlob({ + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", + hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", + size: 5_234_139_343, + fetch: async (url, opts) => { + if (typeof url === "string" && url.includes("/xorbs/")) { + xorbCount++; + } + return fetch(url, opts); + }, + }); + + const xetDownload = await blob.slice(0, 29928).arrayBuffer(); + const bridgeDownload = await fetch( + "https://huggingface.co/celinah/xet-experiments/resolve/main/model5GB.safetensors", + { + headers: { + Range: "bytes=0-29927", + }, + } + ).then((res) => res.arrayBuffer()); + + expect(new Uint8Array(xetDownload)).toEqual(new Uint8Array(bridgeDownload)); + expect(xorbCount).toBe(1); + }); + + it("should load just past the first chunk correctly", async () => { + let xorbCount = 0; + const blob = new XetBlob({ + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", + hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", + size: 5_234_139_343, + fetch: async (url, opts) => { + if (typeof url === "string" && url.includes("/xorbs/")) { + xorbCount++; + } + return fetch(url, opts); + }, + }); + + const xetDownload = await blob.slice(0, 29929).arrayBuffer(); + const bridgeDownload = await fetch( + "https://huggingface.co/celinah/xet-experiments/resolve/main/model5GB.safetensors", + { + headers: { + Range: "bytes=0-29928", + }, + } + ).then((res) => res.arrayBuffer()); + 
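+    // The first chunk appears to be 29,928 bytes (see the previous test, which only
+    // needs one xorb request); reading a single byte past it should trigger a second
+    // xorb request.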
+ expect(xetDownload.byteLength).toBe(29929); + expect(new Uint8Array(xetDownload)).toEqual(new Uint8Array(bridgeDownload)); + expect(xorbCount).toBe(2); + }); + + it("should load the first 200kB correctly", async () => { + let xorbCount = 0; + const blob = new XetBlob({ + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", + hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", + size: 5_234_139_343, + fetch: async (url, opts) => { + if (typeof url === "string" && url.includes("/xorbs/")) { + xorbCount++; + } + return fetch(url, opts); + }, + // internalLogging: true, + }); + + const xetDownload = await blob.slice(0, 200_000).arrayBuffer(); + const bridgeDownload = await fetch( + "https://huggingface.co/celinah/xet-experiments/resolve/main/model5GB.safetensors", + { + headers: { + Range: "bytes=0-199999", + }, + } + ).then((res) => res.arrayBuffer()); + + expect(xetDownload.byteLength).toBe(200_000); + expect(new Uint8Array(xetDownload)).toEqual(new Uint8Array(bridgeDownload)); + expect(xorbCount).toBe(2); + }, 60_000); + + it("should load correctly when loading far into a chunk range", async () => { + const blob = new XetBlob({ + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", + hash: "7b3b6d07673a88cf467e67c1f7edef1a8c268cbf66e9dd9b0366322d4ab56d9b", + size: 5_234_139_343, + // internalLogging: true, + }); + + const xetDownload = await blob.slice(10_000_000, 10_100_000).arrayBuffer(); + const bridgeDownload = await fetch( + "https://huggingface.co/celinah/xet-experiments/resolve/main/model5GB.safetensors", + { + headers: { + Range: "bytes=10000000-10099999", + }, + } + ).then((res) => res.arrayBuffer()); + + console.log("xet", xetDownload.byteLength, "bridge", bridgeDownload.byteLength); + expect(new Uint8Array(xetDownload).length).toEqual(100_000); + expect(new Uint8Array(xetDownload)).toEqual(new Uint8Array(bridgeDownload)); + }); + + it("should load text correctly when offset_into_range starts in a chunk further than the first", async () => { + const blob = new XetBlob({ + refreshUrl: "https://huggingface.co/api/models/celinah/xet-experiments/xet-read-token/main", + hash: "794efea76d8cb372bbe1385d9e51c3384555f3281e629903ecb6abeff7d54eec", + size: 62_914_580, + }); + + // Reconstruction info + // { + // "offset_into_first_range": 600000, + // "terms": + // [ + // { + // "hash": "be748f77930d5929cabd510a15f2c30f2f460b639804ef79dea46affa04fd8b2", + // "unpacked_length": 655360, + // "range": { "start": 0, "end": 5 }, + // }, + // { + // "hash": "be748f77930d5929cabd510a15f2c30f2f460b639804ef79dea46affa04fd8b2", + // "unpacked_length": 655360, + // "range": { "start": 0, "end": 5 }, + // }, + // ], + // "fetch_info": + // { + // "be748f77930d5929cabd510a15f2c30f2f460b639804ef79dea46affa04fd8b2": + // [ + // { + // "range": { "start": 0, "end": 5 }, + // "url": "...", + // "url_range": { "start": 0, "end": 2839 }, + // }, + // ], + // }, + // } + + const text = await blob.slice(600_000, 700_000).text(); + const bridgeDownload = await fetch("https://huggingface.co/celinah/xet-experiments/resolve/main/large_text.txt", { + headers: { + Range: "bytes=600000-699999", + }, + }).then((res) => res.text()); + + console.log("xet", text.length, "bridge", bridgeDownload.length); + expect(text.length).toBe(bridgeDownload.length); + }); + + describe("bg4_regoup_bytes", () => { + it("should regroup bytes when the array is %4 length", () => { + expect(bg4_regoup_bytes(new Uint8Array([1, 5, 2, 6, 3, 
7, 4, 8]))).toEqual( + new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8]) + ); + }); + + it("should regroup bytes when the array is %4 + 1 length", () => { + expect(bg4_regoup_bytes(new Uint8Array([1, 5, 9, 2, 6, 3, 7, 4, 8]))).toEqual( + new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + ); + }); + + it("should regroup bytes when the array is %4 + 2 length", () => { + expect(bg4_regoup_bytes(new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 4, 8]))).toEqual( + new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + ); + }); + + it("should regroup bytes when the array is %4 + 3 length", () => { + expect(bg4_regoup_bytes(new Uint8Array([1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8]))).toEqual( + new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + ); + }); + }); + + describe("when mocked", () => { + describe("loading many chunks every read", () => { + it("should load different slices", async () => { + const chunk1Content = "hello"; + const chunk2Content = "world!"; + const debugged: Array<{ event: "read" | string } & Record> = []; + + const chunks = Array(1000) + .fill(0) + .flatMap(() => [makeChunk(chunk1Content), makeChunk(chunk2Content)]); + + const mergedChunks = await new Blob(chunks).arrayBuffer(); + const wholeText = (chunk1Content + chunk2Content).repeat(1000); + + const totalSize = wholeText.length; + let fetchCount = 0; + + const blob = new XetBlob({ + hash: "test", + size: totalSize, + refreshUrl: "https://huggingface.co", + listener: (e) => debugged.push(e), + fetch: async function (_url, opts) { + const url = new URL(_url as string); + const headers = opts?.headers as Record | undefined; + + switch (url.hostname) { + case "huggingface.co": { + // This is a token + return new Response( + JSON.stringify({ + casUrl: "https://cas.co", + accessToken: "boo", + exp: 1_000_000, + }) + ); + } + case "cas.co": { + // This is the reconstruction info + const range = headers?.["Range"]?.slice("bytes=".length).split("-").map(Number); + + const start = range?.[0] ?? 0; + // const end = range?.[1] ?? 
(totalSize - 1); + + return new Response( + JSON.stringify({ + terms: Array(1000) + .fill(0) + .map(() => ({ + hash: "test", + range: { + start: 0, + end: 2, + }, + unpacked_length: chunk1Content.length + chunk2Content.length, + })), + fetch_info: { + test: [ + { + url: "https://fetch.co", + range: { start: 0, end: 2 }, + url_range: { + start: 0, + end: mergedChunks.byteLength / 1000 - 1, + }, + }, + ], + }, + offset_into_first_range: start, + } satisfies ReconstructionInfo) + ); + } + case "fetch.co": { + fetchCount++; + return new Response( + new ReadableStream({ + pull(controller) { + controller.enqueue(new Uint8Array(mergedChunks)); + controller.close(); + }, + }) + //mergedChunks + ); + } + default: + throw new Error("Unhandled URL"); + } + }, + }); + + const startIndexes = [0, 5, 11, 6, 12, 100, 2000, totalSize - 12, totalSize - 2]; + + for (const index of startIndexes) { + console.log("slice", index); + const content = await blob.slice(index).text(); + expect(content.length).toBe(wholeText.length - index); + expect(content.slice(0, 1000)).toEqual(wholeText.slice(index).slice(0, 1000)); + expect(debugged.filter((e) => e.event === "read").length).toBe(2); // 1 read + 1 undefined + expect(fetchCount).toEqual(1); + + fetchCount = 0; + debugged.length = 0; + } + }); + + it("should load different slices when working with different XORBS", async () => { + const chunk1Content = "hello"; + const chunk2Content = "world!"; + const debugged: Array<{ event: "read" | string } & Record> = []; + + const chunks = Array(1000) + .fill(0) + .flatMap(() => [makeChunk(chunk1Content), makeChunk(chunk2Content)]); + + const mergedChunks = await new Blob(chunks).arrayBuffer(); + const wholeText = (chunk1Content + chunk2Content).repeat(1000); + + const totalSize = wholeText.length; + let fetchCount = 0; + + const blob = new XetBlob({ + hash: "test", + size: totalSize, + refreshUrl: "https://huggingface.co", + listener: (e) => debugged.push(e), + fetch: async function (_url, opts) { + const url = new URL(_url as string); + const headers = opts?.headers as Record | undefined; + + switch (url.hostname) { + case "huggingface.co": { + // This is a token + return new Response( + JSON.stringify({ + casUrl: "https://cas.co", + accessToken: "boo", + exp: 1_000_000, + }) + ); + } + case "cas.co": { + // This is the reconstruction info + const range = headers?.["Range"]?.slice("bytes=".length).split("-").map(Number); + + const start = range?.[0] ?? 0; + // const end = range?.[1] ?? 
(totalSize - 1); + + return new Response( + JSON.stringify({ + terms: Array(1000) + .fill(0) + .map((_, i) => ({ + hash: "test" + (i % 2), + range: { + start: 0, + end: 2, + }, + unpacked_length: chunk1Content.length + chunk2Content.length, + })), + fetch_info: { + test0: [ + { + url: "https://fetch.co", + range: { start: 0, end: 2 }, + url_range: { + start: 0, + end: mergedChunks.byteLength - 1, + }, + }, + ], + test1: [ + { + url: "https://fetch.co", + range: { start: 0, end: 2 }, + url_range: { + start: 0, + end: mergedChunks.byteLength - 1, + }, + }, + ], + }, + offset_into_first_range: start, + } satisfies ReconstructionInfo) + ); + } + case "fetch.co": { + fetchCount++; + return new Response( + new ReadableStream({ + pull(controller) { + controller.enqueue(new Uint8Array(mergedChunks)); + controller.close(); + }, + }) + //mergedChunks + ); + } + default: + throw new Error("Unhandled URL"); + } + }, + }); + + const startIndexes = [0, 5, 11, 6, 12, 100, 2000, totalSize - 12, totalSize - 2]; + + for (const index of startIndexes) { + console.log("slice", index); + const content = await blob.slice(index).text(); + expect(content.length).toBe(wholeText.length - index); + expect(content.slice(0, 1000)).toEqual(wholeText.slice(index).slice(0, 1000)); + expect(debugged.filter((e) => e.event === "read").length).toBe(4); // 1 read + 1 undefined + expect(fetchCount).toEqual(2); + + fetchCount = 0; + debugged.length = 0; + } + }); + }); + + describe("loading one chunk at a time", () => { + it("should load different slices but not till the end", async () => { + const chunk1Content = "hello"; + const chunk2Content = "world!"; + const debugged: Array<{ event: "read" | string } & Record> = []; + + const chunks = Array(1000) + .fill(0) + .flatMap(() => [makeChunk(chunk1Content), makeChunk(chunk2Content)]); + + const totalChunkLength = sum(chunks.map((x) => x.byteLength)); + const wholeText = (chunk1Content + chunk2Content).repeat(1000); + + const totalSize = wholeText.length; + let fetchCount = 0; + + const blob = new XetBlob({ + hash: "test", + size: totalSize, + refreshUrl: "https://huggingface.co", + listener: (e) => debugged.push(e), + fetch: async function (_url, opts) { + const url = new URL(_url as string); + const headers = opts?.headers as Record | undefined; + + switch (url.hostname) { + case "huggingface.co": { + // This is a token + return new Response( + JSON.stringify({ + casUrl: "https://cas.co", + accessToken: "boo", + exp: 1_000_000, + }) + ); + } + case "cas.co": { + // This is the reconstruction info + const range = headers?.["Range"]?.slice("bytes=".length).split("-").map(Number); + + const start = range?.[0] ?? 0; + // const end = range?.[1] ?? 
(totalSize - 1); + + return new Response( + JSON.stringify({ + terms: [ + { + hash: "test", + range: { + start: 0, + end: 2000, + }, + unpacked_length: chunk1Content.length + chunk2Content.length, + }, + ], + fetch_info: { + test: [ + { + url: "https://fetch.co", + range: { start: 0, end: 2000 }, + url_range: { + start: 0, + end: totalChunkLength - 1, + }, + }, + ], + }, + offset_into_first_range: start, + } satisfies ReconstructionInfo) + ); + } + case "fetch.co": { + fetchCount++; + return new Response( + new ReadableStream({ + pull(controller) { + for (const chunk of chunks) { + controller.enqueue(chunk); + } + controller.close(); + }, + }), + { + headers: { + "Content-Range": `bytes 0-${totalChunkLength - 1}/${totalChunkLength}`, + ETag: `"test"`, + "Content-Length": `${totalChunkLength}`, + }, + } + ); + } + default: + throw new Error("Unhandled URL"); + } + }, + }); + + const startIndexes = [0, 5, 11, 6, 12, 100, 2000]; + + for (const index of startIndexes) { + console.log("slice", index); + const content = await blob.slice(index, 4000).text(); + expect(content.length).toBe(4000 - index); + expect(content.slice(0, 1000)).toEqual(wholeText.slice(index).slice(0, 1000)); + expect(fetchCount).toEqual(1); + + fetchCount = 0; + debugged.length = 0; + } + }); + + it("should load different slices", async () => { + const chunk1Content = "hello"; + const chunk2Content = "world!"; + const debugged: Array<{ event: "read" | string } & Record> = []; + + const chunks = Array(1000) + .fill(0) + .flatMap(() => [makeChunk(chunk1Content), makeChunk(chunk2Content)]); + + const totalChunkLength = sum(chunks.map((x) => x.byteLength)); + const wholeText = (chunk1Content + chunk2Content).repeat(1000); + + const totalSize = wholeText.length; + let fetchCount = 0; + + const blob = new XetBlob({ + hash: "test", + size: totalSize, + refreshUrl: "https://huggingface.co", + listener: (e) => debugged.push(e), + fetch: async function (_url, opts) { + const url = new URL(_url as string); + const headers = opts?.headers as Record | undefined; + + switch (url.hostname) { + case "huggingface.co": { + // This is a token + return new Response( + JSON.stringify({ + casUrl: "https://cas.co", + accessToken: "boo", + exp: 1_000_000, + }) + ); + } + case "cas.co": { + // This is the reconstruction info + const range = headers?.["Range"]?.slice("bytes=".length).split("-").map(Number); + + const start = range?.[0] ?? 0; + // const end = range?.[1] ?? 
(totalSize - 1); + + return new Response( + JSON.stringify({ + terms: Array(1000) + .fill(0) + .map(() => ({ + hash: "test", + range: { + start: 0, + end: 2, + }, + unpacked_length: chunk1Content.length + chunk2Content.length, + })), + fetch_info: { + test: [ + { + url: "https://fetch.co", + range: { start: 0, end: 2 }, + url_range: { + start: 0, + end: totalChunkLength - 1, + }, + }, + ], + }, + offset_into_first_range: start, + } satisfies ReconstructionInfo) + ); + } + case "fetch.co": { + fetchCount++; + return new Response( + new ReadableStream({ + pull(controller) { + for (const chunk of chunks) { + controller.enqueue(chunk); + } + controller.close(); + }, + }) + ); + } + default: + throw new Error("Unhandled URL"); + } + }, + }); + + const startIndexes = [0, 5, 11, 6, 12, 100, 2000, totalSize - 12, totalSize - 2]; + + for (const index of startIndexes) { + console.log("slice", index); + const content = await blob.slice(index).text(); + expect(content.length).toBe(wholeText.length - index); + expect(content.slice(0, 1000)).toEqual(wholeText.slice(index).slice(0, 1000)); + expect(debugged.filter((e) => e.event === "read").length).toBe(2000 + 1); // 1 read for each chunk + 1 undefined + expect(fetchCount).toEqual(1); + + fetchCount = 0; + debugged.length = 0; + } + }); + }); + + describe("loading at 29 bytes intervals", () => { + it("should load different slices", async () => { + const chunk1Content = "hello"; + const chunk2Content = "world!"; + const debugged: Array<{ event: "read" | string } & Record> = []; + + const chunks = Array(1000) + .fill(0) + .flatMap(() => [makeChunk(chunk1Content), makeChunk(chunk2Content)]); + const mergedChunks = await new Blob(chunks).arrayBuffer(); + const splitChunks = splitChunk(new Uint8Array(mergedChunks), 29); + + const totalChunkLength = sum(chunks.map((x) => x.byteLength)); + const wholeText = (chunk1Content + chunk2Content).repeat(1000); + + const totalSize = wholeText.length; + let fetchCount = 0; + + const blob = new XetBlob({ + hash: "test", + size: totalSize, + refreshUrl: "https://huggingface.co", + listener: (e) => debugged.push(e), + fetch: async function (_url, opts) { + const url = new URL(_url as string); + const headers = opts?.headers as Record | undefined; + + switch (url.hostname) { + case "huggingface.co": { + // This is a token + return new Response( + JSON.stringify({ + casUrl: "https://cas.co", + accessToken: "boo", + exp: 1_000_000, + }) + ); + } + case "cas.co": { + // This is the reconstruction info + const range = headers?.["Range"]?.slice("bytes=".length).split("-").map(Number); + + const start = range?.[0] ?? 0; + // const end = range?.[1] ?? 
(totalSize - 1); + + return new Response( + JSON.stringify({ + terms: Array(1000) + .fill(0) + .map(() => ({ + hash: "test", + range: { + start: 0, + end: 2, + }, + unpacked_length: chunk1Content.length + chunk2Content.length, + })), + fetch_info: { + test: [ + { + url: "https://fetch.co", + range: { start: 0, end: 2 }, + url_range: { + start: 0, + end: totalChunkLength - 1, + }, + }, + ], + }, + offset_into_first_range: start, + } satisfies ReconstructionInfo) + ); + } + case "fetch.co": { + fetchCount++; + return new Response( + new ReadableStream({ + pull(controller) { + for (const chunk of splitChunks) { + controller.enqueue(chunk); + } + controller.close(); + }, + }) + ); + } + default: + throw new Error("Unhandled URL"); + } + }, + }); + + const startIndexes = [0, 5, 11, 6, 12, 100, 2000, totalSize - 12, totalSize - 2]; + + for (const index of startIndexes) { + console.log("slice", index); + const content = await blob.slice(index).text(); + expect(content.length).toBe(wholeText.length - index); + expect(content.slice(0, 1000)).toEqual(wholeText.slice(index).slice(0, 1000)); + expect(debugged.filter((e) => e.event === "read").length).toBe(Math.ceil(totalChunkLength / 29) + 1); // 1 read for each chunk + 1 undefined + expect(fetchCount).toEqual(1); + + fetchCount = 0; + debugged.length = 0; + } + }); + }); + + describe("loading one byte at a time", () => { + it("should load different slices", async () => { + const chunk1Content = "hello"; + const chunk2Content = "world!"; + const debugged: Array<{ event: "read" | string } & Record> = []; + + const chunks = Array(100) + .fill(0) + .flatMap(() => [makeChunk(chunk1Content), makeChunk(chunk2Content)]) + .flatMap((x) => splitChunk(x, 1)); + + const totalChunkLength = sum(chunks.map((x) => x.byteLength)); + const wholeText = (chunk1Content + chunk2Content).repeat(100); + + const totalSize = wholeText.length; + let fetchCount = 0; + + const blob = new XetBlob({ + hash: "test", + size: totalSize, + refreshUrl: "https://huggingface.co", + listener: (e) => debugged.push(e), + fetch: async function (_url, opts) { + const url = new URL(_url as string); + const headers = opts?.headers as Record | undefined; + + switch (url.hostname) { + case "huggingface.co": { + // This is a token + return new Response( + JSON.stringify({ + casUrl: "https://cas.co", + accessToken: "boo", + exp: 1_000_000, + }) + ); + } + case "cas.co": { + // This is the reconstruction info + const range = headers?.["Range"]?.slice("bytes=".length).split("-").map(Number); + + const start = range?.[0] ?? 0; + // const end = range?.[1] ?? 
(totalSize - 1); + + return new Response( + JSON.stringify({ + terms: Array(100) + .fill(0) + .map(() => ({ + hash: "test", + range: { + start: 0, + end: 2, + }, + unpacked_length: chunk1Content.length + chunk2Content.length, + })), + fetch_info: { + test: [ + { + url: "https://fetch.co", + range: { start: 0, end: 2 }, + url_range: { + start: 0, + end: totalChunkLength - 1, + }, + }, + ], + }, + offset_into_first_range: start, + } satisfies ReconstructionInfo) + ); + } + case "fetch.co": { + fetchCount++; + return new Response( + new ReadableStream({ + pull(controller) { + for (const chunk of chunks) { + controller.enqueue(chunk); + } + controller.close(); + }, + }) + ); + } + default: + throw new Error("Unhandled URL"); + } + }, + }); + + const startIndexes = [0, 5, 11, 6, 12, 100, totalSize - 12, totalSize - 2]; + + for (const index of startIndexes) { + console.log("slice", index); + const content = await blob.slice(index).text(); + expect(content.length).toBe(wholeText.length - index); + expect(content.slice(0, 1000)).toEqual(wholeText.slice(index).slice(0, 1000)); + expect(debugged.filter((e) => e.event === "read").length).toBe(totalChunkLength + 1); // 1 read for each chunk + 1 undefined + expect(fetchCount).toEqual(1); + + fetchCount = 0; + debugged.length = 0; + } + }); + }); + }); +}); + +function makeChunk(content: string) { + const encoded = new TextEncoder().encode(content); + + const array = new Uint8Array(encoded.length + 8); + + const dataView = new DataView(array.buffer); + dataView.setUint8(0, 0); // version + dataView.setUint8(1, encoded.length % 256); // Compressed length + dataView.setUint8(2, (encoded.length >> 8) % 256); // Compressed length + dataView.setUint8(3, (encoded.length >> 16) % 256); // Compressed length + dataView.setUint8(4, 0); // Compression scheme + dataView.setUint8(5, encoded.length % 256); // Uncompressed length + dataView.setUint8(6, (encoded.length >> 8) % 256); // Uncompressed length + dataView.setUint8(7, (encoded.length >> 16) % 256); // Uncompressed length + + array.set(encoded, 8); + + return array; +} + +function splitChunk(chunk: Uint8Array, toLength: number): Uint8Array[] { + const dataView = new DataView(chunk.buffer); + return new Array(Math.ceil(chunk.byteLength / toLength)).fill(0).map((_, i) => { + const array = new Uint8Array(Math.min(toLength, chunk.byteLength - i * toLength)); + + for (let j = 0; j < array.byteLength; j++) { + array[j] = dataView.getUint8(i * toLength + j); + } + return array; + }); +} diff --git a/utils/XetBlob.ts b/utils/XetBlob.ts new file mode 100644 index 0000000000000000000000000000000000000000..3b787b0e5467020d331887629cc04042899c54ce --- /dev/null +++ b/utils/XetBlob.ts @@ -0,0 +1,662 @@ +import { createApiError } from "../error"; +import type { CredentialsParams } from "../types/public"; +import { checkCredentials } from "./checkCredentials"; +import { decompress as lz4_decompress } from "../vendor/lz4js"; +import { RangeList } from "./RangeList"; + +const JWT_SAFETY_PERIOD = 60_000; +const JWT_CACHE_SIZE = 1_000; + +type XetBlobCreateOptions = { + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. 
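+ *
+ * @example
+ * // Sketch of a logging wrapper around the global fetch; hash, size and
+ * // refreshUrl below are placeholders.
+ * const blob = new XetBlob({
+ *   hash: "<xet file hash>",
+ *   size: 1234,
+ *   refreshUrl: "https://huggingface.co/api/models/<user>/<repo>/xet-read-token/main",
+ *   fetch: (url, opts) => {
+ *     console.log("xet fetch", url);
+ *     return fetch(url, opts);
+ *   },
+ * });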
+ */ + fetch?: typeof fetch; + // URL to get the access token from + refreshUrl: string; + size: number; + listener?: (arg: { event: "read" } | { event: "progress"; progress: { read: number; total: number } }) => void; + internalLogging?: boolean; +} & ({ hash: string; reconstructionUrl?: string } | { hash?: string; reconstructionUrl: string }) & + Partial; + +export interface ReconstructionInfo { + /** + * List of CAS blocks + */ + terms: Array<{ + /** Hash of the CAS block */ + hash: string; + /** Total uncompressed length of data of the chunks from range.start to range.end - 1 */ + unpacked_length: number; + /** Chunks. Eg start: 10, end: 100 = chunks 10-99 */ + range: { start: number; end: number }; + }>; + + /** + * Dictionnary of CAS block hash => list of ranges in the block + url to fetch it + */ + fetch_info: Record< + string, + Array<{ + url: string; + /** Chunk range */ + range: { start: number; end: number }; + /** + * Byte range, when making the call to the URL. + * + * We assume that we're given non-overlapping ranges for each hash + */ + url_range: { start: number; end: number }; + }> + >; + /** + * When doing a range request, the offset into the term's uncompressed data. Can be multiple chunks' worth of data. + */ + offset_into_first_range: number; +} + +enum CompressionScheme { + None = 0, + LZ4 = 1, + ByteGroupingLZ4 = 2, +} + +const compressionSchemeLabels: Record = { + [CompressionScheme.None]: "None", + [CompressionScheme.LZ4]: "LZ4", + [CompressionScheme.ByteGroupingLZ4]: "ByteGroupingLZ4", +}; + +interface ChunkHeader { + version: number; // u8, 1 byte + compressed_length: number; // 3 * u8, 3 bytes + compression_scheme: CompressionScheme; // u8, 1 byte + uncompressed_length: number; // 3 * u8, 3 bytes +} + +const CHUNK_HEADER_BYTES = 8; + +/** + * XetBlob is a blob implementation that fetches data directly from the Xet storage + */ +export class XetBlob extends Blob { + fetch: typeof fetch; + accessToken?: string; + refreshUrl: string; + reconstructionUrl?: string; + hash?: string; + start = 0; + end = 0; + internalLogging = false; + reconstructionInfo: ReconstructionInfo | undefined; + listener: XetBlobCreateOptions["listener"]; + + constructor(params: XetBlobCreateOptions) { + super([]); + + this.fetch = params.fetch ?? fetch.bind(globalThis); + this.accessToken = checkCredentials(params); + this.refreshUrl = params.refreshUrl; + this.end = params.size; + this.reconstructionUrl = params.reconstructionUrl; + this.hash = params.hash; + this.listener = params.listener; + this.internalLogging = params.internalLogging ?? 
false; + this.refreshUrl; + } + + override get size(): number { + return this.end - this.start; + } + + #clone() { + const blob = new XetBlob({ + fetch: this.fetch, + hash: this.hash, + refreshUrl: this.refreshUrl, + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + reconstructionUrl: this.reconstructionUrl!, + size: this.size, + }); + + blob.accessToken = this.accessToken; + blob.start = this.start; + blob.end = this.end; + blob.reconstructionInfo = this.reconstructionInfo; + blob.listener = this.listener; + blob.internalLogging = this.internalLogging; + + return blob; + } + + override slice(start = 0, end = this.size): XetBlob { + if (start < 0 || end < 0) { + new TypeError("Unsupported negative start/end on XetBlob.slice"); + } + + const slice = this.#clone(); + + slice.start = this.start + start; + slice.end = Math.min(this.start + end, this.end); + + if (slice.start !== this.start || slice.end !== this.end) { + slice.reconstructionInfo = undefined; + } + + return slice; + } + + #reconstructionInfoPromise?: Promise; + + #loadReconstructionInfo() { + if (this.#reconstructionInfoPromise) { + return this.#reconstructionInfoPromise; + } + + this.#reconstructionInfoPromise = (async () => { + const connParams = await getAccessToken(this.accessToken, this.fetch, this.refreshUrl); + + // debug( + // `curl '${connParams.casUrl}/reconstruction/${this.hash}' -H 'Authorization: Bearer ${connParams.accessToken}'` + // ); + + const resp = await this.fetch(this.reconstructionUrl ?? `${connParams.casUrl}/reconstruction/${this.hash}`, { + headers: { + Authorization: `Bearer ${connParams.accessToken}`, + Range: `bytes=${this.start}-${this.end - 1}`, + }, + }); + + if (!resp.ok) { + throw await createApiError(resp); + } + + this.reconstructionInfo = (await resp.json()) as ReconstructionInfo; + + return this.reconstructionInfo; + })().finally(() => (this.#reconstructionInfoPromise = undefined)); + + return this.#reconstructionInfoPromise; + } + + async #fetch(): Promise> { + if (!this.reconstructionInfo) { + await this.#loadReconstructionInfo(); + } + + const rangeLists = new Map>(); + + if (!this.reconstructionInfo) { + throw new Error("Failed to load reconstruction info"); + } + + for (const term of this.reconstructionInfo.terms) { + let rangeList = rangeLists.get(term.hash); + if (!rangeList) { + rangeList = new RangeList(); + rangeLists.set(term.hash, rangeList); + } + + rangeList.add(term.range.start, term.range.end); + } + const listener = this.listener; + const log = this.internalLogging ? (...args: unknown[]) => console.log(...args) : () => {}; + + async function* readData( + reconstructionInfo: ReconstructionInfo, + customFetch: typeof fetch, + maxBytes: number, + reloadReconstructionInfo: () => Promise + ) { + let totalBytesRead = 0; + let readBytesToSkip = reconstructionInfo.offset_into_first_range; + + for (const term of reconstructionInfo.terms) { + if (totalBytesRead >= maxBytes) { + break; + } + + const rangeList = rangeLists.get(term.hash); + if (!rangeList) { + throw new Error(`Failed to find range list for term ${term.hash}`); + } + + { + const termRanges = rangeList.getRanges(term.range.start, term.range.end); + + if (termRanges.every((range) => range.data)) { + log("all data available for term", term.hash, readBytesToSkip); + rangeLoop: for (const range of termRanges) { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + for (let chunk of range.data!) 
{ + if (readBytesToSkip) { + const skipped = Math.min(readBytesToSkip, chunk.byteLength); + chunk = chunk.slice(skipped); + readBytesToSkip -= skipped; + if (!chunk.byteLength) { + continue; + } + } + if (chunk.byteLength > maxBytes - totalBytesRead) { + chunk = chunk.slice(0, maxBytes - totalBytesRead); + } + totalBytesRead += chunk.byteLength; + // The stream consumer can decide to transfer ownership of the chunk, so we need to return a clone + // if there's more than one range for the same term + yield range.refCount > 1 ? chunk.slice() : chunk; + listener?.({ event: "progress", progress: { read: totalBytesRead, total: maxBytes } }); + + if (totalBytesRead >= maxBytes) { + break rangeLoop; + } + } + } + rangeList.remove(term.range.start, term.range.end); + continue; + } + } + + const fetchInfo = reconstructionInfo.fetch_info[term.hash].find( + (info) => info.range.start <= term.range.start && info.range.end >= term.range.end + ); + + if (!fetchInfo) { + throw new Error( + `Failed to find fetch info for term ${term.hash} and range ${term.range.start}-${term.range.end}` + ); + } + + log("term", term); + log("fetchinfo", fetchInfo); + log("readBytesToSkip", readBytesToSkip); + + let resp = await customFetch(fetchInfo.url, { + headers: { + Range: `bytes=${fetchInfo.url_range.start}-${fetchInfo.url_range.end}`, + }, + }); + + if (resp.status === 403) { + // In case it's expired + reconstructionInfo = await reloadReconstructionInfo(); + resp = await customFetch(fetchInfo.url, { + headers: { + Range: `bytes=${fetchInfo.url_range.start}-${fetchInfo.url_range.end}`, + }, + }); + } + + if (!resp.ok) { + throw await createApiError(resp); + } + + log( + "expected content length", + resp.headers.get("content-length"), + "range", + fetchInfo.url_range, + resp.headers.get("content-range") + ); + + const reader = resp.body?.getReader(); + if (!reader) { + throw new Error("Failed to get reader from response body"); + } + + let done = false; + let chunkIndex = fetchInfo.range.start; + const ranges = rangeList.getRanges(fetchInfo.range.start, fetchInfo.range.end); + + let leftoverBytes: Uint8Array | undefined = undefined; + let totalFetchBytes = 0; + + fetchData: while (!done && totalBytesRead < maxBytes) { + const result = await reader.read(); + listener?.({ event: "read" }); + + done = result.done; + + log("read", result.value?.byteLength, "bytes", "total read", totalBytesRead, "toSkip", readBytesToSkip); + + if (!result.value) { + log("no data in result, cancelled", result); + continue; + } + + totalFetchBytes += result.value.byteLength; + + if (leftoverBytes) { + result.value = new Uint8Array([...leftoverBytes, ...result.value]); + leftoverBytes = undefined; + } + + while (totalBytesRead < maxBytes && result.value.byteLength) { + if (result.value.byteLength < 8) { + // We need 8 bytes to parse the chunk header + leftoverBytes = result.value; + continue fetchData; + } + + const header = new DataView(result.value.buffer, result.value.byteOffset, CHUNK_HEADER_BYTES); + const chunkHeader: ChunkHeader = { + version: header.getUint8(0), + compressed_length: header.getUint8(1) | (header.getUint8(2) << 8) | (header.getUint8(3) << 16), + compression_scheme: header.getUint8(4), + uncompressed_length: header.getUint8(5) | (header.getUint8(6) << 8) | (header.getUint8(7) << 16), + }; + + log("chunk header", chunkHeader, "to skip", readBytesToSkip); + + if (chunkHeader.version !== 0) { + throw new Error(`Unsupported chunk version ${chunkHeader.version}`); + } + + if ( + chunkHeader.compression_scheme !== 
CompressionScheme.None && + chunkHeader.compression_scheme !== CompressionScheme.LZ4 && + chunkHeader.compression_scheme !== CompressionScheme.ByteGroupingLZ4 + ) { + throw new Error( + `Unsupported compression scheme ${ + compressionSchemeLabels[chunkHeader.compression_scheme] ?? chunkHeader.compression_scheme + }` + ); + } + + if (result.value.byteLength < chunkHeader.compressed_length + CHUNK_HEADER_BYTES) { + // We need more data to read the full chunk + leftoverBytes = result.value; + continue fetchData; + } + + result.value = result.value.slice(CHUNK_HEADER_BYTES); + + let uncompressed = + chunkHeader.compression_scheme === CompressionScheme.LZ4 + ? lz4_decompress(result.value.slice(0, chunkHeader.compressed_length), chunkHeader.uncompressed_length) + : chunkHeader.compression_scheme === CompressionScheme.ByteGroupingLZ4 + ? bg4_regoup_bytes( + lz4_decompress( + result.value.slice(0, chunkHeader.compressed_length), + chunkHeader.uncompressed_length + ) + ) + : result.value.slice(0, chunkHeader.compressed_length); + + const range = ranges.find((range) => chunkIndex >= range.start && chunkIndex < range.end); + const shouldYield = chunkIndex >= term.range.start && chunkIndex < term.range.end; + const minRefCountToStore = shouldYield ? 2 : 1; + let stored = false; + + // Assuming non-overlapping fetch_info ranges for the same hash + if (range && range.refCount >= minRefCountToStore) { + range.data ??= []; + range.data.push(uncompressed); + stored = true; + } + + if (shouldYield) { + if (readBytesToSkip) { + const skipped = Math.min(readBytesToSkip, uncompressed.byteLength); + uncompressed = uncompressed.slice(readBytesToSkip); + readBytesToSkip -= skipped; + } + + if (uncompressed.byteLength > maxBytes - totalBytesRead) { + uncompressed = uncompressed.slice(0, maxBytes - totalBytesRead); + } + + if (uncompressed.byteLength) { + log( + "yield", + uncompressed.byteLength, + "bytes", + result.value.byteLength, + "total read", + totalBytesRead, + stored + ); + totalBytesRead += uncompressed.byteLength; + yield stored ? 
uncompressed.slice() : uncompressed; + listener?.({ event: "progress", progress: { read: totalBytesRead, total: maxBytes } }); + } + } + + chunkIndex++; + result.value = result.value.slice(chunkHeader.compressed_length); + } + } + + if ( + done && + totalBytesRead < maxBytes && + totalFetchBytes < fetchInfo.url_range.end - fetchInfo.url_range.start + 1 + ) { + log("done", done, "total read", totalBytesRead, maxBytes, totalFetchBytes); + log("failed to fetch all data for term", term.hash); + throw new Error( + `Failed to fetch all data for term ${term.hash}, fetched ${totalFetchBytes} bytes out of ${ + fetchInfo.url_range.end - fetchInfo.url_range.start + 1 + }` + ); + } + + log("done", done, "total read", totalBytesRead, maxBytes, totalFetchBytes); + + // Release the reader + log("cancel reader"); + await reader.cancel(); + } + } + + const iterator = readData( + this.reconstructionInfo, + this.fetch, + this.end - this.start, + this.#loadReconstructionInfo.bind(this) + ); + + // todo: when Chrome/Safari support it, use ReadableStream.from(readData) + return new ReadableStream( + { + // todo: when Safari supports it, type controller as ReadableByteStreamController + async pull(controller) { + const result = await iterator.next(); + + if (result.value) { + controller.enqueue(result.value); + } + + if (result.done) { + controller.close(); + } + }, + type: "bytes", + // todo: when Safari supports it, add autoAllocateChunkSize param + }, + // todo : use ByteLengthQueuingStrategy when there's good support for it, currently in Node.js it fails due to size being a function + { + highWaterMark: 1_000, // 1_000 chunks for ~1MB of RAM + } + ); + } + + override async arrayBuffer(): Promise { + const result = await this.#fetch(); + + return new Response(result).arrayBuffer(); + } + + override async text(): Promise { + const result = await this.#fetch(); + + return new Response(result).text(); + } + + async response(): Promise { + const result = await this.#fetch(); + + return new Response(result); + } + + override stream(): ReturnType { + const stream = new TransformStream(); + + this.#fetch() + .then((response) => response.pipeThrough(stream)) + .catch((error) => stream.writable.abort(error.message)); + + return stream.readable; + } +} + +const jwtPromises: Map> = new Map(); +/** + * Cache to store JWTs, to avoid making many auth requests when downloading multiple files from the same repo + */ +const jwts: Map< + string, + { + accessToken: string; + expiresAt: Date; + casUrl: string; + } +> = new Map(); + +function cacheKey(params: { refreshUrl: string; initialAccessToken: string | undefined }): string { + return JSON.stringify([params.refreshUrl, params.initialAccessToken]); +} + +// exported for testing purposes +export function bg4_regoup_bytes(bytes: Uint8Array): Uint8Array { + // python code + + // split = len(x) // 4 + // rem = len(x) % 4 + // g1_pos = split + (1 if rem >= 1 else 0) + // g2_pos = g1_pos + split + (1 if rem >= 2 else 0) + // g3_pos = g2_pos + split + (1 if rem == 3 else 0) + // ret = bytearray(len(x)) + // ret[0::4] = x[:g1_pos] + // ret[1::4] = x[g1_pos:g2_pos] + // ret[2::4] = x[g2_pos:g3_pos] + // ret[3::4] = x[g3_pos:] + + // todo: optimize to do it in-place + + const split = Math.floor(bytes.byteLength / 4); + const rem = bytes.byteLength % 4; + const g1_pos = split + (rem >= 1 ? 1 : 0); + const g2_pos = g1_pos + split + (rem >= 2 ? 1 : 0); + const g3_pos = g2_pos + split + (rem == 3 ? 
1 : 0); + + const ret = new Uint8Array(bytes.byteLength); + for (let i = 0, j = 0; i < bytes.byteLength; i += 4, j++) { + ret[i] = bytes[j]; + } + + for (let i = 1, j = g1_pos; i < bytes.byteLength; i += 4, j++) { + ret[i] = bytes[j]; + } + + for (let i = 2, j = g2_pos; i < bytes.byteLength; i += 4, j++) { + ret[i] = bytes[j]; + } + + for (let i = 3, j = g3_pos; i < bytes.byteLength; i += 4, j++) { + ret[i] = bytes[j]; + } + + return ret; + + // alternative implementation (to benchmark which one is faster) + // for (let i = 0; i < bytes.byteLength - 3; i += 4) { + // ret[i] = bytes[i / 4]; + // ret[i + 1] = bytes[g1_pos + i / 4]; + // ret[i + 2] = bytes[g2_pos + i / 4]; + // ret[i + 3] = bytes[g3_pos + i / 4]; + // } + + // if (rem === 1) { + // ret[bytes.byteLength - 1] = bytes[g1_pos - 1]; + // } else if (rem === 2) { + // ret[bytes.byteLength - 2] = bytes[g1_pos - 1]; + // ret[bytes.byteLength - 1] = bytes[g2_pos - 1]; + // } else if (rem === 3) { + // ret[bytes.byteLength - 3] = bytes[g1_pos - 1]; + // ret[bytes.byteLength - 2] = bytes[g2_pos - 1]; + // ret[bytes.byteLength - 1] = bytes[g3_pos - 1]; + // } +} + +async function getAccessToken( + initialAccessToken: string | undefined, + customFetch: typeof fetch, + refreshUrl: string +): Promise<{ accessToken: string; casUrl: string }> { + const key = cacheKey({ refreshUrl, initialAccessToken }); + + const jwt = jwts.get(key); + + if (jwt && jwt.expiresAt > new Date(Date.now() + JWT_SAFETY_PERIOD)) { + return { accessToken: jwt.accessToken, casUrl: jwt.casUrl }; + } + + // If we already have a promise for this repo, return it + const existingPromise = jwtPromises.get(key); + if (existingPromise) { + return existingPromise; + } + + const promise = (async () => { + const resp = await customFetch(refreshUrl, { + headers: { + ...(initialAccessToken + ? 
{ + Authorization: `Bearer ${initialAccessToken}`, + } + : {}), + }, + }); + + if (!resp.ok) { + throw new Error(`Failed to get JWT token: ${resp.status} ${await resp.text()}`); + } + + const json: { accessToken: string; casUrl: string; exp: number } = await resp.json(); + const jwt = { + accessToken: json.accessToken, + expiresAt: new Date(json.exp * 1000), + initialAccessToken, + refreshUrl, + casUrl: json.casUrl, + }; + + jwtPromises.delete(key); + + for (const [key, value] of jwts.entries()) { + if (value.expiresAt < new Date(Date.now() + JWT_SAFETY_PERIOD)) { + jwts.delete(key); + } else { + break; + } + } + if (jwts.size >= JWT_CACHE_SIZE) { + const keyToDelete = jwts.keys().next().value; + if (keyToDelete) { + jwts.delete(keyToDelete); + } + } + jwts.set(key, jwt); + + return { + accessToken: json.accessToken, + casUrl: json.casUrl, + }; + })(); + + jwtPromises.set(key, promise); + + return promise; +} diff --git a/utils/base64FromBytes.ts b/utils/base64FromBytes.ts new file mode 100644 index 0000000000000000000000000000000000000000..5327bbfe25838372cc7e25123d1cc9beaf80ceda --- /dev/null +++ b/utils/base64FromBytes.ts @@ -0,0 +1,11 @@ +export function base64FromBytes(arr: Uint8Array): string { + if (globalThis.Buffer) { + return globalThis.Buffer.from(arr).toString("base64"); + } else { + const bin: string[] = []; + arr.forEach((byte) => { + bin.push(String.fromCharCode(byte)); + }); + return globalThis.btoa(bin.join("")); + } +} diff --git a/utils/checkCredentials.ts b/utils/checkCredentials.ts new file mode 100644 index 0000000000000000000000000000000000000000..0e1717054b5dd1d03f4bd0b9c8f986a2ea4560f7 --- /dev/null +++ b/utils/checkCredentials.ts @@ -0,0 +1,18 @@ +import type { CredentialsParams } from "../types/public"; + +export function checkAccessToken(accessToken: string): void { + if (!accessToken.startsWith("hf_")) { + throw new TypeError("Your access token must start with 'hf_'"); + } +} + +export function checkCredentials(params: Partial): string | undefined { + if (params.accessToken) { + checkAccessToken(params.accessToken); + return params.accessToken; + } + if (params.credentials?.accessToken) { + checkAccessToken(params.credentials.accessToken); + return params.credentials.accessToken; + } +} diff --git a/utils/chunk.ts b/utils/chunk.ts new file mode 100644 index 0000000000000000000000000000000000000000..20718cc89e5e32fd99d99be93dead0e4e24949f4 --- /dev/null +++ b/utils/chunk.ts @@ -0,0 +1,25 @@ +import { range } from "./range"; + +/** + * Chunk array into arrays of length at most `chunkSize` + * + * @param chunkSize must be greater than or equal to 1 + */ +export function chunk(arr: T, chunkSize: number): T[] { + if (isNaN(chunkSize) || chunkSize < 1) { + throw new RangeError("Invalid chunk size: " + chunkSize); + } + + if (!arr.length) { + return []; + } + + /// Small optimization to not chunk buffers unless needed + if (arr.length <= chunkSize) { + return [arr]; + } + + return range(Math.ceil(arr.length / chunkSize)).map((i) => { + return arr.slice(i * chunkSize, (i + 1) * chunkSize); + }) as T[]; +} diff --git a/utils/createBlob.ts b/utils/createBlob.ts new file mode 100644 index 0000000000000000000000000000000000000000..5d5f200a66ec1946d2a770cfd7ccde1b9f724ec4 --- /dev/null +++ b/utils/createBlob.ts @@ -0,0 +1,30 @@ +import { WebBlob } from "./WebBlob"; +import { isFrontend } from "./isFrontend"; + +/** + * This function allow to retrieve either a FileBlob or a WebBlob from a URL. 
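+ *
+ * A minimal usage sketch (the URL below is just an illustration, not a real endpoint):
+ *
+ *   const blob = await createBlob(new URL("https://example.com/some-file.bin"));
+ *   console.log(blob.size);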
+ * + * From the backend: + * - support local files + * - support http resources with absolute URLs + * + * From the frontend: + * - support http resources with absolute or relative URLs + */ +export async function createBlob(url: URL, opts?: { fetch?: typeof fetch; accessToken?: string }): Promise { + if (url.protocol === "http:" || url.protocol === "https:") { + return WebBlob.create(url, { fetch: opts?.fetch, accessToken: opts?.accessToken }); + } + + if (isFrontend) { + throw new TypeError(`Unsupported URL protocol "${url.protocol}"`); + } + + if (url.protocol === "file:") { + const { FileBlob } = await import("./FileBlob"); + + return FileBlob.create(url); + } + + throw new TypeError(`Unsupported URL protocol "${url.protocol}"`); +} diff --git a/utils/createBlobs.ts b/utils/createBlobs.ts new file mode 100644 index 0000000000000000000000000000000000000000..625bef5fde3f53c9c308ea613d880a609db5db37 --- /dev/null +++ b/utils/createBlobs.ts @@ -0,0 +1,51 @@ +import { WebBlob } from "./WebBlob"; +import { isFrontend } from "./isFrontend"; + +/** + * This function allow to retrieve either a FileBlob or a WebBlob from a URL. + * + * From the backend: + * - support local files + * - support local folders + * - support http resources with absolute URLs + * + * From the frontend: + * - support http resources with absolute or relative URLs + */ +export async function createBlobs( + url: URL, + destPath: string, + opts?: { fetch?: typeof fetch; maxFolderDepth?: number; accessToken?: string } +): Promise> { + if (url.protocol === "http:" || url.protocol === "https:") { + const blob = await WebBlob.create(url, { fetch: opts?.fetch, accessToken: opts?.accessToken }); + return [{ path: destPath, blob }]; + } + + if (isFrontend) { + throw new TypeError(`Unsupported URL protocol "${url.protocol}"`); + } + + if (url.protocol === "file:") { + const { FileBlob } = await import("./FileBlob"); + const { subPaths } = await import("./sub-paths"); + const paths = await subPaths(url, opts?.maxFolderDepth); + + if (paths.length === 1 && paths[0].relativePath === ".") { + const blob = await FileBlob.create(url); + return [{ path: destPath, blob }]; + } + + return Promise.all( + paths.map(async (path) => ({ + path: `${destPath}/${path.relativePath}` + .replace(/\/[.]$/, "") + .replaceAll("//", "/") + .replace(/^[.]?\//, ""), + blob: await FileBlob.create(new URL(path.path)), + })) + ); + } + + throw new TypeError(`Unsupported URL protocol "${url.protocol}"`); +} diff --git a/utils/eventToGenerator.spec.ts b/utils/eventToGenerator.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..59ed182a4a2394f6662afeee1d4f42fe4692e0b5 --- /dev/null +++ b/utils/eventToGenerator.spec.ts @@ -0,0 +1,44 @@ +import { describe, expect, it } from "vitest"; +import { eventToGenerator } from "./eventToGenerator"; + +describe("eventToGenerator", () => { + it("should handle synchronous events", async () => { + const it = eventToGenerator((yieldCallback, returnCallback) => { + yieldCallback(1); + yieldCallback(2); + returnCallback(3); + }); + + const results = []; + let res: IteratorResult; + do { + res = await it.next(); + if (!res.done) { + results.push(res.value); + } + } while (!res.done); + + expect(results).toEqual([1, 2]); + expect(res.value).toBe(3); + }); + + it("should handle asynchronous events", async () => { + const it = eventToGenerator((yieldCallback, returnCallback) => { + setTimeout(() => yieldCallback(1), 100); + setTimeout(() => yieldCallback(2), 200); + setTimeout(() => returnCallback(3), 300); + 
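+ // The yields arrive asynchronously, but eventToGenerator buffers them, so the
+ // consumer loop below still sees 1, 2 in order and then the return value 3.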
}); + + const results = []; + let res: IteratorResult; + do { + res = await it.next(); + if (!res.done) { + results.push(res.value); + } + } while (!res.done); + + expect(results).toEqual([1, 2]); + expect(res.value).toBe(3); + }); +}); diff --git a/utils/eventToGenerator.ts b/utils/eventToGenerator.ts new file mode 100644 index 0000000000000000000000000000000000000000..a06e503b69cd2ef656e99b4d60260afb60dbea60 --- /dev/null +++ b/utils/eventToGenerator.ts @@ -0,0 +1,64 @@ +export async function* eventToGenerator( + cb: ( + yieldCallback: (y: YieldType) => void, + returnCallback: (r: ReturnType) => void, + rejectCallack: (reason: unknown) => void + ) => unknown +): AsyncGenerator { + const promises: Array<{ + p: Promise<{ done: true; value: ReturnType } | { done: false; value: YieldType }>; + resolve: (value: { done: true; value: ReturnType } | { done: false; value: YieldType }) => void; + reject: (reason?: unknown) => void; + }> = []; + + function addPromise() { + let resolve: (value: { done: true; value: ReturnType } | { done: false; value: YieldType }) => void; + let reject: (reason?: unknown) => void; + const p = new Promise<{ done: true; value: ReturnType } | { done: false; value: YieldType }>((res, rej) => { + resolve = res; + reject = rej; + }); + // @ts-expect-error TS doesn't know that promise callback is executed immediately + promises.push({ p, resolve, reject }); + } + + addPromise(); + + const callbackRes = Promise.resolve() + .then(() => + cb( + (y) => { + addPromise(); + promises.at(-2)?.resolve({ done: false, value: y }); + }, + (r) => { + addPromise(); + promises.at(-2)?.resolve({ done: true, value: r }); + }, + (err) => promises.shift()?.reject(err) + ) + ) + .catch((err) => promises.shift()?.reject(err)); + + while (1) { + const p = promises[0]; + if (!p) { + throw new Error("Logic error in eventGenerator, promises should never be empty"); + } + const result = await p.p; + promises.shift(); + if (result.done) { + await callbackRes; // Clean up, may be removed in the future + // // Cleanup promises - shouldn't be needed due to above await + // for (const promise of promises) { + // promise.resolve(result); + // await promise.p; + // } + return result.value; + } + yield result.value; + } + + // So TS doesn't complain + throw new Error("Unreachable"); +} diff --git a/utils/hexFromBytes.ts b/utils/hexFromBytes.ts new file mode 100644 index 0000000000000000000000000000000000000000..a6c331c09525a0906333e7007e5724858795aa35 --- /dev/null +++ b/utils/hexFromBytes.ts @@ -0,0 +1,11 @@ +export function hexFromBytes(arr: Uint8Array): string { + if (globalThis.Buffer) { + return globalThis.Buffer.from(arr).toString("hex"); + } else { + const bin: string[] = []; + arr.forEach((byte) => { + bin.push(byte.toString(16).padStart(2, "0")); + }); + return bin.join(""); + } +} diff --git a/utils/insecureRandomString.ts b/utils/insecureRandomString.ts new file mode 100644 index 0000000000000000000000000000000000000000..f9954d431b26374eb47d4b81f10a5837990a6e90 --- /dev/null +++ b/utils/insecureRandomString.ts @@ -0,0 +1,3 @@ +export function insecureRandomString(): string { + return Math.random().toString(36).slice(2); +} diff --git a/utils/isBackend.ts b/utils/isBackend.ts new file mode 100644 index 0000000000000000000000000000000000000000..1e6f27998645f2971dcdd92503d78de521273a26 --- /dev/null +++ b/utils/isBackend.ts @@ -0,0 +1,6 @@ +const isBrowser = typeof window !== "undefined" && typeof window.document !== "undefined"; + +const isWebWorker = + typeof self === "object" && self.constructor 
&& self.constructor.name === "DedicatedWorkerGlobalScope"; + +export const isBackend = !isBrowser && !isWebWorker; diff --git a/utils/isFrontend.ts b/utils/isFrontend.ts new file mode 100644 index 0000000000000000000000000000000000000000..0b9bab392e71f315704c210bc0e8ff210379703d --- /dev/null +++ b/utils/isFrontend.ts @@ -0,0 +1,3 @@ +import { isBackend } from "./isBackend"; + +export const isFrontend = !isBackend; diff --git a/utils/omit.ts b/utils/omit.ts new file mode 100644 index 0000000000000000000000000000000000000000..8743dba87f5c18ca8e55509d5a311c2f3a863134 --- /dev/null +++ b/utils/omit.ts @@ -0,0 +1,14 @@ +import { pick } from "./pick"; +import { typedInclude } from "./typedInclude"; + +/** + * Return copy of object, omitting blacklisted array of props + */ +export function omit, K extends keyof T>( + o: T, + props: K[] | K +): Pick> { + const propsArr = Array.isArray(props) ? props : [props]; + const letsKeep = (Object.keys(o) as (keyof T)[]).filter((prop) => !typedInclude(propsArr, prop)); + return pick(o, letsKeep); +} diff --git a/utils/parseLinkHeader.ts b/utils/parseLinkHeader.ts new file mode 100644 index 0000000000000000000000000000000000000000..6939a89be6701240097f8b2f7f6c581b4ab6c739 --- /dev/null +++ b/utils/parseLinkHeader.ts @@ -0,0 +1,8 @@ +/** + * Parse Link HTTP header, eg `; rel="next"` + */ +export function parseLinkHeader(header: string): Record { + const regex = /<(https?:[/][/][^>]+)>;\s+rel="([^"]+)"/g; + + return Object.fromEntries([...header.matchAll(regex)].map(([, url, rel]) => [rel, url])); +} diff --git a/utils/pick.ts b/utils/pick.ts new file mode 100644 index 0000000000000000000000000000000000000000..bd32e4532ef41a475a946e94ea6e000ddf7422a0 --- /dev/null +++ b/utils/pick.ts @@ -0,0 +1,13 @@ +/** + * Return copy of object, only keeping whitelisted properties. 
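+ *
+ * Illustration (values arbitrary):
+ *   pick({ a: 1, b: 2, c: 3 }, ["a", "c"]) // => { a: 1, c: 3 }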
+ */ +export function pick(o: T, props: K[] | ReadonlyArray): Pick { + return Object.assign( + {}, + ...props.map((prop) => { + if (o[prop] !== undefined) { + return { [prop]: o[prop] }; + } + }) + ); +} diff --git a/utils/promisesQueue.spec.ts b/utils/promisesQueue.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..3e9ea6d124ff5cc290eb6aefffa4ce8148b3c46a --- /dev/null +++ b/utils/promisesQueue.spec.ts @@ -0,0 +1,48 @@ +import { describe, expect, it } from "vitest"; +import { promisesQueue } from "./promisesQueue"; + +describe("promisesQueue", () => { + it("should handle multiple errors without triggering an uncaughtException", async () => { + const factories = [ + () => Promise.reject(new Error("error 1")), + () => Promise.reject(new Error("error 2")), + () => Promise.reject(new Error("error 3")), + ]; + + try { + await promisesQueue(factories, 10); + } catch (err) { + if (!(err instanceof Error)) { + throw err; + } + } + + try { + await promisesQueue(factories, 1); + } catch (err) { + if (!(err instanceof Error)) { + throw err; + } + expect(err.message).toBe("error 1"); + } + }); + + it("should return ordered results", async () => { + const factories = [ + () => Promise.resolve(1), + () => Promise.resolve(2), + () => Promise.resolve(3), + () => Promise.resolve(4), + () => Promise.resolve(5), + () => Promise.resolve(6), + () => Promise.resolve(7), + () => Promise.resolve(8), + () => Promise.resolve(9), + () => Promise.resolve(10), + ]; + + const results = await promisesQueue(factories, 3); + + expect(results).toEqual([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + }); +}); diff --git a/utils/promisesQueue.ts b/utils/promisesQueue.ts new file mode 100644 index 0000000000000000000000000000000000000000..35d2d06907e3aad537e96f99509297f79d5ee547 --- /dev/null +++ b/utils/promisesQueue.ts @@ -0,0 +1,23 @@ +/** + * Execute queue of promises. + * + * Inspired by github.com/rxaviers/async-pool + */ +export async function promisesQueue(factories: (() => Promise)[], concurrency: number): Promise { + const results: T[] = []; + const executing: Set> = new Set(); + let index = 0; + for (const factory of factories) { + const closureIndex = index++; + const e = factory().then((r) => { + results[closureIndex] = r; + executing.delete(e); + }); + executing.add(e); + if (executing.size >= concurrency) { + await Promise.race(executing); + } + } + await Promise.all(executing); + return results; +} diff --git a/utils/promisesQueueStreaming.ts b/utils/promisesQueueStreaming.ts new file mode 100644 index 0000000000000000000000000000000000000000..cbe966276a88f1d0a3fe0882003d052f12f976ed --- /dev/null +++ b/utils/promisesQueueStreaming.ts @@ -0,0 +1,25 @@ +/** + * Execute queue of promises in a streaming fashion. 
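+ *
+ * Usage sketch (`files` and `upload` here are placeholders, not part of this library):
+ *   await promisesQueueStreaming(files.map((file) => () => upload(file)), 5);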
+ * + * Optimized for streaming: + * - Expects an iterable as input + * - Does not return a list of all results + * + * Inspired by github.com/rxaviers/async-pool + */ +export async function promisesQueueStreaming( + factories: AsyncIterable<() => Promise> | Iterable<() => Promise>, + concurrency: number +): Promise { + const executing: Promise[] = []; + for await (const factory of factories) { + const e = factory().then(() => { + executing.splice(executing.indexOf(e), 1); + }); + executing.push(e); + if (executing.length >= concurrency) { + await Promise.race(executing); + } + } + await Promise.all(executing); +} diff --git a/utils/range.ts b/utils/range.ts new file mode 100644 index 0000000000000000000000000000000000000000..d7ebababf2ba06366899f5d192b08e849b708f9a --- /dev/null +++ b/utils/range.ts @@ -0,0 +1,13 @@ +/** + * One param: create list of integers from 0 (inclusive) to n (exclusive) + * Two params: create list of integers from a (inclusive) to b (exclusive) + */ +export function range(n: number, b?: number): number[] { + return b + ? Array(b - n) + .fill(0) + .map((_, i) => n + i) + : Array(n) + .fill(0) + .map((_, i) => i); +} diff --git a/utils/sha256-node.ts b/utils/sha256-node.ts new file mode 100644 index 0000000000000000000000000000000000000000..b068d1a218a73e99512478733d71e754de46bbd2 --- /dev/null +++ b/utils/sha256-node.ts @@ -0,0 +1,26 @@ +import { Readable } from "node:stream"; +import type { ReadableStream } from "node:stream/web"; +import { createHash } from "node:crypto"; + +export async function* sha256Node( + buffer: ArrayBuffer | Blob, + opts?: { + abortSignal?: AbortSignal; + } +): AsyncGenerator { + const sha256Stream = createHash("sha256"); + const size = buffer instanceof Blob ? buffer.size : buffer.byteLength; + let done = 0; + const readable = + buffer instanceof Blob ? 
Readable.fromWeb(buffer.stream() as ReadableStream) : Readable.from(Buffer.from(buffer)); + + for await (const buffer of readable) { + sha256Stream.update(buffer); + done += buffer.length; + yield done / size; + + opts?.abortSignal?.throwIfAborted(); + } + + return sha256Stream.digest("hex"); +} diff --git a/utils/sha256.spec.ts b/utils/sha256.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..8e62b936b60d94debe9b91103f30eeedc1826712 --- /dev/null +++ b/utils/sha256.spec.ts @@ -0,0 +1,50 @@ +import { describe, it, expect } from "vitest"; +import { sha256 } from "./sha256"; + +const smallContent = "hello world"; +const smallContentSHA256 = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"; +const bigContent = "O123456789".repeat(100_000); +const bigContentSHA256 = "a3bbce7ee1df7233d85b5f4d60faa3755f93f537804f8b540c72b0739239ddf8"; +const biggerContent = "0123456789".repeat(1_000_000); +const biggerContentSHA256 = "d52fcc26b48dbd4d79b125eb0a29b803ade07613c67ac7c6f2751aefef008486"; + +describe("sha256", () => { + async function calcSHA256(content: string, useWebWorker: boolean) { + const iterator = sha256(new Blob([content]), { useWebWorker }); + let res: IteratorResult; + do { + res = await iterator.next(); + } while (!res.done); + return res.value; + } + + it("Calculate hash of a small file", async () => { + const sha = await calcSHA256(smallContent, false); + expect(sha).toBe(smallContentSHA256); + }); + + it("Calculate hash of a big file", async () => { + const sha = await calcSHA256(bigContent, false); + expect(sha).toBe(bigContentSHA256); + }); + + it("Calculate hash of a bigger file", async () => { + const sha = await calcSHA256(biggerContent, false); + expect(sha).toBe(biggerContentSHA256); + }); + + it("Calculate hash of a small file (+ web worker)", async () => { + const sha = await calcSHA256(smallContent, true); + expect(sha).toBe(smallContentSHA256); + }); + + it("Calculate hash of a big file (+ web worker)", async () => { + const sha = await calcSHA256(bigContent, true); + expect(sha).toBe(bigContentSHA256); + }); + + it("Calculate hash of a bigger file (+ web worker)", async () => { + const sha = await calcSHA256(biggerContent, true); + expect(sha).toBe(biggerContentSHA256); + }); +}); diff --git a/utils/sha256.ts b/utils/sha256.ts new file mode 100644 index 0000000000000000000000000000000000000000..d432909b34d7cb4046748a5c353f5875d97f5939 --- /dev/null +++ b/utils/sha256.ts @@ -0,0 +1,166 @@ +import { eventToGenerator } from "./eventToGenerator"; +import { hexFromBytes } from "./hexFromBytes"; +import { isFrontend } from "./isFrontend"; + +async function getWebWorkerCode() { + const sha256Module = await import("../vendor/hash-wasm/sha256-wrapper"); + return URL.createObjectURL(new Blob([sha256Module.createSHA256WorkerCode()])); +} + +const pendingWorkers: Worker[] = []; +const runningWorkers: Set = new Set(); + +let resolve: () => void; +let waitPromise: Promise = new Promise((r) => { + resolve = r; +}); + +async function getWorker(poolSize?: number): Promise { + { + const worker = pendingWorkers.pop(); + if (worker) { + runningWorkers.add(worker); + return worker; + } + } + if (!poolSize) { + const worker = new Worker(await getWebWorkerCode()); + runningWorkers.add(worker); + return worker; + } + + if (poolSize <= 0) { + throw new TypeError("Invalid webworker pool size: " + poolSize); + } + + while (runningWorkers.size >= poolSize) { + await waitPromise; + } + + const worker = new Worker(await getWebWorkerCode()); + 
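+ // Register the fresh worker as running so the pool-size check above stays accurate for other callers.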
runningWorkers.add(worker); + return worker; +} + +async function freeWorker(worker: Worker, poolSize: number | undefined): Promise { + if (!poolSize) { + return destroyWorker(worker); + } + runningWorkers.delete(worker); + pendingWorkers.push(worker); + const r = resolve; + waitPromise = new Promise((r) => { + resolve = r; + }); + r(); +} + +function destroyWorker(worker: Worker): void { + runningWorkers.delete(worker); + worker.terminate(); + const r = resolve; + waitPromise = new Promise((r) => { + resolve = r; + }); + r(); +} + +/** + * @returns hex-encoded sha + * @yields progress (0-1) + */ +export async function* sha256( + buffer: Blob, + opts?: { useWebWorker?: boolean | { minSize?: number; poolSize?: number }; abortSignal?: AbortSignal } +): AsyncGenerator { + yield 0; + + const maxCryptoSize = + typeof opts?.useWebWorker === "object" && opts?.useWebWorker.minSize !== undefined + ? opts.useWebWorker.minSize + : 10_000_000; + if (buffer.size < maxCryptoSize && globalThis.crypto?.subtle) { + const res = hexFromBytes( + new Uint8Array( + await globalThis.crypto.subtle.digest("SHA-256", buffer instanceof Blob ? await buffer.arrayBuffer() : buffer) + ) + ); + + yield 1; + + return res; + } + + if (isFrontend) { + if (opts?.useWebWorker) { + try { + const poolSize = typeof opts?.useWebWorker === "object" ? opts.useWebWorker.poolSize : undefined; + const worker = await getWorker(poolSize); + return yield* eventToGenerator((yieldCallback, returnCallback, rejectCallack) => { + worker.addEventListener("message", (event) => { + if (event.data.sha256) { + freeWorker(worker, poolSize); + returnCallback(event.data.sha256); + } else if (event.data.progress) { + yieldCallback(event.data.progress); + + try { + opts.abortSignal?.throwIfAborted(); + } catch (err) { + destroyWorker(worker); + rejectCallack(err); + } + } else { + destroyWorker(worker); + rejectCallack(event); + } + }); + worker.addEventListener("error", (event) => { + destroyWorker(worker); + rejectCallack(event.error); + }); + worker.postMessage({ file: buffer }); + }); + } catch (err) { + console.warn("Failed to use web worker for sha256", err); + } + } + if (!wasmModule) { + wasmModule = await import("../vendor/hash-wasm/sha256-wrapper"); + } + + const sha256 = await wasmModule.createSHA256(); + sha256.init(); + + const reader = buffer.stream().getReader(); + const total = buffer.size; + let bytesDone = 0; + + while (true) { + const { done, value } = await reader.read(); + + if (done) { + break; + } + + sha256.update(value); + bytesDone += value.length; + yield bytesDone / total; + + opts?.abortSignal?.throwIfAborted(); + } + + return sha256.digest("hex"); + } + + if (!cryptoModule) { + cryptoModule = await import("./sha256-node"); + } + + return yield* cryptoModule.sha256Node(buffer, { abortSignal: opts?.abortSignal }); +} + +// eslint-disable-next-line @typescript-eslint/consistent-type-imports +let cryptoModule: typeof import("./sha256-node"); +// eslint-disable-next-line @typescript-eslint/consistent-type-imports +let wasmModule: typeof import("../vendor/hash-wasm/sha256-wrapper"); diff --git a/utils/sub-paths.spec.ts b/utils/sub-paths.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..6dcb773ab7f106e8bcf1d09751ff01b1e0daceba --- /dev/null +++ b/utils/sub-paths.spec.ts @@ -0,0 +1,39 @@ +import { mkdir, writeFile } from "fs/promises"; +import { tmpdir } from "os"; +import { describe, expect, it } from "vitest"; +import { subPaths } from "./sub-paths"; +import { pathToFileURL } from "url"; + 
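+// Exercises subPaths end-to-end against a real temporary directory tree created under os.tmpdir().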
+describe("sub-paths", () => { + it("should retrieve all sub-paths of a directory", async () => { + const tmpDir = tmpdir(); + + await mkdir(`${tmpDir}/test-dir/sub`, { recursive: true }); + + await writeFile(`${tmpDir}/test-dir/sub/file1.txt`, "file1"); + await writeFile(`${tmpDir}/test-dir/sub/file2.txt`, "file2"); + await writeFile(`${tmpDir}/test-dir/file3.txt`, "file3"); + await writeFile(`${tmpDir}/test-dir/file4.txt`, "file4"); + const result = await subPaths(pathToFileURL(`${tmpDir}/test-dir`)); + + expect(result).toEqual([ + { + path: pathToFileURL(`${tmpDir}/test-dir/file3.txt`), + relativePath: "file3.txt", + }, + { + path: pathToFileURL(`${tmpDir}/test-dir/file4.txt`), + relativePath: "file4.txt", + }, + + { + path: pathToFileURL(`${tmpDir}/test-dir/sub/file1.txt`), + relativePath: "sub/file1.txt", + }, + { + path: pathToFileURL(`${tmpDir}/test-dir/sub/file2.txt`), + relativePath: "sub/file2.txt", + }, + ]); + }); +}); diff --git a/utils/sub-paths.ts b/utils/sub-paths.ts new file mode 100644 index 0000000000000000000000000000000000000000..15682c14f854cc4fda95ae997aeebf96a5029257 --- /dev/null +++ b/utils/sub-paths.ts @@ -0,0 +1,38 @@ +import { readdir, stat } from "node:fs/promises"; +import { fileURLToPath, pathToFileURL } from "node:url"; + +/** + * Recursively retrieves all sub-paths of a given directory up to a specified depth. + */ +export async function subPaths( + path: URL, + maxDepth = 10 +): Promise< + Array<{ + path: URL; + relativePath: string; + }> +> { + const state = await stat(path); + if (!state.isDirectory()) { + return [{ path, relativePath: "." }]; + } + + const files = await readdir(path, { withFileTypes: true }); + const ret: Array<{ path: URL; relativePath: string }> = []; + for (const file of files) { + const filePath = pathToFileURL(fileURLToPath(path) + "/" + file.name); + if (file.isDirectory()) { + ret.push( + ...(await subPaths(filePath, maxDepth - 1)).map((subPath) => ({ + ...subPath, + relativePath: `${file.name}/${subPath.relativePath}`, + })) + ); + } else { + ret.push({ path: filePath, relativePath: file.name }); + } + } + + return ret; +} diff --git a/utils/sum.ts b/utils/sum.ts new file mode 100644 index 0000000000000000000000000000000000000000..9d3fe6f15960f0f373a4b8669cace09c41c4b02f --- /dev/null +++ b/utils/sum.ts @@ -0,0 +1,6 @@ +/** + * Sum of elements in array + */ +export function sum(arr: number[]): number { + return arr.reduce((a, b) => a + b, 0); +} diff --git a/utils/symlink.spec.ts b/utils/symlink.spec.ts new file mode 100644 index 0000000000000000000000000000000000000000..56edaf4f33c170389d4799cdec7b9c59686f1a63 --- /dev/null +++ b/utils/symlink.spec.ts @@ -0,0 +1,89 @@ +/* eslint-disable @typescript-eslint/consistent-type-imports */ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { describe, expect, it, vi } from "vitest"; +import { createSymlink } from "./symlink"; +import { readFileSync, writeFileSync } from "node:fs"; +import { lstat, rm } from "node:fs/promises"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; + +let failSymlink = false; +vi.mock("node:fs/promises", async (importOriginal) => ({ + ...(await importOriginal()), + symlink: async (...args: any[]) => { + if (failSymlink) { + failSymlink = false; + throw new Error("Symlink not supported"); + } + + // @ts-expect-error - ignore + return (await importOriginal()).symlink(...args); + }, +})); + +describe("utils/symlink", () => { + it("should create a symlink", async () => { + writeFileSync(join(tmpdir(), "test.txt"), "hello 
world"); + await createSymlink({ + sourcePath: join(tmpdir(), "test.txt"), + finalPath: join(tmpdir(), "test-symlink.txt"), + }); + + const stats = await lstat(join(tmpdir(), "test-symlink.txt")); + expect(stats.isSymbolicLink()).toBe(process.platform !== "win32"); + + // Test file content + const content = readFileSync(join(tmpdir(), "test-symlink.txt"), "utf8"); + expect(content).toBe("hello world"); + + // Cleanup + await rm(join(tmpdir(), "test-symlink.txt")); + await rm(join(tmpdir(), "test.txt")); + }); + + it("should work when symlinking twice", async () => { + writeFileSync(join(tmpdir(), "test.txt"), "hello world"); + writeFileSync(join(tmpdir(), "test2.txt"), "hello world2"); + await createSymlink({ + sourcePath: join(tmpdir(), "test.txt"), + finalPath: join(tmpdir(), "test-symlink.txt"), + }); + await createSymlink({ + sourcePath: join(tmpdir(), "test2.txt"), + finalPath: join(tmpdir(), "test-symlink.txt"), + }); + + const stats = await lstat(join(tmpdir(), "test-symlink.txt")); + expect(stats.isSymbolicLink()).toBe(process.platform !== "win32"); + + // Test file content + const content = readFileSync(join(tmpdir(), "test-symlink.txt"), "utf8"); + expect(content).toBe("hello world2"); + + // Cleanup + await rm(join(tmpdir(), "test-symlink.txt")); + await rm(join(tmpdir(), "test.txt")); + await rm(join(tmpdir(), "test2.txt")); + }); + + it("should work when symlink doesn't work (windows)", async () => { + writeFileSync(join(tmpdir(), "test.txt"), "hello world"); + + failSymlink = true; + await createSymlink({ + sourcePath: join(tmpdir(), "test.txt"), + finalPath: join(tmpdir(), "test-symlink.txt"), + }); + + const stats = await lstat(join(tmpdir(), "test-symlink.txt")); + expect(stats.isSymbolicLink()).toBe(false); + + // Test file content + const content = readFileSync(join(tmpdir(), "test-symlink.txt"), "utf8"); + expect(content).toBe("hello world"); + + // Cleanup + await rm(join(tmpdir(), "test-symlink.txt")); + await rm(join(tmpdir(), "test.txt")); + }); +}); diff --git a/utils/symlink.ts b/utils/symlink.ts new file mode 100644 index 0000000000000000000000000000000000000000..17cedf66ad17629435dcd492238fc9a258068961 --- /dev/null +++ b/utils/symlink.ts @@ -0,0 +1,65 @@ +/** + * Heavily inspired by https://github.com/huggingface/huggingface_hub/blob/fcfd14361bd03f23f82efced1aa65a7cbfa4b922/src/huggingface_hub/file_download.py#L517 + */ + +import * as fs from "node:fs/promises"; +import * as path from "node:path"; +import * as os from "node:os"; + +function expandUser(path: string): string { + if (path.startsWith("~")) { + return path.replace("~", os.homedir()); + } + return path; +} + +/** + * Create a symbolic link named dst pointing to src. + * + * By default, it will try to create a symlink using a relative path. Relative paths have 2 advantages: + * - If the cache_folder is moved (example: back-up on a shared drive), relative paths within the cache folder will + * not break. + * - Relative paths seems to be better handled on Windows. Issue was reported 3 times in less than a week when + * changing from relative to absolute paths. See https://github.com/huggingface/huggingface_hub/issues/1398, + * https://github.com/huggingface/diffusers/issues/2729 and https://github.com/huggingface/transformers/pull/22228. + * NOTE: The issue with absolute paths doesn't happen on admin mode. + * When creating a symlink from the cache to a local folder, it is possible that a relative path cannot be created. + * This happens when paths are not on the same volume. 
In that case, we use absolute paths. + * + * The result layout looks something like + * └── [ 128] snapshots + * ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + * │ ├── [ 52] README.md -> ../../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + * │ └── [ 76] pytorch_model.bin -> ../../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + * + * If symlinks cannot be created on this platform (most likely to be Windows), the workaround is to avoid symlinks by + * having the actual file in `dst`. If it is a new file (`new_blob=True`), we move it to `dst`. If it is not a new file + * (`new_blob=False`), we don't know if the blob file is already referenced elsewhere. To avoid breaking existing + * cache, the file is duplicated on the disk. + */ +export async function createSymlink(params: { + /** + * The path to the symlink. + */ + finalPath: string; + /** + * The path the symlink should point to. + */ + sourcePath: string; +}): Promise { + const abs_src = path.resolve(expandUser(params.sourcePath)); + const abs_dst = path.resolve(expandUser(params.finalPath)); + + try { + await fs.rm(abs_dst); + } catch { + // ignore + } + + try { + await fs.symlink(path.relative(path.dirname(abs_dst), abs_src), abs_dst); + } catch { + console.info(`Symlink not supported. Copying file from ${abs_src} to ${abs_dst}`); + await fs.copyFile(abs_src, abs_dst); + } +} diff --git a/utils/toRepoId.ts b/utils/toRepoId.ts new file mode 100644 index 0000000000000000000000000000000000000000..9273266928980d0e07b229c8ac0b224df7d02577 --- /dev/null +++ b/utils/toRepoId.ts @@ -0,0 +1,54 @@ +import type { RepoDesignation, RepoId } from "../types/public"; + +export function toRepoId(repo: RepoDesignation): RepoId { + if (typeof repo !== "string") { + return repo; + } + + if (repo.startsWith("model/") || repo.startsWith("models/")) { + throw new TypeError( + "A repo designation for a model should not start with 'models/', directly specify the model namespace / name" + ); + } + + if (repo.startsWith("space/")) { + throw new TypeError("Spaces should start with 'spaces/', plural, not 'space/'"); + } + + if (repo.startsWith("dataset/")) { + throw new TypeError("Datasets should start with 'dataset/', plural, not 'dataset/'"); + } + + const slashes = repo.split("/").length - 1; + + if (repo.startsWith("spaces/")) { + if (slashes !== 2) { + throw new TypeError("Space Id must include namespace and name of the space"); + } + + return { + type: "space", + name: repo.slice("spaces/".length), + }; + } + + if (repo.startsWith("datasets/")) { + if (slashes > 2) { + throw new TypeError("Too many slashes in repo designation: " + repo); + } + + return { + type: "dataset", + name: repo.slice("datasets/".length), + }; + } + + if (slashes > 1) { + throw new TypeError("Too many slashes in repo designation: " + repo); + } + + return { + type: "model", + name: repo, + }; +} diff --git a/utils/typedEntries.ts b/utils/typedEntries.ts new file mode 100644 index 0000000000000000000000000000000000000000..031ba7daa0cc381ce650224c040f67261fbd56c3 --- /dev/null +++ b/utils/typedEntries.ts @@ -0,0 +1,5 @@ +import type { Entries } from "../vendor/type-fest/entries"; + +export function typedEntries>(obj: T): Entries { + return Object.entries(obj) as Entries; +} diff --git a/utils/typedInclude.ts b/utils/typedInclude.ts new file mode 100644 index 0000000000000000000000000000000000000000..71e2f7a7e111995a744589dd34cb090d9743ea16 --- /dev/null +++ b/utils/typedInclude.ts @@ -0,0 +1,3 @@ +export function typedInclude(arr: readonly T[], v: 
V): v is T { + return arr.includes(v as T); +} diff --git a/vendor/hash-wasm/LICENSE b/vendor/hash-wasm/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..689ccbe872a76ddfeb0196e0e4b9d60d67480220 --- /dev/null +++ b/vendor/hash-wasm/LICENSE @@ -0,0 +1,39 @@ +MIT License + +Copyright (c) 2020 Dani Biró + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +Embedded C implementations might use other, similarly permissive licenses. +Check the beginning of the files from the /src directory. + +Special thank you to the authors of original C algorithms: + +- Alexander Peslyak +- Aleksey Kravchenko +- Colin Percival +- Stephan Brumme +- Steve Reid +- Samuel Neves +- Solar Designer +- Project Nayuki +- ARM Limited +- Yanbo Li dreamfly281@gmail.com, goldboar@163.comYanbo Li +- Mark Adler +- Yann Collet diff --git a/vendor/hash-wasm/build.sh b/vendor/hash-wasm/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..18424dfbc296bf3418e77fce35fbde5298a4a6a3 --- /dev/null +++ b/vendor/hash-wasm/build.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +CURRENT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +cd $CURRENT_PATH + +# Clean up +docker kill hash-wasm-builder +docker rm hash-wasm-builder + +# Start container +docker run -it -d --name hash-wasm-builder emscripten/emsdk:3.1.55 bash + +# Copy & compile +docker exec hash-wasm-builder bash -c "mkdir /source" +docker cp ./sha256.c hash-wasm-builder:/source +docker exec hash-wasm-builder bash -c "\ + cd /source && \ + emcc sha256.c -o sha256.js -msimd128 -sSINGLE_FILE -sMODULARIZE=1 -sENVIRONMENT=web,worker -sEXPORTED_FUNCTIONS=_Hash_Init,_Hash_Update,_Hash_Final,_GetBufferPtr -sFILESYSTEM=0 -fno-rtti -fno-exceptions -O1 -sMODULARIZE=1 -sEXPORT_ES6=1 \ + " +# Patch "_scriptDir" variable +docker exec hash-wasm-builder bash -c "\ + cd /source && \ + sed -i 's\var _scriptDir\var _unused\g' ./sha256.js && \ + sed -i 's\_scriptDir\false\g' ./sha256.js \ + " + +# Copy back compiled file +docker cp hash-wasm-builder:/source/sha256.js . 
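+# Note: the generated sha256.js (a single-file ES6 module, per the emcc flags above) is committed
+# in this folder, so a rebuild should only be needed when sha256.c or these flags change.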
+ + +# Clean up +docker kill hash-wasm-builder +docker rm hash-wasm-builder diff --git a/vendor/hash-wasm/sha256-wrapper.ts b/vendor/hash-wasm/sha256-wrapper.ts new file mode 100644 index 0000000000000000000000000000000000000000..c6ad75b36b0852c1ba4b669b11ca3215812171a0 --- /dev/null +++ b/vendor/hash-wasm/sha256-wrapper.ts @@ -0,0 +1,62 @@ +import WasmModule from "./sha256"; + +export async function createSHA256(isInsideWorker = false): Promise<{ + init(): void; + update(data: Uint8Array): void; + digest(method: "hex"): string; +}> { + const BUFFER_MAX_SIZE = 8 * 1024 * 1024; + const wasm: Awaited> = isInsideWorker + ? // @ts-expect-error WasmModule will be populated inside self object + await self["SHA256WasmModule"]() + : await WasmModule(); + const heap = wasm.HEAPU8.subarray(wasm._GetBufferPtr()); + return { + init() { + wasm._Hash_Init(256); + }, + update(data: Uint8Array) { + let byteUsed = 0; + while (byteUsed < data.byteLength) { + const bytesLeft = data.byteLength - byteUsed; + const length = Math.min(bytesLeft, BUFFER_MAX_SIZE); + heap.set(data.subarray(byteUsed, byteUsed + length)); + wasm._Hash_Update(length); + byteUsed += length; + } + }, + digest(method: "hex") { + if (method !== "hex") { + throw new Error("Only digest hex is supported"); + } + wasm._Hash_Final(); + const result = Array.from(heap.slice(0, 32)); + return result.map((b) => b.toString(16).padStart(2, "0")).join(""); + }, + }; +} + +export function createSHA256WorkerCode(): string { + return ` + self.addEventListener('message', async (event) => { + const { file } = event.data; + const sha256 = await self.createSHA256(true); + sha256.init(); + const reader = file.stream().getReader(); + const total = file.size; + let bytesDone = 0; + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + sha256.update(value); + bytesDone += value.length; + postMessage({ progress: bytesDone / total }); + } + postMessage({ sha256: sha256.digest('hex') }); + }); + self.SHA256WasmModule = ${WasmModule.toString()}; + self.createSHA256 = ${createSHA256.toString()}; + `; +} diff --git a/vendor/hash-wasm/sha256.c b/vendor/hash-wasm/sha256.c new file mode 100644 index 0000000000000000000000000000000000000000..b8c0cb7eaf322cb4f74c42908f2d2542e5894233 --- /dev/null +++ b/vendor/hash-wasm/sha256.c @@ -0,0 +1,432 @@ +/* sha256.c - an implementation of SHA-256/224 hash functions + * based on FIPS 180-3 (Federal Information Processing Standart). + * + * Copyright (c) 2010, Aleksey Kravchenko + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + * AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. 
+ + * Modified for hash-wasm by Dani Biró + */ + +#define WITH_BUFFER + + +////////////////////////////////////////////////////////////////////////// + +#include +#include + +#ifndef NULL +#define NULL 0 +#endif + +#ifdef _MSC_VER +#define WASM_EXPORT +#define __inline__ +#else +#define WASM_EXPORT __attribute__((visibility("default"))) +#endif + +#ifdef WITH_BUFFER + +#define MAIN_BUFFER_SIZE 8 * 1024 * 1024 +alignas(128) uint8_t main_buffer[MAIN_BUFFER_SIZE]; + +WASM_EXPORT +uint8_t *Hash_GetBuffer() { + return main_buffer; +} + +#endif + +// Sometimes LLVM emits these functions during the optimization step +// even with -nostdlib -fno-builtin flags +static __inline__ void* memcpy(void* dst, const void* src, uint32_t cnt) { + uint8_t *destination = dst; + const uint8_t *source = src; + while (cnt) { + *(destination++)= *(source++); + --cnt; + } + return dst; +} + +static __inline__ void* memset(void* dst, const uint8_t value, uint32_t cnt) { + uint8_t *p = dst; + while (cnt--) { + *p++ = value; + } + return dst; +} + +static __inline__ void* memcpy2(void* dst, const void* src, uint32_t cnt) { + uint64_t *destination64 = dst; + const uint64_t *source64 = src; + while (cnt >= 8) { + *(destination64++)= *(source64++); + cnt -= 8; + } + + uint8_t *destination = (uint8_t*)destination64; + const uint8_t *source = (uint8_t*)source64; + while (cnt) { + *(destination++)= *(source++); + --cnt; + } + return dst; +} + +static __inline__ void memcpy16(void* dst, const void* src) { + uint64_t* dst64 = (uint64_t*)dst; + uint64_t* src64 = (uint64_t*)src; + + dst64[0] = src64[0]; + dst64[1] = src64[1]; +} + +static __inline__ void memcpy32(void* dst, const void* src) { + uint64_t* dst64 = (uint64_t*)dst; + uint64_t* src64 = (uint64_t*)src; + + #pragma clang loop unroll(full) + for (int i = 0; i < 4; i++) { + dst64[i] = src64[i]; + } +} + +static __inline__ void memcpy64(void* dst, const void* src) { + uint64_t* dst64 = (uint64_t*)dst; + uint64_t* src64 = (uint64_t*)src; + + #pragma clang loop unroll(full) + for (int i = 0; i < 8; i++) { + dst64[i] = src64[i]; + } +} + +static __inline__ uint64_t widen8to64(const uint8_t value) { + return value | (value << 8) | (value << 16) | (value << 24); +} + +static __inline__ void memset16(void* dst, const uint8_t value) { + uint64_t val = widen8to64(value); + uint64_t* dst64 = (uint64_t*)dst; + + dst64[0] = val; + dst64[1] = val; +} + +static __inline__ void memset32(void* dst, const uint8_t value) { + uint64_t val = widen8to64(value); + uint64_t* dst64 = (uint64_t*)dst; + + #pragma clang loop unroll(full) + for (int i = 0; i < 4; i++) { + dst64[i] = val; + } +} + +static __inline__ void memset64(void* dst, const uint8_t value) { + uint64_t val = widen8to64(value); + uint64_t* dst64 = (uint64_t*)dst; + + #pragma clang loop unroll(full) + for (int i = 0; i < 8; i++) { + dst64[i] = val; + } +} + +static __inline__ void memset128(void* dst, const uint8_t value) { + uint64_t val = widen8to64(value); + uint64_t* dst64 = (uint64_t*)dst; + + #pragma clang loop unroll(full) + for (int i = 0; i < 16; i++) { + dst64[i] = val; + } +} + + +////////////////////////////////////////////////////////////////////////// + +#define sha256_block_size 64 +#define sha256_hash_size 32 +#define sha224_hash_size 28 +#define ROTR32(dword, n) ((dword) >> (n) ^ ((dword) << (32 - (n)))) +#define bswap_32(x) __builtin_bswap32(x) + +struct sha256_ctx { + uint32_t message[16]; /* 512-bit buffer for leftovers */ + uint64_t length; /* number of processed bytes */ + uint32_t hash[8]; /* 256-bit 
algorithm internal hashing state */ + uint32_t digest_length; /* length of the algorithm digest in bytes */ +}; + +struct sha256_ctx sctx; +struct sha256_ctx* ctx = &sctx; + +/* SHA-224 and SHA-256 constants for 64 rounds. These words represent + * the first 32 bits of the fractional parts of the cube + * roots of the first 64 prime numbers. */ +static const uint32_t rhash_k256[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, + 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, + 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, + 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, + 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 +}; + +/* The SHA256/224 functions defined by FIPS 180-3, 4.1.2 */ +/* Optimized version of Ch(x,y,z)=((x & y) | (~x & z)) */ +#define Ch(x, y, z) ((z) ^ ((x) & ((y) ^ (z)))) +/* Optimized version of Maj(x,y,z)=((x & y) ^ (x & z) ^ (y & z)) */ +#define Maj(x, y, z) (((x) & (y)) ^ ((z) & ((x) ^ (y)))) + +#define Sigma0(x) (ROTR32((x), 2) ^ ROTR32((x), 13) ^ ROTR32((x), 22)) +#define Sigma1(x) (ROTR32((x), 6) ^ ROTR32((x), 11) ^ ROTR32((x), 25)) +#define sigma0(x) (ROTR32((x), 7) ^ ROTR32((x), 18) ^ ((x) >> 3)) +#define sigma1(x) (ROTR32((x), 17) ^ ROTR32((x), 19) ^ ((x) >> 10)) + +/* Recalculate element n-th of circular buffer W using formula + * W[n] = sigma1(W[n - 2]) + W[n - 7] + sigma0(W[n - 15]) + W[n - 16]; */ +#define RECALCULATE_W(W, n) \ + (W[n] += \ + (sigma1(W[(n - 2) & 15]) + W[(n - 7) & 15] + sigma0(W[(n - 15) & 15]))) + +#define ROUND(a, b, c, d, e, f, g, h, k, data) \ + { \ + uint32_t T1 = h + Sigma1(e) + Ch(e, f, g) + k + (data); \ + d += T1, h = T1 + Sigma0(a) + Maj(a, b, c); \ + } +#define ROUND_1_16(a, b, c, d, e, f, g, h, n) \ + ROUND(a, b, c, d, e, f, g, h, rhash_k256[n], W[n] = bswap_32(block[n])) +#define ROUND_17_64(a, b, c, d, e, f, g, h, n) \ + ROUND(a, b, c, d, e, f, g, h, k[n], RECALCULATE_W(W, n)) + +/** + * Initialize context before calculaing hash. + * + */ +void sha256_init() { + /* Initial values. These words were obtained by taking the first 32 + * bits of the fractional parts of the square roots of the first + * eight prime numbers. */ + static const uint32_t SHA256_H0[8] = { + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, + 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19 + }; + + ctx->length = 0; + ctx->digest_length = sha256_hash_size; + + /* initialize algorithm state */ + + #pragma clang loop vectorize(enable) + for (uint8_t i = 0; i < 8; i += 2) { + *(uint64_t*)&ctx->hash[i] = *(uint64_t*)&SHA256_H0[i]; + } +} + +/** + * Initialize context before calculaing hash. + * + */ +void sha224_init() { + /* Initial values from FIPS 180-3. These words were obtained by taking + * bits from 33th to 64th of the fractional parts of the square + * roots of ninth through sixteenth prime numbers. 
*/ + static const uint32_t SHA224_H0[8] = { + 0xc1059ed8, 0x367cd507, 0x3070dd17, 0xf70e5939, + 0xffc00b31, 0x68581511, 0x64f98fa7, 0xbefa4fa4 + }; + + ctx->length = 0; + ctx->digest_length = sha224_hash_size; + + #pragma clang loop vectorize(enable) + for (uint8_t i = 0; i < 8; i += 2) { + *(uint64_t*)&ctx->hash[i] = *(uint64_t*)&SHA224_H0[i]; + } +} + +/** + * The core transformation. Process a 512-bit block. + * + * @param hash algorithm state + * @param block the message block to process + */ +static void sha256_process_block(uint32_t hash[8], uint32_t block[16]) { + uint32_t A, B, C, D, E, F, G, H; + uint32_t W[16]; + const uint32_t* k; + int i; + + A = hash[0], B = hash[1], C = hash[2], D = hash[3]; + E = hash[4], F = hash[5], G = hash[6], H = hash[7]; + + /* Compute SHA using alternate Method: FIPS 180-3 6.1.3 */ + ROUND_1_16(A, B, C, D, E, F, G, H, 0); + ROUND_1_16(H, A, B, C, D, E, F, G, 1); + ROUND_1_16(G, H, A, B, C, D, E, F, 2); + ROUND_1_16(F, G, H, A, B, C, D, E, 3); + ROUND_1_16(E, F, G, H, A, B, C, D, 4); + ROUND_1_16(D, E, F, G, H, A, B, C, 5); + ROUND_1_16(C, D, E, F, G, H, A, B, 6); + ROUND_1_16(B, C, D, E, F, G, H, A, 7); + ROUND_1_16(A, B, C, D, E, F, G, H, 8); + ROUND_1_16(H, A, B, C, D, E, F, G, 9); + ROUND_1_16(G, H, A, B, C, D, E, F, 10); + ROUND_1_16(F, G, H, A, B, C, D, E, 11); + ROUND_1_16(E, F, G, H, A, B, C, D, 12); + ROUND_1_16(D, E, F, G, H, A, B, C, 13); + ROUND_1_16(C, D, E, F, G, H, A, B, 14); + ROUND_1_16(B, C, D, E, F, G, H, A, 15); + + #pragma clang loop vectorize(enable) + for (i = 16, k = &rhash_k256[16]; i < 64; i += 16, k += 16) { + ROUND_17_64(A, B, C, D, E, F, G, H, 0); + ROUND_17_64(H, A, B, C, D, E, F, G, 1); + ROUND_17_64(G, H, A, B, C, D, E, F, 2); + ROUND_17_64(F, G, H, A, B, C, D, E, 3); + ROUND_17_64(E, F, G, H, A, B, C, D, 4); + ROUND_17_64(D, E, F, G, H, A, B, C, 5); + ROUND_17_64(C, D, E, F, G, H, A, B, 6); + ROUND_17_64(B, C, D, E, F, G, H, A, 7); + ROUND_17_64(A, B, C, D, E, F, G, H, 8); + ROUND_17_64(H, A, B, C, D, E, F, G, 9); + ROUND_17_64(G, H, A, B, C, D, E, F, 10); + ROUND_17_64(F, G, H, A, B, C, D, E, 11); + ROUND_17_64(E, F, G, H, A, B, C, D, 12); + ROUND_17_64(D, E, F, G, H, A, B, C, 13); + ROUND_17_64(C, D, E, F, G, H, A, B, 14); + ROUND_17_64(B, C, D, E, F, G, H, A, 15); + } + + hash[0] += A, hash[1] += B, hash[2] += C, hash[3] += D; + hash[4] += E, hash[5] += F, hash[6] += G, hash[7] += H; +} + +/** + * Calculate message hash. + * Can be called repeatedly with chunks of the message to be hashed. + * + * @param size length of the message chunk + */ +WASM_EXPORT +void Hash_Update(uint32_t size) { + const uint8_t* msg = main_buffer; + uint32_t index = (uint32_t)ctx->length & 63; + ctx->length += size; + + /* fill partial block */ + if (index) { + uint32_t left = sha256_block_size - index; + uint32_t end = size < left ? size : left; + uint8_t* message8 = (uint8_t*)ctx->message; + for (uint8_t i = 0; i < end; i++) { + *(message8 + index + i) = msg[i]; + } + if (size < left) return; + + /* process partial block */ + sha256_process_block(ctx->hash, (uint32_t*)ctx->message); + msg += left; + size -= left; + } + + while (size >= sha256_block_size) { + uint32_t* aligned_message_block = (uint32_t*)msg; + + sha256_process_block(ctx->hash, aligned_message_block); + msg += sha256_block_size; + size -= sha256_block_size; + } + + if (size) { + /* save leftovers */ + for (uint8_t i = 0; i < size; i++) { + *(((uint8_t*)ctx->message) + i) = msg[i]; + } + } +} + +/** + * Store calculated hash into the given array. 
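+ * The digest bytes end up at the start of main_buffer, which the JS wrapper reads back
+ * through the pointer returned by GetBufferPtr().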
+ * + */ +WASM_EXPORT +void Hash_Final() { + uint32_t index = ((uint32_t)ctx->length & 63) >> 2; + uint32_t shift = ((uint32_t)ctx->length & 3) * 8; + + /* pad message and run for last block */ + + /* append the byte 0x80 to the message */ + ctx->message[index] &= ~(0xFFFFFFFFu << shift); + ctx->message[index++] ^= 0x80u << shift; + + /* if no room left in the message to store 64-bit message length */ + if (index > 14) { + /* then fill the rest with zeros and process it */ + while (index < 16) { + ctx->message[index++] = 0; + } + sha256_process_block(ctx->hash, ctx->message); + index = 0; + } + + while (index < 14) { + ctx->message[index++] = 0; + } + + ctx->message[14] = bswap_32((uint32_t)(ctx->length >> 29)); + ctx->message[15] = bswap_32((uint32_t)(ctx->length << 3)); + sha256_process_block(ctx->hash, ctx->message); + + #pragma clang loop vectorize(enable) + for (int32_t i = 7; i >= 0; i--) { + ctx->hash[i] = bswap_32(ctx->hash[i]); + } + + for (uint8_t i = 0; i < ctx->digest_length; i++) { + main_buffer[i] = *(((uint8_t*)ctx->hash) + i); + } +} + +WASM_EXPORT +uint32_t Hash_Init(uint32_t bits) { + if (bits == 224) { + sha224_init(); + } else { + sha256_init(); + } + return 0; +} + +WASM_EXPORT +const uint32_t STATE_SIZE = sizeof(*ctx); + +WASM_EXPORT +uint8_t* Hash_GetState() { + return (uint8_t*) ctx; +} + +WASM_EXPORT +uint32_t GetBufferPtr() { + return (uint32_t) main_buffer; +} diff --git a/vendor/hash-wasm/sha256.d.ts b/vendor/hash-wasm/sha256.d.ts new file mode 100644 index 0000000000000000000000000000000000000000..b6d0f51481360ced2c39d6140091d7da6681f150 --- /dev/null +++ b/vendor/hash-wasm/sha256.d.ts @@ -0,0 +1,8 @@ +declare function Module(): Promise<{ + HEAPU8: Uint8Array; + _Hash_Init(type: number): void; + _Hash_Update(length: number): void; + _Hash_Final(): void; + _GetBufferPtr(): number; +}>; +export default Module; diff --git a/vendor/hash-wasm/sha256.js b/vendor/hash-wasm/sha256.js new file mode 100644 index 0000000000000000000000000000000000000000..7ff85e582d49404ba7565cef5428115398baac3f --- /dev/null +++ b/vendor/hash-wasm/sha256.js @@ -0,0 +1,685 @@ + +var Module = (() => { + var _unused = import.meta.url; + + return ( +function(moduleArg = {}) { + +// include: shell.js +// The Module object: Our interface to the outside world. We import +// and export values on it. There are various ways Module can be used: +// 1. Not defined. We create it here +// 2. A function parameter, function(Module) { ..generated code.. } +// 3. pre-run appended it, var Module = {}; ..generated code.. +// 4. External script tag defines var Module. +// We need to check if Module already exists (e.g. case 3 above). +// Substitution will be replaced with actual code on later stage of the build, +// this way Closure Compiler will not mangle it (e.g. case 4. above). +// Note that if you want to run closure, and also to use Module +// after the generated code, you will need to define var Module = {}; +// before the code. Then that object will be used in the code, and you +// can continue to use Module afterwards as well. 
+var Module = moduleArg; + +// Set up the promise that indicates the Module is initialized +var readyPromiseResolve, readyPromiseReject; +Module['ready'] = new Promise((resolve, reject) => { + readyPromiseResolve = resolve; + readyPromiseReject = reject; +}); + +// --pre-jses are emitted after the Module integration code, so that they can +// refer to Module (if they choose; they can also define Module) + + +// Sometimes an existing Module object exists with properties +// meant to overwrite the default module functionality. Here +// we collect those properties and reapply _after_ we configure +// the current environment's defaults to avoid having to be so +// defensive during initialization. +var moduleOverrides = Object.assign({}, Module); + +var arguments_ = []; +var thisProgram = './this.program'; +var quit_ = (status, toThrow) => { + throw toThrow; +}; + +// Determine the runtime environment we are in. You can customize this by +// setting the ENVIRONMENT setting at compile time (see settings.js). + +// Attempt to auto-detect the environment +var ENVIRONMENT_IS_WEB = typeof window == 'object'; +var ENVIRONMENT_IS_WORKER = typeof importScripts == 'function'; +// N.b. Electron.js environment is simultaneously a NODE-environment, but +// also a web environment. +var ENVIRONMENT_IS_NODE = typeof process == 'object' && typeof process.versions == 'object' && typeof process.versions.node == 'string'; +var ENVIRONMENT_IS_SHELL = !ENVIRONMENT_IS_WEB && !ENVIRONMENT_IS_NODE && !ENVIRONMENT_IS_WORKER; + +// `/` should be present at the end if `scriptDirectory` is not empty +var scriptDirectory = ''; +function locateFile(path) { + if (Module['locateFile']) { + return Module['locateFile'](path, scriptDirectory); + } + return scriptDirectory + path; +} + +// Hooks that are implemented differently in different runtime environments. +var read_, + readAsync, + readBinary; + +// Note that this includes Node.js workers when relevant (pthreads is enabled). +// Node.js workers are detected as a combination of ENVIRONMENT_IS_WORKER and +// ENVIRONMENT_IS_NODE. +if (ENVIRONMENT_IS_WEB || ENVIRONMENT_IS_WORKER) { + if (ENVIRONMENT_IS_WORKER) { // Check worker, not web, since window could be polyfilled + scriptDirectory = self.location.href; + } else if (typeof document != 'undefined' && document.currentScript) { // web + scriptDirectory = document.currentScript.src; + } + // When MODULARIZE, this JS may be executed later, after document.currentScript + // is gone, so we saved it, and we use it here instead of any other info. + if (false) { + scriptDirectory = false; + } + // blob urls look like blob:http://site.com/etc/etc and we cannot infer anything from them. + // otherwise, slice off the final part of the url to find the script directory. + // if scriptDirectory does not contain a slash, lastIndexOf will return -1, + // and scriptDirectory will correctly be replaced with an empty string. + // If scriptDirectory contains a query (starting with ?) or a fragment (starting with #), + // they are removed because they could contain a slash. + if (scriptDirectory.startsWith('blob:')) { + scriptDirectory = ''; + } else { + scriptDirectory = scriptDirectory.substr(0, scriptDirectory.replace(/[?#].*/, '').lastIndexOf('/')+1); + } + + // Differentiate the Web Worker from the Node Worker case, as reading must + // be done differently. 
+ { +// include: web_or_worker_shell_read.js +read_ = (url) => { + var xhr = new XMLHttpRequest(); + xhr.open('GET', url, false); + xhr.send(null); + return xhr.responseText; + } + + if (ENVIRONMENT_IS_WORKER) { + readBinary = (url) => { + var xhr = new XMLHttpRequest(); + xhr.open('GET', url, false); + xhr.responseType = 'arraybuffer'; + xhr.send(null); + return new Uint8Array(/** @type{!ArrayBuffer} */(xhr.response)); + }; + } + + readAsync = (url, onload, onerror) => { + var xhr = new XMLHttpRequest(); + xhr.open('GET', url, true); + xhr.responseType = 'arraybuffer'; + xhr.onload = () => { + if (xhr.status == 200 || (xhr.status == 0 && xhr.response)) { // file URLs can return 0 + onload(xhr.response); + return; + } + onerror(); + }; + xhr.onerror = onerror; + xhr.send(null); + } + +// end include: web_or_worker_shell_read.js + } +} else +{ +} + +var out = Module['print'] || console.log.bind(console); +var err = Module['printErr'] || console.error.bind(console); + +// Merge back in the overrides +Object.assign(Module, moduleOverrides); +// Free the object hierarchy contained in the overrides, this lets the GC +// reclaim data used. +moduleOverrides = null; + +// Emit code to handle expected values on the Module object. This applies Module.x +// to the proper local x. This has two benefits: first, we only emit it if it is +// expected to arrive, and second, by using a local everywhere else that can be +// minified. + +if (Module['arguments']) arguments_ = Module['arguments']; + +if (Module['thisProgram']) thisProgram = Module['thisProgram']; + +if (Module['quit']) quit_ = Module['quit']; + +// perform assertions in shell.js after we set up out() and err(), as otherwise if an assertion fails it cannot print the message +// end include: shell.js + +// include: preamble.js +// === Preamble library stuff === + +// Documentation for the public APIs defined in this file must be updated in: +// site/source/docs/api_reference/preamble.js.rst +// A prebuilt local version of the documentation is available at: +// site/build/text/docs/api_reference/preamble.js.txt +// You can also build docs locally as HTML or other formats in site/ +// An online HTML version (which may be of a different version of Emscripten) +// is up at http://kripken.github.io/emscripten-site/docs/api_reference/preamble.js.html + +var wasmBinary; +if (Module['wasmBinary']) wasmBinary = Module['wasmBinary']; + +if (typeof WebAssembly != 'object') { + abort('no native wasm support detected'); +} + +// include: base64Utils.js +// Converts a string of base64 into a byte array (Uint8Array). +function intArrayFromBase64(s) { + + var decoded = atob(s); + var bytes = new Uint8Array(decoded.length); + for (var i = 0 ; i < decoded.length ; ++i) { + bytes[i] = decoded.charCodeAt(i); + } + return bytes; +} + +// If filename is a base64 data URI, parses and returns data (Buffer on node, +// Uint8Array otherwise). If filename is not a base64 data URI, returns undefined. +function tryParseAsDataURI(filename) { + if (!isDataURI(filename)) { + return; + } + + return intArrayFromBase64(filename.slice(dataURIPrefix.length)); +} +// end include: base64Utils.js +// Wasm globals + +var wasmMemory; + +//======================================== +// Runtime essentials +//======================================== + +// whether we are quitting the application. no code should run after this. +// set in exit() and abort() +var ABORT = false; + +// set by exit() and abort(). Passed to 'onExit' handler. 
+// NOTE: This is also used as the process return code code in shell environments +// but only when noExitRuntime is false. +var EXITSTATUS; + +// In STRICT mode, we only define assert() when ASSERTIONS is set. i.e. we +// don't define it at all in release modes. This matches the behaviour of +// MINIMAL_RUNTIME. +// TODO(sbc): Make this the default even without STRICT enabled. +/** @type {function(*, string=)} */ +function assert(condition, text) { + if (!condition) { + // This build was created without ASSERTIONS defined. `assert()` should not + // ever be called in this configuration but in case there are callers in + // the wild leave this simple abort() implementation here for now. + abort(text); + } +} + +// Memory management + +var HEAP, +/** @type {!Int8Array} */ + HEAP8, +/** @type {!Uint8Array} */ + HEAPU8, +/** @type {!Int16Array} */ + HEAP16, +/** @type {!Uint16Array} */ + HEAPU16, +/** @type {!Int32Array} */ + HEAP32, +/** @type {!Uint32Array} */ + HEAPU32, +/** @type {!Float32Array} */ + HEAPF32, +/** @type {!Float64Array} */ + HEAPF64; + +// include: runtime_shared.js +function updateMemoryViews() { + var b = wasmMemory.buffer; + Module['HEAP8'] = HEAP8 = new Int8Array(b); + Module['HEAP16'] = HEAP16 = new Int16Array(b); + Module['HEAPU8'] = HEAPU8 = new Uint8Array(b); + Module['HEAPU16'] = HEAPU16 = new Uint16Array(b); + Module['HEAP32'] = HEAP32 = new Int32Array(b); + Module['HEAPU32'] = HEAPU32 = new Uint32Array(b); + Module['HEAPF32'] = HEAPF32 = new Float32Array(b); + Module['HEAPF64'] = HEAPF64 = new Float64Array(b); +} +// end include: runtime_shared.js +// include: runtime_stack_check.js +// end include: runtime_stack_check.js +// include: runtime_assertions.js +// end include: runtime_assertions.js +var __ATPRERUN__ = []; // functions called before the runtime is initialized +var __ATINIT__ = []; // functions called during startup +var __ATEXIT__ = []; // functions called during shutdown +var __ATPOSTRUN__ = []; // functions called after the main() is called + +var runtimeInitialized = false; + +function preRun() { + if (Module['preRun']) { + if (typeof Module['preRun'] == 'function') Module['preRun'] = [Module['preRun']]; + while (Module['preRun'].length) { + addOnPreRun(Module['preRun'].shift()); + } + } + callRuntimeCallbacks(__ATPRERUN__); +} + +function initRuntime() { + runtimeInitialized = true; + + + callRuntimeCallbacks(__ATINIT__); +} + +function postRun() { + + if (Module['postRun']) { + if (typeof Module['postRun'] == 'function') Module['postRun'] = [Module['postRun']]; + while (Module['postRun'].length) { + addOnPostRun(Module['postRun'].shift()); + } + } + + callRuntimeCallbacks(__ATPOSTRUN__); +} + +function addOnPreRun(cb) { + __ATPRERUN__.unshift(cb); +} + +function addOnInit(cb) { + __ATINIT__.unshift(cb); +} + +function addOnExit(cb) { +} + +function addOnPostRun(cb) { + __ATPOSTRUN__.unshift(cb); +} + +// include: runtime_math.js +// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Math/imul + +// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Math/fround + +// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Math/clz32 + +// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Math/trunc + +// end include: runtime_math.js +// A counter of dependencies for calling run(). If we need to +// do asynchronous work before running, increment this and +// decrement it. 
Incrementing must happen in a place like +// Module.preRun (used by emcc to add file preloading). +// Note that you can add dependencies in preRun, even though +// it happens right before run - run will be postponed until +// the dependencies are met. +var runDependencies = 0; +var runDependencyWatcher = null; +var dependenciesFulfilled = null; // overridden to take different actions when all run dependencies are fulfilled + +function getUniqueRunDependency(id) { + return id; +} + +function addRunDependency(id) { + runDependencies++; + + Module['monitorRunDependencies']?.(runDependencies); + +} + +function removeRunDependency(id) { + runDependencies--; + + Module['monitorRunDependencies']?.(runDependencies); + + if (runDependencies == 0) { + if (runDependencyWatcher !== null) { + clearInterval(runDependencyWatcher); + runDependencyWatcher = null; + } + if (dependenciesFulfilled) { + var callback = dependenciesFulfilled; + dependenciesFulfilled = null; + callback(); // can add another dependenciesFulfilled + } + } +} + +/** @param {string|number=} what */ +function abort(what) { + Module['onAbort']?.(what); + + what = 'Aborted(' + what + ')'; + // TODO(sbc): Should we remove printing and leave it up to whoever + // catches the exception? + err(what); + + ABORT = true; + EXITSTATUS = 1; + + what += '. Build with -sASSERTIONS for more info.'; + + // Use a wasm runtime error, because a JS error might be seen as a foreign + // exception, which means we'd run destructors on it. We need the error to + // simply make the program stop. + // FIXME This approach does not work in Wasm EH because it currently does not assume + // all RuntimeErrors are from traps; it decides whether a RuntimeError is from + // a trap or not based on a hidden field within the object. So at the moment + // we don't have a way of throwing a wasm trap from JS. TODO Make a JS API that + // allows this in the wasm spec. + + // Suppress closure compiler warning here. Closure compiler's builtin extern + // definition for WebAssembly.RuntimeError claims it takes no arguments even + // though it can. + // TODO(https://github.com/google/closure-compiler/pull/3913): Remove if/when upstream closure gets fixed. + /** @suppress {checkTypes} */ + var e = new WebAssembly.RuntimeError(what); + + readyPromiseReject(e); + // Throw the error whether or not MODULARIZE is set because abort is used + // in code paths apart from instantiation where an exception is expected + // to be thrown when abort is called. + throw e; +} + +// include: memoryprofiler.js +// end include: memoryprofiler.js +// include: URIUtils.js +// Prefix of data URIs emitted by SINGLE_FILE and related options. +var dataURIPrefix = 'data:application/octet-stream;base64,'; + +/** + * Indicates whether filename is a base64 data URI. 
+ * @noinline + */ +var isDataURI = (filename) => filename.startsWith(dataURIPrefix); + +/** + * Indicates whether filename is delivered via file protocol (as opposed to http/https) + * @noinline + */ +var isFileURI = (filename) => filename.startsWith('file://'); +// end include: URIUtils.js +// include: runtime_exceptions.js +// end include: runtime_exceptions.js +var wasmBinaryFile; + wasmBinaryFile = 'data:application/octet-stream;base64,AGFzbQEAAAABHQZgAX8AYAABf2AAAGABfwF/YAJ/fwBgA39/fwF/Aw0MAgAEAgMBBQABAQADBAUBcAEBAQUGAQGAAoACBg4CfwFB8IuEBAt/AUEACweYAQoGbWVtb3J5AgARX193YXNtX2NhbGxfY3RvcnMAAAtIYXNoX1VwZGF0ZQABCkhhc2hfRmluYWwAAwlIYXNoX0luaXQABAxHZXRCdWZmZXJQdHIABRlfX2luZGlyZWN0X2Z1bmN0aW9uX3RhYmxlAQAJc3RhY2tTYXZlAAkMc3RhY2tSZXN0b3JlAAoKc3RhY2tBbGxvYwALCossDAIAC+4CAgV/AX5BACgCwAoiASABKQNAIgYgAK18NwNAAkACQAJAIAanQT9xIgINAEGACyEBIAAhAgwBC0HAACACayEDAkAgAEUNACADIAAgAyAASRshBCABIAJqIQVBACEBA0AgBSABIgFqQYALIAFqLQAAOgAAIAFBAWoiAiEBIAIgBEcNAAsLAkACQCAAIANJIgRFDQBBgAshASAAIQIMAQtBACgCwAoiAUHIAGogARACQYALIANqIQEgACADayECCyABIQEgAiECIAQNAQsgASEBAkACQCACIgJBwABPDQAgASEFIAIhAAwBCyACIQIgASEEA0BBACgCwApByABqIAQiBBACIAJBQGoiASECIARBwABqIgUhBCAFIQUgASEAIAFBP0sNAAsLIAUhBSAAIgBFDQBBACEBQQAhAgNAQQAoAsAKIAEiAWogBSABai0AADoAACACQQFqIgJB/wFxIgQhASACIQIgACAESw0ACwsLqCEBK38gACgCCCICIAAoAgQiAyAAKAIAIgRzcSADIARxcyAEQR53IARBE3dzIARBCndzaiAAKAIQIgVBGncgBUEVd3MgBUEHd3MgACgCHCIGaiAAKAIYIgcgACgCFCIIcyAFcSAHc2ogASgCACIJQRh0IAlBgP4DcUEIdHIgCUEIdkGA/gNxIAlBGHZyciIKakGY36iUBGoiC2oiCSAEcyADcSAJIARxcyAJQR53IAlBE3dzIAlBCndzaiAHIAEoAgQiDEEYdCAMQYD+A3FBCHRyIAxBCHZBgP4DcSAMQRh2cnIiDWogCyAAKAIMIg5qIg8gCCAFc3EgCHNqIA9BGncgD0EVd3MgD0EHd3NqQZGJ3YkHaiIQaiIMIAlzIARxIAwgCXFzIAxBHncgDEETd3MgDEEKd3NqIAggASgCCCILQRh0IAtBgP4DcUEIdHIgC0EIdkGA/gNxIAtBGHZyciIRaiAQIAJqIhIgDyAFc3EgBXNqIBJBGncgEkEVd3MgEkEHd3NqQc/3g657aiITaiILIAxzIAlxIAsgDHFzIAtBHncgC0ETd3MgC0EKd3NqIAUgASgCDCIQQRh0IBBBgP4DcUEIdHIgEEEIdkGA/gNxIBBBGHZyciIUaiATIANqIhMgEiAPc3EgD3NqIBNBGncgE0EVd3MgE0EHd3NqQaW3181+aiIVaiIQIAtzIAxxIBAgC3FzIBBBHncgEEETd3MgEEEKd3NqIA8gASgCECIWQRh0IBZBgP4DcUEIdHIgFkEIdkGA/gNxIBZBGHZyciIXaiAVIARqIhYgEyASc3EgEnNqIBZBGncgFkEVd3MgFkEHd3NqQduE28oDaiIYaiIPIBBzIAtxIA8gEHFzIA9BHncgD0ETd3MgD0EKd3NqIAEoAhQiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiGSASaiAYIAlqIhIgFiATc3EgE3NqIBJBGncgEkEVd3MgEkEHd3NqQfGjxM8FaiIYaiIJIA9zIBBxIAkgD3FzIAlBHncgCUETd3MgCUEKd3NqIAEoAhgiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiGiATaiAYIAxqIhMgEiAWc3EgFnNqIBNBGncgE0EVd3MgE0EHd3NqQaSF/pF5aiIYaiIMIAlzIA9xIAwgCXFzIAxBHncgDEETd3MgDEEKd3NqIAEoAhwiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiGyAWaiAYIAtqIhYgEyASc3EgEnNqIBZBGncgFkEVd3MgFkEHd3NqQdW98dh6aiIYaiILIAxzIAlxIAsgDHFzIAtBHncgC0ETd3MgC0EKd3NqIAEoAiAiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiHCASaiAYIBBqIhIgFiATc3EgE3NqIBJBGncgEkEVd3MgEkEHd3NqQZjVnsB9aiIYaiIQIAtzIAxxIBAgC3FzIBBBHncgEEETd3MgEEEKd3NqIAEoAiQiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiHSATaiAYIA9qIhMgEiAWc3EgFnNqIBNBGncgE0EVd3MgE0EHd3NqQYG2jZQBaiIYaiIPIBBzIAtxIA8gEHFzIA9BHncgD0ETd3MgD0EKd3NqIAEoAigiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiHiAWaiAYIAlqIhYgEyASc3EgEnNqIBZBGncgFkEVd3MgFkEHd3NqQb6LxqECaiIYaiIJIA9zIBBxIAkgD3FzIAlBHncgCUETd3MgCUEKd3NqIAEoAiwiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiHyASaiAYIAxqIhIgFiATc3EgE3NqIBJBGncgEkEVd3MgEkEHd3NqQcP7sagFaiIYaiIMIAlzIA9xIAwgCXFzIAxBHncgDEETd3MgDEEKd3NqIAEoAjAiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiICATaiAYIAtqIhMgEiAWc3EgFnNqIBNBGncgE0EVd3MgE0EHd3NqQfS6+ZUHaiIYaiILIAxzIAlxIAsgDHFzIAtBHncgC0ETd3MgC0EKd3NqIAEoAjQiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiISAWaiAYIBBqIhAgEyASc3EgEnNqIBBBGncgEEEVd3MgEEEHd3NqQf7j+oZ4ai
IYaiIWIAtzIAxxIBYgC3FzIBZBHncgFkETd3MgFkEKd3NqIAEoAjgiFUEYdCAVQYD+A3FBCHRyIBVBCHZBgP4DcSAVQRh2cnIiIiASaiAYIA9qIg8gECATc3EgE3NqIA9BGncgD0EVd3MgD0EHd3NqQaeN8N55aiIVaiISIBZzIAtxIBIgFnFzIBJBHncgEkETd3MgEkEKd3NqIAEoAjwiAUEYdCABQYD+A3FBCHRyIAFBCHZBgP4DcSABQRh2cnIiIyATaiAVIAlqIgEgDyAQc3EgEHNqIAFBGncgAUEVd3MgAUEHd3NqQfTi74x8aiIJaiEVIBIhGCAWISQgCyElIAkgDGohJiABIScgDyEoIBAhKSAjISMgIiEiICEhISAgISAgHyEfIB4hHiAdIR0gHCEcIBshGyAaIRogGSEZIBchFyAUIRQgESERIA0hECAKIQxBgAkhAUEQISoDQCAVIgkgGCIKcyAkIitxIAkgCnFzIAlBHncgCUETd3MgCUEKd3NqIBAiEEEZdyAQQQ53cyAQQQN2cyAMaiAdIh1qICIiFkEPdyAWQQ13cyAWQQp2c2oiDCApaiAmIhIgJyIPICgiE3NxIBNzaiASQRp3IBJBFXdzIBJBB3dzaiABIgEoAgBqIiRqIgsgCXMgCnEgCyAJcXMgC0EedyALQRN3cyALQQp3c2ogESIYQRl3IBhBDndzIBhBA3ZzIBBqIB4iHmogIyIVQQ93IBVBDXdzIBVBCnZzaiINIBNqIAEoAgRqICQgJWoiEyASIA9zcSAPc2ogE0EadyATQRV3cyATQQd3c2oiJWoiECALcyAJcSAQIAtxcyAQQR53IBBBE3dzIBBBCndzaiAUIiRBGXcgJEEOd3MgJEEDdnMgGGogHyIfaiAMQQ93IAxBDXdzIAxBCnZzaiIRIA9qIAEoAghqICUgK2oiGCATIBJzcSASc2ogGEEadyAYQRV3cyAYQQd3c2oiJWoiDyAQcyALcSAPIBBxcyAPQR53IA9BE3dzIA9BCndzaiAXIhdBGXcgF0EOd3MgF0EDdnMgJGogICIgaiANQQ93IA1BDXdzIA1BCnZzaiIUIBJqIAEoAgxqICUgCmoiCiAYIBNzcSATc2ogCkEadyAKQRV3cyAKQQd3c2oiJWoiEiAPcyAQcSASIA9xcyASQR53IBJBE3dzIBJBCndzaiATIBkiJEEZdyAkQQ53cyAkQQN2cyAXaiAhIiFqIBFBD3cgEUENd3MgEUEKdnNqIhdqIAEoAhBqICUgCWoiEyAKIBhzcSAYc2ogE0EadyATQRV3cyATQQd3c2oiJWoiCSAScyAPcSAJIBJxcyAJQR53IAlBE3dzIAlBCndzaiABKAIUIBoiGkEZdyAaQQ53cyAaQQN2cyAkaiAWaiAUQQ93IBRBDXdzIBRBCnZzaiIZaiAYaiAlIAtqIhggEyAKc3EgCnNqIBhBGncgGEEVd3MgGEEHd3NqIiVqIgsgCXMgEnEgCyAJcXMgC0EedyALQRN3cyALQQp3c2ogASgCGCAbIiRBGXcgJEEOd3MgJEEDdnMgGmogFWogF0EPdyAXQQ13cyAXQQp2c2oiGmogCmogJSAQaiIKIBggE3NxIBNzaiAKQRp3IApBFXdzIApBB3dzaiIlaiIQIAtzIAlxIBAgC3FzIBBBHncgEEETd3MgEEEKd3NqIAEoAhwgHCIcQRl3IBxBDndzIBxBA3ZzICRqIAxqIBlBD3cgGUENd3MgGUEKdnNqIhtqIBNqICUgD2oiJCAKIBhzcSAYc2ogJEEadyAkQRV3cyAkQQd3c2oiE2oiDyAQcyALcSAPIBBxcyAPQR53IA9BE3dzIA9BCndzaiABKAIgIB1BGXcgHUEOd3MgHUEDdnMgHGogDWogGkEPdyAaQQ13cyAaQQp2c2oiHGogGGogEyASaiIYICQgCnNxIApzaiAYQRp3IBhBFXdzIBhBB3dzaiITaiISIA9zIBBxIBIgD3FzIBJBHncgEkETd3MgEkEKd3NqIAEoAiQgHkEZdyAeQQ53cyAeQQN2cyAdaiARaiAbQQ93IBtBDXdzIBtBCnZzaiIdaiAKaiATIAlqIgkgGCAkc3EgJHNqIAlBGncgCUEVd3MgCUEHd3NqIgpqIhMgEnMgD3EgEyAScXMgE0EedyATQRN3cyATQQp3c2ogASgCKCAfQRl3IB9BDndzIB9BA3ZzIB5qIBRqIBxBD3cgHEENd3MgHEEKdnNqIh5qICRqIAogC2oiCiAJIBhzcSAYc2ogCkEadyAKQRV3cyAKQQd3c2oiJGoiCyATcyAScSALIBNxcyALQR53IAtBE3dzIAtBCndzaiABKAIsICBBGXcgIEEOd3MgIEEDdnMgH2ogF2ogHUEPdyAdQQ13cyAdQQp2c2oiH2ogGGogJCAQaiIYIAogCXNxIAlzaiAYQRp3IBhBFXdzIBhBB3dzaiIkaiIQIAtzIBNxIBAgC3FzIBBBHncgEEETd3MgEEEKd3NqIAEoAjAgIUEZdyAhQQ53cyAhQQN2cyAgaiAZaiAeQQ93IB5BDXdzIB5BCnZzaiIgaiAJaiAkIA9qIiQgGCAKc3EgCnNqICRBGncgJEEVd3MgJEEHd3NqIg9qIgkgEHMgC3EgCSAQcXMgCUEedyAJQRN3cyAJQQp3c2ogASgCNCAWQRl3IBZBDndzIBZBA3ZzICFqIBpqIB9BD3cgH0ENd3MgH0EKdnNqIiFqIApqIA8gEmoiDyAkIBhzcSAYc2ogD0EadyAPQRV3cyAPQQd3c2oiCmoiEiAJcyAQcSASIAlxcyASQR53IBJBE3dzIBJBCndzaiABKAI4IBVBGXcgFUEOd3MgFUEDdnMgFmogG2ogIEEPdyAgQQ13cyAgQQp2c2oiImogGGogCiATaiITIA8gJHNxICRzaiATQRp3IBNBFXdzIBNBB3dzaiIYaiIWIBJzIAlxIBYgEnFzIBZBHncgFkETd3MgFkEKd3NqIAEoAjwgDEEZdyAMQQ53cyAMQQN2cyAVaiAcaiAhQQ93ICFBDXdzICFBCnZzaiIKaiAkaiAYIAtqIgsgEyAPc3EgD3NqIAtBGncgC0EVd3MgC0EHd3NqIiZqIishFSAWIRggEiEkIAkhJSAmIBBqIiwhJiALIScgEyEoIA8hKSAKISMgIiEiICEhISAgISAgHyEfIB4hHiAdIR0gHCEcIBshGyAaIRogGSEZIBchFyAUIRQgESERIA0hECAMIQwgAUHAAGohASAqIgpBEGohKiAKQTBJDQALIAAgDyAGajYCHCAAIBMgB2o2AhggACALIAhqNgIUIAAgLCAFajYCECAAIAkgDmo2AgwgACASIAJqNgIIIAAgFiADajYCBCAAICsgBGo2AgAL1AMDBX8BfgF7QQAoAsAKIgAgACgCQCIBQQJ2QQ9xIgJBAnRqIgMgAygCAEF/IAFBA3QiAXRBf3NxQYABIAF0czYCAAJAAkAgAkEOTw0AIAJBAWohAAwBCwJAIAJBDkcNACAAQQA2AjwLIABByABqIAAQAkEAI
QALAkAgACIAQQ1LDQBBACgCwAogAEECdCIAakEAQTggAGsQBhoLQQAoAsAKIgAgACkDQCIFpyICQRt0IAJBC3RBgID8B3FyIAJBBXZBgP4DcSACQQN0QRh2cnI2AjwgACAFQh2IpyICQRh0IAJBgP4DcUEIdHIgAkEIdkGA/gNxIAJBGHZycjYCOCAAQcgAaiAAEAJBACgCwApBPGohAUEAIQADQCABQQcgACIAa0ECdGoiAiAC/QACACAG/Q0MDQ4PCAkKCwQFBgcAAQIDIAb9DQMCAQAHBgUECwoJCA8ODQwgBv0NDA0ODwgJCgsEBQYHAAECA/0LAgAgAEEEaiICIQAgAkEIRw0ACwJAQQAoAsAKIgMoAmhFDQAgA0HIAGohBEEAIQBBACECA0BBgAsgACIAaiAEIABqLQAAOgAAIAJBAWoiAkH/AXEiASEAIAIhAiADKAJoIAFLDQALCwtxAQJ/QQAoAsAKIgFCADcDQCABQcgAaiECAkAgAEHgAUcNACABQRw2AmggAkEQakEA/QAEsAj9CwIAIAJBAP0ABKAI/QsCAEEADwsgAUEgNgJoIAJBEGpBAP0ABJAI/QsCACACQQD9AASACP0LAgBBAAsFAEGACwvyAgIDfwF+AkAgAkUNACAAIAE6AAAgACACaiIDQX9qIAE6AAAgAkEDSQ0AIAAgAToAAiAAIAE6AAEgA0F9aiABOgAAIANBfmogAToAACACQQdJDQAgACABOgADIANBfGogAToAACACQQlJDQAgAEEAIABrQQNxIgRqIgMgAUH/AXFBgYKECGwiATYCACADIAIgBGtBfHEiBGoiAkF8aiABNgIAIARBCUkNACADIAE2AgggAyABNgIEIAJBeGogATYCACACQXRqIAE2AgAgBEEZSQ0AIAMgATYCGCADIAE2AhQgAyABNgIQIAMgATYCDCACQXBqIAE2AgAgAkFsaiABNgIAIAJBaGogATYCACACQWRqIAE2AgAgBCADQQRxQRhyIgVrIgJBIEkNACABrUKBgICAEH4hBiADIAVqIQEDQCABIAY3AxggASAGNwMQIAEgBjcDCCABIAY3AwAgAUEgaiEBIAJBYGoiAkEfSw0ACwsgAAsGACAAJAELBAAjAQsEACMACwYAIAAkAAsSAQJ/IwAgAGtBcHEiASQAIAELC9ICAgBBgAgLwAJn5glqha5nu3Lzbjw69U+lf1IOUYxoBZur2YMfGc3gW9ieBcEH1Xw2F91wMDlZDvcxC8D/ERVYaKeP+WSkT/q+mC+KQpFEN3HP+8C1pdu16VvCVjnxEfFZpII/ktVeHKuYqgfYAVuDEr6FMSTDfQxVdF2+cv6x3oCnBtybdPGbwcFpm+SGR77vxp3BD8yhDCRvLOktqoR0StypsFzaiPl2UlE+mG3GMajIJwOwx39Zv/ML4MZHkafVUWPKBmcpKRSFCrcnOCEbLvxtLE0TDThTVHMKZbsKanYuycKBhSxykqHov6JLZhqocItLwqNRbMcZ6JLRJAaZ1oU1DvRwoGoQFsGkGQhsNx5Md0gntbywNLMMHDlKqthOT8qcW/NvLmjugo90b2OleBR4yIQIAseM+v++kOtsUKT3o/m+8nhxxgBBwAoLBIAFgAA='; + if (!isDataURI(wasmBinaryFile)) { + wasmBinaryFile = locateFile(wasmBinaryFile); + } + +function getBinarySync(file) { + if (file == wasmBinaryFile && wasmBinary) { + return new Uint8Array(wasmBinary); + } + var binary = tryParseAsDataURI(file); + if (binary) { + return binary; + } + if (readBinary) { + return readBinary(file); + } + throw 'both async and sync fetching of the wasm failed'; +} + +function getBinaryPromise(binaryFile) { + + // Otherwise, getBinarySync should be able to get it synchronously + return Promise.resolve().then(() => getBinarySync(binaryFile)); +} + +function instantiateArrayBuffer(binaryFile, imports, receiver) { + return getBinaryPromise(binaryFile).then((binary) => { + return WebAssembly.instantiate(binary, imports); + }).then(receiver, (reason) => { + err(`failed to asynchronously prepare wasm: ${reason}`); + + abort(reason); + }); +} + +function instantiateAsync(binary, binaryFile, imports, callback) { + return instantiateArrayBuffer(binaryFile, imports, callback); +} + +// Create the wasm instance. +// Receives the wasm imports, returns the exports. +function createWasm() { + // prepare imports + var info = { + 'env': wasmImports, + 'wasi_snapshot_preview1': wasmImports, + }; + // Load the wasm module and create an instance of using native support in the JS engine. + // handle a generated wasm instance, receiving its exports and + // performing other necessary setup + /** @param {WebAssembly.Module=} module*/ + function receiveInstance(instance, module) { + wasmExports = instance.exports; + + + + wasmMemory = wasmExports['memory']; + + updateMemoryViews(); + + addOnInit(wasmExports['__wasm_call_ctors']); + + removeRunDependency('wasm-instantiate'); + return wasmExports; + } + // wait for the pthread pool (if any) + addRunDependency('wasm-instantiate'); + + // Prefer streaming instantiation if available. 
+ function receiveInstantiationResult(result) { + // 'result' is a ResultObject object which has both the module and instance. + // receiveInstance() will swap in the exports (to Module.asm) so they can be called + // TODO: Due to Closure regression https://github.com/google/closure-compiler/issues/3193, the above line no longer optimizes out down to the following line. + // When the regression is fixed, can restore the above PTHREADS-enabled path. + receiveInstance(result['instance']); + } + + // User shell pages can write their own Module.instantiateWasm = function(imports, successCallback) callback + // to manually instantiate the Wasm module themselves. This allows pages to + // run the instantiation parallel to any other async startup actions they are + // performing. + // Also pthreads and wasm workers initialize the wasm instance through this + // path. + if (Module['instantiateWasm']) { + + try { + return Module['instantiateWasm'](info, receiveInstance); + } catch(e) { + err(`Module.instantiateWasm callback failed with error: ${e}`); + // If instantiation fails, reject the module ready promise. + readyPromiseReject(e); + } + } + + // If instantiation fails, reject the module ready promise. + instantiateAsync(wasmBinary, wasmBinaryFile, info, receiveInstantiationResult).catch(readyPromiseReject); + return {}; // no exports yet; we'll fill them in later +} + +// Globals used by JS i64 conversions (see makeSetValue) +var tempDouble; +var tempI64; + +// include: runtime_debug.js +// end include: runtime_debug.js +// === Body === +// end include: preamble.js + + + /** @constructor */ + function ExitStatus(status) { + this.name = 'ExitStatus'; + this.message = `Program terminated with exit(${status})`; + this.status = status; + } + + var callRuntimeCallbacks = (callbacks) => { + while (callbacks.length > 0) { + // Pass the module as the first argument. 
+ callbacks.shift()(Module); + } + }; + + + /** + * @param {number} ptr + * @param {string} type + */ + function getValue(ptr, type = 'i8') { + if (type.endsWith('*')) type = '*'; + switch (type) { + case 'i1': return HEAP8[ptr]; + case 'i8': return HEAP8[ptr]; + case 'i16': return HEAP16[((ptr)>>1)]; + case 'i32': return HEAP32[((ptr)>>2)]; + case 'i64': abort('to do getValue(i64) use WASM_BIGINT'); + case 'float': return HEAPF32[((ptr)>>2)]; + case 'double': return HEAPF64[((ptr)>>3)]; + case '*': return HEAPU32[((ptr)>>2)]; + default: abort(`invalid type for getValue: ${type}`); + } + } + + var noExitRuntime = Module['noExitRuntime'] || true; + + + /** + * @param {number} ptr + * @param {number} value + * @param {string} type + */ + function setValue(ptr, value, type = 'i8') { + if (type.endsWith('*')) type = '*'; + switch (type) { + case 'i1': HEAP8[ptr] = value; break; + case 'i8': HEAP8[ptr] = value; break; + case 'i16': HEAP16[((ptr)>>1)] = value; break; + case 'i32': HEAP32[((ptr)>>2)] = value; break; + case 'i64': abort('to do setValue(i64) use WASM_BIGINT'); + case 'float': HEAPF32[((ptr)>>2)] = value; break; + case 'double': HEAPF64[((ptr)>>3)] = value; break; + case '*': HEAPU32[((ptr)>>2)] = value; break; + default: abort(`invalid type for setValue: ${type}`); + } + } +var wasmImports = { + +}; +var wasmExports = createWasm(); +var ___wasm_call_ctors = () => (___wasm_call_ctors = wasmExports['__wasm_call_ctors'])(); +var _Hash_Update = Module['_Hash_Update'] = (a0) => (_Hash_Update = Module['_Hash_Update'] = wasmExports['Hash_Update'])(a0); +var _Hash_Final = Module['_Hash_Final'] = () => (_Hash_Final = Module['_Hash_Final'] = wasmExports['Hash_Final'])(); +var _Hash_Init = Module['_Hash_Init'] = (a0) => (_Hash_Init = Module['_Hash_Init'] = wasmExports['Hash_Init'])(a0); +var _GetBufferPtr = Module['_GetBufferPtr'] = () => (_GetBufferPtr = Module['_GetBufferPtr'] = wasmExports['GetBufferPtr'])(); +var stackSave = () => (stackSave = wasmExports['stackSave'])(); +var stackRestore = (a0) => (stackRestore = wasmExports['stackRestore'])(a0); +var stackAlloc = (a0) => (stackAlloc = wasmExports['stackAlloc'])(a0); + + +// include: postamble.js +// === Auto-generated postamble setup entry stuff === + + + + +var calledRun; + +dependenciesFulfilled = function runCaller() { + // If run has never been called, and we should call run (INVOKE_RUN is true, and Module.noInitialRun is not false) + if (!calledRun) run(); + if (!calledRun) dependenciesFulfilled = runCaller; // try this again later, after new deps are fulfilled +}; + +function run() { + + if (runDependencies > 0) { + return; + } + + preRun(); + + // a preRun added a dependency, run will be called later + if (runDependencies > 0) { + return; + } + + function doRun() { + // run may have just been called through dependencies being fulfilled just in this very frame, + // or while the async setStatus time below was happening + if (calledRun) return; + calledRun = true; + Module['calledRun'] = true; + + if (ABORT) return; + + initRuntime(); + + readyPromiseResolve(Module); + if (Module['onRuntimeInitialized']) Module['onRuntimeInitialized'](); + + postRun(); + } + + if (Module['setStatus']) { + Module['setStatus']('Running...'); + setTimeout(function() { + setTimeout(function() { + Module['setStatus'](''); + }, 1); + doRun(); + }, 1); + } else + { + doRun(); + } +} + +if (Module['preInit']) { + if (typeof Module['preInit'] == 'function') Module['preInit'] = [Module['preInit']]; + while (Module['preInit'].length > 0) { + 
    Module['preInit'].pop()();
+  }
+}
+
+run();
+
+// end include: postamble.js
+
+
+
+  return moduleArg.ready
+}
+);
+})();
+export default Module;
\ No newline at end of file
diff --git a/vendor/lz4js/LICENSE b/vendor/lz4js/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..7dfda0d31577cba1033826adacb3c799c234790a
--- /dev/null
+++ b/vendor/lz4js/LICENSE
@@ -0,0 +1,11 @@
+ISC License
+
+Copyright 2019 John Chadwick
+
+Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+---
+
+Note: this license is not actually included in https://github.com/Benzinga/lz4js, but the package.json specifies the ISC license
diff --git a/vendor/lz4js/index.ts b/vendor/lz4js/index.ts
new file mode 100644
index 0000000000000000000000000000000000000000..e258051c1b24021a6d7edbd2d4838a1df502d885
--- /dev/null
+++ b/vendor/lz4js/index.ts
@@ -0,0 +1,537 @@
+// lz4.js - An implementation of Lz4 in plain JavaScript.
+//
+// TODO:
+// - Unify header parsing/writing.
+// - Support options (block size, checksums)
+// - Support streams
+// - Better error handling (handle bad offset, etc.)
+// - HC support (better search algorithm)
+// - Tests/benchmarking
+
+import * as xxhash from "./xxh32.js";
+import * as util from "./util.js";
+
+// Constants
+// --
+
+// Compression format parameters/constants.
+const minMatch = 4;
+const matchSearchLimit = 12;
+const minTrailingLitterals = 5;
+const skipTrigger = 6;
+const hashSize = 1 << 16;
+
+// Token constants.
+const mlBits = 4;
+const mlMask = (1 << mlBits) - 1;
+const runBits = 4;
+const runMask = (1 << runBits) - 1;
+
+// Shared buffers
+const blockBuf = makeBuffer(5 << 20);
+const hashTable = makeHashTable();
+
+// Frame constants.
+const magicNum = 0x184d2204;
+
+// Frame descriptor flags.
+const fdContentChksum = 0x4;
+const fdContentSize = 0x8;
+const fdBlockChksum = 0x10;
+// var fdBlockIndep = 0x20;
+const fdVersion = 0x40;
+const fdVersionMask = 0xc0;
+
+// Block sizes.
+const bsUncompressed = 0x80000000;
+const bsDefault = 7;
+const bsShift = 4;
+const bsMask = 7;
+const bsMap: Record<number, number> = {
+	4: 0x10000,
+	5: 0x40000,
+	6: 0x100000,
+	7: 0x400000,
+};
+
+// Utility functions/primitives
+// --
+
+// Makes our hashtable. On older browsers, may return a plain array.
+function makeHashTable() {
+	try {
+		return new Uint32Array(hashSize);
+	} catch (error) {
+		const hashTable = new Array(hashSize);
+
+		for (let i = 0; i < hashSize; i++) {
+			hashTable[i] = 0;
+		}
+
+		return hashTable;
+	}
+}
+
+// Clear hashtable.
+function clearHashTable(table: Uint32Array | number[]) {
+	for (let i = 0; i < hashSize; i++) {
+		table[i] = 0;
+	}
+}
+
+// Makes a byte buffer. On older browsers, may return a plain array.
+function makeBuffer(size: number) { + return new Uint8Array(size); +} + +function sliceArray(array: Uint8Array, start: number, end: number) { + return array.slice(start, end); +} + +// Implementation +// -- + +// Calculates an upper bound for lz4 compression. +export function compressBound(n: number) { + return (n + n / 255 + 16) | 0; +} + +// Calculates an upper bound for lz4 decompression, by reading the data. +export function decompressBound(src: Uint8Array) { + let sIndex = 0; + + // Read magic number + if (util.readU32(src, sIndex) !== magicNum) { + throw new Error("invalid magic number"); + } + + sIndex += 4; + + // Read descriptor + const descriptor = src[sIndex++]; + + // Check version + if ((descriptor & fdVersionMask) !== fdVersion) { + throw new Error("incompatible descriptor version " + (descriptor & fdVersionMask)); + } + + // Read flags + const useBlockSum = (descriptor & fdBlockChksum) !== 0; + const useContentSize = (descriptor & fdContentSize) !== 0; + + // Read block size + const bsIdx = (src[sIndex++] >> bsShift) & bsMask; + + if (bsMap[bsIdx] === undefined) { + throw new Error("invalid block size " + bsIdx); + } + + const maxBlockSize = bsMap[bsIdx]; + + // Get content size + if (useContentSize) { + return util.readU64(src, sIndex); + } + + // Checksum + sIndex++; + + // Read blocks. + let maxSize = 0; + while (true) { + let blockSize = util.readU32(src, sIndex); + sIndex += 4; + + if (blockSize & bsUncompressed) { + blockSize &= ~bsUncompressed; + maxSize += blockSize; + } else if (blockSize > 0) { + maxSize += maxBlockSize; + } + + if (blockSize === 0) { + return maxSize; + } + + if (useBlockSum) { + sIndex += 4; + } + + sIndex += blockSize; + } +} + +// Decompresses a block of Lz4. +export function decompressBlock(src: Uint8Array, dst: Uint8Array, sIndex: number, sLength: number, dIndex: number) { + let mLength, mOffset, sEnd, n, i; + const hasCopyWithin = dst.copyWithin !== undefined && dst.fill !== undefined; + + // Setup initial state. + sEnd = sIndex + sLength; + + // Consume entire input block. + while (sIndex < sEnd) { + const token = src[sIndex++]; + + // Copy literals. + let literalCount = token >> 4; + if (literalCount > 0) { + // Parse length. + if (literalCount === 0xf) { + while (true) { + literalCount += src[sIndex]; + if (src[sIndex++] !== 0xff) { + break; + } + } + } + + // Copy literals + for (n = sIndex + literalCount; sIndex < n; ) { + dst[dIndex++] = src[sIndex++]; + } + } + + if (sIndex >= sEnd) { + break; + } + + // Copy match. + mLength = token & 0xf; + + // Parse offset. + mOffset = src[sIndex++] | (src[sIndex++] << 8); + + // Parse length. + if (mLength === 0xf) { + while (true) { + mLength += src[sIndex]; + if (src[sIndex++] !== 0xff) { + break; + } + } + } + + mLength += minMatch; + + // Copy match + // prefer to use typedarray.copyWithin for larger matches + // NOTE: copyWithin doesn't work as required by LZ4 for overlapping sequences + // e.g. mOffset=1, mLength=30 (repeach char 30 times) + // we special case the repeat char w/ array.fill + if (hasCopyWithin && mOffset === 1) { + dst.fill(dst[dIndex - 1] | 0, dIndex, dIndex + mLength); + dIndex += mLength; + } else if (hasCopyWithin && mOffset > mLength && mLength > 31) { + dst.copyWithin(dIndex, dIndex - mOffset, dIndex - mOffset + mLength); + dIndex += mLength; + } else { + for (i = dIndex - mOffset, n = i + mLength; i < n; ) { + dst[dIndex++] = dst[i++] | 0; + } + } + } + + return dIndex; +} + +// Compresses a block with Lz4. 
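+// Each encoded sequence is a 1-byte token (high nibble = literal count, low
+// nibble = match length - minMatch), optional extra literal-length bytes, the
+// literals themselves, a 2-byte little-endian match offset, and optional extra
+// match-length bytes when a count overflows its 4-bit field.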
+export function compressBlock( + src: Uint8Array, + dst: Uint8Array, + sIndex: number, + sLength: number, + hashTable: Uint32Array | number[] +) { + let mIndex, mAnchor, mLength, mOffset, mStep; + let literalCount, dIndex, sEnd, n; + + // Setup initial state. + dIndex = 0; + sEnd = sLength + sIndex; + mAnchor = sIndex; + + let searchMatchCount = (1 << skipTrigger) + 3; + + // Search for matches with a limit of matchSearchLimit bytes + // before the end of block (Lz4 spec limitation.) + while (sIndex <= sEnd - matchSearchLimit) { + const seq = util.readU32(src, sIndex); + let hash = util.hashU32(seq) >>> 0; + + // Crush hash to 16 bits. + hash = (((hash >> 16) ^ hash) >>> 0) & 0xffff; + + // Look for a match in the hashtable. NOTE: remove one; see below. + mIndex = hashTable[hash] - 1; + + // Put pos in hash table. NOTE: add one so that zero = invalid. + hashTable[hash] = sIndex + 1; + + // Determine if there is a match (within range.) + if (mIndex < 0 || (sIndex - mIndex) >>> 16 > 0 || util.readU32(src, mIndex) !== seq) { + mStep = searchMatchCount++ >> skipTrigger; + sIndex += mStep; + continue; + } + + searchMatchCount = (1 << skipTrigger) + 3; + + // Calculate literal count and offset. + literalCount = sIndex - mAnchor; + mOffset = sIndex - mIndex; + + // We've already matched one word, so get that out of the way. + sIndex += minMatch; + mIndex += minMatch; + + // Determine match length. + // N.B.: mLength does not include minMatch, Lz4 adds it back + // in decoding. + mLength = sIndex; + while (sIndex < sEnd - minTrailingLitterals && src[sIndex] === src[mIndex]) { + sIndex++; + mIndex++; + } + mLength = sIndex - mLength; + + // Write token + literal count. + const token = mLength < mlMask ? mLength : mlMask; + if (literalCount >= runMask) { + dst[dIndex++] = (runMask << mlBits) + token; + for (n = literalCount - runMask; n >= 0xff; n -= 0xff) { + dst[dIndex++] = 0xff; + } + dst[dIndex++] = n; + } else { + dst[dIndex++] = (literalCount << mlBits) + token; + } + + // Write literals. + for (let i = 0; i < literalCount; i++) { + dst[dIndex++] = src[mAnchor + i]; + } + + // Write offset. + dst[dIndex++] = mOffset; + dst[dIndex++] = mOffset >> 8; + + // Write match length. + if (mLength >= mlMask) { + for (n = mLength - mlMask; n >= 0xff; n -= 0xff) { + dst[dIndex++] = 0xff; + } + dst[dIndex++] = n; + } + + // Move the anchor. + mAnchor = sIndex; + } + + // Nothing was encoded. + if (mAnchor === 0) { + return 0; + } + + // Write remaining literals. + // Write literal token+count. + literalCount = sEnd - mAnchor; + if (literalCount >= runMask) { + dst[dIndex++] = runMask << mlBits; + for (n = literalCount - runMask; n >= 0xff; n -= 0xff) { + dst[dIndex++] = 0xff; + } + dst[dIndex++] = n; + } else { + dst[dIndex++] = literalCount << mlBits; + } + + // Write literals. + sIndex = mAnchor; + while (sIndex < sEnd) { + dst[dIndex++] = src[sIndex++]; + } + + return dIndex; +} + +// Decompresses a frame of Lz4 data. 
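+// A frame is a 4-byte magic number, a frame descriptor (flag byte, block-size
+// byte, optional content size, header checksum byte), then data blocks, each
+// prefixed by a 4-byte size whose high bit marks an uncompressed block, and is
+// terminated by a zero-size block plus an optional content checksum.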
+export function decompressFrame(src: Uint8Array, dst: Uint8Array) { + let useBlockSum, useContentSum, useContentSize, descriptor; + let sIndex = 0; + let dIndex = 0; + + // Read magic number + if (util.readU32(src, sIndex) !== magicNum) { + throw new Error("invalid magic number"); + } + + sIndex += 4; + + // Read descriptor + descriptor = src[sIndex++]; + + // Check version + if ((descriptor & fdVersionMask) !== fdVersion) { + throw new Error("incompatible descriptor version"); + } + + // Read flags + useBlockSum = (descriptor & fdBlockChksum) !== 0; + useContentSum = (descriptor & fdContentChksum) !== 0; + useContentSize = (descriptor & fdContentSize) !== 0; + + // Read block size + const bsIdx = (src[sIndex++] >> bsShift) & bsMask; + + if (bsMap[bsIdx] === undefined) { + throw new Error("invalid block size"); + } + + if (useContentSize) { + // TODO: read content size + sIndex += 8; + } + + sIndex++; + + // Read blocks. + while (true) { + var compSize; + + compSize = util.readU32(src, sIndex); + sIndex += 4; + + if (compSize === 0) { + break; + } + + if (useBlockSum) { + // TODO: read block checksum + sIndex += 4; + } + + // Check if block is compressed + if ((compSize & bsUncompressed) !== 0) { + // Mask off the 'uncompressed' bit + compSize &= ~bsUncompressed; + + // Copy uncompressed data into destination buffer. + for (let j = 0; j < compSize; j++) { + dst[dIndex++] = src[sIndex++]; + } + } else { + // Decompress into blockBuf + dIndex = decompressBlock(src, dst, sIndex, compSize, dIndex); + sIndex += compSize; + } + } + + if (useContentSum) { + // TODO: read content checksum + sIndex += 4; + } + + return dIndex; +} + +// Compresses data to an Lz4 frame. +export function compressFrame(src: Uint8Array, dst: Uint8Array) { + let dIndex = 0; + + // Write magic number. + util.writeU32(dst, dIndex, magicNum); + dIndex += 4; + + // Descriptor flags. + dst[dIndex++] = fdVersion; + dst[dIndex++] = bsDefault << bsShift; + + // Descriptor checksum. + dst[dIndex] = xxhash.hash(0, dst, 4, dIndex - 4) >> 8; + dIndex++; + + // Write blocks. + const maxBlockSize = bsMap[bsDefault]; + let remaining = src.length; + let sIndex = 0; + + // Clear the hashtable. + clearHashTable(hashTable); + + // Split input into blocks and write. + while (remaining > 0) { + let compSize = 0; + const blockSize = remaining > maxBlockSize ? maxBlockSize : remaining; + + compSize = compressBlock(src, blockBuf, sIndex, blockSize, hashTable); + + if (compSize > blockSize || compSize === 0) { + // Output uncompressed. + util.writeU32(dst, dIndex, 0x80000000 | blockSize); + dIndex += 4; + + for (let z = sIndex + blockSize; sIndex < z; ) { + dst[dIndex++] = src[sIndex++]; + } + + remaining -= blockSize; + } else { + // Output compressed. + util.writeU32(dst, dIndex, compSize); + dIndex += 4; + + for (let j = 0; j < compSize; ) { + dst[dIndex++] = blockBuf[j++]; + } + + sIndex += blockSize; + remaining -= blockSize; + } + } + + // Write blank end block. + util.writeU32(dst, dIndex, 0); + dIndex += 4; + + return dIndex; +} + +// Decompresses a buffer containing an Lz4 frame. maxSize is optional; if not +// provided, a maximum size will be determined by examining the data. The +// buffer returned will always be perfectly-sized. 
+export function decompress(src: Uint8Array, maxSize: number) { + let dst, size; + + if (maxSize === undefined) { + maxSize = decompressBound(src); + } + dst = makeBuffer(maxSize); + size = decompressFrame(src, dst); + + if (size !== maxSize) { + dst = sliceArray(dst, 0, size); + } + + return dst; +} + +// Compresses a buffer to an Lz4 frame. maxSize is optional; if not provided, +// a buffer will be created based on the theoretical worst output size for a +// given input size. The buffer returned will always be perfectly-sized. +export function compress(src: Uint8Array, maxSize: number) { + let dst, size; + + if (maxSize === undefined) { + maxSize = compressBound(src.length); + } + + dst = makeBuffer(maxSize); + size = compressFrame(src, dst); + + if (size !== maxSize) { + dst = sliceArray(dst, 0, size); + } + + return dst; +} diff --git a/vendor/lz4js/util.ts b/vendor/lz4js/util.ts new file mode 100644 index 0000000000000000000000000000000000000000..5d579b76604d2fe13537423de069f61eca50a2d4 --- /dev/null +++ b/vendor/lz4js/util.ts @@ -0,0 +1,54 @@ +// Simple hash function, from: http://burtleburtle.net/bob/hash/integer.html. +// Chosen because it doesn't use multiply and achieves full avalanche. +export function hashU32(a: number): number { + a = a | 0; + a = (a + 2127912214 + (a << 12)) | 0; + a = a ^ -949894596 ^ (a >>> 19); + a = (a + 374761393 + (a << 5)) | 0; + a = (a + -744332180) ^ (a << 9); + a = (a + -42973499 + (a << 3)) | 0; + return (a ^ -1252372727 ^ (a >>> 16)) | 0; +} + +// Reads a 64-bit little-endian integer from an array. +export function readU64(b: Uint8Array, n: number): number { + let x = 0; + x |= b[n++] << 0; + x |= b[n++] << 8; + x |= b[n++] << 16; + x |= b[n++] << 24; + x |= b[n++] << 32; + x |= b[n++] << 40; + x |= b[n++] << 48; + x |= b[n++] << 56; + return x; +} + +// Reads a 32-bit little-endian integer from an array. +export function readU32(b: Uint8Array, n: number): number { + let x = 0; + x |= b[n++] << 0; + x |= b[n++] << 8; + x |= b[n++] << 16; + x |= b[n++] << 24; + return x; +} + +// Writes a 32-bit little-endian integer from an array. +export function writeU32(b: Uint8Array, n: number, x: number): void { + b[n++] = (x >> 0) & 0xff; + b[n++] = (x >> 8) & 0xff; + b[n++] = (x >> 16) & 0xff; + b[n++] = (x >> 24) & 0xff; +} + +// Multiplies two numbers using 32-bit integer multiplication. +// Algorithm from Emscripten. 
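+// Splitting the operands into 16-bit halves keeps every intermediate product
+// exactly representable in a double, mirroring what Math.imul does natively.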
+export function imul(a: number, b: number): number { + const ah = a >>> 16; + const al = a & 65535; + const bh = b >>> 16; + const bl = b & 65535; + + return (al * bl + ((ah * bl + al * bh) << 16)) | 0; +} diff --git a/vendor/lz4js/xxh32.ts b/vendor/lz4js/xxh32.ts new file mode 100644 index 0000000000000000000000000000000000000000..b9d135c657c69da01cda9beab2abe1b06b009e2e --- /dev/null +++ b/vendor/lz4js/xxh32.ts @@ -0,0 +1,96 @@ +// xxh32.js - implementation of xxhash32 in plain JavaScript +import * as util from "./util.js"; + +// xxhash32 primes +const prime1 = 0x9e3779b1; +const prime2 = 0x85ebca77; +const prime3 = 0xc2b2ae3d; +const prime4 = 0x27d4eb2f; +const prime5 = 0x165667b1; + +// Utility functions/primitives +// -- +function rotl32(x: number, r: number): number { + x = x | 0; + r = r | 0; + + return (x >>> ((32 - r) | 0)) | (x << r) | 0; +} + +function rotmul32(h: number, r: number, m: number): number { + h = h | 0; + r = r | 0; + m = m | 0; + + return util.imul((h >>> ((32 - r) | 0)) | (h << r), m) | 0; +} + +function shiftxor32(h: number, s: number): number { + h = h | 0; + s = s | 0; + + return ((h >>> s) ^ h) | 0; +} + +// Implementation +// -- + +function xxhapply(h: number, src: number, m0: number, s: number, m1: number): number { + return rotmul32(util.imul(src, m0) + h, s, m1); +} + +function xxh1(h: number, src: Uint8Array, index: number): number { + return rotmul32(h + util.imul(src[index], prime5), 11, prime1); +} + +function xxh4(h: number, src: Uint8Array, index: number): number { + return xxhapply(h, util.readU32(src, index), prime3, 17, prime4); +} + +function xxh16(h: number[], src: Uint8Array, index: number): number[] { + return [ + xxhapply(h[0], util.readU32(src, index + 0), prime2, 13, prime1), + xxhapply(h[1], util.readU32(src, index + 4), prime2, 13, prime1), + xxhapply(h[2], util.readU32(src, index + 8), prime2, 13, prime1), + xxhapply(h[3], util.readU32(src, index + 12), prime2, 13, prime1), + ]; +} + +function xxh32(seed: number, src: Uint8Array, index: number, len: number): number { + let h; + const l = len; + if (len >= 16) { + h = [seed + prime1 + prime2, seed + prime2, seed, seed - prime1]; + + while (len >= 16) { + h = xxh16(h, src, index); + + index += 16; + len -= 16; + } + + h = rotl32(h[0], 1) + rotl32(h[1], 7) + rotl32(h[2], 12) + rotl32(h[3], 18) + l; + } else { + h = (seed + prime5 + len) >>> 0; + } + + while (len >= 4) { + h = xxh4(h, src, index); + + index += 4; + len -= 4; + } + + while (len > 0) { + h = xxh1(h, src, index); + + index++; + len--; + } + + h = shiftxor32(util.imul(shiftxor32(util.imul(shiftxor32(h, 15), prime2), 13), prime3), 16); + + return h >>> 0; +} + +export const hash = xxh32; diff --git a/vendor/type-fest/basic.ts b/vendor/type-fest/basic.ts new file mode 100644 index 0000000000000000000000000000000000000000..3fa40a039955fec376a3e8cdc42dc425a00a859b --- /dev/null +++ b/vendor/type-fest/basic.ts @@ -0,0 +1,31 @@ +/** +Matches a JSON object. + +This type can be useful to enforce some input to be JSON-compatible or as a super-type to be extended from. Don't use this as a direct return type as the user would have to double-cast it: `jsonObject as unknown as CustomResponse`. Instead, you could extend your CustomResponse type from it to ensure your type only uses JSON-compatible types: `interface CustomResponse extends JsonObject { … }`. + +@category JSON +*/ +export type JsonObject = { [Key in string]: JsonValue } & { [Key in string]?: JsonValue | undefined }; + +/** +Matches a JSON array. 
+
+@category JSON
+*/
+export type JsonArray = JsonValue[] | readonly JsonValue[];
+
+/**
+Matches any valid JSON primitive value.
+
+@category JSON
+*/
+export type JsonPrimitive = string | number | boolean | null;
+
+/**
+Matches any valid JSON value.
+
+@see `Jsonify` if you need to transform a type to one that is assignable to `JsonValue`.
+
+@category JSON
+*/
+export type JsonValue = JsonPrimitive | JsonObject | JsonArray;
diff --git a/vendor/type-fest/entries.ts b/vendor/type-fest/entries.ts
new file mode 100644
index 0000000000000000000000000000000000000000..7716e47480b299a3b846f59df3349bbb355be5ee
--- /dev/null
+++ b/vendor/type-fest/entries.ts
@@ -0,0 +1,65 @@
+import type { ArrayEntry, MapEntry, ObjectEntry, SetEntry } from "./entry";
+
+type ArrayEntries<BaseType extends readonly unknown[]> = Array<ArrayEntry<BaseType>>;
+type MapEntries<BaseType> = Array<MapEntry<BaseType>>;
+type ObjectEntries<BaseType> = Array<ObjectEntry<BaseType>>;
+type SetEntries<BaseType extends Set<unknown>> = Array<SetEntry<BaseType>>;
+
+/**
+Many collections have an `entries` method which returns an array of a given object's own enumerable string-keyed property [key, value] pairs. The `Entries` type will return the type of that collection's entries.
+
+For example the {@link https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Object/entries|`Object`}, {@link https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Map/entries|`Map`}, {@link https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/entries|`Array`}, and {@link https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set/entries|`Set`} collections all have this method. Note that `WeakMap` and `WeakSet` do not have this method since their entries are not enumerable.
+
+@see `Entry` if you want to just access the type of a single entry.
+
+@example
+```
+import type {Entries} from 'type-fest';
+
+interface Example {
+	someKey: number;
+}
+
+const manipulatesEntries = (examples: Entries<Example>) => examples.map(example => [
+	// Does some arbitrary processing on the key (with type information available)
+	example[0].toUpperCase(),
+
+	// Does some arbitrary processing on the value (with type information available)
+	example[1].toFixed()
+]);
+
+const example: Example = {someKey: 1};
+const entries = Object.entries(example) as Entries<Example>;
+const output = manipulatesEntries(entries);
+
+// Objects
+const objectExample = {a: 1};
+const objectEntries: Entries<typeof objectExample> = [['a', 1]];
+
+// Arrays
+const arrayExample = ['a', 1];
+const arrayEntries: Entries<typeof arrayExample> = [[0, 'a'], [1, 1]];
+
+// Maps
+const mapExample = new Map([['a', 1]]);
+const mapEntries: Entries<typeof mapExample> = [['a', 1]];
+
+// Sets
+const setExample = new Set(['a', 1]);
+const setEntries: Entries<typeof setExample> = [['a', 'a'], [1, 1]];
+```
+
+@category Object
+@category Map
+@category Set
+@category Array
+*/
+export type Entries<BaseType> = BaseType extends Map<unknown, unknown>
+	? MapEntries<BaseType>
+	: BaseType extends Set<unknown>
+	? SetEntries<BaseType>
+	: BaseType extends readonly unknown[]
+	? ArrayEntries<BaseType>
+	: BaseType extends object
+	? ObjectEntries<BaseType>
+	: never;
diff --git a/vendor/type-fest/entry.ts b/vendor/type-fest/entry.ts
new file mode 100644
index 0000000000000000000000000000000000000000..ed3650c90cf54bbe72b039e84e70158ad0cf3bd8
--- /dev/null
+++ b/vendor/type-fest/entry.ts
@@ -0,0 +1,68 @@
+type MapKey<BaseType> = BaseType extends Map<infer KeyType, unknown> ? KeyType : never;
+type MapValue<BaseType> = BaseType extends Map<unknown, infer ValueType> ? ValueType : never;
+
+export type ArrayEntry<BaseType extends readonly unknown[]> = [number, BaseType[number]];
+export type MapEntry<BaseType> = [MapKey<BaseType>, MapValue<BaseType>];
+export type ObjectEntry<BaseType> = [keyof BaseType, BaseType[keyof BaseType]];
+export type SetEntry<BaseType> = BaseType extends Set<infer ItemType> ? [ItemType, ItemType] : never;
+
+/**
+Many collections have an `entries` method which returns an array of a given object's own enumerable string-keyed property [key, value] pairs. The `Entry` type will return the type of that collection's entry.
+
+For example the {@link https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Object/entries|`Object`}, {@link https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Map/entries|`Map`}, {@link https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/entries|`Array`}, and {@link https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Set/entries|`Set`} collections all have this method. Note that `WeakMap` and `WeakSet` do not have this method since their entries are not enumerable.
+
+@see `Entries` if you want to just access the type of the array of entries (which is the return of the `.entries()` method).
+
+@example
+```
+import type {Entry} from 'type-fest';
+
+interface Example {
+	someKey: number;
+}
+
+const manipulatesEntry = (example: Entry<Example>) => [
+	// Does some arbitrary processing on the key (with type information available)
+	example[0].toUpperCase(),
+
+	// Does some arbitrary processing on the value (with type information available)
+	example[1].toFixed(),
+];
+
+const example: Example = {someKey: 1};
+const entry = Object.entries(example)[0] as Entry<Example>;
+const output = manipulatesEntry(entry);
+
+// Objects
+const objectExample = {a: 1};
+const objectEntry: Entry<typeof objectExample> = ['a', 1];
+
+// Arrays
+const arrayExample = ['a', 1];
+const arrayEntryString: Entry<typeof arrayExample> = [0, 'a'];
+const arrayEntryNumber: Entry<typeof arrayExample> = [1, 1];
+
+// Maps
+const mapExample = new Map([['a', 1]]);
+const mapEntry: Entry<typeof mapExample> = ['a', 1];
+
+// Sets
+const setExample = new Set(['a', 1]);
+const setEntryString: Entry<typeof setExample> = ['a', 'a'];
+const setEntryNumber: Entry<typeof setExample> = [1, 1];
+```
+
+@category Object
+@category Map
+@category Array
+@category Set
+*/
+export type Entry<BaseType> = BaseType extends Map<unknown, unknown>
+	? MapEntry<BaseType>
+	: BaseType extends Set<unknown>
+	? SetEntry<BaseType>
+	: BaseType extends readonly unknown[]
+	? ArrayEntry<BaseType>
+	: BaseType extends object
+	? ObjectEntry<BaseType>
+	: never;
diff --git a/vendor/type-fest/except.ts b/vendor/type-fest/except.ts
new file mode 100644
index 0000000000000000000000000000000000000000..b18f739d16141d613e5d922969f2924655060f30
--- /dev/null
+++ b/vendor/type-fest/except.ts
@@ -0,0 +1,71 @@
+import type { IsEqual } from "./is-equal";
+
+/**
+Filter out keys from an object.
+
+Returns `never` if `Exclude` is strictly equal to `Key`.
+Returns `never` if `Key` extends `Exclude`.
+Returns `Key` otherwise.
+
+@example
+```
+type Filtered = Filter<'foo', 'foo'>;
+//=> never
+```
+
+@example
+```
+type Filtered = Filter<'bar', string>;
+//=> never
+```
+
+@example
+```
+type Filtered = Filter<'bar', 'foo'>;
+//=> 'bar'
+```
+
+@see {Except}
+*/
+type Filter<KeyType, ExcludeType> = IsEqual<KeyType, ExcludeType> extends true
+	? never
+	: KeyType extends ExcludeType
+	? never
+	: KeyType;
+
+/**
+Create a type from an object type without certain keys.
+
+We recommend setting the `requireExactProps` option to `true`.
+
+This type is a stricter version of [`Omit`](https://www.typescriptlang.org/docs/handbook/release-notes/typescript-3-5.html#the-omit-helper-type). The `Omit` type does not restrict the omitted keys to be keys present on the given type, while `Except` does. The benefits of a stricter type are avoiding typos and allowing the compiler to pick up on rename refactors automatically.
+
+This type was proposed to the TypeScript team, which declined it, saying they prefer that libraries implement stricter versions of the built-in types ([microsoft/TypeScript#30825](https://github.com/microsoft/TypeScript/issues/30825#issuecomment-523668235)).
+
+@example
+```
+import type {Except} from 'type-fest';
+
+type Foo = {
+	a: number;
+	b: string;
+};
+
+type FooWithoutA = Except<Foo, 'a'>;
+//=> {b: string}
+
+const fooWithoutA: FooWithoutA = {a: 1, b: '2'};
+//=> errors: 'a' does not exist in type '{ b: string; }'
+
+type FooWithoutB = Except<Foo, 'b', {requireExactProps: true}>;
+//=> {a: number} & Partial<Record<'b', never>>
+
+const fooWithoutB: FooWithoutB = {a: 1, b: '2'};
+//=> errors at 'b': Type 'string' is not assignable to type 'undefined'.
+```
+
+@category Object
+*/
+export type Except<ObjectType, KeysType extends keyof ObjectType> = {
+	[KeyType in keyof ObjectType as Filter<KeyType, KeysType>]: ObjectType[KeyType];
+};
diff --git a/vendor/type-fest/is-equal.ts b/vendor/type-fest/is-equal.ts
new file mode 100644
index 0000000000000000000000000000000000000000..d6ff2e53c4d25df11584f0a59504e08170e0a3b8
--- /dev/null
+++ b/vendor/type-fest/is-equal.ts
@@ -0,0 +1,27 @@
+/**
+Returns a boolean for whether the two given types are equal.
+
+@link https://github.com/microsoft/TypeScript/issues/27024#issuecomment-421529650
+@link https://stackoverflow.com/questions/68961864/how-does-the-equals-work-in-typescript/68963796#68963796
+
+Use-cases:
+- If you want to make a conditional branch based on the result of a comparison of two types.
+
+@example
+```
+import type {IsEqual} from 'type-fest';
+
+// This type returns a boolean for whether the given array includes the given item.
+// `IsEqual` is used to compare the given array at position 0 and the given item and then return true if they are equal.
+type Includes<Value extends readonly any[], Item> =
+	Value extends readonly [Value[0], ...infer rest]
+		? IsEqual<Value[0], Item> extends true
+			? true
+			: Includes<rest, Item>
+		: false;
+```
+
+@category Type Guard
+@category Utilities
+*/
+export type IsEqual<A, B> = (<G>() => G extends A ? 1 : 2) extends <G>() => G extends B ? 1 : 2 ? true : false;
diff --git a/vendor/type-fest/license-cc0 b/vendor/type-fest/license-cc0
new file mode 100644
index 0000000000000000000000000000000000000000..0e259d42c996742e9e3cba14c677129b2c1b6311
--- /dev/null
+++ b/vendor/type-fest/license-cc0
@@ -0,0 +1,121 @@
+Creative Commons Legal Code
+
+CC0 1.0 Universal
+
+    CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
+    LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
+    ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
+    INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
+    REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
+    PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
+    THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
+    HEREUNDER.
+
+Statement of Purpose
+
+The laws of most jurisdictions throughout the world automatically confer
+exclusive Copyright and Related Rights (defined below) upon the creator
+and subsequent owner(s) (each and all, an "owner") of an original work of
+authorship and/or a database (each, a "Work").
+
+Certain owners wish to permanently relinquish those rights to a Work for
+the purpose of contributing to a commons of creative, cultural and
+scientific works ("Commons") that the public can reliably and without fear
+of later claims of infringement build upon, modify, incorporate in other
+works, reuse and redistribute as freely as possible in any form whatsoever
+and for any purposes, including without limitation commercial purposes.
+These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. 
+extent the Waiver is so judged Affirmer hereby grants to each affected
+person a royalty-free, non transferable, non sublicensable, non exclusive,
+irrevocable and unconditional license to exercise Affirmer's Copyright and
+Related Rights in the Work (i) in all territories worldwide, (ii) for the
+maximum duration provided by applicable law or treaty (including future
+time extensions), (iii) in any current or future medium and for any number
+of copies, and (iv) for any purpose whatsoever, including without
+limitation commercial, advertising or promotional purposes (the
+"License"). The License shall be deemed effective as of the date CC0 was
+applied by Affirmer to the Work. Should any part of the License for any
+reason be judged legally invalid or ineffective under applicable law, such
+partial invalidity or ineffectiveness shall not invalidate the remainder
+of the License, and in such case Affirmer hereby affirms that he or she
+will not (i) exercise any of his or her remaining Copyright and Related
+Rights in the Work or (ii) assert any associated claims and causes of
+action with respect to the Work, in either case contrary to Affirmer's
+express Statement of Purpose.
+
+4. Limitations and Disclaimers.
+
+ a. No trademark or patent rights held by Affirmer are waived, abandoned,
+    surrendered, licensed or otherwise affected by this document.
+ b. Affirmer offers the Work as-is and makes no representations or
+    warranties of any kind concerning the Work, express, implied,
+    statutory or otherwise, including without limitation warranties of
+    title, merchantability, fitness for a particular purpose, non
+    infringement, or the absence of latent or other defects, accuracy, or
+    the present or absence of errors, whether or not discoverable, all to
+    the greatest extent permissible under applicable law.
+ c. Affirmer disclaims responsibility for clearing rights of other persons
+    that may apply to the Work or any use thereof, including without
+    limitation any person's Copyright and Related Rights in the Work.
+    Further, Affirmer disclaims responsibility for obtaining any necessary
+    consents, permissions or other rights required for any use of the
+    Work.
+ d. Affirmer understands and acknowledges that Creative Commons is not a
+    party to this document and has no duty or obligation with respect to
+    this CC0 or use of the Work.
diff --git a/vendor/type-fest/license-mit b/vendor/type-fest/license-mit
new file mode 100644
index 0000000000000000000000000000000000000000..fa7ceba3eb4a9657a9db7f3ffca4e4e97a9019de
--- /dev/null
+++ b/vendor/type-fest/license-mit
@@ -0,0 +1,9 @@
+MIT License
+
+Copyright (c) Sindre Sorhus <sindresorhus@gmail.com> (https://sindresorhus.com)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/vendor/type-fest/set-required.ts b/vendor/type-fest/set-required.ts
new file mode 100644
index 0000000000000000000000000000000000000000..8e4c6417a92e2b600e66cb00d2c4ff9675620269
--- /dev/null
+++ b/vendor/type-fest/set-required.ts
@@ -0,0 +1,34 @@
+import type { Except } from "./except";
+import type { Simplify } from "./simplify";
+
+/**
+Create a type that makes the given keys required. The remaining keys are kept as is. The sister of the `SetOptional` type.
+
+Use-case: You want to define a single model where the only thing that changes is whether or not some of the keys are required.
+
+@example
+```
+import type {SetRequired} from 'type-fest';
+
+type Foo = {
+	a?: number;
+	b: string;
+	c?: boolean;
+}
+
+type SomeRequired = SetRequired<Foo, 'b' | 'c'>;
+// type SomeRequired = {
+// 	a?: number;
+// 	b: string; // Was already required and still is.
+// 	c: boolean; // Is now required.
+// }
+```
+
+@category Object
+*/
+export type SetRequired<BaseType, Keys extends keyof BaseType> = Simplify<
+	// Pick just the keys that are optional from the base type.
+	Except<BaseType, Keys> &
+		// Pick the keys that should be required from the base type and make them required.
+		Required<Pick<BaseType, Keys>>
+>;
diff --git a/vendor/type-fest/simplify.ts b/vendor/type-fest/simplify.ts
new file mode 100644
index 0000000000000000000000000000000000000000..f4564fe7043dd43ce6ac4de4b332d7c85915910c
--- /dev/null
+++ b/vendor/type-fest/simplify.ts
@@ -0,0 +1,59 @@
+/**
+Useful to flatten the type output to improve type hints shown in editors. And also to transform an interface into a type to aid with assignability.
+
+@example
+```
+import type {Simplify} from 'type-fest';
+
+type PositionProps = {
+	top: number;
+	left: number;
+};
+
+type SizeProps = {
+	width: number;
+	height: number;
+};
+
+// In your editor, hovering over `Props` will show a flattened object with all the properties.
+type Props = Simplify<PositionProps & SizeProps>;
+```
+
+Sometimes it is desired to pass a value as a function argument that has a different type. At first inspection it may seem assignable, and then you discover it is not because the `value`'s type definition was defined as an interface. In the following example, `fn` requires an argument of type `Record<string, unknown>`. If the value is defined as a literal, then it is assignable. And if the `value` is defined as a type using the `Simplify` utility the value is assignable. But if the `value` is defined as an interface, it is not assignable because the interface is not sealed and elsewhere a non-string property could be added to the interface.
+
+If the type definition must be an interface (perhaps it was defined in a third-party npm package), then the `value` can be defined as `const value: Simplify<SomeInterface> = ...`. Then `value` will be assignable to the `fn` argument. Or the `value` can be cast as `Simplify<SomeInterface>` if you can't re-declare the `value`.
+
+@example
+```
+import type {Simplify} from 'type-fest';
+
+interface SomeInterface {
+	foo: number;
+	bar?: string;
+	baz: number | undefined;
+}
+
+type SomeType = {
+	foo: number;
+	bar?: string;
+	baz: number | undefined;
+};
+
+const literal = {foo: 123, bar: 'hello', baz: 456};
+const someType: SomeType = literal;
+const someInterface: SomeInterface = literal;
+
+function fn(object: Record<string, unknown>): void {}
+
+fn(literal); // Good: literal object type is sealed
+fn(someType); // Good: type is sealed
+fn(someInterface); // Error: Index signature for type 'string' is missing in type 'someInterface'. Because `interface` can be re-opened
+fn(someInterface as Simplify<SomeInterface>); // Good: transform an `interface` into a `type`
+```
+
+@link https://github.com/microsoft/TypeScript/issues/15300
+
+@category Object
+*/
+// eslint-disable-next-line @typescript-eslint/ban-types
+export type Simplify<T> = { [KeyType in keyof T]: T[KeyType] } & {};
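For orientation (not part of the diff itself): a minimal sketch of how these vendored helpers might compose in consuming code. The `UploadOptions` type, its fields, and the relative import paths (assuming a file at the package root, next to `vendor/`) are hypothetical illustrations, not APIs introduced by this changeset.

```ts
import type { Except } from "./vendor/type-fest/except";
import type { SetRequired } from "./vendor/type-fest/set-required";
import type { Simplify } from "./vendor/type-fest/simplify";

// Hypothetical options shape, used only to exercise the utilities.
interface UploadOptions {
	repo: string;
	branch?: string;
	commitTitle?: string;
	dryRun?: boolean;
}

// Internally require `branch` while leaving the other keys as declared.
type UploadOptionsWithBranch = SetRequired<UploadOptions, "branch">;

// Drop an option before forwarding it elsewhere. Unlike the built-in `Omit`,
// `Except` constrains the key to `keyof UploadOptions`, so a typo such as
// `Except<UploadOptions, "dryrun">` fails to compile. `Simplify` flattens the
// intersection so editor hovers show a single object type.
type ForwardedOptions = Simplify<Except<UploadOptions, "dryRun">>;

const opts: UploadOptionsWithBranch = {
	repo: "user/model",
	branch: "main",
};

const forwarded: ForwardedOptions = {
	repo: opts.repo,
	branch: opts.branch,
	commitTitle: "update",
};
```

Because the helpers are imported with `import type` only, vendoring them adds no runtime dependency; they are erased entirely at compile time.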