mirror of
https://github.com/tcsenpai/multi1.git
synced 2025-06-06 02:55:21 +00:00
litellm now requests JSON from the provider while supporting non-compliant JSON schemas
This commit is contained in:
parent
1f21d83f39
commit
4da9935c86
@ -1,5 +1,13 @@
|
|||||||
from api_handlers import BaseHandler
|
from api_handlers import BaseHandler
|
||||||
from litellm import completion
|
from litellm import completion, set_verbose
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import json
|
||||||
|
|
||||||
|
class ResponseSchema(BaseModel):
    """Schema for a single reasoning step returned by the LLM.

    Passed to the provider (as a JSON schema) so the completion endpoint
    returns structured output matching these fields.
    """

    # Short label for the step.
    title: str = Field(..., description="Title of the reasoning step")
    # The actual reasoning text for this step.
    content: str = Field(..., description="Content demonstrating the thought process")
    # Self-reported confidence, constrained to the 0-100 range.
    confidence: int = Field(..., ge=0, le=100, description="Confidence level (0-100)")
    # Control token deciding whether the reasoning loop continues.
    next_action: str = Field(..., description="Either 'continue' or 'final_answer'")
class LiteLLMHandler(BaseHandler):
|
class LiteLLMHandler(BaseHandler):
|
||||||
def __init__(self, model, api_base=None, api_key=None):
|
def __init__(self, model, api_base=None, api_key=None):
|
||||||
@ -9,15 +17,35 @@ class LiteLLMHandler(BaseHandler):
|
|||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
|
||||||
def _make_request(self, messages, max_tokens):
|
def _make_request(self, messages, max_tokens):
|
||||||
|
set_verbose=True
|
||||||
response = completion(
|
response = completion(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
response_format= { "type": "json_schema", "json_schema": ResponseSchema.model_json_schema() , "strict": True },
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
api_base=self.api_base,
|
api_base=self.api_base,
|
||||||
api_key=self.api_key
|
api_key=self.api_key,
|
||||||
|
stream=False,
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content
|
|
||||||
|
# Parse the JSON content from the response
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
print("\nResponse from LiteLLM:")
|
||||||
|
print(content)
|
||||||
|
print("===\n")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.loads(content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print("Warning: Response is not valid JSON. Formatting raw content.")
|
||||||
|
return {
|
||||||
|
"title": "Raw Response",
|
||||||
|
"content": "Warning: Response is not valid JSON. Formatting raw content.\n\n" + content,
|
||||||
|
"confidence": 50,
|
||||||
|
"next_action": "continue"
|
||||||
|
}
|
||||||
|
|
||||||
def _process_response(self, response, is_final_answer):
|
def _process_response(self, response, is_final_answer):
|
||||||
return super()._process_response(response, is_final_answer)
|
# The response is already validated against the schema
|
||||||
|
return response
|
Loading…
x
Reference in New Issue
Block a user