From b8714efca8ace19ea73c122c7fb731a3c0078137 Mon Sep 17 00:00:00 2001 From: "Junius Free B. Fontamillas" <8164667+juniusfree@users.noreply.github.com> Date: Tue, 15 Oct 2024 18:43:42 +0800 Subject: [PATCH 1/2] Implement streaming --- app/(dashboard)/dashboard/pull-request.tsx | 239 +++++++++++++-------- app/api/generate-tests/route.ts | 25 +-- app/api/generate-tests/schema.ts | 9 + 3 files changed, 170 insertions(+), 103 deletions(-) diff --git a/app/(dashboard)/dashboard/pull-request.tsx b/app/(dashboard)/dashboard/pull-request.tsx index 4d39dae6..37979bf2 100644 --- a/app/(dashboard)/dashboard/pull-request.tsx +++ b/app/(dashboard)/dashboard/pull-request.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState } from "react"; +import { useCallback, useState } from "react"; import { Button } from "@/components/ui/button"; import { GitPullRequestDraft, @@ -18,15 +18,20 @@ import Link from "next/link"; import { Checkbox } from "@/components/ui/checkbox"; import dynamic from "next/dynamic"; import { PullRequest, TestFile } from "./types"; -import { generateTestsResponseSchema } from "@/app/api/generate-tests/schema"; import { useToast } from "@/hooks/use-toast"; -import { commitChangesToPullRequest, getPullRequestInfo, getFailingTests } from "@/lib/github"; +import { + commitChangesToPullRequest, + getPullRequestInfo, + getFailingTests, +} from "@/lib/github"; import { Input } from "@/components/ui/input"; -import useSWR from 'swr'; -import { fetchBuildStatus } from '@/lib/github'; -import { LogView } from './log-view' -import { getLatestRunId } from '@/lib/github' -import { cn } from "@/lib/utils" +import useSWR from "swr"; +import { fetchBuildStatus } from "@/lib/github"; +import { experimental_useObject as useObject } from "ai/react"; +import { TestFileSchema } from "@/app/api/generate-tests/schema"; +import { LogView } from "./log-view"; +import { getLatestRunId } from "@/lib/github"; +import { cn } from "@/lib/utils"; const ReactDiffViewer = dynamic(() => import("react-diff-viewer"), { ssr: false, @@ -36,13 +41,20 @@ interface PullRequestItemProps { pullRequest: PullRequest; } -export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequestItemProps) { +export function PullRequestItem({ + pullRequest: initialPullRequest, +}: PullRequestItemProps) { const [optimisticRunning, setOptimisticRunning] = useState(false); const [showLogs, setShowLogs] = useState(false); const { data: pullRequest, mutate } = useSWR( `pullRequest-${initialPullRequest.id}`, - () => fetchBuildStatus(initialPullRequest.repository.owner.login, initialPullRequest.repository.name, initialPullRequest.number), + () => + fetchBuildStatus( + initialPullRequest.repository.owner.login, + initialPullRequest.repository.name, + initialPullRequest.number + ), { fallbackData: initialPullRequest, refreshInterval: optimisticRunning ? 10000 : 0, @@ -55,15 +67,31 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest ); const { data: latestRunId } = useSWR( - pullRequest.buildStatus === 'success' || pullRequest.buildStatus === 'failure' - ? ['latestRunId', pullRequest.repository.owner.login, pullRequest.repository.name, pullRequest.branchName] + pullRequest.buildStatus === "success" || + pullRequest.buildStatus === "failure" + ? [ + "latestRunId", + pullRequest.repository.owner.login, + pullRequest.repository.name, + pullRequest.branchName, + ] : null, - () => getLatestRunId(pullRequest.repository.owner.login, pullRequest.repository.name, pullRequest.branchName) + () => + getLatestRunId( + pullRequest.repository.owner.login, + pullRequest.repository.name, + pullRequest.branchName + ) ); const [testFiles, setTestFiles] = useState([]); - const [selectedFiles, setSelectedFiles] = useState>({}); - const [expandedFiles, setExpandedFiles] = useState>({}); + const [oldTestFiles, setOldTestFiles] = useState([]); + const [selectedFiles, setSelectedFiles] = useState>( + {} + ); + const [expandedFiles, setExpandedFiles] = useState>( + {} + ); const [analyzing, setAnalyzing] = useState(false); const [loading, setLoading] = useState(false); const [error, setError] = useState(null); @@ -73,6 +101,35 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest const isRunning = optimisticRunning || pullRequest.buildStatus === "running"; const isPending = !optimisticRunning && pullRequest.buildStatus === "pending"; + const { + object, + submit, + isLoading: isStreaming, + } = useObject({ + api: "/api/generate-tests", + schema: TestFileSchema, + onFinish: async (result) => { + const { testFiles: oldTestFiles } = await getPullRequestInfo( + pullRequest.repository.owner.login, + pullRequest.repository.name, + pullRequest.number + ); + const { filteredTestFiles, newSelectedFiles, newExpandedFiles } = + handleTestFilesUpdate(oldTestFiles, result.object?.tests); + setTestFiles(filteredTestFiles); + setSelectedFiles(newSelectedFiles); + setExpandedFiles(newExpandedFiles); + setAnalyzing(false); + setLoading(false); + }, + onError: (error) => { + console.error("Error generating test files:", error); + setError("Failed to generate test files."); + setAnalyzing(false); + setLoading(false); + }, + }); + const handleTests = async (pr: PullRequest, mode: "write" | "update") => { setAnalyzing(true); setLoading(true); @@ -80,9 +137,9 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest try { const { diff, testFiles: oldTestFiles } = await getPullRequestInfo( - pr.repository.owner.login, - pr.repository.name, - pr.number + pullRequest.repository.owner.login, + pullRequest.repository.name, + pullRequest.number ); let testFilesToUpdate = oldTestFiles; @@ -93,35 +150,21 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest pr.repository.name, pr.number ); - testFilesToUpdate = oldTestFiles.filter(file => - failingTests.some(failingFile => failingFile.name === file.name) + testFilesToUpdate = oldTestFiles.filter((file) => + failingTests.some((failingFile) => failingFile.name === file.name) ); } - const response = await fetch("/api/generate-tests", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - mode, - pr_id: pr.id, - pr_diff: diff, - test_files: testFilesToUpdate, - }), + setOldTestFiles(testFilesToUpdate); + submit({ + mode, + pr_id: pr.id, + pr_diff: diff, + test_files: testFilesToUpdate, }); - - if (!response.ok) { - throw new Error("Failed to generate test files"); - } - - const data = await response.json(); - const parsedData = generateTestsResponseSchema.parse(data); - handleTestFilesUpdate(oldTestFiles, parsedData); } catch (error) { console.error("Error generating test files:", error); setError("Failed to generate test files."); - } finally { setAnalyzing(false); setLoading(false); } @@ -129,13 +172,13 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest const handleTestFilesUpdate = ( oldTestFiles: TestFile[], - newTestFiles: TestFile[] + newTestFiles?: (Partial | undefined)[] ) => { - if (newTestFiles.length > 0) { + if (newTestFiles && newTestFiles.length > 0) { const filteredTestFiles = newTestFiles .filter((file): file is TestFile => file !== undefined) .map((file) => { - const oldFile = oldTestFiles.find( + const oldFile = oldTestFiles?.find( (oldFile) => oldFile.name === file.name ); return { @@ -143,7 +186,7 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest oldContent: oldFile ? oldFile.content : "", }; }); - setTestFiles(filteredTestFiles); + const newSelectedFiles: Record = {}; const newExpandedFiles: Record = {}; filteredTestFiles.forEach((file) => { @@ -151,9 +194,13 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest newExpandedFiles[fileName] = true; newSelectedFiles[fileName] = true; }); - setSelectedFiles(newSelectedFiles); - setExpandedFiles(newExpandedFiles); + return { filteredTestFiles, newSelectedFiles, newExpandedFiles }; } + return { + filteredTestFiles: [], + newSelectedFiles: {}, + newExpandedFiles: {}, + }; }; const commitChanges = async () => { @@ -196,7 +243,6 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest setExpandedFiles({}); mutate(); - } catch (error) { console.error("Error committing changes:", error); setError("Failed to commit changes. Please try again."); @@ -231,6 +277,13 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest })); }; + const { filteredTestFiles, newSelectedFiles, newExpandedFiles } = + handleTestFilesUpdate(oldTestFiles, object?.tests); + + const testFilesToShow = isStreaming ? filteredTestFiles : testFiles; + const selectedFilesToShow = isStreaming ? newSelectedFiles : selectedFiles; + const expandedFilesToShow = isStreaming ? newExpandedFiles : expandedFiles; + return (
@@ -264,32 +317,39 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest href={`https://github.com/${pullRequest.repository.full_name}/pull/${pullRequest.number}/checks`} className="text-sm underline text-gray-600" > - Build: {isRunning ? "Running" : isPending ? "Pending" : pullRequest.buildStatus} + Build:{" "} + {isRunning + ? "Running" + : isPending + ? "Pending" + : pullRequest.buildStatus} - {(pullRequest.buildStatus === 'success' || pullRequest.buildStatus === 'failure') && latestRunId && ( - - )} + {(pullRequest.buildStatus === "success" || + pullRequest.buildStatus === "failure") && + latestRunId && ( + + )} - {testFiles.length > 0 ? ( + {testFilesToShow.length > 0 ? ( ) : ( )}
@@ -333,23 +401,19 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest {error}
)} - {(loading || analyzing || testFiles.length > 0) && ( + + {(loading || analyzing || testFilesToShow.length > 0) && (

Test files

- {analyzing ? ( -
- - Analyzing PR diff... -
- ) : ( + {testFilesToShow.length > 0 ? (
- {testFiles.map((file) => ( + {testFilesToShow.map((file) => (
handleFileToggle(file.name)} />
- {expandedFiles[file.name] && ( + {expandedFilesToShow[file.name] && (
!value) || + Object.values(selectedFilesToShow).every( + (value) => !value + ) || loading || !commitMessage.trim() } @@ -399,6 +465,11 @@ export function PullRequestItem({ pullRequest: initialPullRequest }: PullRequest
+ ) : ( +
+ + Analyzing PR diff... +
)}
)} diff --git a/app/api/generate-tests/route.ts b/app/api/generate-tests/route.ts index ee01f659..fd4da4c5 100644 --- a/app/api/generate-tests/route.ts +++ b/app/api/generate-tests/route.ts @@ -1,9 +1,6 @@ import { anthropic } from "@ai-sdk/anthropic"; -import { generateObject } from "ai"; -import { z } from "zod"; -import { GenerateTestsInput } from "./schema"; - -export const maxDuration = 30; +import { streamObject } from "ai"; +import { GenerateTestsInput, TestFileSchema } from "./schema"; export async function POST(req: Request) { const { mode, pr_diff, test_files } = @@ -29,22 +26,12 @@ export async function POST(req: Request) { Respond with an array of test files with their name being the path to the file and the content being the full contents of the updated test file.`; - const { object } = await generateObject({ + const result = await streamObject({ model: anthropic("claude-3-5-sonnet-20240620"), - output: "array", - schema: z.object({ - name: z.string(), - content: z.string(), - }), + output: "object", + schema: TestFileSchema, prompt, }); - const tests = []; - for (const test of object) { - tests.push(test); - } - - return new Response(JSON.stringify(tests), { - headers: { "Content-Type": "application/json" }, - }); + return result.toTextStreamResponse(); } diff --git a/app/api/generate-tests/schema.ts b/app/api/generate-tests/schema.ts index 16b5925e..fadb9f75 100644 --- a/app/api/generate-tests/schema.ts +++ b/app/api/generate-tests/schema.ts @@ -19,5 +19,14 @@ export const generateTestsResponseSchema = z.array( }) ); +export const TestFileSchema = z.object({ + tests: z.array( + z.object({ + name: z.string(), + content: z.string(), + }) + ), +}); + export type GenerateTestsInput = z.infer; export type GenerateTestsResponse = z.infer; From 2df2fc0a8336b8e6c3758aee33b3eb825b4c3c10 Mon Sep 17 00:00:00 2001 From: "Junius Free B. Fontamillas" <8164667+juniusfree@users.noreply.github.com> Date: Tue, 15 Oct 2024 19:46:02 +0800 Subject: [PATCH 2/2] Update test files --- .../dashboard/pull-request.test.tsx | 175 ++++++++++++------ 1 file changed, 114 insertions(+), 61 deletions(-) diff --git a/app/(dashboard)/dashboard/pull-request.test.tsx b/app/(dashboard)/dashboard/pull-request.test.tsx index d1821635..4bed08ca 100644 --- a/app/(dashboard)/dashboard/pull-request.test.tsx +++ b/app/(dashboard)/dashboard/pull-request.test.tsx @@ -1,12 +1,15 @@ -import React from 'react'; -import { render, screen, fireEvent, waitFor } from '@testing-library/react'; -import { PullRequestItem } from './pull-request'; -import { vi, describe, it, expect, beforeEach } from 'vitest'; -import { PullRequest } from './types'; -import useSWR from 'swr'; - -vi.mock('@/lib/github', async () => { - const actual = await vi.importActual('@/lib/github'); +import React from "react"; +import { render, screen, fireEvent, waitFor } from "@testing-library/react"; +import { PullRequestItem } from "./pull-request"; +import { vi, describe, it, expect, beforeEach } from "vitest"; +import { PullRequest } from "./types"; +import useSWR from "swr"; +import { fetchBuildStatus } from "@/lib/github"; +import { experimental_useObject as useObject } from "ai/react"; +import { act } from "react"; + +vi.mock("@/lib/github", async () => { + const actual = await vi.importActual("@/lib/github"); return { ...(actual as object), getPullRequestInfo: vi.fn(), @@ -17,44 +20,56 @@ vi.mock('@/lib/github', async () => { }; }); -vi.mock('@/hooks/use-toast', () => ({ +vi.mock("@/hooks/use-toast", () => ({ useToast: vi.fn(() => ({ toast: vi.fn(), })), })); -vi.mock('next/link', () => ({ - default: ({ children, href }: { children: React.ReactNode; href: string }) => ( - {children} - ), +vi.mock("next/link", () => ({ + default: ({ + children, + href, + }: { + children: React.ReactNode; + href: string; + }) => {children}, })); -vi.mock('react-diff-viewer', () => ({ +vi.mock("react-diff-viewer", () => ({ default: () =>
Mocked Diff Viewer
, })); -vi.mock('swr', () => ({ +vi.mock("swr", () => ({ default: vi.fn(), })); -vi.mock('./log-view', () => ({ +vi.mock("./log-view", () => ({ LogView: () =>
Mocked Log View
, })); -describe('PullRequestItem', () => { +vi.mock("./log-view", () => ({ + LogView: () =>
Mocked Log View
, +})); + +vi.mock("ai/react", () => ({ + experimental_useObject: vi.fn(), +})); + +describe("PullRequestItem", () => { const mockPullRequest: PullRequest = { id: 1, - title: 'Test PR', + title: "Test PR", number: 123, - buildStatus: 'success', + buildStatus: "success", isDraft: false, - branchName: 'feature-branch', + branchName: "feature-branch", repository: { id: 1, - name: 'test-repo', - full_name: 'owner/test-repo', + name: "test-repo", + full_name: "owner/test-repo", owner: { - login: 'owner', + login: "owner", }, }, }; @@ -69,17 +84,22 @@ describe('PullRequestItem', () => { isValidating: false, isLoading: false, }); + vi.mocked(useObject).mockReturnValue({ + object: null, + submit: vi.fn(), + isLoading: false, + }); }); - it('renders the pull request information correctly', () => { + it("renders the pull request information correctly", () => { render(); - expect(screen.getByText('Test PR')).toBeInTheDocument(); - expect(screen.getByText('#123')).toBeInTheDocument(); - expect(screen.getByText('Build: success')).toBeInTheDocument(); + expect(screen.getByText("Test PR")).toBeInTheDocument(); + expect(screen.getByText("#123")).toBeInTheDocument(); + expect(screen.getByText("Build: success")).toBeInTheDocument(); }); - it('displays running build status', () => { - const runningPR = { ...mockPullRequest, buildStatus: 'running' }; + it("displays running build status", () => { + const runningPR = { ...mockPullRequest, buildStatus: "running" }; vi.mocked(useSWR).mockReturnValue({ data: runningPR, mutate: vi.fn(), @@ -88,12 +108,12 @@ describe('PullRequestItem', () => { isLoading: false, }); render(); - expect(screen.getByText('Build: Running')).toBeInTheDocument(); - expect(screen.getByText('Running...')).toBeInTheDocument(); + expect(screen.getByText("Build: Running")).toBeInTheDocument(); + expect(screen.getByText("Running...")).toBeInTheDocument(); }); - it('disables buttons when build is running', () => { - const runningPR = { ...mockPullRequest, buildStatus: 'running' }; + it("disables buttons when build is running", () => { + const runningPR = { ...mockPullRequest, buildStatus: "running" }; vi.mocked(useSWR).mockReturnValue({ data: runningPR, mutate: vi.fn(), @@ -102,14 +122,16 @@ describe('PullRequestItem', () => { isLoading: false, }); render(); - expect(screen.getByText('Running...')).toBeDisabled(); + expect(screen.getByText("Running...")).toBeDisabled(); }); - it('updates build status periodically', async () => { + it("updates build status periodically", async () => { const mutate = vi.fn(); const fetchBuildStatusMock = vi.fn().mockResolvedValue(mockPullRequest); + vi.mocked(fetchBuildStatus).mockImplementation(fetchBuildStatusMock); + vi.mocked(useSWR).mockImplementation((key, fetcher, options) => { - if (typeof fetcher === 'function') { + if (typeof fetcher === "function") { fetcher(); } return { @@ -122,7 +144,7 @@ describe('PullRequestItem', () => { }); render(); - + await waitFor(() => { expect(useSWR).toHaveBeenCalledWith( `pullRequest-${mockPullRequest.id}`, @@ -134,20 +156,37 @@ describe('PullRequestItem', () => { }) ); }); + + // Verify that fetchBuildStatus is called with the correct parameters + + expect(fetchBuildStatusMock).toHaveBeenCalledWith( + mockPullRequest.repository.owner.login, + mockPullRequest.repository.name, + mockPullRequest.number + ); }); - it('triggers revalidation after committing changes', async () => { - const { getPullRequestInfo, commitChangesToPullRequest } = await import('@/lib/github'); + it("triggers revalidation after committing changes", async () => { + const { getPullRequestInfo, commitChangesToPullRequest } = await import( + "@/lib/github" + ); vi.mocked(getPullRequestInfo).mockResolvedValue({ - diff: 'mock diff', - testFiles: [{ name: 'existing_test.ts', content: 'existing content' }], + diff: "mock diff", + testFiles: [{ name: "existing_test.ts", content: "existing content" }], }); - vi.mocked(commitChangesToPullRequest).mockResolvedValue('https://github.com/commit/123'); + vi.mocked(commitChangesToPullRequest).mockResolvedValue( + "https://github.com/commit/123" + ); - vi.mocked(global.fetch).mockResolvedValue({ - ok: true, - json: () => Promise.resolve([{ name: 'generated_test.ts', content: 'generated content' }]), - } as Response); + const mockSubmit = vi.fn(); + vi.mocked(useObject).mockReturnValue({ + object: null, + submit: mockSubmit, + isLoading: false, + setInput: vi.fn(), + error: null, + stop: vi.fn(), + }); const mutate = vi.fn(); vi.mocked(useSWR).mockReturnValue({ @@ -159,14 +198,28 @@ describe('PullRequestItem', () => { }); render(); - const writeTestsButton = screen.getByText('Write new tests'); + + const writeTestsButton = screen.getByText("Write new tests"); fireEvent.click(writeTestsButton); await waitFor(() => { - expect(screen.getByText('generated_test.ts')).toBeInTheDocument(); + expect(screen.getByText("Analyzing PR diff...")).toBeInTheDocument(); + }); + + await act(async () => { + const { onFinish } = vi.mocked(useObject).mock.calls[0][0]; + await onFinish({ + object: { + tests: [{ name: "generated_test.ts", content: "generated content" }], + }, + }); + }); + + await waitFor(() => { + expect(screen.getByText("generated_test.ts")).toBeInTheDocument(); }); - const commitButton = screen.getByText('Commit changes'); + const commitButton = screen.getByText("Commit changes"); fireEvent.click(commitButton); await waitFor(() => { @@ -175,36 +228,36 @@ describe('PullRequestItem', () => { }); }); - it('shows and hides logs when toggle is clicked', async () => { + it("shows and hides logs when toggle is clicked", async () => { vi.mocked(useSWR).mockReturnValue({ - data: { ...mockPullRequest, buildStatus: 'success' }, + data: { ...mockPullRequest, buildStatus: "success" }, mutate: vi.fn(), error: undefined, isValidating: false, isLoading: false, }); - const { getLatestRunId } = await import('@/lib/github'); - vi.mocked(getLatestRunId).mockResolvedValue('123'); + const { getLatestRunId } = await import("@/lib/github"); + vi.mocked(getLatestRunId).mockResolvedValue("123"); render(); await waitFor(() => { - expect(screen.getByText('Show Logs')).toBeInTheDocument(); + expect(screen.getByText("Show Logs")).toBeInTheDocument(); }); - fireEvent.click(screen.getByText('Show Logs')); + fireEvent.click(screen.getByText("Show Logs")); await waitFor(() => { - expect(screen.getByTestId('log-view')).toBeInTheDocument(); - expect(screen.getByText('Hide Logs')).toBeInTheDocument(); + expect(screen.getByTestId("log-view")).toBeInTheDocument(); + expect(screen.getByText("Hide Logs")).toBeInTheDocument(); }); - fireEvent.click(screen.getByText('Hide Logs')); + fireEvent.click(screen.getByText("Hide Logs")); await waitFor(() => { - expect(screen.queryByTestId('log-view')).not.toBeInTheDocument(); - expect(screen.getByText('Show Logs')).toBeInTheDocument(); + expect(screen.queryByTestId("log-view")).not.toBeInTheDocument(); + expect(screen.getByText("Show Logs")).toBeInTheDocument(); }); }); });