diff --git a/index.test.ts b/index.test.ts index 2645ca4..c4d7e06 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,11 +1,5 @@ import { expect, jest, test } from "@jest/globals"; -import Replicate, { - ApiError, - Model, - Prediction, - validateWebhook, - parseProgressFromLogs, -} from "replicate"; +import Replicate, { ApiError, Model, Prediction, validateWebhook, parseProgressFromLogs } from "replicate"; import nock from "nock"; import { Readable } from "node:stream"; import { createReadableStream } from "./lib/stream"; @@ -42,8 +36,7 @@ const fileTestCases = [ describe("Replicate client", () => { let unmatched: any[] = []; - const handleNoMatch = (req: unknown, options: any, body: string) => - unmatched.push({ req, options, body }); + const handleNoMatch = (req: unknown, options: any, body: string) => unmatched.push({ req, options, body }); beforeEach(() => { client = new Replicate({ auth: "test-token" }); @@ -123,8 +116,7 @@ describe("Replicate client", () => { { name: "Super resolution", slug: "super-resolution", - description: - "Upscaling models that create high-quality images from low-quality images.", + description: "Upscaling models that create high-quality images from low-quality images.", }, { name: "Image classification", @@ -147,8 +139,7 @@ describe("Replicate client", () => { nock(BASE_URL).get("/collections/super-resolution").reply(200, { name: "Super resolution", slug: "super-resolution", - description: - "Upscaling models that create high-quality images from low-quality images.", + description: "Upscaling models that create high-quality images from low-quality images.", models: [], }); @@ -188,9 +179,7 @@ describe("Replicate client", () => { results: [{ url: "https://replicate.com/some-user/model-1" }], next: "https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", }) - .get( - "/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw" - ) + .get("/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") .reply(200, { results: [{ url: "https://replicate.com/some-user/model-2" }], next: null, @@ -248,12 +237,10 @@ describe("Replicate client", () => { expectedResponse: { id: "ufawqhfynnddngldkgtslldrkq", model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", urls: { get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + cancel: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, input: testCase.input, created_at: "2022-04-26T22:13:06.224088Z", @@ -263,79 +250,64 @@ describe("Replicate client", () => { }, })); - test.each(predictionTestCases)( - "$description", - async ({ input, expectedResponse }) => { - nock(BASE_URL) - .post("/predictions", { - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: input as Record, - webhook: "http://test.host/webhook", - webhook_events_filter: ["output", "completed"], - }) - .reply(200, expectedResponse); - - const response = await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + test.each(predictionTestCases)("$description", async ({ input, expectedResponse }) => { + nock(BASE_URL) + .post("/predictions", { + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: input as Record, webhook: "http://test.host/webhook", webhook_events_filter: ["output", "completed"], - }); + }) + .reply(200, expectedResponse); - expect(response.input).toEqual(input); - expect(response.status).toBe(expectedResponse.status); - } - ); + const response = await client.predictions.create({ + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: input as Record, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }); - test.each(fileTestCases)( - "converts a $type input into a Replicate file URL", - async ({ value: data, type }) => { - const mockedFetch = jest.spyOn(client, "fetch"); + expect(response.input).toEqual(input); + expect(response.status).toBe(expectedResponse.status); + }); - nock(BASE_URL) - .post("/files") - .reply(201, { - urls: { - get: "https://replicate.com/api/files/123", - }, - }) - .post( - "/predictions", - (body) => body.input.data === "https://replicate.com/api/files/123" - ) - .reply(201, (_uri: string, body: Record) => { - return body; - }); + test.each(fileTestCases)("converts a $type input into a Replicate file URL", async ({ value: data, type }) => { + const mockedFetch = jest.spyOn(client, "fetch"); - const prediction = await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: { - prompt: "Tell me a story", - data, + nock(BASE_URL) + .post("/files") + .reply(201, { + urls: { + get: "https://replicate.com/api/files/123", }, + }) + .post("/predictions", (body) => body.input.data === "https://replicate.com/api/files/123") + .reply(201, (_uri: string, body: Record) => { + return body; }); - expect(client.fetch).toHaveBeenCalledWith( - new URL("https://api.replicate.com/v1/files"), - { - method: "POST", - body: expect.any(FormData), - headers: expect.any(Object), - } - ); - const form = mockedFetch.mock.calls[0][1]?.body as FormData; - // @ts-ignore - expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`)); - - expect(prediction.input).toEqual({ + const prediction = await client.predictions.create({ + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { prompt: "Tell me a story", - data: "https://replicate.com/api/files/123", - }); - } - ); + data, + }, + }); + + expect(client.fetch).toHaveBeenCalledWith(new URL("https://api.replicate.com/v1/files"), { + method: "POST", + body: expect.any(FormData), + headers: expect.any(Object), + }); + const form = mockedFetch.mock.calls[0][1]?.body as FormData; + // @ts-ignore + expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`)); + + expect(prediction.input).toEqual({ + prompt: "Tell me a story", + data: "https://replicate.com/api/files/123", + }); + }); test.each(fileTestCases)( "converts a $type input into a base64 encoded string", @@ -351,8 +323,7 @@ describe("Replicate client", () => { }); await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { prompt: "Tell me a story", data, @@ -361,7 +332,38 @@ describe("Replicate client", () => { }); expect(actual?.input.data).toEqual(expected); - } + }, + ); + + test.each(fileTestCases)( + "raises an error when the file upload fails with 4xx error for a $type input", + async ({ value: data, expected }) => { + let actual: Record | undefined; + nock(BASE_URL) + .post("/files") + .reply(401, "Unauthorized") + .post("/predictions") + .reply(201, (_uri: string, body: Record) => { + actual = body; + return body; + }); + + await expect(async () => { + await client.predictions.create({ + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { + prompt: "Tell me a story", + data, + }, + stream: true, + }); + }).rejects.toThrowError( + expect.objectContaining({ + name: "ApiError", + message: expect.stringContaining("401"), + }), + ); + }, ); test("Passes stream parameter to API endpoint", async () => { @@ -373,8 +375,7 @@ describe("Replicate client", () => { }); await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { prompt: "Tell me a story", }, @@ -385,8 +386,7 @@ describe("Replicate client", () => { test("Throws an error if webhook URL is invalid", async () => { await expect(async () => { await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: "Alice", }, @@ -402,15 +402,14 @@ describe("Replicate client", () => { status: 400, detail: "Invalid input", }, - { "Content-Type": "application/json" } + { "Content-Type": "application/json" }, ); try { expect.hasAssertions(); await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: null, }, @@ -429,15 +428,14 @@ describe("Replicate client", () => { { detail: "Too many requests", }, - { "Content-Type": "application/json", "Retry-After": "1" } + { "Content-Type": "application/json", "Retry-After": "1" }, ) .post("/predictions") .reply(201, { id: "ufawqhfynnddngldkgtslldrkq", }); const prediction = await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: "Alice", }, @@ -451,19 +449,18 @@ describe("Replicate client", () => { { detail: "Internal server error", }, - { "Content-Type": "application/json" } + { "Content-Type": "application/json" }, ); await expect( client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: "Alice", }, - }) + }), ).rejects.toThrow( - `Request to https://api.replicate.com/v1/predictions failed with status 500 Internal Server Error: {"detail":"Internal server error"}.` + `Request to https://api.replicate.com/v1/predictions failed with status 500 Internal Server Error: {"detail":"Internal server error"}.`, ); }); }); @@ -475,12 +472,10 @@ describe("Replicate client", () => { .reply(200, { id: "rrr4z55ocneqzikepnug6xezpe", model: "stability-ai/stable-diffusion", - version: - "be04660a5b93ef2aff61e3668dedb4cbeb14941e62a3fd5998364a32d613e35e", + version: "be04660a5b93ef2aff61e3668dedb4cbeb14941e62a3fd5998364a32d613e35e", urls: { get: "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe", - cancel: - "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe/cancel", + cancel: "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe/cancel", }, created_at: "2022-09-13T22:54:18.578761Z", started_at: "2022-09-13T22:54:19.438525Z", @@ -499,9 +494,7 @@ describe("Replicate client", () => { predict_time: 4.484541, }, }); - const prediction = await client.predictions.get( - "rrr4z55ocneqzikepnug6xezpe" - ); + const prediction = await client.predictions.get("rrr4z55ocneqzikepnug6xezpe"); expect(prediction.id).toBe("rrr4z55ocneqzikepnug6xezpe"); }); @@ -513,16 +506,14 @@ describe("Replicate client", () => { { detail: "Too many requests", }, - { "Content-Type": "application/json", "Retry-After": "1" } + { "Content-Type": "application/json", "Retry-After": "1" }, ) .get("/predictions/rrr4z55ocneqzikepnug6xezpe") .reply(200, { id: "rrr4z55ocneqzikepnug6xezpe", }); - const prediction = await client.predictions.get( - "rrr4z55ocneqzikepnug6xezpe" - ); + const prediction = await client.predictions.get("rrr4z55ocneqzikepnug6xezpe"); expect(prediction.id).toBe("rrr4z55ocneqzikepnug6xezpe"); }); @@ -534,16 +525,14 @@ describe("Replicate client", () => { { detail: "Internal server error", }, - { "Content-Type": "application/json" } + { "Content-Type": "application/json" }, ) .get("/predictions/rrr4z55ocneqzikepnug6xezpe") .reply(200, { id: "rrr4z55ocneqzikepnug6xezpe", }); - const prediction = await client.predictions.get( - "rrr4z55ocneqzikepnug6xezpe" - ); + const prediction = await client.predictions.get("rrr4z55ocneqzikepnug6xezpe"); expect(prediction.id).toBe("rrr4z55ocneqzikepnug6xezpe"); }); }); @@ -555,12 +544,10 @@ describe("Replicate client", () => { .reply(200, { id: "ufawqhfynnddngldkgtslldrkq", model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", urls: { get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + cancel: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, created_at: "2022-04-26T22:13:06.224088Z", started_at: "2022-04-26T22:13:06.224088Z", @@ -575,9 +562,7 @@ describe("Replicate client", () => { metrics: {}, }); - const prediction = await client.predictions.cancel( - "ufawqhfynnddngldkgtslldrkq" - ); + const prediction = await client.predictions.cancel("ufawqhfynnddngldkgtslldrkq"); expect(prediction.status).toBe("canceled"); }); @@ -595,12 +580,10 @@ describe("Replicate client", () => { { id: "jpzd7hm5gfcapbfyt4mqytarku", model: "stability-ai/stable-diffusion", - version: - "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", + version: "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", urls: { get: "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku", - cancel: - "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku/cancel", + cancel: "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku/cancel", }, created_at: "2022-04-26T20:00:40.658234Z", started_at: "2022-04-26T20:00:84.583803Z", @@ -623,9 +606,7 @@ describe("Replicate client", () => { results: [{ id: "ufawqhfynnddngldkgtslldrkq" }], next: "https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", }) - .get( - "/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw" - ) + .get("/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") .reply(200, { results: [{ id: "rrr4z55ocneqzikepnug6xezpe" }], next: null, @@ -635,10 +616,7 @@ describe("Replicate client", () => { for await (const batch of client.paginate(client.predictions.list)) { results.push(...batch); } - expect(results).toEqual([ - { id: "ufawqhfynnddngldkgtslldrkq" }, - { id: "rrr4z55ocneqzikepnug6xezpe" }, - ]); + expect(results).toEqual([{ id: "ufawqhfynnddngldkgtslldrkq" }, { id: "rrr4z55ocneqzikepnug6xezpe" }]); // Add more tests for error handling, edge cases, etc. }); @@ -647,13 +625,10 @@ describe("Replicate client", () => { describe("trainings.create", () => { test("Calls the correct API route with the correct payload", async () => { nock(BASE_URL) - .post( - "/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings" - ) + .post("/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings") .reply(200, { id: "zz4ibbonubfz7carwiefibzgga", - version: - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + version: "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", status: "starting", input: { text: "...", @@ -675,25 +650,20 @@ describe("Replicate client", () => { input: { text: "...", }, - } + }, ); expect(training.id).toBe("zz4ibbonubfz7carwiefibzgga"); }); test("Throws an error if webhook is not a valid URL", async () => { await expect( - client.trainings.create( - "owner", - "model", - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", - { - destination: "new_owner/new_model", - input: { - text: "...", - }, - webhook: "invalid-url", - } - ) + client.trainings.create("owner", "model", "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", { + destination: "new_owner/new_model", + input: { + text: "...", + }, + webhook: "invalid-url", + }), ).rejects.toThrow("Invalid webhook URL"); }); @@ -753,9 +723,7 @@ describe("Replicate client", () => { completed_at: null, }); - const training = await client.trainings.cancel( - "zz4ibbonubfz7carwiefibzgga" - ); + const training = await client.trainings.cancel("zz4ibbonubfz7carwiefibzgga"); expect(training.status).toBe("canceled"); }); @@ -773,12 +741,10 @@ describe("Replicate client", () => { { id: "jpzd7hm5gfcapbfyt4mqytarku", model: "stability-ai/sdxl", - version: - "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", + version: "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", urls: { get: "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku", - cancel: - "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku/cancel", + cancel: "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku/cancel", }, created_at: "2022-04-26T20:00:40.658234Z", started_at: "2022-04-26T20:00:84.583803Z", @@ -801,9 +767,7 @@ describe("Replicate client", () => { results: [{ id: "ufawqhfynnddngldkgtslldrkq" }], next: "https://api.replicate.com/v1/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", }) - .get( - "/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw" - ) + .get("/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") .reply(200, { results: [{ id: "rrr4z55ocneqzikepnug6xezpe" }], next: null, @@ -813,10 +777,7 @@ describe("Replicate client", () => { for await (const batch of client.paginate(client.trainings.list)) { results.push(...batch); } - expect(results).toEqual([ - { id: "ufawqhfynnddngldkgtslldrkq" }, - { id: "rrr4z55ocneqzikepnug6xezpe" }, - ]); + expect(results).toEqual([{ id: "ufawqhfynnddngldkgtslldrkq" }, { id: "rrr4z55ocneqzikepnug6xezpe" }]); // Add more tests for error handling, edge cases, etc. }); @@ -829,12 +790,10 @@ describe("Replicate client", () => { .reply(200, { id: "mfrgcyzzme2wkmbwgzrgmntcg", model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", urls: { get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + cancel: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, created_at: "2022-09-10T09:44:22.165836Z", started_at: null, @@ -848,17 +807,13 @@ describe("Replicate client", () => { logs: null, metrics: {}, }); - const prediction = await client.deployments.predictions.create( - "replicate", - "greeter", - { - input: { - text: "Alice", - }, - webhook: "http://test.host/webhook", - webhook_events_filter: ["output", "completed"], - } - ); + const prediction = await client.deployments.predictions.create("replicate", "greeter", { + input: { + text: "Alice", + }, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }); expect(prediction.id).toBe("mfrgcyzzme2wkmbwgzrgmntcg"); }); // Add more tests for error handling, edge cases, etc. @@ -874,8 +829,7 @@ describe("Replicate client", () => { current_release: { number: 1, model: "stability-ai/sdxl", - version: - "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + version: "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", created_at: "2024-02-15T16:32:57.018467Z", created_by: { type: "organization", @@ -891,10 +845,7 @@ describe("Replicate client", () => { }, }); - const deployment = await client.deployments.get( - "acme", - "my-app-image-generator" - ); + const deployment = await client.deployments.get("acme", "my-app-image-generator"); expect(deployment.owner).toBe("acme"); expect(deployment.name).toBe("my-app-image-generator"); @@ -913,8 +864,7 @@ describe("Replicate client", () => { current_release: { number: 1, model: "stability-ai/sdxl", - version: - "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + version: "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", created_at: "2024-02-15T16:32:57.018467Z", created_by: { type: "organization", @@ -933,8 +883,7 @@ describe("Replicate client", () => { const deployment = await client.deployments.create({ name: "my-app-image-generator", model: "stability-ai/sdxl", - version: - "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + version: "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", hardware: "gpu-t4", min_instances: 1, max_instances: 5, @@ -957,8 +906,7 @@ describe("Replicate client", () => { current_release: { number: 2, model: "stability-ai/sdxl", - version: - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + version: "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", created_at: "2024-02-16T08:14:22.345678Z", created_by: { type: "organization", @@ -974,25 +922,18 @@ describe("Replicate client", () => { }, }); - const deployment = await client.deployments.update( - "acme", - "my-app-image-generator", - { - version: - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", - hardware: "gpu-a40-large", - min_instances: 3, - max_instances: 10, - } - ); + const deployment = await client.deployments.update("acme", "my-app-image-generator", { + version: "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + hardware: "gpu-a40-large", + min_instances: 3, + max_instances: 10, + }); expect(deployment.current_release.number).toBe(2); expect(deployment.current_release.version).toBe( - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532" - ); - expect(deployment.current_release.configuration.hardware).toBe( - "gpu-a40-large" + "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", ); + expect(deployment.current_release.configuration.hardware).toBe("gpu-a40-large"); expect(deployment.current_release.configuration.min_instances).toBe(3); expect(deployment.current_release.configuration.max_instances).toBe(10); }); @@ -1001,14 +942,9 @@ describe("Replicate client", () => { describe("deployments.delete", () => { test("Calls the correct API route with the correct payload", async () => { - nock(BASE_URL) - .delete("/deployments/acme/my-app-image-generator") - .reply(204); + nock(BASE_URL).delete("/deployments/acme/my-app-image-generator").reply(204); - const success = await client.deployments.delete( - "acme", - "my-app-image-generator" - ); + const success = await client.deployments.delete("acme", "my-app-image-generator"); expect(success).toBe(true); }); }); @@ -1054,8 +990,7 @@ describe("Replicate client", () => { status: "starting", created_at: "2023-11-27T13:35:45.99397566Z", urls: { - cancel: - "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel", + cancel: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel", get: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci", }, }); @@ -1265,7 +1200,7 @@ describe("Replicate client", () => { (prediction) => { const progress = parseProgressFromLogs(prediction); callback(prediction, progress); - } + }, ); expect(output).toBe("Goodbye!"); @@ -1277,7 +1212,7 @@ describe("Replicate client", () => { status: "starting", logs: null, }, - null + null, ); expect(callback).toHaveBeenNthCalledWith( @@ -1291,7 +1226,7 @@ describe("Replicate client", () => { percentage: 0.4, current: 2, total: 5, - } + }, ); expect(callback).toHaveBeenNthCalledWith( @@ -1305,7 +1240,7 @@ describe("Replicate client", () => { percentage: 0.8, current: 4, total: 5, - } + }, ); expect(callback).toHaveBeenNthCalledWith( @@ -1320,7 +1255,7 @@ describe("Replicate client", () => { percentage: 1.0, current: 5, total: 5, - } + }, ); expect(callback).toHaveBeenCalledTimes(4); @@ -1354,7 +1289,7 @@ describe("Replicate client", () => { input: { text: "Hello, world!" }, wait: { interval: 1 }, }, - progress + progress, ); expect(output).toBe("Goodbye!"); @@ -1397,9 +1332,7 @@ describe("Replicate client", () => { output: "foobar", }); - await expect( - client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } }) - ).resolves.not.toThrow(); + await expect(client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } })).resolves.not.toThrow(); }); test("Throws an error for invalid identifiers", async () => { @@ -1416,15 +1349,12 @@ describe("Replicate client", () => { test("Throws an error if webhook URL is invalid", async () => { await expect(async () => { - await client.run( - "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - { - input: { - text: "Alice", - }, - webhook: "invalid-url", - } - ); + await client.run("owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { + input: { + text: "Alice", + }, + webhook: "invalid-url", + }); }).rejects.toThrow("Invalid webhook URL"); }); @@ -1463,7 +1393,7 @@ describe("Replicate client", () => { input: { text: "Hello, world!" }, signal, }, - onProgress + onProgress, ); expect(body).toBeDefined(); @@ -1475,19 +1405,19 @@ describe("Replicate client", () => { 1, expect.objectContaining({ status: "processing", - }) + }), ); expect(onProgress).toHaveBeenNthCalledWith( 2, expect.objectContaining({ status: "processing", - }) + }), ); expect(onProgress).toHaveBeenNthCalledWith( 3, expect.objectContaining({ status: "canceled", - }) + }), ); scope.done(); @@ -1512,8 +1442,7 @@ describe("Replicate client", () => { "Content-Type": "application/json", "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", "Webhook-Timestamp": "1614265330", - "Webhook-Signature": - "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", + "Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", }, body: `{"test": 2432232314}`, }); @@ -1556,7 +1485,7 @@ describe("Replicate client", () => { id: EVENT_2 data: {} - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1587,7 +1516,7 @@ describe("Replicate client", () => { id: EVENT_3 data: {} - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1621,7 +1550,7 @@ describe("Replicate client", () => { id: EVENT_2 data: {} - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1653,7 +1582,7 @@ describe("Replicate client", () => { id: EVENT_2 data: {} - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1774,7 +1703,7 @@ describe("Replicate client", () => { id: EVENT_1 data: hello world - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1796,7 +1725,7 @@ describe("Replicate client", () => { id: EVENT_2 data: An unexpected error occurred - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1804,9 +1733,7 @@ describe("Replicate client", () => { done: false, value: { event: "output", id: "EVENT_1", data: "hello world" }, }); - await expect(iterator.next()).rejects.toThrowError( - "An unexpected error occurred" - ); + await expect(iterator.next()).rejects.toThrowError("An unexpected error occurred"); expect(await iterator.next()).toEqual({ done: true }); }); @@ -1814,7 +1741,7 @@ describe("Replicate client", () => { const stream = createStream("{}", 500); const iterator = stream[Symbol.asyncIterator](); await expect(iterator.next()).rejects.toThrowError( - "Request to https://stream.replicate.com/fake_stream failed with status 500" + "Request to https://stream.replicate.com/fake_stream failed with status 500", ); expect(await iterator.next()).toEqual({ done: true }); }); diff --git a/lib/util.js b/lib/util.js index 3745d9f..b4483ae 100644 --- a/lib/util.js +++ b/lib/util.js @@ -67,18 +67,11 @@ async function validateWebhook(requestData, secret) { const signedContent = `${id}.${timestamp}.${body}`; - const computedSignature = await createHMACSHA256( - signingSecret.split("_").pop(), - signedContent - ); + const computedSignature = await createHMACSHA256(signingSecret.split("_").pop(), signedContent); - const expectedSignatures = signature - .split(" ") - .map((sig) => sig.split(",")[1]); + const expectedSignatures = signature.split(" ").map((sig) => sig.split(",")[1]); - return expectedSignatures.some( - (expectedSignature) => expectedSignature === computedSignature - ); + return expectedSignatures.some((expectedSignature) => expectedSignature === computedSignature); } /** @@ -105,13 +98,9 @@ async function createHMACSHA256(secret, data) { crypto = require.call(null, "node:crypto").webcrypto; } - const key = await crypto.subtle.importKey( - "raw", - base64ToBytes(secret), - { name: "HMAC", hash: "SHA-256" }, - false, - ["sign"] - ); + const key = await crypto.subtle.importKey("raw", base64ToBytes(secret), { name: "HMAC", hash: "SHA-256" }, false, [ + "sign", + ]); const signature = await crypto.subtle.sign("HMAC", key, encoder.encode(data)); return bytesToBase64(signature); @@ -235,6 +224,9 @@ async function transformFileInputs(client, inputs, strategy) { try { return await transformFileInputsToReplicateFileURLs(client, inputs); } catch (error) { + if (error instanceof ApiError && error.response.status >= 400 && error.response.status < 500) { + throw error; + } return await transformFileInputsToBase64EncodedDataURIs(inputs); } default: @@ -296,7 +288,7 @@ async function transformFileInputsToBase64EncodedDataURIs(inputs) { totalBytes += buffer.byteLength; if (totalBytes > MAX_DATA_URI_SIZE) { throw new Error( - `Combined filesize of prediction ${totalBytes} bytes exceeds 10mb limit for inline encoding, please provide URLs instead` + `Combined filesize of prediction ${totalBytes} bytes exceeds 10mb limit for inline encoding, please provide URLs instead`, ); } @@ -354,14 +346,11 @@ function isPlainObject(value) { if (proto === null) { return true; } - const Ctor = - Object.prototype.hasOwnProperty.call(proto, "constructor") && - proto.constructor; + const Ctor = Object.prototype.hasOwnProperty.call(proto, "constructor") && proto.constructor; return ( typeof Ctor === "function" && Ctor instanceof Ctor && - Function.prototype.toString.call(Ctor) === - Function.prototype.toString.call(Object) + Function.prototype.toString.call(Ctor) === Function.prototype.toString.call(Object) ); }