Merge pull request #62 from barun-saha/ollama
Offline LLMs via Ollama
barun-saha authored Dec 8, 2024
2 parents 89c5253 + 50f37bd commit 89a3160
Showing 5 changed files with 146 additions and 38 deletions.
29 changes: 29 additions & 0 deletions README.md
@@ -47,6 +47,8 @@ Different LLMs offer different styles of content generation. Use one of the following

The Mistral models do not strictly require an access token. However, you are encouraged to obtain and use your own Hugging Face access token.

In addition, offline LLMs served by Ollama can be used. Read on to learn more.


# Icons

@@ -62,6 +64,33 @@ To run this project by yourself, you need to provide the `HUGGINGFACEHUB_API_TOKEN`
for example, in a `.env` file. Alternatively, you can provide the access token in the app's user interface (UI) itself. For other LLM providers, the API key can only be specified in the UI. For image search, the `PEXEL_API_KEY` should be made available as an environment variable.
Visit the respective websites to obtain the API keys.

## Offline LLMs Using Ollama

SlideDeck AI supports offline LLMs for generating the contents of slide decks. This is typically suitable for individuals or organizations who want to use self-hosted LLMs, for example, because of privacy concerns.

Offline LLMs are made available via Ollama. Therefore, a prerequisite here is to have [Ollama installed](https://ollama.com/download) on the system and the desired [LLM](https://ollama.com/search) pulled locally.
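
For instance, a model can be pulled and verified as follows (a minimal sketch; `mistral-nemo:latest` is only an illustrative choice, and any locally supported model works):

```bash
ollama pull mistral-nemo:latest   # download the model locally
ollama list                       # verify that the model is now available
```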

In addition, the `RUN_IN_OFFLINE_MODE` environment variable needs to be set to `True` to enable the offline mode. This can be done, for example, via a `.env` file or from the terminal. The typical steps to use SlideDeck AI in offline mode (in a `bash` shell) are as follows:

```bash
ollama list # View locally available LLMs
export RUN_IN_OFFLINE_MODE=True # Enable the offline mode to use Ollama
git clone https://github.com/barun-saha/slide-deck-ai.git
cd slide-deck-ai
python -m venv venv # Create a virtual environment
source venv/bin/activate # On a Linux system
pip install -r requirements.txt
streamlit run ./app.py # Run the application
```

The `.env` file should be created inside the `slide-deck-ai` directory.
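
As a minimal sketch, such a `.env` file might contain something like the following (the Pexels key value is a placeholder, not a real credential):

```bash
RUN_IN_OFFLINE_MODE=True
PEXEL_API_KEY=your-pexels-api-key   # placeholder; needed only for image search
```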

The UI is similar to the online mode. However, rather than selecting an LLM from a list, you have to type the name of the Ollama model to be used in a text box. No API key is asked for here.

The online and offline modes are mutually exclusive. So, setting `RUN_IN_OFFLINE_MODE` to `False` will make SlideDeck AI use the online LLMs (i.e., the "original" mode). By default, `RUN_IN_OFFLINE_MODE` is set to `False`.

Finally, the focus is on using offline LLMs, not on going completely offline. Internet connectivity is still required to fetch images from Pexels.


# Live Demo

85 changes: 59 additions & 26 deletions app.py
@@ -3,23 +3,34 @@
"""
import datetime
import logging
import os
import pathlib
import random
import tempfile
from typing import List, Union

import httpx
import huggingface_hub
import json5
import ollama
import requests
import streamlit as st
from dotenv import load_dotenv
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate

import global_config as gcfg
from global_config import GlobalConfig
from helpers import llm_helper, pptx_helper, text_helper


load_dotenv()


RUN_IN_OFFLINE_MODE = os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true'


@st.cache_data
def _load_strings() -> dict:
"""
@@ -135,25 +146,36 @@ def reset_api_key():
horizontal=True
)

# The LLMs
llm_provider_to_use = st.sidebar.selectbox(
label='2: Select an LLM to use:',
options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
index=GlobalConfig.DEFAULT_MODEL_INDEX,
help=GlobalConfig.LLM_PROVIDER_HELP,
on_change=reset_api_key
).split(' ')[0]

# The API key/access token
api_key_token = st.text_input(
label=(
'3: Paste your API key/access token:\n\n'
'*Mandatory* for Cohere and Gemini LLMs.'
' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
),
type='password',
key='api_key_input'
)
if RUN_IN_OFFLINE_MODE:
llm_provider_to_use = st.text_input(
label='2: Enter Ollama model name to use:',
help=(
'Specify a correct, locally available LLM, found by running `ollama list`, for'
' example `mistral:v0.2` and `mistral-nemo:latest`. Having an Ollama-compatible'
' and supported GPU is strongly recommended.'
)
)
api_key_token: str = ''
else:
# The LLMs
llm_provider_to_use = st.sidebar.selectbox(
label='2: Select an LLM to use:',
options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
index=GlobalConfig.DEFAULT_MODEL_INDEX,
help=GlobalConfig.LLM_PROVIDER_HELP,
on_change=reset_api_key
).split(' ')[0]

# The API key/access token
api_key_token = st.text_input(
label=(
'3: Paste your API key/access token:\n\n'
'*Mandatory* for Cohere and Gemini LLMs.'
' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
),
type='password',
key='api_key_input'
)


def build_ui():
@@ -200,7 +222,11 @@ def set_up_chat_ui():
placeholder=APP_TEXT['chat_placeholder'],
max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
):
provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
provider, llm_name = llm_helper.get_provider_model(
llm_provider_to_use,
use_ollama=RUN_IN_OFFLINE_MODE
)
print(f'{llm_provider_to_use=}, {provider=}, {llm_name=}, {api_key_token=}')

if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
return
@@ -233,7 +259,7 @@ def set_up_chat_ui():
llm = llm_helper.get_langchain_llm(
provider=provider,
model=llm_name,
max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
max_new_tokens=gcfg.get_max_output_tokens(llm_provider_to_use),
api_key=api_key_token.strip(),
)

@@ -252,18 +278,17 @@ def set_up_chat_ui():
# Update the progress bar with an approx progress percentage
progress_bar.progress(
min(
len(response) / GlobalConfig.VALID_MODELS[
llm_provider_to_use
]['max_new_tokens'],
len(response) / gcfg.get_max_output_tokens(llm_provider_to_use),
0.95
),
text='Streaming content...this might take a while...'
)
except requests.exceptions.ConnectionError:
except (httpx.ConnectError, requests.exceptions.ConnectionError):
handle_error(
'A connection error occurred while streaming content from the LLM endpoint.'
' Unfortunately, the slide deck cannot be generated. Please try again later.'
' Alternatively, try selecting a different LLM from the dropdown list.',
' Alternatively, try selecting a different LLM from the dropdown list. If you are'
' using Ollama, make sure that Ollama is already running on your system.',
True
)
return
Expand All @@ -274,6 +299,14 @@ def set_up_chat_ui():
True
)
return
except ollama.ResponseError:
handle_error(
f'The model `{llm_name}` is unavailable with Ollama on your system.'
f' Make sure that you have provided the correct LLM name or pull it using'
f' `ollama pull {llm_name}`. View LLMs available locally by running `ollama list`.',
True
)
return
except Exception as ex:
handle_error(
f'An unexpected error occurred while generating the content: {ex}'
24 changes: 22 additions & 2 deletions global_config.py
@@ -20,7 +20,13 @@ class GlobalConfig:
PROVIDER_COHERE = 'co'
PROVIDER_GOOGLE_GEMINI = 'gg'
PROVIDER_HUGGING_FACE = 'hf'
VALID_PROVIDERS = {PROVIDER_COHERE, PROVIDER_GOOGLE_GEMINI, PROVIDER_HUGGING_FACE}
PROVIDER_OLLAMA = 'ol'
VALID_PROVIDERS = {
PROVIDER_COHERE,
PROVIDER_GOOGLE_GEMINI,
PROVIDER_HUGGING_FACE,
PROVIDER_OLLAMA
}
VALID_MODELS = {
'[co]command-r-08-2024': {
'description': 'simpler, slower',
@@ -47,7 +53,7 @@ class GlobalConfig:
'LLM provider codes:\n\n'
'- **[co]**: Cohere\n'
'- **[gg]**: Google Gemini API\n'
'- **[hf]**: Hugging Face Inference Endpoint\n'
'- **[hf]**: Hugging Face Inference API\n'
)
DEFAULT_MODEL_INDEX = 2
LLM_MODEL_TEMPERATURE = 0.2
@@ -125,3 +131,17 @@ class GlobalConfig:
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)


def get_max_output_tokens(llm_name: str) -> int:
"""
Get the max output tokens value configured for an LLM. Return a default value if not configured.
:param llm_name: The name of the LLM.
:return: Max output tokens or a default count.
"""

try:
return GlobalConfig.VALID_MODELS[llm_name]['max_new_tokens']
except KeyError:
return 2048
39 changes: 30 additions & 9 deletions helpers/llm_helper.py
@@ -17,8 +17,9 @@


LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
OLLAMA_MODEL_REGEX = re.compile(r'[a-zA-Z0-9._:-]+$')
# 6-64 characters long, only containing alphanumeric characters, hyphens, and underscores
API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9\-_]{6,64}$')
API_KEY_REGEX = re.compile(r'^[a-zA-Z0-9_-]{6,64}$')
HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
REQUEST_TIMEOUT = 35

@@ -39,20 +40,28 @@
http_session.mount('http://', adapter)


def get_provider_model(provider_model: str) -> Tuple[str, str]:
def get_provider_model(provider_model: str, use_ollama: bool) -> Tuple[str, str]:
"""
Parse and get LLM provider and model name from strings like `[provider]model/name-version`.
:param provider_model: The provider, model name string from `GlobalConfig`.
:return: The provider and the model name.
:param use_ollama: Whether Ollama is used (i.e., running in offline mode).
:return: The provider and the model name; empty strings in case no matching pattern found.
"""

match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
provider_model = provider_model.strip()

if match:
inside_brackets = match.group(1)
outside_brackets = match.group(2)
return inside_brackets, outside_brackets
if use_ollama:
match = OLLAMA_MODEL_REGEX.match(provider_model)
if match:
return GlobalConfig.PROVIDER_OLLAMA, match.group(0)
else:
match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)

if match:
inside_brackets = match.group(1)
outside_brackets = match.group(2)
return inside_brackets, outside_brackets

return '', ''

@@ -152,6 +161,18 @@ def get_langchain_llm(
streaming=True,
)

if provider == GlobalConfig.PROVIDER_OLLAMA:
from langchain_ollama.llms import OllamaLLM

logger.debug('Getting LLM via Ollama: %s', model)
return OllamaLLM(
model=model,
temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
num_predict=max_new_tokens,
format='json',
streaming=True,
)

return None


@@ -163,4 +184,4 @@ def get_langchain_llm(
]

for text in inputs:
print(get_provider_model(text))
print(get_provider_model(text, use_ollama=False))
7 changes: 6 additions & 1 deletion requirements.txt
@@ -12,9 +12,10 @@ langchain-core~=0.3.0
langchain-community==0.3.0
langchain-google-genai==2.0.6
langchain-cohere==0.3.3
langchain-ollama==0.2.1
streamlit~=1.38.0

python-pptx
python-pptx~=0.6.21
# metaphor-python
json5~=0.9.14
requests~=2.32.3
@@ -32,3 +33,7 @@ certifi==2024.8.30
urllib3==2.2.3

anyio==4.4.0

httpx~=0.27.2
huggingface-hub~=0.24.5
ollama~=0.4.3
