Skip to content

Commit

Permalink
sorting apis
Browse files · Browse the repository at this point in the history
  • Loading branch information
soldni committed Oct 29, 2024
1 parent 46e65a6 commit 2bfda88
Show file tree
Hide file tree
Showing 13 changed files with 442 additions and 6 deletions.
3 changes: 2 additions & 1 deletion classifiers/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ dependencies = [
"torch",
"transformers",
"wandb",
"jq"
"jq",
"grequests"
]

[project.urls]
Expand Down
Empty file.
5 changes: 5 additions & 0 deletions classifiers/src/dolma_classifiers/inference/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .inference import main, parse_args

# Allow `python -m dolma_classifiers.inference` to run the CLI directly.
if __name__ == "__main__":
    main(parse_args())
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,3 @@ def parse_args() -> argparse.Namespace:
WandbLogger.name = opts.wandb_name or WandbLogger.name or sanitize_model_name(opts.model_name, opts.__dict__)

return opts


if __name__ == "__main__":
args = parse_args()
main(args)
File renamed without changes.
File renamed without changes.
Empty file.
197 changes: 197 additions & 0 deletions classifiers/src/dolma_classifiers/label/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#!/usr/bin/env python3

import argparse
import glob
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List
from urllib.parse import urlparse

import grequests
import jinja2
import urllib3

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)  # module-level logger, per stdlib convention

OPENAI_API_ENDPOINT = "https://api.openai.com/v1/chat/completions"  # chat-completions endpoint used by every request


class DocumentProcessor:
    """Labels JSONL documents by sending templated prompts to the OpenAI chat API.

    Reads newline-delimited JSON documents (from a local glob pattern or a
    remote URL whose body lists file URLs), renders each document through a
    Jinja2 prompt template, POSTs the prompts to the GPT-4 chat-completions
    endpoint in concurrent batches via grequests, and appends the annotated
    documents to per-input JSONL files under ``destination``.
    """

    def __init__(
        self,
        documents_path: str,
        destination: str,
        prompt_template: str,
        api_key: str,
        batch_size: int = 5,
        max_retries: int = 3,
        retry_delay: int = 1
    ):
        """
        Args:
            documents_path: Glob pattern for local JSONL files, or an http(s)
                URL pointing to a newline-separated list of file URLs.
            destination: Directory where processed output files are written.
            prompt_template: Jinja2 template string rendered per document.
            api_key: OpenAI API key placed in the Authorization header.
            batch_size: Number of documents requested concurrently per batch.
            max_retries: Currently unused; reserved for a future retry loop.
            retry_delay: Currently unused; reserved for a future retry loop.
        """
        self.documents_path = documents_path
        self.destination = destination
        self.prompt_template = prompt_template
        self.api_key = api_key
        self.batch_size = batch_size
        # NOTE(review): no retry loop exists yet; these are stored only to
        # keep the constructor interface stable.
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.template = jinja2.Template(prompt_template)

    def _create_request(self, document: Dict[str, Any]) -> "tuple | None":
        """Build a ``(grequests request, document)`` pair for one document.

        Returns None when the template cannot be rendered or the request
        cannot be constructed. (The previous annotation claimed a bare
        AsyncRequest, but a pair -- or None -- is what is actually returned.)
        """
        try:
            # Render the prompt template with document fields
            prompt = self.template.render(**document)

            # Prepare the request payload
            payload = {
                "model": "gpt-4",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant that processes documents."},
                    {"role": "user", "content": prompt}
                ]
            }

            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}"
            }

            # Create the request object
            return grequests.post(
                OPENAI_API_ENDPOINT,
                json=payload,
                headers=headers,
                timeout=30
            ), document

        except Exception as e:
            logger.error(f"Error creating request: {e}")
            return None

    def _process_response(self, response, document: Dict[str, Any]) -> Dict[str, Any]:
        """Attach the API result (or an error message) to the document and return it."""
        try:
            if response is None:
                # grequests.map yields None for requests that failed outright
                # (connection error, timeout); record that explicitly instead
                # of relying on the AttributeError fallback below.
                document['error'] = "API Error: no response (connection failure or timeout)"
            elif response.status_code == 200:
                result = response.json()
                document['gpt4_response'] = result['choices'][0]['message']['content']
            else:
                document['error'] = f"API Error: {response.status_code} - {response.text}"
        except Exception as e:
            document['error'] = f"Processing Error: {str(e)}"

        return document

    def _process_batch(self, batch: List[Dict[str, Any]], output_file: str):
        """Send one batch of documents concurrently and append results to output_file."""
        # Drop documents whose request could not be built; bail out early if
        # none survive (zip(*[]) would raise a ValueError otherwise).
        request_pairs = [pair for pair in (self._create_request(doc) for doc in batch) if pair is not None]
        if not request_pairs:
            logger.warning("No valid requests could be built for this batch; skipping")
            return
        requests, documents = zip(*request_pairs)

        # Make async requests
        responses = grequests.map(requests, size=len(requests))

        # Process responses and write to file
        with open(output_file, 'a', encoding='utf-8') as f:
            for response, document in zip(responses, documents):
                result = self._process_response(response, document)
                f.write(json.dumps(result) + '\n')

    def _download_file(self, url: str, local_path: str) -> str:
        """Download a remote text file to local storage and return its path.

        Raises:
            Exception: if the server responds with a non-200 status.
        """
        with urllib3.PoolManager() as http:
            response = http.request('GET', url)
            if response.status == 200:
                with open(local_path, 'w', encoding='utf-8') as f:
                    f.write(response.data.decode('utf-8'))
                return local_path
            else:
                raise Exception(f"Failed to download file: {response.status}")

    def _get_file_paths(self) -> List[str]:
        """Get list of files to process, handling both local and remote paths.

        Raises:
            Exception: if the remote file list cannot be fetched. (The previous
                version fell through and returned None on a non-200 status,
                which crashed the caller at ``len(file_paths)``.)
        """
        if urlparse(self.documents_path).scheme in ('http', 'https'):
            # Remote mode: documents_path is a URL whose body is a
            # newline-separated list of file URLs to download.
            temp_dir = Path('temp_downloads')
            temp_dir.mkdir(exist_ok=True)

            local_paths = []
            with urllib3.PoolManager() as http:
                response = http.request('GET', self.documents_path)
                if response.status != 200:
                    raise Exception(f"Failed to fetch file list: {response.status}")
                file_list = response.data.decode('utf-8').splitlines()
                for url in file_list:
                    local_path = temp_dir / Path(urlparse(url).path).name
                    self._download_file(url, str(local_path))
                    local_paths.append(str(local_path))
            return local_paths
        else:
            # Handle local files
            return glob.glob(self.documents_path)

    def process_files(self):
        """Main method to process all files found at documents_path."""
        # Create destination directory if it doesn't exist
        os.makedirs(self.destination, exist_ok=True)

        # Get list of files to process
        file_paths = self._get_file_paths()
        logger.info(f"Found {len(file_paths)} files to process")

        for file_path in file_paths:
            try:
                # Read input file (one JSON document per line)
                with open(file_path, 'r', encoding='utf-8') as f:
                    documents = [json.loads(line) for line in f]

                # Create output file path
                output_file = os.path.join(
                    self.destination,
                    f"processed_{os.path.basename(file_path)}"
                )

                # Process documents in batches
                for i in range(0, len(documents), self.batch_size):
                    batch = documents[i:i + self.batch_size]
                    self._process_batch(batch, output_file)
                    logger.info(f"Processed batch {i//self.batch_size + 1} of file {file_path}")

            except Exception as e:
                logger.error(f"Error processing file {file_path}: {e}")

def main():
    """Command-line entry point: parse arguments and run the document processor."""
    parser = argparse.ArgumentParser(description='Process documents with GPT-4')
    parser.add_argument('--documents', required=True, help='Glob pattern for input documents')
    parser.add_argument('--destination', required=True, help='Output directory')
    parser.add_argument('--prompt', required=True, help='Prompt template')
    parser.add_argument('--api-key', required=True, help='OpenAI API key')
    parser.add_argument('--batch-size', type=int, default=5, help='Batch size for processing')
    opts = parser.parse_args()

    # A --prompt value naming an existing file is read from disk; otherwise
    # the value itself is used verbatim as the template text.
    if os.path.isfile(opts.prompt):
        with open(opts.prompt, 'r') as fh:
            template_text = fh.read()
    else:
        template_text = opts.prompt

    # Build the processor and run it over every matched input file.
    DocumentProcessor(
        documents_path=opts.documents,
        destination=opts.destination,
        prompt_template=template_text,
        api_key=opts.api_key,
        batch_size=opts.batch_size,
    ).process_files()

if __name__ == "__main__":
    main()
50 changes: 50 additions & 0 deletions classifiers/src/dolma_classifiers/label/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
from dataclasses import dataclass, field

import aiohttp


@dataclass(frozen=True)
class Message:
    """An immutable chat message: a role ("system", "user", ...) and its text."""

    role: str
    content: str

    def to_dict(self) -> dict:
        """Serialize to the wire format expected by chat-completion APIs."""
        return {"role": self.role, "content": self.content}


@dataclass(frozen=True)
class BaseApiRequest:
    """An immutable chat-API request: endpoint, messages, and extra settings.

    `parameters` is merged into the JSON body alongside the serialized
    messages; `headers` is passed through to the HTTP POST unchanged.
    """

    endpoint: str
    messages: list[Message]
    parameters: dict = field(default_factory=dict)
    headers: dict = field(default_factory=dict)

    async def make(self):
        """POST the request body to `endpoint` and return the decoded JSON response."""
        body = dict(self.parameters)
        body["messages"] = [msg.to_dict() for msg in self.messages]
        async with aiohttp.ClientSession() as session:
            async with session.post(self.endpoint, json=body, headers=self.headers) as resp:
                return await resp.json()


@dataclass(frozen=True)
class Gpt4oRequest(BaseApiRequest):
    """A BaseApiRequest preconfigured for the GPT-4o chat-completions model.

    The Authorization header reads OPENAI_API_KEY from the environment at
    instance-construction time (inside the default_factory).
    """

    model: str = "gpt-4o"
    temperature: float = 1.0
    top_p: float = 1.0
    headers: dict = field(
        default_factory=lambda: {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
        }
    )

    def __post_init__(self):
        # Mutating the contents of the `parameters` dict is legal on a frozen
        # dataclass: no attribute is rebound, only the dict is updated.
        self.parameters["model"] = self.model
        self.parameters["temperature"] = self.temperature
        self.parameters["top_p"] = self.top_p
107 changes: 107 additions & 0 deletions classifiers/src/dolma_classifiers/label/templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Any, Dict, Optional

import jq


class JqTemplate:
    """
    A template engine that processes strings containing JQ expressions in
    {expression} syntax. Supports escaping curly braces with {{ and }}.

    Limitation: an expression may not contain a literal '}' (e.g. jq object
    constructors such as {a: .b}), because the first '}' after the opening
    brace is treated as the closing delimiter.
    """

    def __init__(self, template_string: str):
        """
        Initialize the template with a template string.
        Args:
            template_string: The template string containing JQ expressions in {expression} syntax
        """
        self.template_string = template_string
        self._compiled = self._compile_template(template_string)

    @staticmethod
    def _compile_template(template_string: str) -> list[tuple[str, Optional[jq.jq]]]:
        """
        Compile the template string into a list of (text, expression) tuples.
        Args:
            template_string: The template string to compile
        Returns:
            List of tuples containing (text, compiled_jq_expression)
        Raises:
            ValueError: If there are unmatched braces or invalid JQ expressions
        """
        parts = []
        current_pos = 0

        # Replace escaped braces with NUL-delimited sentinels so the scanner
        # below never mistakes them for expression delimiters; NUL cannot
        # appear in ordinary template text.
        template_string = template_string.replace("{{", "\0LEFT_BRACE\0").replace("}}", "\0RIGHT_BRACE\0")

        while current_pos < len(template_string):
            # Find next unescaped opening brace
            start = template_string.find("{", current_pos)

            if start == -1:
                # No more expressions, add remaining text
                text = template_string[current_pos:]
                text = text.replace("\0LEFT_BRACE\0", "{").replace("\0RIGHT_BRACE\0", "}")
                parts.append((text, None))
                break

            # Add text before the expression
            if start > current_pos:
                text = template_string[current_pos:start]
                text = text.replace("\0LEFT_BRACE\0", "{").replace("\0RIGHT_BRACE\0", "}")
                parts.append((text, None))

            # Find matching closing brace
            end = template_string.find("}", start)
            if end == -1:
                raise ValueError(f"Unmatched opening brace at position {start}")

            # Extract and compile JQ expression
            expr = template_string[start + 1:end].strip()
            try:
                compiled_expr = jq.compile(expr)
            except ValueError as e:
                # Chain the original error so the jq parse failure is visible.
                raise ValueError(f"Invalid JQ expression '{expr}': {str(e)}") from e

            parts.append(("", compiled_expr))
            current_pos = end + 1

        return parts

    def render(self, data: Dict[str, Any]) -> str:
        """
        Render the template by evaluating all JQ expressions against the provided data.
        Args:
            data: Dictionary containing the data to evaluate expressions against
        Returns:
            The rendered template string
        Raises:
            ValueError: If any JQ expression fails to evaluate
        """
        result = []

        for text, expr in self._compiled:
            result.append(text)
            if expr is None:
                continue

            try:
                # Evaluate expression and take its first result
                evaluated = expr.input(data).first()
                # Render None as empty string, but keep falsy values such as
                # 0 or False -- the previous `str(evaluated or "")` silently
                # dropped legitimate falsy results.
                result.append("" if evaluated is None else str(evaluated))
            except StopIteration:
                # No results from JQ expression
                result.append("")
            except Exception as e:
                raise ValueError(f"Error evaluating expression: {str(e)}") from e

        return "".join(result)
Empty file added classifiers/tests/__init__.py
Empty file.
Loading

0 comments on commit 2bfda88

Please sign in to comment.