Skip to content

Commit

Permalink
add trial from experiment detail
Browse files Browse the repository at this point in the history
  • Loading branch information
rasca committed Nov 14, 2024
1 parent 7dcb68e commit ac4ded2
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 15 deletions.
17 changes: 14 additions & 3 deletions flou/flou/experiments/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,26 @@ async def list_trials(experiment_id: uuid.UUID, session = Depends(get_session)):
).all()
return trials

@router.post("/{experiment_id}/trials/", response_model=TrialList)
@router.post("/{experiment_id}/trials/", response_model=LTMId)
async def create_trial(experiment_id: uuid.UUID, trial: TrialCreate, session = Depends(get_session)):
if not session.get(Experiment, experiment_id):
raise HTTPException(status_code=404, detail="Experiment not found")
new_trial = Trial(experiment_id=experiment_id, **trial.model_dump())

# for now we are just creating a new LTM based on fqn
trial_kwargs = trial.model_dump()
fqn = trial_kwargs.pop("fqn")
db = get_db(session)
ltm = db.get_ltm_class(fqn)()

# create the LTM and assign it to the trial
trial_kwargs["ltm_id"] = ltm.start(payload={}, playground=True)

new_trial = Trial(experiment_id=experiment_id, **trial_kwargs)
session.add(new_trial)
session.commit()
session.refresh(new_trial)
return new_trial
return {"id": new_trial.ltm_id}


@router.get("/{experiment_id}/trials/{trial_id}", response_model=TrialList)
async def get_trial(experiment_id: uuid.UUID, trial_id: uuid.UUID, session = Depends(get_session)):
Expand Down
1 change: 0 additions & 1 deletion flou/flou/experiments/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class TrialCreateBase(TrialBase):


class TrialCreate(TrialCreateBase):
experiment_id: UUID
name: str


Expand Down
5 changes: 3 additions & 2 deletions studio/src/global.scss
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ tr th {
.table-header {
display: flex;
justify-content: space-between;
width: 100%;
}

.table-controls {
Expand All @@ -213,11 +214,11 @@ tr th {
justify-content: end;
padding: 0 1rem;

a {
a, button {
display: flex;
align-items: center;
gap: 0.5rem;
color: var(--black-40, rgba(28, 28, 28, 0.40));
color: var(--black-80, rgba(28, 28, 28, 0.40));
text-decoration: none;

&:hover {
Expand Down
7 changes: 4 additions & 3 deletions studio/src/lib/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ const registryUrl = `${PUBLIC_API_BASE_URL}ltm/registry`;

export let getRegistry = async (fetch: any) => {
await fetch(registryUrl)
.then((response) => response.json())
.then((data) => {
.then((response: any) => response.json())
.then((data: any) => {
registry = data;
return data;
})
.catch((error) => {
.catch((error: any) => {
console.log(error);
return [];
});
Expand Down
4 changes: 0 additions & 4 deletions studio/src/routes/(app)/experiments/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,4 @@
{/if}

<style lang="scss">
.table-header {
display: flex;
justify-content: space-between;
}
</style>
11 changes: 9 additions & 2 deletions studio/src/routes/(app)/experiments/[id]/+page.svelte
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<script lang="ts">
import { ListMagnifyingGlass } from 'phosphor-svelte';
import { ListMagnifyingGlass, Plus } from 'phosphor-svelte';
import Block from '$lib/UI/Block.svelte';
import { formatDate } from '$lib/utils';
import InlineEditTextArea from '$lib/UI/InlineEditTextArea.svelte';
Expand Down Expand Up @@ -79,7 +79,14 @@
</dl> -->
</Block>
<Block>
<h3>Trials</h3>
<div class="table-header">
<h3>Trials</h3>
<div class="table-controls">
<a href="/experiments/{data.experiment.id}/trials/new">
<Plus size="1rem" /> New Trial
</a>
</div>
</div>
<table>
<tr>
<th>#</th>
Expand Down
90 changes: 90 additions & 0 deletions studio/src/routes/(app)/experiments/[id]/trials/new/+page.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
<script lang="ts">
import type { PageData } from './$types';
import { goto } from '$app/navigation';
import { PUBLIC_API_BASE_URL } from '$env/static/public';
import Block from '$lib/UI/Block.svelte';
import Select from '$lib/UI/Select.svelte';
import { superForm, setMessage, setError } from 'sveltekit-superforms';
import { _newTrialSchema } from './+page';
import { zod } from 'sveltekit-superforms/adapters';
export let data: PageData;
$: ({experiment, params, registryOptions} = data)
const { form, errors, message, constraints, enhance, isTainted } = superForm(data.form, {
SPA: true,
dataType: 'json',
validators: zod(_newTrialSchema),
onUpdate({ form }) {
if (form.valid) {
handleSubmit(form.data)
}
}
});
let handleSubmit = async (formData: any) => {
let url = `${PUBLIC_API_BASE_URL}experiments/${experiment.id}/trials`;
await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify(formData)
})
.then((response) => response.json())
.then((data) => {
goto(`/playground/${data.id}`);
})
.catch((error) => {
console.error('Error:', error);
});
};
</script>

<h2>New Trial</h2>
<Block>
{#if $message}<div>{$message}</div>{/if}
<form use:enhance>
<div>

Create a new Trial for experiment <b>#{experiment.index} {experiment.name}</b>.
</div>

<p>Choose from the following:</p>

<label>
Name
<input
aria-invalid={$errors.name ? 'true' : undefined}
bind:value={$form.name}
{...$constraints.name}
/>
</label>
{#if $errors.name}<span class="invalid">{$errors.name}</span>{/if}
<Select
ariaInvalid={$errors.fqn ? 'true' : undefined}
bind:value={$form.fqn}
{...$constraints.fqn}
options={data.registryOptions}
label="LTM"
emptyLabel="Select LTM"
/>
{#if $errors.fqn && isTainted('fqn')}<span class="invalid">{$errors.fqn}</span>{/if}
<label>
Inputs (JSON)
<textarea
aria-invalid={$errors.inputs ? 'true' : undefined}
bind:value={$form.inputs}
{...$constraints.inputs}
/>
</label>
{#if $errors.inputs}<span class="invalid">{$errors.inputs}</span>{/if}

<div class="buttons full-width">
<a class="button secondary large" href="/experiments">Cancel</a>
<button type="submit" class="primary large">Create Trial</button>
</div>
</form>
</Block>
41 changes: 41 additions & 0 deletions studio/src/routes/(app)/experiments/[id]/trials/new/+page.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import type { PageLoad } from './$types';
import { error } from '@sveltejs/kit';
import { superValidate } from 'sveltekit-superforms';
import { zod } from 'sveltekit-superforms/adapters';
import { z } from 'zod';
import { PUBLIC_API_BASE_URL } from '$env/static/public';
import { getRegistry } from '$lib/registry';

let registryOptions: any = [];

const newTrialInitial = {
name: 'Trial',
fqn: undefined,
};

export const _newTrialSchema = z.object({
name: z.string().default('Trial'),
fqn: z.string()
.refine((data: any) => {
const validFQNs = registryOptions.map((option: any) => option[0]);
return validFQNs.includes(data);
}, {
message: 'Invalid LTM',
}),
inputs: z.string().optional()
});

export const load: PageLoad = async ({ params, fetch }) => {

const experiment = fetch(`${PUBLIC_API_BASE_URL}experiments/${params.id}`)
.then((response) => response.json())
.catch((error) => {
console.log(error);
return [];
});
registryOptions = await getRegistry(fetch);

const form = await superValidate(newTrialInitial, zod(_newTrialSchema));

return { form, registryOptions, params, experiment: await experiment };
};
24 changes: 24 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,27 @@ def test_rollback_trial(session):
assert trials[1].name == rollback_data["new_trial"]["name"]
assert trials[1].rollback_index == 1
assert trials[1].snapshot_index == 2


def test_create_trial(session):
# First create an experiment
experiment = Experiment(name="Test Experiment", description="Test experiment")
session.add(experiment)
session.commit()

# Create trial data
trial_creation_data = {
"name": "New Trial",
"fqn": "tests.test_ltm.Root",
}

response = client.post(f"/api/v0/experiments/{experiment.id}/trials", json=trial_creation_data)
assert response.status_code == 200
data = response.json()
assert "id" in data

# Verify trial was created correctly
trials = session.query(Trial).filter(Trial.experiment_id == experiment.id).all()
assert len(trials) == 1
assert trials[0].name == "New Trial"
assert trials[0].ltm_id == data["id"]

0 comments on commit ac4ded2

Please sign in to comment.