Skip to content

Commit

Permalink
create new trial on rollback recover
Browse files Browse the repository at this point in the history
  • Loading branch information
rasca committed Nov 15, 2024
1 parent ac4ded2 commit 6bbebe7
Show file tree
Hide file tree
Showing 7 changed files with 552 additions and 429 deletions.
2 changes: 1 addition & 1 deletion flou/flou/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .utils import json_dumps


engine = create_engine(settings.database.url, json_serializer=json_dumps, echo=True) # )
engine = create_engine(settings.database.url, json_serializer=json_dumps) #, echo=True )


def get_db(session=None):
Expand Down
37 changes: 18 additions & 19 deletions flou/flou/engine/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ async def websocket_endpoint(

@router.post("/ltm/{ltm_id}/rollback")
async def rollback(
snapshot: Rollback,
snapshot: Rollback | None = None,
rollback: RollbackIndex | None = None,
new_trial: AddTrial | None = None,
ltm_id: int = Path(..., description="The LTM instance id"),
session=Depends(get_session),
Expand All @@ -216,7 +217,18 @@ async def rollback(
"""
db = get_db()
ltm = db.load_ltm(ltm_id, snapshots=True)
ltm = db.rollback(ltm, snapshot.index, replay=snapshot.replay, reason="replay" if snapshot.replay else "manual")
rollback_args = {
"ltm": ltm,
}
if snapshot:
rollback_args["snapshot_index"] = snapshot.index
rollback_args["replay"] = snapshot.replay
rollback_args["reason"] = "replay" if snapshot.replay else "manual"
elif rollback:
rollback_args["rollback_index"] = rollback.index
rollback_args["reason"] = "recover rollback"

ltm = db.rollback(**rollback_args)

trial = (
session.query(Trial)
Expand All @@ -228,7 +240,9 @@ async def rollback(
result = {"success": True}

if trial:
trial.outputs = new_trial.previous_trial_outputs
if new_trial and new_trial.previous_trial_outputs:
trial.outputs = new_trial.previous_trial_outputs
session.add(trial)

# Create new trial with same name and experiment
new_trial = Trial(
Expand All @@ -237,31 +251,16 @@ async def rollback(
**new_trial.model_dump(include={"inputs"}),
name=new_trial.name or trial.name,
rollback_index=len(ltm._rollbacks or []), # this rollback doesn't exist yet
snapshot_index=snapshot.index,
snapshot_index=snapshot.index if snapshot else 0,
)
session.add(new_trial)
session.add(trial)
session.commit()

result["trial"] = new_trial

return result


@router.post("/ltm/{ltm_id}/recover-rollback")
async def rollback(
rollback: RollbackIndex,
ltm_id: int = Path(..., description="The LTM instance id"),
):
"""
Undo a rollback
"""
db = get_db()
ltm = db.load_ltm(ltm_id, snapshots=True)
db.rollback(ltm, rollback_index=rollback.index, reason="recover rollback")
return True


@router.post("/ltm/{ltm_id}/retry")
async def retry(
error_list: ErrorList,
Expand Down
33 changes: 21 additions & 12 deletions studio/src/lib/Components/NewTrialFromRollback.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
});
export let showTrialModal = false;
export let selectedAction: 'rollback' | 'replay' = 'rollback';
export let selectedSnapshotIndex: number;
export let selectedAction: 'rollback' | 'replay' | 'recover-rollback' = 'rollback';
export let selectedIndex: number;
export let ltm: any = {};
const { form, errors, message, constraints, enhance } = superForm(
Expand All @@ -31,7 +31,7 @@
if (form.valid) {
showTrialModal = false;
dispatch('newTrial', {
snapshotIndex: selectedSnapshotIndex,
index: selectedIndex,
replay: selectedAction === 'replay',
newTrialData: form.data
});
Expand All @@ -52,17 +52,26 @@
<TestTube size="2.5rem" />
</div>
<div>
{#if selectedAction === 'rollback'}
Rollback
{#if selectedAction === 'recover-rollback'}
{@const lastSnapshot = ltm.rollbacks[selectedIndex].snapshots.at(-1)}
Recover Rollback #{selectedIndex}
<div class="snapshot-item">
Last Snapshot
<SnapshotItem item={lastSnapshot.item}></SnapshotItem>
</div>
{:else}
Replay
{/if}
from snapshot #{selectedSnapshotIndex}
<div class="snapshot-item">
{#if ltm.snapshots[selectedSnapshotIndex]}
<SnapshotItem item={ltm.snapshots[selectedSnapshotIndex].item}></SnapshotItem>
{#if selectedAction === 'rollback'}
Rollback
{:else}
Replay
{/if}
</div>
from snapshot #{selectedIndex}
<div class="snapshot-item">
{#if ltm.snapshots[selectedIndex]}
<SnapshotItem item={ltm.snapshots[selectedIndex].item}></SnapshotItem>
{/if}
</div>
{/if}
</div>
</div>
{#if $form}
Expand Down
104 changes: 73 additions & 31 deletions studio/src/lib/Components/Rollbacks.svelte
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
<script lang="ts">
import { ClockCounterClockwise, Lifebuoy } from 'phosphor-svelte';
import { ClockCounterClockwise, Lifebuoy, TestTube } from 'phosphor-svelte';
import { createEventDispatcher } from 'svelte';
import Paginator from '$lib/UI/Paginator.svelte';
import SnapshotItem from '$lib/Components/SnapshotItem.svelte';
import { formatDate } from '$lib/utils';
import { PUBLIC_API_BASE_URL } from '$env/static/public';
import NewTrialFromRollback from '$lib/Components/NewTrialFromRollback.svelte';
let dispatch = createEventDispatcher();
export let ltm: any;
export let ltmId;
export let experiment: any;
let index = 0;
let showTrialModal = false;
let selectedAction = 'recover-rollback';
let selectedIndex = 0;
const ltmUrl = `${PUBLIC_API_BASE_URL}ltm/${ltmId}`;
let recoverRollback = async (snapshot_index: number) => {
let postData = {
index: snapshot_index,
let recoverRollback = async (rollbackIndex: number, newTrialData: any = undefined) => {
let postData: any = {
rollback: {
index: rollbackIndex
}
};
await fetch(`${ltmUrl}/recover-rollback`, {
console.log('NEW TRIAL DATA');
console.log(newTrialData);
if (newTrialData) {
postData['new_trial'] = newTrialData;
}
await fetch(`${ltmUrl}/rollback`, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
Expand All @@ -38,28 +50,44 @@
</script>

<h3><ClockCounterClockwise size="1.25rem" />Rollbacks</h3>
<table>
<tr>
<th>Reason</th>
<th>#Snapshots</th>
<th>Last Snapshot</th>
<th>Time</th>
<th></th>
</tr>
{#each ltm.rollbacks as rollback, i}
{@const lastSnapshot = rollback.snapshots.at(-1)}
<tr
class:current={i === index}
on:click={() => index = i}
on:keydown={() => index = i}
role="button"
>
<td title={rollback.reason}>{rollback.reason}</td>
<td title={rollback.snapshots.length}>{rollback.snapshots.length}</td>
<td title={`${lastSnapshot.reason}`}>{lastSnapshot.reason}: <SnapshotItem item={lastSnapshot.item} />
<td title={formatDate(rollback.time)}>{formatDate(rollback.time)}</td>
<td>
<div class="snapshot-controls">
<table>
<tr>
<th>Reason</th>
<th>#Snapshots</th>
<th>Last Snapshot</th>
<th>Time</th>
<th>
{#if experiment}
<span class="th-icon">New trial <TestTube /></span>
{/if}
</th>
</tr>
{#each ltm.rollbacks as rollback, i}
{@const lastSnapshot = rollback.snapshots.at(-1)}
<tr
class:current={i === index}
on:click={() => (index = i)}
on:keydown={() => (index = i)}
role="button"
>
<td title={rollback.reason}>{rollback.reason}</td>
<td title={rollback.snapshots.length}>{rollback.snapshots.length}</td>
<td title={`${lastSnapshot.reason}`}
>{lastSnapshot.reason}: <SnapshotItem item={lastSnapshot.item} />
</td><td title={formatDate(rollback.time)}>{formatDate(rollback.time)}</td>
<td>
<div class="snapshot-controls">
{#if experiment}
<button
on:click={() => {
selectedIndex = i;
showTrialModal = true;
}}
title="Recover rollback"
>
<Lifebuoy size="1.25rem" />
</button>
{:else}
<button
on:click={() => {
recoverRollback(i);
Expand All @@ -68,7 +96,21 @@
>
<Lifebuoy size="1.25rem" />
</button>
</td>
</tr>{/each}
</table>
<Paginator bind:index={index} collection={ltm.rollbacks} />
{/if}
</div>
</td>
</tr>{/each}
</table>
<Paginator bind:index collection={ltm.rollbacks} />

{#if experiment}
<NewTrialFromRollback
bind:showTrialModal
bind:selectedIndex
bind:ltm
selectedAction="recover-rollback"
on:newTrial={(event) => {
recoverRollback(event.detail.index, event.detail.newTrialData);
}}
/>
{/if}
Loading

0 comments on commit 6bbebe7

Please sign in to comment.