from .generator import GeneratorLLM

from vllm import LLM, SamplingParams
import logging
from typing import List, Dict

class Vllm(GeneratorLLM):

    def __init__(self):
        """
        Handle generation using vLLM.
        """
        super().__init__()
        self.logger = logging.getLogger(__name__)
        # The engine is loaded lazily on the first generate() call, so model
        # weights are not pulled into GPU memory at construction time.
        self.llm = None

    def convert_history_to_prompt(self, history: List[Dict[str, str]]) -> str:
        """
        Convert OpenAI-format history to a single prompt string for vLLM.
        """
        prompt = ""
        for message in history:
            role = message["role"]
            content = message["content"]
            if role == "system":
                prompt += f"System: {content}\n"
            elif role == "user":
                prompt += f"User: {content}\n"
            elif role == "assistant":
                prompt += f"Assistant: {content}\n"
        prompt += "Assistant: "
        return prompt

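    # Illustrative note (not part of the original file): given a history like
    #   [{"role": "system", "content": "Be brief."},
    #    {"role": "user", "content": "Hi"}]
    # the method above returns the single prompt string:
    #   "System: Be brief.\nUser: Hi\nAssistant: "
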
    def generate(self, history: List[Dict[str, str]]):
        """
        Generate a response using vLLM from OpenAI-format message history.

        Args:
            history: List of dictionaries in OpenAI format
                [{"role": "user", "content": "..."}, ...]
        """
        self.logger.info(f"Using {self.model} for generation with vLLM")
        if self.llm is None:
            # gpu_memory_utilization is an engine option, so it belongs to the
            # LLM constructor rather than to SamplingParams.
            self.llm = LLM(model=self.model, gpu_memory_utilization=0.5)

        try:
            with self.state.lock:
                self.state.is_generating = True
                self.state.last_complete_sentence = ""
                self.state.current_buffer = ""

            prompt = self.convert_history_to_prompt(history)

            sampling_params = SamplingParams(
                temperature=0.7,
                max_tokens=512,
            )
            # The offline LLM.generate API returns completed RequestOutput
            # objects (one per prompt); it does not stream tokens.
            outputs = self.llm.generate(prompt, sampling_params, use_tqdm=False)
            for output in outputs:
                content = output.outputs[0].text
                with self.state.lock:
                    if '.' in content:
                        self.logger.info(self.state.current_buffer)
                    self.state.current_buffer += content

            with self.state.lock:
                self.logger.info(f"Final output: {self.state.current_buffer}")

        except Exception as e:
            self.logger.error(f"Error during generation: {str(e)}")
            raise e

        finally:
            self.logger.info("Generation complete")
            with self.state.lock:
                self.state.is_generating = False
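

# A minimal usage sketch (not part of the original file). It assumes the
# GeneratorLLM base class stores the model name on `self.model` and exposes a
# shared `state` object with `lock` and `current_buffer`, as used above; the
# direct attribute assignment and the model name below are illustrative
# assumptions and may differ from the real project setup.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    generator = Vllm()
    generator.model = "Qwen/Qwen2.5-1.5B-Instruct"  # assumed attribute; illustrative model
    history = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Give me one fun fact about Mars."},
    ]
    generator.generate(history)
    print(generator.state.current_buffer)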