From 1bfc139bfebf41d4f3f72bdc376367b06442adb5 Mon Sep 17 00:00:00 2001 From: nbqu Date: Fri, 23 Feb 2024 15:38:07 +0900 Subject: [PATCH 01/11] fix API parameters --- dsp/modules/google.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index 581c98414..da71a16d1 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -1,4 +1,4 @@ -import math +import os from typing import Any, Optional import backoff @@ -6,10 +6,13 @@ try: import google.generativeai as genai + from google.api_core.exceptions import GoogleAPICallError + google_api_error = GoogleAPICallError except ImportError: google_api_error = Exception print("Not loading Google because it is not installed.") + def backoff_hdlr(details): """Handler from https://pypi.org/project/backoff/""" print( @@ -34,7 +37,7 @@ class Google(LM): def __init__( self, - model: str = "gemini-pro-1.0", + model: str = "models/gemini-1.0-pro", api_key: Optional[str] = None, **kwargs ): @@ -51,27 +54,32 @@ def __init__( Additional arguments to pass to the API provider. """ super().__init__(model) - self.google = genai.configure(api_key=self.api_key) + api_key = os.environ.get("GOOGLE_API_KEY") if api_key is None else api_key + genai.configure(api_key=api_key) self.provider = "google" self.kwargs = { - "model_name": model, + "candidate_count": 1, "temperature": 0.0 if "temperature" not in kwargs else kwargs["temperature"], "max_output_tokens": 2048, "top_p": 1, - "top_k": 1, + "top_k": 40, **kwargs } + self.config = genai.GenerationConfig(**self.kwargs) + self.llm = genai.GenerativeModel(model_name=model, generation_config=self.config) + self.history: list[dict[str, Any]] = [] def basic_request(self, prompt: str, **kwargs): raw_kwargs = kwargs kwargs = { **self.kwargs, - "prompt": prompt, **kwargs, } - response = self.co.generate(**kwargs) + + # Google uses "candidate_count" instead of "num_generations" + response = self.llm.generate_content(prompt, generation_config=kwargs) history = { "prompt": prompt, From f1e2e9fd5ad9606ef70184a4b0720c39781d1ace Mon Sep 17 00:00:00 2001 From: nbqu Date: Fri, 23 Feb 2024 15:38:40 +0900 Subject: [PATCH 02/11] Add Google module to DSP package --- dsp/modules/__init__.py | 1 + dspy/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py index 1fb45bdac..469f6fee4 100644 --- a/dsp/modules/__init__.py +++ b/dsp/modules/__init__.py @@ -9,6 +9,7 @@ from .ollama import * from .clarifai import * from .bedrock import * +from .google import * from .hf_client import HFClientTGI diff --git a/dspy/__init__.py b/dspy/__init__.py index c30295518..b5fadb4e0 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -21,6 +21,7 @@ ColBERTv2 = dsp.ColBERTv2 Pyserini = dsp.PyseriniRetriever Clarifai = dsp.ClarifaiLLM +Google = dsp.Google HFClientTGI = dsp.HFClientTGI HFClientVLLM = HFClientVLLM From 768fb186b051b9c283aed6222893d109fc5da1d9 Mon Sep 17 00:00:00 2001 From: nbqu Date: Fri, 23 Feb 2024 16:32:05 +0900 Subject: [PATCH 03/11] Refactor Google API usage to support "num_generations" argument --- dsp/modules/google.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index da71a16d1..7a1b5d901 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -56,8 +56,12 @@ def __init__( super().__init__(model) api_key = os.environ.get("GOOGLE_API_KEY") if api_key is None else api_key genai.configure(api_key=api_key) + + # Google API uses "candidate_count" instead of "n" or "num_generations" + # For now, google API only supports 1 generation at a time. Raises an error if candidate_count > 1 + num_generations = kwargs.pop("n", kwargs.pop("num_generations", 1)) self.provider = "google" - self.kwargs = { + kwargs = { "candidate_count": 1, "temperature": 0.0 if "temperature" not in kwargs else kwargs["temperature"], "max_output_tokens": 2048, @@ -66,9 +70,14 @@ def __init__( **kwargs } - self.config = genai.GenerationConfig(**self.kwargs) + self.config = genai.GenerationConfig(**kwargs) self.llm = genai.GenerativeModel(model_name=model, generation_config=self.config) + self.kwargs = { + "n": num_generations, + **kwargs, + } + self.history: list[dict[str, Any]] = [] def basic_request(self, prompt: str, **kwargs): @@ -78,7 +87,9 @@ def basic_request(self, prompt: str, **kwargs): **kwargs, } - # Google uses "candidate_count" instead of "num_generations" + # Google disallows "n" arguments + kwargs.pop("n", None) + response = self.llm.generate_content(prompt, generation_config=kwargs) history = { @@ -109,4 +120,14 @@ def __call__( return_sorted: bool = False, **kwargs ): - return self.request(prompt, **kwargs) + assert only_completed, "for now" + assert return_sorted is False, "for now" + + n = kwargs.pop("n", 1) + + completions = [] + for i in range(n): + response = self.request(prompt, **kwargs) + completions.append(response.text) + + return completions From e6a31abd8b11c590ae6cc2b42b169197615d3153 Mon Sep 17 00:00:00 2001 From: nbqu Date: Mon, 26 Feb 2024 13:31:09 +0900 Subject: [PATCH 04/11] Add safety settings to Google API wrapper --- dsp/modules/google.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index 7a1b5d901..10f1df44b 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -29,6 +29,25 @@ def giveup_hdlr(details): return True +BLOCK_ONLY_HIGH = [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_ONLY_HIGH" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_ONLY_HIGH" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_ONLY_HIGH" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_ONLY_HIGH" + }, +] + class Google(LM): """Wrapper around Google's API. @@ -39,6 +58,7 @@ def __init__( self, model: str = "models/gemini-1.0-pro", api_key: Optional[str] = None, + safety_settings: Optional[Iterable] = BLOCK_ONLY_HIGH, **kwargs ): """ @@ -71,7 +91,9 @@ def __init__( } self.config = genai.GenerationConfig(**kwargs) - self.llm = genai.GenerativeModel(model_name=model, generation_config=self.config) + self.llm = genai.GenerativeModel(model_name=model, + generation_config=self.config, + safety_settings=safety_settings) self.kwargs = { "n": num_generations, From b519d9e167e3aedda505eadd2ccbbd05052b0b4d Mon Sep 17 00:00:00 2001 From: nbqu Date: Mon, 26 Feb 2024 13:31:52 +0900 Subject: [PATCH 05/11] Update Google module to handle multiple generations with temperature 0.0 --- dsp/modules/google.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index 10f1df44b..eea13d7af 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -1,5 +1,5 @@ import os -from typing import Any, Optional +from typing import Any, Iterable, Optional import backoff from dsp.modules.lm import LM @@ -80,6 +80,9 @@ def __init__( # Google API uses "candidate_count" instead of "n" or "num_generations" # For now, google API only supports 1 generation at a time. Raises an error if candidate_count > 1 num_generations = kwargs.pop("n", kwargs.pop("num_generations", 1)) + if num_generations > 1 and kwargs['temperature'] == 0.0: + kwargs['temperature'] = 0.7 + self.provider = "google" kwargs = { "candidate_count": 1, @@ -110,7 +113,9 @@ def basic_request(self, prompt: str, **kwargs): } # Google disallows "n" arguments - kwargs.pop("n", None) + n = kwargs.pop("n", None) + if n is not None and n > 1 and kwargs['temperature'] == 0.0: + kwargs['temperature'] = 0.7 response = self.llm.generate_content(prompt, generation_config=kwargs) @@ -128,6 +133,7 @@ def basic_request(self, prompt: str, **kwargs): backoff.expo, (google_api_error), max_time=1000, + max_tries=5, on_backoff=backoff_hdlr, giveup=giveup_hdlr, ) @@ -150,6 +156,6 @@ def __call__( completions = [] for i in range(n): response = self.request(prompt, **kwargs) - completions.append(response.text) + completions.append(response.parts[0].text) return completions From 0339b13538897f8ca3f9cbda40ca02b0e183e1bc Mon Sep 17 00:00:00 2001 From: nbqu Date: Mon, 26 Feb 2024 14:18:00 +0900 Subject: [PATCH 06/11] Add inspect_history method to Google class --- dsp/modules/google.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index eea13d7af..78c4ff0f8 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -159,3 +159,41 @@ def __call__( completions.append(response.parts[0].text) return completions + + def inspect_history(self, n: int = 1, skip: int = 0): + """Prints the last n prompts and their completions. + TODO: print the valid choice that contains filled output field instead of the first + """ + + last_prompt = None + printed = [] + n = n + skip + + for x in reversed(self.history[-100:]): + prompt = x["prompt"] + + if prompt != last_prompt: + printed.append( + ( + prompt, + x['response'] + ) + ) + + last_prompt = prompt + + if len(printed) >= n: + break + + for idx, (prompt, response) in enumerate(reversed(printed)): + # skip the first `skip` prompts + if (n - idx - 1) < skip: + continue + + print("\n\n\n") + print(prompt, end="") + text = response.parts[0].text + self.print_green(text, end="") + print("\n\n\n") + + From ff16223e2012098f44b528dcc4a12088cffacee8 Mon Sep 17 00:00:00 2001 From: nbqu Date: Mon, 26 Feb 2024 20:54:56 +0900 Subject: [PATCH 07/11] rebased from head --- dsp/modules/google.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index a422c637f..2fdd99887 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -10,8 +10,7 @@ google_api_error = GoogleAPICallError except ImportError: google_api_error = Exception - # print("Not loading Google because it is not installed.") - + print("Not loading Google because it is not installed.") def backoff_hdlr(details): @@ -132,7 +131,7 @@ def basic_request(self, prompt: str, **kwargs): @backoff.on_exception( backoff.expo, - (Exception), + (google_api_error), max_time=1000, max_tries=5, on_backoff=backoff_hdlr, @@ -147,7 +146,7 @@ def __call__( prompt: str, only_completed: bool = True, return_sorted: bool = False, - **kwargs, + **kwargs ): assert only_completed, "for now" assert return_sorted is False, "for now" From 9e660fda458746e73a4ed208acc2471e3512817e Mon Sep 17 00:00:00 2001 From: nbqu Date: Tue, 27 Feb 2024 10:09:25 +0900 Subject: [PATCH 08/11] deduplication: raise temperature only when requested --- dsp/modules/google.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index 2fdd99887..3a5a57063 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -48,6 +48,7 @@ def giveup_hdlr(details): }, ] + class Google(LM): """Wrapper around Google's API. @@ -80,8 +81,6 @@ def __init__( # Google API uses "candidate_count" instead of "n" or "num_generations" # For now, google API only supports 1 generation at a time. Raises an error if candidate_count > 1 num_generations = kwargs.pop("n", kwargs.pop("num_generations", 1)) - if num_generations > 1 and kwargs['temperature'] == 0.0: - kwargs['temperature'] = 0.7 self.provider = "google" kwargs = { @@ -195,5 +194,3 @@ def inspect_history(self, n: int = 1, skip: int = 0): text = response.parts[0].text self.print_green(text, end="") print("\n\n\n") - - From 48f241c43126fcc8797f3c26260d04e7f2e047f5 Mon Sep 17 00:00:00 2001 From: nbqu Date: Tue, 27 Feb 2024 12:28:13 +0900 Subject: [PATCH 09/11] Fix response format in Google class and update condition in LM class for deduplication --- dsp/modules/google.py | 38 +------------------------------------- dsp/modules/lm.py | 8 +++++--- 2 files changed, 6 insertions(+), 40 deletions(-) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index 3a5a57063..c6c6bd43e 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -120,7 +120,7 @@ def basic_request(self, prompt: str, **kwargs): history = { "prompt": prompt, - "response": response, + "response": [response], "kwargs": kwargs, "raw_kwargs": raw_kwargs, } @@ -158,39 +158,3 @@ def __call__( completions.append(response.parts[0].text) return completions - - def inspect_history(self, n: int = 1, skip: int = 0): - """Prints the last n prompts and their completions. - TODO: print the valid choice that contains filled output field instead of the first - """ - - last_prompt = None - printed = [] - n = n + skip - - for x in reversed(self.history[-100:]): - prompt = x["prompt"] - - if prompt != last_prompt: - printed.append( - ( - prompt, - x['response'] - ) - ) - - last_prompt = prompt - - if len(printed) >= n: - break - - for idx, (prompt, response) in enumerate(reversed(printed)): - # skip the first `skip` prompts - if (n - idx - 1) < skip: - continue - - print("\n\n\n") - print(prompt, end="") - text = response.parts[0].text - self.print_green(text, end="") - print("\n\n\n") diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py index 6305338d7..e2965d49a 100644 --- a/dsp/modules/lm.py +++ b/dsp/modules/lm.py @@ -46,14 +46,14 @@ def inspect_history(self, n: int = 1, skip: int = 0): if prompt != last_prompt: - if provider=="clarifai": + if provider == "clarifai" or provider == "google": printed.append( ( prompt, x['response'] - ) + ) ) - else: + else: printed.append( ( prompt, @@ -82,6 +82,8 @@ def inspect_history(self, n: int = 1, skip: int = 0): text = ' ' + self._get_choice_text(choices[0]).strip() elif provider == "clarifai": text=choices + elif provider == "google": + text = choices[0].parts[0].text else: text = choices[0]["text"] self.print_green(text, end="") From 8328ac63d48c400c92459f7259293e0816f048e3 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Tue, 27 Feb 2024 13:56:07 -0800 Subject: [PATCH 10/11] Update google.py --- dsp/modules/google.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index c6c6bd43e..27d4ecf82 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -10,7 +10,7 @@ google_api_error = GoogleAPICallError except ImportError: google_api_error = Exception - print("Not loading Google because it is not installed.") + # print("Not loading Google because it is not installed.") def backoff_hdlr(details): From c6e96ed4e2c66ea77310c3c149091815215d2719 Mon Sep 17 00:00:00 2001 From: Connor Shorten Date: Tue, 27 Feb 2024 18:32:51 -0500 Subject: [PATCH 11/11] Slightly increase max_tries, issue around 500 errors --- dsp/modules/google.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsp/modules/google.py b/dsp/modules/google.py index 27d4ecf82..d7546fee3 100644 --- a/dsp/modules/google.py +++ b/dsp/modules/google.py @@ -132,7 +132,7 @@ def basic_request(self, prompt: str, **kwargs): backoff.expo, (google_api_error), max_time=1000, - max_tries=5, + max_tries=8, on_backoff=backoff_hdlr, giveup=giveup_hdlr, )