Streaming token support
@@ -75,9 +75,20 @@ async fn run_interactive(mut agent: Agent, show_prompt: bool, show_code: bool) -
println!(
"I solve problems by writing and executing code. Tell me what you need to accomplish!"
);
println!();

// Display provider and model information
match agent.get_provider_info() {
Ok((provider, model)) => {
println!("🔧 Provider: {} | Model: {}", provider, model);
}
Err(e) => {
error!("Failed to get provider info: {}", e);
}
}

println!();
println!("Type 'exit' or 'quit' to exit, use Up/Down arrows for command history");
println!("Press ESC during operations to cancel the current request");
println!();

// Initialize rustyline editor with history

@@ -6,8 +6,7 @@ use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::field::debug;
use tracing::info;
use tracing::{error, field::debug, info};

#[derive(Debug, Clone)]
pub struct ContextWindow {
@@ -94,8 +93,9 @@ impl Agent {
// Set default provider
providers.set_default(&config.providers.default_provider)?;

// Initialize context window with configured max context length
let context_window = ContextWindow::new(config.agent.max_context_length as u32);
// Determine context window size based on active provider
let context_length = Self::determine_context_length(&config, &providers)?;
let context_window = ContextWindow::new(context_length);

Ok(Self {
providers,
@@ -104,6 +104,62 @@ impl Agent {
})
}

fn determine_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
// Get the active provider to determine context length
let provider = providers.get(None)?;
let provider_name = provider.name();
let model_name = provider.model();

// Use provider-specific context length if available, otherwise fall back to agent config
let context_length = match provider_name {
"embedded" => {
// For embedded models, use the configured context_length or model-specific defaults
if let Some(embedded_config) = &config.providers.embedded {
embedded_config.context_length.unwrap_or_else(|| {
// Model-specific defaults for embedded models
match embedded_config.model_type.to_lowercase().as_str() {
"codellama" => 16384, // CodeLlama supports 16k context
"llama" => 4096, // Base Llama models
"mistral" => 8192, // Mistral models
_ => 4096, // Conservative default
}
})
} else {
config.agent.max_context_length as u32
}
}
"openai" => {
// OpenAI model-specific context lengths
match model_name {
m if m.contains("gpt-4") => 128000, // GPT-4 models have 128k context
m if m.contains("gpt-3.5") => 16384, // GPT-3.5-turbo has 16k context
_ => 4096, // Conservative default
}
}
"anthropic" => {
// Anthropic model-specific context lengths
match model_name {
m if m.contains("claude-3") => 200000, // Claude-3 has 200k context
m if m.contains("claude-2") => 100000, // Claude-2 has 100k context
_ => 100000, // Conservative default for Claude
}
}
_ => config.agent.max_context_length as u32,
};

info!(
"Using context length: {} tokens for provider: {} (model: {})",
context_length, provider_name, model_name
);

Ok(context_length)
}
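
For illustration only, a condensed sketch of how the fallback above is expected to resolve for a few hypothetical provider/model combinations, with the agent's configured max_context_length as the catch-all (the provider and model strings here are assumed examples, not taken from the commit):

```rust
// Illustrative sketch; provider/model strings are hypothetical examples.
fn expected_context_length(provider: &str, model: &str, configured_max: u32) -> u32 {
    match provider {
        "embedded" => 16_384, // e.g. a CodeLlama model_type with no explicit context_length
        "openai" if model.contains("gpt-4") => 128_000,
        "anthropic" if model.contains("claude-3") => 200_000,
        _ => configured_max, // falls back to config.agent.max_context_length
    }
}

fn main() {
    assert_eq!(expected_context_length("openai", "gpt-4-turbo", 8_192), 128_000);
    assert_eq!(expected_context_length("unknown", "custom-model", 8_192), 8_192);
}
```
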

pub fn get_provider_info(&self) -> Result<(String, String)> {
let provider = self.providers.get(None)?;
Ok((provider.name().to_string(), provider.model().to_string()))
}

pub async fn execute_task(
&mut self,
description: &str,
@@ -168,41 +224,67 @@ impl Agent {
) -> Result<String> {
info!("Executing task: {}", description);

let total_start = Instant::now();

let provider = self.providers.get(None)?;

let system_prompt = format!(
"You are G3, a code-first AI agent. Your goal is to solve problems by writing code that completes the desired task.
"You are G3, a general-purpose AI agent. Your goal is to analyze and write code to solve given problems.

When given a task:
1. Analyze what needs to be done
2. Rate the difficulty of the task from 1 (easy, file operations) to 10 (difficult, build complex applications like Firefox)
3. Choose the most appropriate programming language{}
4. Include any necessary imports/dependencies
5. Add error handling where appropriate
6. Generate code to complete the task, or ask for more details, but no other output
G3 uses LLMs with tool calling capability.
Tools allow external systems to provide context and data to G3. You solve higher level problems using
tools, and can interact with multiple at once. When you want to perform an action, use 'I' as the pronoun.

Prefer these languages:
- Bash/Shell: File operations, system administration, simple tasks
- Python: Complex data processing, when libraries are needed
- Rust: Performance-critical tasks, system programming
# Available Tools
- shell:
Execute a command in the shell.

Only use Rust/Python when you need libraries or complex logic that bash can't handle easily.
This will return the output and error concatenated into a single string, as
you would see from running on the command line. There will also be an indication
of if the command succeeded or failed.

Format your code response in markdown backticks as follows:
difficulty rating: [X]
```[language]
[code]
```
Avoid commands that produce a large amount of output, and consider piping those outputs to files.

with nothing afterwards.",
if let Some(lang) = language {
format!(" (prefer {})", lang)
} else {
" based on the task type".to_string()
}
);
**Important**: Each shell command runs in its own process. Things like directory changes or
sourcing files do not persist between tool calls. So you may need to repeat them each time by
stringing together commands, e.g. `cd example && ls` or `source env/bin/activate && pip install numpy`

Multiple commands: Use ; or && to chain commands, avoid newlines
Pathnames: Use absolute paths and avoid cd unless explicitly requested

Usage:
- Call the `shell` tool with the desired bash/shell commands.

- search:
Search the web for information about any topic.

- final_output:
This tool signals the final output for a user in a conversation and MUST be used for the final message to the user. You must
pass in a detailed summary of the work done to this tool call.

Purpose:
- Collects the final output for a user
- Provides clear validation feedback when output isn't valid

Usage:
- Call the `final_output` tool with a summary of the work performed.

# Response Guidelines
- Use Markdown formatting for all responses.
- Follow best practices for Markdown, including:
- Using headers for organization.
- Bullet points for lists.
- Links formatted correctly, either as linked text (e.g., [this is linked text](https://example.com)) or automatic links using angle brackets (e.g., <http://example.com/>).
- For code, use fenced code blocks by placing triple backticks (` ``` `) before and after the code. Include the language identifier after the opening backticks (e.g., ` ```python `) to enable syntax highlighting.
- Ensure clarity, conciseness, and proper formatting to enhance readability and usability.

IMPORTANT INSTRUCTIONS:

Please keep going until the user's query is completely resolved, before ending your turn and yielding back to the user.
Only terminate your turn when you are sure that the problem is solved.

If you are not sure about file content or codebase structure, or other information pertaining to the user's request,
use your tools to read files and gather the relevant information: do NOT guess or make up an answer. It is important
you use tools that can assist with providing the right context.
");

if show_prompt {
println!("🔍 System Prompt:");
@@ -232,26 +314,33 @@ with nothing afterwards.",
messages,
max_tokens: Some(2048),
temperature: Some(0.2),
stream: false,
stream: true, // Enable streaming
};

// Time the LLM call with cancellation support
// Time the LLM call with cancellation support and streaming
let llm_start = Instant::now();
let response = tokio::select! {
result = provider.complete(request) => result?,
let response_content = tokio::select! {
result = self.stream_completion(request) => result?,
_ = cancellation_token.cancelled() => {
return Err(anyhow::anyhow!("Operation cancelled by user"));
}
};
let llm_duration = llm_start.elapsed();

// Update context window with actual token usage
self.context_window.update_usage(&response.usage);
// Create a mock usage for now (we'll need to track this during streaming)
let mock_usage = g3_providers::Usage {
prompt_tokens: 100, // Estimate
completion_tokens: response_content.len() as u32 / 4, // Rough estimate
total_tokens: 100 + (response_content.len() as u32 / 4),
};

// Update context window with estimated token usage
self.context_window.update_usage(&mock_usage);

// Add assistant response to context window
let assistant_message = Message {
role: MessageRole::Assistant,
content: response.content.clone(),
content: response_content.clone(),
};
self.context_window.add_message(assistant_message);
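
As a rough illustration of the placeholder arithmetic above (not part of the commit): the completion estimate reuses the same 4-characters-per-token heuristic as the embedded provider's estimate_tokens helper, while prompt_tokens stays at the fixed 100-token placeholder.

```rust
// Sketch of the estimate above; the 100-token prompt figure is the commit's placeholder.
fn estimate_tokens(text: &str) -> u32 {
    (text.len() / 4) as u32 // rough heuristic: ~4 characters per token
}

fn main() {
    let response_content = "x".repeat(2_000); // pretend a 2,000-character streamed reply
    let completion_tokens = estimate_tokens(&response_content); // 500
    let total_tokens = 100 + completion_tokens; // 600
    assert_eq!((completion_tokens, total_tokens), (500, 600));
}
```
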
@@ -259,19 +348,16 @@ with nothing afterwards.",
let exec_start = Instant::now();
let executor = CodeExecutor::new();
let result = tokio::select! {
result = executor.execute_from_response_with_options(&response.content, show_code) => result?,
result = executor.execute_from_response_with_options(&response_content, show_code) => result?,
_ = cancellation_token.cancelled() => {
return Err(anyhow::anyhow!("Operation cancelled by user"));
}
};
let exec_duration = exec_start.elapsed();

let total_duration = total_start.elapsed();

if show_timing {
let timing_summary = format!(
"\n{} [💡: {} ⚡️: {}]",
Self::format_duration(total_duration),
"\n💭 {} | ⚡️ {}",
Self::format_duration(llm_duration),
Self::format_duration(exec_duration)
);
@@ -285,6 +371,39 @@ with nothing afterwards.",
&self.context_window
}

async fn stream_completion(&self, request: CompletionRequest) -> Result<String> {
use tokio_stream::StreamExt;

let provider = self.providers.get(None)?;
let mut stream = provider.stream(request).await?;

let mut full_content = String::new();
print!("🤖 "); // Start the response indicator
use std::io::{self, Write};
io::stdout().flush()?;

while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
print!("{}", chunk.content);
io::stdout().flush()?;
full_content.push_str(&chunk.content);

if chunk.finished {
break;
}
}
Err(e) => {
error!("Streaming error: {}", e);
return Err(e);
}
}
}

println!(); // New line after streaming completes
Ok(full_content)
}
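
A minimal, runnable sketch of the channel-backed streaming pattern this commit uses end to end (assuming only the tokio and tokio-stream crates; the local Chunk struct stands in for the crate's CompletionChunk): the provider side pushes chunks into an mpsc channel wrapped in a ReceiverStream, and the consumer drains it with the same next().await loop as stream_completion above.

```rust
use tokio_stream::{wrappers::ReceiverStream, StreamExt};

// Stand-in for g3_providers::CompletionChunk { content, finished }.
struct Chunk {
    content: String,
    finished: bool,
}

#[tokio::main]
async fn main() {
    let (tx, rx) = tokio::sync::mpsc::channel::<Chunk>(100);

    // Producer: emits a few token pieces, then a final "finished" chunk.
    tokio::spawn(async move {
        for piece in ["Hello", ", ", "world"] {
            let _ = tx.send(Chunk { content: piece.to_string(), finished: false }).await;
        }
        let _ = tx.send(Chunk { content: String::new(), finished: true }).await;
    });

    // Consumer: same shape as the stream_completion loop above.
    let mut stream = ReceiverStream::new(rx);
    let mut full_content = String::new();
    while let Some(chunk) = stream.next().await {
        full_content.push_str(&chunk.content);
        if chunk.finished {
            break;
        }
    }
    assert_eq!(full_content, "Hello, world");
}
```
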

fn format_duration(duration: Duration) -> String {
let total_ms = duration.as_millis();

@@ -1,13 +1,19 @@
use g3_providers::{LLMProvider, CompletionRequest, CompletionResponse, CompletionStream, CompletionChunk, Usage, Message, MessageRole};
use anyhow::Result;
use llama_cpp::{LlamaModel, LlamaSession, LlamaParams, SessionParams, standard_sampler::{StandardSampler, SamplerStage}};
use g3_providers::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Usage,
};
use llama_cpp::{
standard_sampler::{SamplerStage, StandardSampler},
LlamaModel, LlamaParams, LlamaSession, SessionParams,
};
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, info, error, warn};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, warn};

pub struct EmbeddedProvider {
model: Arc<LlamaModel>,
@@ -30,43 +36,44 @@ impl EmbeddedProvider {
threads: Option<u32>,
) -> Result<Self> {
info!("Loading embedded model from: {}", model_path);

// Expand tilde in path
let expanded_path = shellexpand::tilde(&model_path);
let model_path = Path::new(expanded_path.as_ref());

if !model_path.exists() {
anyhow::bail!("Model file not found: {}", model_path.display());
}

// Set up model parameters
let mut params = LlamaParams::default();

if let Some(gpu_layers) = gpu_layers {
params.n_gpu_layers = gpu_layers;
info!("Using {} GPU layers", gpu_layers);
}

let context_size = context_length.unwrap_or(4096);
info!("Using context length: {}", context_size);

// Load the model
info!("Loading model...");
let model = LlamaModel::load_from_file(model_path, params)
.map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;

// Create session with parameters
let mut session_params = SessionParams::default();
session_params.n_ctx = context_size;
if let Some(threads) = threads {
session_params.n_threads = threads;
}

let session = model.create_session(session_params)

let session = model
.create_session(session_params)
.map_err(|e| anyhow::anyhow!("Failed to create session: {}", e))?;

info!("Successfully loaded {} model", model_type);

Ok(Self {
model: Arc::new(model),
session: Arc::new(Mutex::new(session)),
@@ -77,15 +84,18 @@ impl EmbeddedProvider {
generation_active: Arc::new(AtomicBool::new(false)),
})
}

fn format_messages(&self, messages: &[Message]) -> String {
// Use proper prompt format for CodeLlama
let mut formatted = String::new();

for message in messages {
match message.role {
MessageRole::System => {
formatted.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n", message.content));
formatted.push_str(&format!(
"[INST] <<SYS>>\n{}\n<</SYS>>\n\n",
message.content
));
}
MessageRole::User => {
formatted.push_str(&format!("{} [/INST] ", message.content));
@@ -95,108 +105,127 @@ impl EmbeddedProvider {
}
}
}

formatted
}
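
For reference, a small sketch of the Llama-2-style chat template the function above emits for a system plus user message pair (the sample strings are made up):

```rust
// Illustrative only: mirrors the push_str calls in format_messages above.
fn main() {
    let system = "You are G3, a general-purpose AI agent.";
    let user = "List the files in /tmp";
    let mut formatted = String::new();
    formatted.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n", system));
    formatted.push_str(&format!("{} [/INST] ", user));
    println!("{formatted}");
    // [INST] <<SYS>>
    // You are G3, a general-purpose AI agent.
    // <</SYS>>
    //
    // List the files in /tmp [/INST]
}
```
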

async fn generate_completion(&self, prompt: &str, max_tokens: u32, temperature: f32) -> Result<String> {

async fn generate_completion(
&self,
prompt: &str,
max_tokens: u32,
temperature: f32,
) -> Result<String> {
let session = self.session.clone();
let prompt = prompt.to_string();

// Calculate dynamic max tokens based on available context headroom
let prompt_tokens = self.estimate_tokens(&prompt);
let available_tokens = self.context_length.saturating_sub(prompt_tokens).saturating_sub(50); // Reserve 50 tokens for safety
let available_tokens = self
.context_length
.saturating_sub(prompt_tokens)
.saturating_sub(50); // Reserve 50 tokens for safety
let dynamic_max_tokens = std::cmp::min(max_tokens as usize, available_tokens as usize);

debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",

debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);

// Add timeout to the entire operation
let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts

let result = tokio::time::timeout(timeout_duration, tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
};

debug!("Starting inference with prompt length: {} chars, estimated {} tokens", prompt.len(), prompt_tokens);

// Set context to the prompt
debug!("About to call set_context...");
session.set_context(&prompt)
.map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?;
debug!("set_context completed successfully");

// Create sampler with temperature
debug!("Creating sampler...");
let stages = vec![
SamplerStage::Temperature(temperature),
SamplerStage::TopK(40),
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);
debug!("Sampler created successfully");

// Start completion with dynamic max tokens
debug!("About to call start_completing_with with {} max tokens...", dynamic_max_tokens);
let mut completion_handle = session.start_completing_with(sampler, dynamic_max_tokens)
.map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?;
debug!("start_completing_with completed successfully");

let mut generated_text = String::new();
let mut token_count = 0;
let start_time = std::time::Instant::now();

debug!("Starting token generation loop...");

// Generate tokens with dynamic limits
while let Some(token) = completion_handle.next_token() {
// Check for timeout on each token
if start_time.elapsed() > std::time::Duration::from_secs(25) {
debug!("Token generation timeout after {} tokens", token_count);
break;

let result = tokio::time::timeout(
timeout_duration,
tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
};

debug!(
"Starting inference with prompt length: {} chars, estimated {} tokens",
prompt.len(),
prompt_tokens
);

// Set context to the prompt
debug!("About to call set_context...");
session
.set_context(&prompt)
.map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?;
debug!("set_context completed successfully");

// Create sampler with temperature
debug!("Creating sampler...");
let stages = vec![
SamplerStage::Temperature(temperature),
SamplerStage::TopK(40),
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);
debug!("Sampler created successfully");

// Start completion with dynamic max tokens
debug!(
"About to call start_completing_with with {} max tokens...",
dynamic_max_tokens
);
let mut completion_handle = session
.start_completing_with(sampler, dynamic_max_tokens)
.map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?;
debug!("start_completing_with completed successfully");

let mut generated_text = String::new();
let mut token_count = 0;
let start_time = std::time::Instant::now();

debug!("Starting token generation loop...");

// Generate tokens with dynamic limits
while let Some(token) = completion_handle.next_token() {
// Check for timeout on each token
if start_time.elapsed() > std::time::Duration::from_secs(25) {
debug!("Token generation timeout after {} tokens", token_count);
break;
}

let token_string = session.model().token_to_piece(token);
generated_text.push_str(&token_string);
token_count += 1;

if token_count <= 10 || token_count % 50 == 0 {
debug!("Generated token {}: '{}'", token_count, token_string);
}

// Use dynamic token limit
if token_count >= dynamic_max_tokens {
debug!("Reached dynamic token limit: {}", dynamic_max_tokens);
break;
}

// Stop on completion markers
if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
break;
}
}

let token_string = session.model().token_to_piece(token);
generated_text.push_str(&token_string);
token_count += 1;

if token_count <= 10 || token_count % 50 == 0 {
debug!("Generated token {}: '{}'", token_count, token_string);
}

// Use dynamic token limit
if token_count >= dynamic_max_tokens {
debug!("Reached dynamic token limit: {}", dynamic_max_tokens);
break;
}

// Stop on completion markers
if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
break;
}

// Stop on natural completion points after reasonable generation
if token_count >= 20 && (
generated_text.trim().ends_with("```") ||
(generated_text.contains("```") && generated_text.matches("```").count() % 2 == 0) // Complete code blocks
) {
debug!("Hit code block completion at {} tokens", token_count);
break;
}
}

debug!("Token generation loop completed. Generated {} tokens in {:?}", token_count, start_time.elapsed());
Ok((generated_text.trim().to_string(), token_count))
})).await;

debug!(
"Token generation loop completed. Generated {} tokens in {:?}",
token_count,
start_time.elapsed()
);
Ok((generated_text.trim().to_string(), token_count))
}),
)
.await;

match result {
Ok(inner_result) => match inner_result {
Ok(task_result) => match task_result {
Ok((text, token_count)) => {
info!("Completed generation: {} tokens (dynamic limit was {})", token_count, dynamic_max_tokens);
info!(
"Completed generation: {} tokens (dynamic limit was {})",
token_count, dynamic_max_tokens
);
Ok(text)
}
Err(e) => Err(e),
@@ -209,7 +238,7 @@ impl EmbeddedProvider {
}
}
}

// Helper function to estimate token count from text
fn estimate_tokens(&self, text: &str) -> u32 {
// Rough estimation: average 4 characters per token
@@ -221,20 +250,25 @@ impl EmbeddedProvider {
#[async_trait::async_trait]
impl LLMProvider for EmbeddedProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!("Processing completion request with {} messages", request.messages.len());
debug!(
"Processing completion request with {} messages",
request.messages.len()
);

let prompt = self.format_messages(&request.messages);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);

debug!("Formatted prompt length: {} chars", prompt.len());

let content = self.generate_completion(&prompt, max_tokens, temperature).await?;

let content = self
.generate_completion(&prompt, max_tokens, temperature)
.await?;

// Estimate token usage (rough approximation)
let prompt_tokens = (prompt.len() / 4) as u32; // Rough estimate: 4 chars per token
let completion_tokens = (content.len() / 4) as u32;

Ok(CompletionResponse {
content,
usage: Usage {
@@ -245,34 +279,38 @@ impl LLMProvider for EmbeddedProvider {
model: self.model_name.clone(),
})
}

async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!("Processing streaming request with {} messages", request.messages.len());
debug!(
"Processing streaming request with {} messages",
request.messages.len()
);

let prompt = self.format_messages(&request.messages);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);

let (tx, rx) = mpsc::channel(100);
let session = self.session.clone();
let prompt = prompt.to_string();

// Spawn streaming task
tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
let _ =
tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
return;
}
};

// Set context to the prompt
if let Err(e) = session.set_context(&prompt) {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to set context: {}", e)));
return;
}

// Create sampler with temperature
let stages = vec![
SamplerStage::Temperature(temperature),
@@ -280,67 +318,55 @@ impl LLMProvider for EmbeddedProvider {
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);

// Start completion
let mut completion_handle = match session.start_completing_with(sampler, max_tokens as usize) {
let mut completion_handle = match session
.start_completing_with(sampler, max_tokens as usize)
{
Ok(handle) => handle,
Err(e) => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e)));
let _ =
tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e)));
return;
}
};

let mut accumulated_text = String::new();
let mut token_count = 0;

// Stream tokens with proper limits
while let Some(token) = completion_handle.next_token() {
let token_string = session.model().token_to_piece(token);

accumulated_text.push_str(&token_string);
token_count += 1;

let chunk = CompletionChunk {
content: token_string.clone(),
finished: false,
};

if tx.blocking_send(Ok(chunk)).is_err() {
break; // Receiver dropped
}

// Enforce token limit
if token_count >= max_tokens as usize {
debug!("Reached max token limit in streaming: {}", max_tokens);
break;
}

// Stop if we hit common stop sequences
if accumulated_text.contains("### Human") ||
accumulated_text.contains("### System") ||
accumulated_text.contains("<|end|>") ||
accumulated_text.contains("</s>") ||
accumulated_text.trim().ends_with("```") {
if accumulated_text.contains("### Human")
|| accumulated_text.contains("### System")
|| accumulated_text.contains("<|end|>")
|| accumulated_text.contains("</s>")
{
debug!("Hit stop sequence in streaming, stopping generation");
break;
}

// Emergency brake for streaming too
if token_count > 0 && token_count % 100 == 0 {
debug!("Streaming: Generated {} tokens so far", token_count);
if accumulated_text.trim().len() > 50 &&
(accumulated_text.contains('\n') || accumulated_text.len() > 200) {
if accumulated_text.trim().ends_with('.') ||
accumulated_text.trim().ends_with('!') ||
accumulated_text.trim().ends_with('?') ||
accumulated_text.trim().ends_with('\n') {
debug!("Found natural stopping point in streaming at {} tokens", token_count);
break;
}
}
}
}

// Send final chunk
let final_chunk = CompletionChunk {
content: String::new(),
@@ -348,14 +374,14 @@ impl LLMProvider for EmbeddedProvider {
};
let _ = tx.blocking_send(Ok(final_chunk));
});

Ok(ReceiverStream::new(rx))
}

fn name(&self) -> &str {
"embedded"
}

fn model(&self) -> &str {
&self.model_name
}