Skip to content

Commit

Permalink
[Fix] Fixing http_request open_ai_client mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
sunishsheth2009 committed Jan 22, 2025
1 parent 5576d32 commit 77e838e
Showing 1 changed file with 33 additions and 8 deletions.
41 changes: 33 additions & 8 deletions databricks/sdk/mixins/open_ai_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json as js
from typing import Dict, Optional
import databricks
from typing import Dict, Optional, Any
from dataclasses import dataclass, asdict

from databricks.sdk.service.serving import (ExternalFunctionRequestHttpMethod,
ExternalFunctionResponse,
Expand All @@ -8,6 +10,24 @@

class ServingEndpointsExt(ServingEndpointsAPI):

@dataclass
class ExternalFunctionResponseOverride(ExternalFunctionResponse):
text: Dict[str, Any] = None
"""The content of the response"""

@classmethod
def from_dict(cls, d: Dict[str, Any]) -> "ExternalFunctionResponse":
"""Deserializes the ExternalFunctionResponse from a dictionary."""
return cls(status_code=200, text=d)

def to_dict(self) -> Dict[str, Any]:
"""Serializes the object back into a dictionary."""
result = asdict(self) # Use dataclasses.asdict to serialize fields
# Ensure the text field is serialized correctly
if self.text is not None and not isinstance(self.text, str):
result["text"] = js.dumps(self.text) # Serialize text as JSON if it's a dict
return result

# Using the HTTP Client to pass in the databricks authorization
# This method will be called on every invocation, so when using with model serving will always get the refreshed token
def _get_authorized_http_client(self):
Expand Down Expand Up @@ -82,10 +102,15 @@ def http_request(self,
:returns: :class:`ExternalFunctionResponse`
"""

return super.http_request(connection_name=conn,
method=method,
path=path,
headers=js.dumps(headers),
json=js.dumps(json),
params=js.dumps(params),
)
databricks.sdk.service.serving.ExternalFunctionResponse = (
ServingEndpointsExt.ExternalFunctionResponseOverride)

response = super().http_request(connection_name=conn,
method=method,
path=path,
headers=js.dumps(headers) if headers is not None else None,
json=js.dumps(json) if json is not None else None,
params=js.dumps(params) if params is not None else None)

# Convert the overridden response back to the original response type
return ExternalFunctionResponse.from_dict(response.to_dict())

0 comments on commit 77e838e

Please sign in to comment.