Skip to content

Commit

Permalink
adding model config modal and changing model context
Browse files Browse the repository at this point in the history
  • Loading branch information
Aayush-Agnihotri committed Dec 26, 2024
1 parent 805338c commit 55daffa
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 53 deletions.
2 changes: 2 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Client Secrets (All must be prefixed with NEXT_PUBLIC_)
NEXT_PUBLIC_DEFAULT_MODEL=

NEXT_PUBLIC_FIREBASE_API_KEY=
Expand All @@ -8,6 +9,7 @@ NEXT_PUBLIC_FIREBASE_MESSAGING_SENDER_ID=
NEXT_PUBLIC_FIREBASE_APP_ID=
NEXT_PUBLIC_FIREBASE_MEASUREMENT_ID=

# Server Secrets
# For Firebase Admin SDK env variables, see https://github.com/firebase/firebase-admin-node/discussions/2043
FIREBASE_ADMIN_CLIENT_EMAIL=
FIREBASE_ADMIN_PRIVATE_KEY=
Expand Down
14 changes: 13 additions & 1 deletion src/app/chat/[chatId]/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import ChatMenuNavbar from '@/app/components/chat/ChatMenuNavbar';
import ChatHeader from '@/app/components/chat/ChatHeader';
import InputField from '@/app/components/InputField';
import ChatMessage from '@/app/components/chat/ChatMessage';
// import { toast } from 'react-toastify';
// import Toast from '@/app/components/Toast';

export default function ChatPage() {
const { user } = useAuth();
const { selectedModel } = useModel();
const { selectedModel, fetchActiveModels } = useModel();
const [messages, setMessages] = useState<Message[]>([]);
const [messageStreaming, setMessageStreaming] = useState(false);
const pathname = usePathname();
Expand All @@ -30,6 +32,14 @@ export default function ChatPage() {
}
}, [messages]);

// TODO: Figure out which models become active and when
// const displayToast = async () => {
// const activeModels = await fetchActiveModels();
// if (!activeModels.includes('llama')) {
// toast.info(`Loading ${selectedModel} into memory...`);
// }
// };

const createChatCompletionRequestBody = (message: string) => {
const body: ChatCompletionRequest = {
model: selectedModel,
Expand Down Expand Up @@ -169,6 +179,7 @@ export default function ChatPage() {
setMessages((prevMessages) => [...prevMessages, userMessage]);
setMessageStreaming(true);
sendStreamedMessage(message);
// displayToast();
};

return (
Expand All @@ -191,6 +202,7 @@ export default function ChatPage() {
</div>
</div>
</div>
{/* <Toast /> */}
</Protected>
);
}
1 change: 0 additions & 1 deletion src/app/components/Toast.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ export default function ErrorToast() {
newestOnTop={false}
closeOnClick={true}
rtl={false}
pauseOnFocusLoss
pauseOnHover
theme="colored"
transition={Slide}
Expand Down
92 changes: 42 additions & 50 deletions src/app/components/chat/ChatHeader.tsx
Original file line number Diff line number Diff line change
@@ -1,48 +1,31 @@
import { useModel } from '@/contexts/ModelContext';
import { AllModelsResponse } from '@/types/model';
import { useState, useEffect, useRef } from 'react';
import Modal from '../modals/Modal';
import ModelInfoModal from '../modals/ModelInfoModal';
import ModelConfigModal from '../modals/ModelConfigModal';
import Spinner from '../Spinner';

// TODO: Add extra dropdown to set model temperature, etc

export default function ChatHeader() {
const { selectedModel, setSelectedModel } = useModel();
const { selectedModel, setSelectedModel, fetchAllModels, fetchActiveModels } = useModel();
const [modelDropdownOpen, setModelDropdownOpen] = useState(false);
const [loading, setLoading] = useState(false);
const [isModelInfoModalOpen, setIsModelInfoModalOpen] = useState(false);
const [availableModels, setAvailableModels] = useState<AllModelsResponse>({
models: [],
});
const [availableModels, setAvailableModels] = useState<string[]>([]);
const [activeModels, setActiveModels] = useState<string[]>([]);
const [isModelInfoModalOpen, setIsModelInfoModalOpen] = useState(false);
const [isModelConfigModalOpen, setIsModelConfigModalOpen] = useState(false);
const modelDropdownRef = useRef<HTMLDivElement>(null);

const openModelDropdown = async () => {
if (modelDropdownOpen) return;
setLoading(true);
await Promise.all([fetchAllModels(), fetchActiveModels()]).catch((error) => {
console.error(error);
setAvailableModels({
models: [
{
name: 'Unable to load models',
modified_at: '',
size: 0,
digest: '',
details: {
format: '',
family: '',
families: null,
parameter_size: '',
quantization_level: '',
},
expires_at: '',
size_vram: 0,
},
],
});
});
try {
const [allModels, activeModels] = await Promise.all([fetchAllModels(), fetchActiveModels()]);
setActiveModels(activeModels);
setAvailableModels(allModels);
} catch (error) {
setActiveModels([]);
setAvailableModels([`Error fetching models - ${(error as Error).message}`]);
}
setModelDropdownOpen(true);
setLoading(false);
};
Expand All @@ -52,20 +35,6 @@ export default function ChatHeader() {
setModelDropdownOpen(false);
};

const fetchAllModels = async () => {
await fetch('/api/models/all')
.then((response) => response.json())
.then((data: AllModelsResponse) => setAvailableModels(data));
};

const fetchActiveModels = async () => {
await fetch('/api/models/active')
.then((response) => response.json())
.then((data: AllModelsResponse) => {
setActiveModels(data.models.map((model) => model.name));
});
};

const handleClickOutside = (event: MouseEvent) => {
if (modelDropdownRef.current && !modelDropdownRef.current.contains(event.target as Node)) {
setModelDropdownOpen(false);
Expand All @@ -82,7 +51,7 @@ export default function ChatHeader() {
return (
<div className="m-auto flex w-11/12 flex-row items-center justify-end pt-3">
<div className="relative" ref={modelDropdownRef}>
<div className="flex flex-row items-center">
<div className="flex flex-row items-center gap-2">
<div className={`${loading ? 'flex items-center' : 'hidden'} mr-2`}>
<Spinner width="1" height="1" />
</div>
Expand Down Expand Up @@ -124,23 +93,40 @@ export default function ChatHeader() {
></path>
</svg>
</button>

<button onClick={() => setIsModelConfigModalOpen(true)}>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
strokeWidth="1.5"
stroke="currentColor"
className="size-6"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
d="M10.5 6h9.75M10.5 6a1.5 1.5 0 1 1-3 0m3 0a1.5 1.5 0 1 0-3 0M3.75 6H7.5m3 12h9.75m-9.75 0a1.5 1.5 0 0 1-3 0m3 0a1.5 1.5 0 0 0-3 0m-3.75 0H7.5m9-6h3.75m-3.75 0a1.5 1.5 0 0 1-3 0m3 0a1.5 1.5 0 0 0-3 0m-9.75 0h9.75"
/>
</svg>
</button>
</div>
{modelDropdownOpen && (
<div className="absolute right-0 top-full z-10 w-40 rounded-md bg-gray-100 shadow-lg">
{availableModels &&
availableModels.models.map((model, index) => (
availableModels.map((model, index) => (
<button
key={index}
onClick={() => selectModel(model.name)}
onClick={() => selectModel(model)}
className="flex w-full items-center px-4 py-2 text-left text-sm text-secondaryColor hover:rounded-md hover:bg-gray-200"
>
{activeModels.includes(model.name) && (
{activeModels.includes(model) && (
<div className="mr-2 h-2 w-2 rounded-full bg-green-500"></div>
)}
<span
className={`flex-1 truncate ${selectedModel === model.name ? 'font-semibold' : ''}`}
className={`flex-1 truncate ${selectedModel === model ? 'font-semibold' : ''}`}
>
{model.name}
{model}
</span>
</button>
))}
Expand All @@ -153,6 +139,12 @@ export default function ChatHeader() {
<ModelInfoModal />
</Modal>
)}

{isModelConfigModalOpen && (
<Modal onClose={() => setIsModelConfigModalOpen(false)}>
<ModelConfigModal />
</Modal>
)}
</div>
);
}
7 changes: 7 additions & 0 deletions src/app/components/modals/ModelConfigModal.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
export default function ModelConfigModal() {
return (
<div>
<h1>ModelConfigModal</h1>
</div>
);
}
23 changes: 22 additions & 1 deletion src/contexts/ModelContext.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
'use client';

import { createContext, useContext, useState, ReactNode } from 'react';
import { AllModelsResponse } from '@/types/model';
import { env } from 'next-runtime-env';

interface ModelContextType {
selectedModel: string;
setSelectedModel: (model: string) => void;
fetchAllModels: () => Promise<string[]>;
fetchActiveModels: () => Promise<string[]>;
}

const ModelContext = createContext<ModelContextType | undefined>(undefined);
Expand All @@ -22,8 +25,26 @@ export const ModelProvider = ({ children }: { children: ReactNode }) => {
const defaultModel = env('NEXT_PUBLIC_DEFAULT_MODEL');
const [selectedModel, setSelectedModel] = useState<string>(defaultModel!);

const fetchAllModels = async () => {
return await fetch('/api/models/all')
.then((response) => response.json())
.then((data: AllModelsResponse) => {
return data.models.map((model) => model.name);
});
};

const fetchActiveModels = async () => {
return await fetch('/api/models/active')
.then((response) => response.json())
.then((data: AllModelsResponse) => {
return data.models.map((model) => model.name);
});
};

return (
<ModelContext.Provider value={{ selectedModel, setSelectedModel }}>
<ModelContext.Provider
value={{ selectedModel, setSelectedModel, fetchAllModels, fetchActiveModels }}
>
{children}
</ModelContext.Provider>
);
Expand Down

0 comments on commit 55daffa

Please sign in to comment.