Skip to content

Commit

Permalink
Merge pull request #24 from TJC-LP/release-v0.2.1
Browse files: browse the repository at this point in the history
Release v0.2.1
  • Loading branch information
arcaputo3 authored Sep 24, 2024
2 parents 2a6ef7d + 9de1b1f commit 2836604
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "spark-instructor"
version = "0.2.0"
version = "0.2.1"
description = "A library for building structured LLM responses with Spark"
readme = "README.md"
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion spark_instructor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Init for ``spark_instructor``."""

__version__ = "0.2.0"
__version__ = "0.2.1"
__author__ = "Richie Caputo"
__email__ = "[email protected]"

Expand Down
133 changes: 132 additions & 1 deletion tests/utils/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyspark.sql.functions import col, lit
from pyspark.sql.functions import array, col, lit, struct
from pyspark.sql.types import (
ArrayType,
MapType,
Expand Down Expand Up @@ -88,6 +88,137 @@ def test_create_chat_completion_messages(spark):
schema = df.select(result.alias("messages")).schema

assert isinstance(schema, StructType)
assert schema == StructType(
[
StructField(
"messages",
ArrayType(
StructType(
[
StructField("role", StringType(), False),
StructField("content", StringType(), True),
StructField(
"image_urls",
ArrayType(
StructType(
[
StructField("url", StringType(), False),
StructField("detail", StringType(), True),
]
),
True,
),
True,
),
StructField("name", StringType(), True),
StructField(
"tool_calls",
ArrayType(
StructType(
[
StructField("id", StringType(), False),
StructField(
"function",
StructType(
[
StructField("arguments", StringType(), False),
StructField("name", StringType(), False),
]
),
False,
),
StructField("type", StringType(), False),
]
),
True,
),
True,
),
StructField("tool_call_id", StringType(), True),
]
),
False,
),
False,
)
]
)

result_data = df.collect()[0]["messages"]
assert len(result_data) == 2
assert result_data[0]["role"] == "system"
assert result_data[0]["content"] == "Be helpful"
assert result_data[1]["role"] == "user"
assert result_data[1]["content"] == "Hello"


def test_create_chat_completion_messages_nullable(spark, valid_base64):
# Test with minimal required fields
df = spark.createDataFrame([("Hello", "Be helpful")], ["user_msg", "sys_msg"]).withColumn(
"image_urls", array(struct(lit(valid_base64).alias("url"), lit("auto").alias("detail")))
)
messages = [
{"role": lit("system"), "content": "sys_msg"},
{"role": lit("user"), "content": "user_msg", "image_urls": "image_urls"},
]
result = create_chat_completion_messages(messages, strict=False)
df = df.withColumn("messages", result)
schema = df.select(result.alias("messages")).schema
assert schema == StructType(
[
StructField(
"messages",
ArrayType(
StructType(
[
StructField("role", StringType(), True),
StructField("content", StringType(), True),
StructField(
"image_urls",
ArrayType(
StructType(
[
StructField("url", StringType(), True),
StructField("detail", StringType(), True),
]
),
True,
),
True,
),
StructField("name", StringType(), True),
StructField(
"tool_calls",
ArrayType(
StructType(
[
StructField("id", StringType(), True),
StructField(
"function",
StructType(
[
StructField("arguments", StringType(), True),
StructField("name", StringType(), True),
]
),
True,
),
StructField("type", StringType(), True),
]
),
True,
),
True,
),
StructField("tool_call_id", StringType(), True),
]
),
False,
),
False,
)
]
)

result_data = df.collect()[0]["messages"]
assert len(result_data) == 2
Expand Down

0 comments on commit 2836604

Please sign in to comment.