Streaming token support

Dhanji Prasanna
2025-09-08 13:24:39 +10:00
parent 33d4cef00b
commit 1e06b9fea3
3 changed files with 358 additions and 202 deletions

View File

@@ -75,9 +75,20 @@ async fn run_interactive(mut agent: Agent, show_prompt: bool, show_code: bool) -
println!(
"I solve problems by writing and executing code. Tell me what you need to accomplish!"
);
println!();
// Display provider and model information
match agent.get_provider_info() {
Ok((provider, model)) => {
println!("🔧 Provider: {} | Model: {}", provider, model);
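// Example output (model name illustrative): "🔧 Provider: embedded | Model: codellama-7b-instruct"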
}
Err(e) => {
error!("Failed to get provider info: {}", e);
}
}
println!();
println!("Type 'exit' or 'quit' to exit, use Up/Down arrows for command history");
println!("Press ESC during operations to cancel the current request");
println!();
// Initialize rustyline editor with history

View File

@@ -6,8 +6,7 @@ use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::field::debug;
use tracing::info;
use tracing::{error, field::debug, info};
#[derive(Debug, Clone)]
pub struct ContextWindow {
@@ -94,8 +93,9 @@ impl Agent {
// Set default provider
providers.set_default(&config.providers.default_provider)?;
// Initialize context window with configured max context length
let context_window = ContextWindow::new(config.agent.max_context_length as u32);
// Determine context window size based on active provider
let context_length = Self::determine_context_length(&config, &providers)?;
let context_window = ContextWindow::new(context_length);
Ok(Self {
providers,
@@ -104,6 +104,62 @@ impl Agent {
})
}
fn determine_context_length(config: &Config, providers: &ProviderRegistry) -> Result<u32> {
// Get the active provider to determine context length
let provider = providers.get(None)?;
let provider_name = provider.name();
let model_name = provider.model();
// Use provider-specific context length if available, otherwise fall back to agent config
let context_length = match provider_name {
"embedded" => {
// For embedded models, use the configured context_length or model-specific defaults
if let Some(embedded_config) = &config.providers.embedded {
embedded_config.context_length.unwrap_or_else(|| {
// Model-specific defaults for embedded models
match embedded_config.model_type.to_lowercase().as_str() {
"codellama" => 16384, // CodeLlama supports 16k context
"llama" => 4096, // Base Llama models
"mistral" => 8192, // Mistral models
_ => 4096, // Conservative default
}
})
} else {
config.agent.max_context_length as u32
}
}
"openai" => {
// OpenAI model-specific context lengths
match model_name {
m if m.contains("gpt-4") => 128000, // GPT-4 models have 128k context
m if m.contains("gpt-3.5") => 16384, // GPT-3.5-turbo has 16k context
_ => 4096, // Conservative default
}
}
"anthropic" => {
// Anthropic model-specific context lengths
match model_name {
m if m.contains("claude-3") => 200000, // Claude-3 has 200k context
m if m.contains("claude-2") => 100000, // Claude-2 has 100k context
_ => 100000, // Conservative default for Claude
}
}
_ => config.agent.max_context_length as u32,
};
info!(
"Using context length: {} tokens for provider: {} (model: {})",
context_length, provider_name, model_name
);
Ok(context_length)
}
pub fn get_provider_info(&self) -> Result<(String, String)> {
let provider = self.providers.get(None)?;
Ok((provider.name().to_string(), provider.model().to_string()))
}
pub async fn execute_task(
&mut self,
description: &str,
@@ -168,41 +224,67 @@ impl Agent {
) -> Result<String> {
info!("Executing task: {}", description);
let total_start = Instant::now();
let provider = self.providers.get(None)?;
let system_prompt = format!(
"You are G3, a code-first AI agent. Your goal is to solve problems by writing code that completes the desired task.
"You are G3, a general-purpose AI agent. Your goal is to analyze and write code to solve given problems.
When given a task:
1. Analyze what needs to be done
2. Rate the difficulty of the task from 1 (easy, file operations) to 10 (difficult, build complex applications like Firefox)
3. Choose the most appropriate programming language{}
4. Include any necessary imports/dependencies
5. Add error handling where appropriate
6. Generate code to complete the task, or ask for more details, but no other output
G3 uses LLMs with tool calling capability.
Tools allow external systems to provide context and data to G3. You solve higher-level problems using
tools, and can interact with multiple tools at once. When you want to perform an action, use 'I' as the pronoun.
Prefer these languages:
- Bash/Shell: File operations, system administration, simple tasks
- Python: Complex data processing, when libraries are needed
- Rust: Performance-critical tasks, system programming
# Available Tools
- shell:
Execute a command in the shell.
Only use Rust/Python when you need libraries or complex logic that bash can't handle easily.
This will return the output and error concatenated into a single string, as
you would see from running on the command line. There will also be an indication
of whether the command succeeded or failed.
Format your code response in markdown backticks as follows:
difficulty rating: [X]
```[language]
[code]
```
Avoid commands that produce a large amount of output, and consider piping those outputs to files.
with nothing afterwards.",
if let Some(lang) = language {
format!(" (prefer {})", lang)
} else {
" based on the task type".to_string()
}
);
**Important**: Each shell command runs in its own process. Things like directory changes or
sourcing files do not persist between tool calls. So you may need to repeat them each time by
stringing together commands, e.g. `cd example && ls` or `source env/bin/activate && pip install numpy`
Multiple commands: Use ; or && to chain commands, avoid newlines
Pathnames: Use absolute paths and avoid cd unless explicitly requested
Usage:
- Call the `shell` tool with the desired bash/shell commands.
- search:
Search the web for information about any topic.
- final_output:
This tool signals the final output for a user in a conversation and MUST be used for the final message to the user. You must
pass a detailed summary of the work performed to this tool call.
Purpose:
- Collects the final output for a user
- Provides clear validation feedback when output isn't valid
Usage:
- Call the `final_output` tool with a summary of the work performed.
# Response Guidelines
- Use Markdown formatting for all responses.
- Follow best practices for Markdown, including:
- Using headers for organization.
- Bullet points for lists.
- Links formatted correctly, either as linked text (e.g., [this is linked text](https://example.com)) or automatic links using angle brackets (e.g., <http://example.com/>).
- For code, use fenced code blocks by placing triple backticks (` ``` `) before and after the code. Include the language identifier after the opening backticks (e.g., ` ```python `) to enable syntax highlighting.
- Ensure clarity, conciseness, and proper formatting to enhance readability and usability.
IMPORTANT INSTRUCTIONS:
Please keep going until the user's query is completely resolved before ending your turn and yielding back to the user.
Only terminate your turn when you are sure that the problem is solved.
If you are not sure about file content or codebase structure, or other information pertaining to the user's request,
use your tools to read files and gather the relevant information: do NOT guess or make up an answer. It is important
that you use tools that can help provide the right context.
");
if show_prompt {
println!("🔍 System Prompt:");
@@ -232,26 +314,33 @@ with nothing afterwards.",
messages,
max_tokens: Some(2048),
temperature: Some(0.2),
stream: false,
stream: true, // Enable streaming
};
// Time the LLM call with cancellation support
// Time the LLM call with cancellation support and streaming
let llm_start = Instant::now();
let response = tokio::select! {
result = provider.complete(request) => result?,
let response_content = tokio::select! {
result = self.stream_completion(request) => result?,
_ = cancellation_token.cancelled() => {
return Err(anyhow::anyhow!("Operation cancelled by user"));
}
};
let llm_duration = llm_start.elapsed();
// Update context window with actual token usage
self.context_window.update_usage(&response.usage);
// Create a mock usage for now (we'll need to track this during streaming)
let mock_usage = g3_providers::Usage {
prompt_tokens: 100, // Estimate
completion_tokens: response_content.len() as u32 / 4, // Rough estimate
total_tokens: 100 + (response_content.len() as u32 / 4),
};
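// Worked example (hypothetical numbers): an 800-character streamed response gives
// completion_tokens ≈ 800 / 4 = 200, so total_tokens ≈ 100 + 200 = 300.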
// Update context window with estimated token usage
self.context_window.update_usage(&mock_usage);
// Add assistant response to context window
let assistant_message = Message {
role: MessageRole::Assistant,
content: response.content.clone(),
content: response_content.clone(),
};
self.context_window.add_message(assistant_message);
@@ -259,19 +348,16 @@ with nothing afterwards.",
let exec_start = Instant::now();
let executor = CodeExecutor::new();
let result = tokio::select! {
result = executor.execute_from_response_with_options(&response.content, show_code) => result?,
result = executor.execute_from_response_with_options(&response_content, show_code) => result?,
_ = cancellation_token.cancelled() => {
return Err(anyhow::anyhow!("Operation cancelled by user"));
}
};
let exec_duration = exec_start.elapsed();
let total_duration = total_start.elapsed();
if show_timing {
let timing_summary = format!(
"\n{} [💡: {} ⚡️: {}]",
Self::format_duration(total_duration),
"\n💭 {} | ⚡️ {}",
Self::format_duration(llm_duration),
Self::format_duration(exec_duration)
);
@@ -285,6 +371,39 @@ with nothing afterwards.",
&self.context_window
}
async fn stream_completion(&self, request: CompletionRequest) -> Result<String> {
use tokio_stream::StreamExt;
let provider = self.providers.get(None)?;
let mut stream = provider.stream(request).await?;
let mut full_content = String::new();
print!("🤖 "); // Start the response indicator
use std::io::{self, Write};
io::stdout().flush()?;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
print!("{}", chunk.content);
io::stdout().flush()?;
full_content.push_str(&chunk.content);
if chunk.finished {
break;
}
}
Err(e) => {
error!("Streaming error: {}", e);
return Err(e);
}
}
}
println!(); // New line after streaming completes
Ok(full_content)
}
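// Sketch of a possible follow-up (not part of this commit): once real usage tracking
// lands, stream_completion could return an estimated Usage alongside the text so
// callers don't rebuild the chars/4 estimate. The function name and placement are
// hypothetical; the heuristic mirrors the mock_usage estimate above.
async fn stream_completion_with_usage(
    &self,
    request: CompletionRequest,
) -> Result<(String, g3_providers::Usage)> {
    // Estimate prompt size before the request is consumed by stream_completion.
    let prompt_chars: usize = request.messages.iter().map(|m| m.content.len()).sum();
    let content = self.stream_completion(request).await?;
    // Same rough 4-characters-per-token heuristic used elsewhere in this file.
    let prompt_tokens = (prompt_chars / 4) as u32;
    let completion_tokens = (content.len() / 4) as u32;
    let usage = g3_providers::Usage {
        prompt_tokens,
        completion_tokens,
        total_tokens: prompt_tokens + completion_tokens,
    };
    Ok((content, usage))
}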
fn format_duration(duration: Duration) -> String {
let total_ms = duration.as_millis();

View File

@@ -1,13 +1,19 @@
use g3_providers::{LLMProvider, CompletionRequest, CompletionResponse, CompletionStream, CompletionChunk, Usage, Message, MessageRole};
use anyhow::Result;
use llama_cpp::{LlamaModel, LlamaSession, LlamaParams, SessionParams, standard_sampler::{StandardSampler, SamplerStage}};
use g3_providers::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Usage,
};
use llama_cpp::{
standard_sampler::{SamplerStage, StandardSampler},
LlamaModel, LlamaParams, LlamaSession, SessionParams,
};
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, info, error, warn};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error, info, warn};
pub struct EmbeddedProvider {
model: Arc<LlamaModel>,
@@ -30,43 +36,44 @@ impl EmbeddedProvider {
threads: Option<u32>,
) -> Result<Self> {
info!("Loading embedded model from: {}", model_path);
// Expand tilde in path
let expanded_path = shellexpand::tilde(&model_path);
let model_path = Path::new(expanded_path.as_ref());
if !model_path.exists() {
anyhow::bail!("Model file not found: {}", model_path.display());
}
// Set up model parameters
let mut params = LlamaParams::default();
if let Some(gpu_layers) = gpu_layers {
params.n_gpu_layers = gpu_layers;
info!("Using {} GPU layers", gpu_layers);
}
let context_size = context_length.unwrap_or(4096);
info!("Using context length: {}", context_size);
// Load the model
info!("Loading model...");
let model = LlamaModel::load_from_file(model_path, params)
.map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;
// Create session with parameters
let mut session_params = SessionParams::default();
session_params.n_ctx = context_size;
if let Some(threads) = threads {
session_params.n_threads = threads;
}
let session = model.create_session(session_params)
let session = model
.create_session(session_params)
.map_err(|e| anyhow::anyhow!("Failed to create session: {}", e))?;
info!("Successfully loaded {} model", model_type);
Ok(Self {
model: Arc::new(model),
session: Arc::new(Mutex::new(session)),
@@ -77,15 +84,18 @@ impl EmbeddedProvider {
generation_active: Arc::new(AtomicBool::new(false)),
})
}
fn format_messages(&self, messages: &[Message]) -> String {
// Use proper prompt format for CodeLlama
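// Example output shape (content illustrative), following the Llama-2/CodeLlama chat template:
// "[INST] <<SYS>>\nYou are G3...\n<</SYS>>\n\nList the files in /tmp [/INST] "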
let mut formatted = String::new();
for message in messages {
match message.role {
MessageRole::System => {
formatted.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n", message.content));
formatted.push_str(&format!(
"[INST] <<SYS>>\n{}\n<</SYS>>\n\n",
message.content
));
}
MessageRole::User => {
formatted.push_str(&format!("{} [/INST] ", message.content));
@@ -95,108 +105,127 @@ impl EmbeddedProvider {
}
}
}
formatted
}
async fn generate_completion(&self, prompt: &str, max_tokens: u32, temperature: f32) -> Result<String> {
async fn generate_completion(
&self,
prompt: &str,
max_tokens: u32,
temperature: f32,
) -> Result<String> {
let session = self.session.clone();
let prompt = prompt.to_string();
// Calculate dynamic max tokens based on available context headroom
let prompt_tokens = self.estimate_tokens(&prompt);
let available_tokens = self.context_length.saturating_sub(prompt_tokens).saturating_sub(50); // Reserve 50 tokens for safety
let available_tokens = self
.context_length
.saturating_sub(prompt_tokens)
.saturating_sub(50); // Reserve 50 tokens for safety
let dynamic_max_tokens = std::cmp::min(max_tokens as usize, available_tokens as usize);
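// Worked example (hypothetical numbers): with context_length = 4096 and a prompt
// estimated at 1000 tokens, available_tokens = 4096 - 1000 - 50 = 3046, so a
// requested max_tokens of 2048 stays at 2048 while 4096 would be clamped to 3046.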
debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
debug!("Context calculation: prompt_tokens={}, context_length={}, available_tokens={}, dynamic_max_tokens={}",
prompt_tokens, self.context_length, available_tokens, dynamic_max_tokens);
// Add timeout to the entire operation
let timeout_duration = std::time::Duration::from_secs(30); // Increased timeout for larger contexts
let result = tokio::time::timeout(timeout_duration, tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
};
debug!("Starting inference with prompt length: {} chars, estimated {} tokens", prompt.len(), prompt_tokens);
// Set context to the prompt
debug!("About to call set_context...");
session.set_context(&prompt)
.map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?;
debug!("set_context completed successfully");
// Create sampler with temperature
debug!("Creating sampler...");
let stages = vec![
SamplerStage::Temperature(temperature),
SamplerStage::TopK(40),
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);
debug!("Sampler created successfully");
// Start completion with dynamic max tokens
debug!("About to call start_completing_with with {} max tokens...", dynamic_max_tokens);
let mut completion_handle = session.start_completing_with(sampler, dynamic_max_tokens)
.map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?;
debug!("start_completing_with completed successfully");
let mut generated_text = String::new();
let mut token_count = 0;
let start_time = std::time::Instant::now();
debug!("Starting token generation loop...");
// Generate tokens with dynamic limits
while let Some(token) = completion_handle.next_token() {
// Check for timeout on each token
if start_time.elapsed() > std::time::Duration::from_secs(25) {
debug!("Token generation timeout after {} tokens", token_count);
break;
let result = tokio::time::timeout(
timeout_duration,
tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => return Err(anyhow::anyhow!("Model is busy, please try again")),
};
debug!(
"Starting inference with prompt length: {} chars, estimated {} tokens",
prompt.len(),
prompt_tokens
);
// Set context to the prompt
debug!("About to call set_context...");
session
.set_context(&prompt)
.map_err(|e| anyhow::anyhow!("Failed to set context: {}", e))?;
debug!("set_context completed successfully");
// Create sampler with temperature
debug!("Creating sampler...");
let stages = vec![
SamplerStage::Temperature(temperature),
SamplerStage::TopK(40),
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);
debug!("Sampler created successfully");
// Start completion with dynamic max tokens
debug!(
"About to call start_completing_with with {} max tokens...",
dynamic_max_tokens
);
let mut completion_handle = session
.start_completing_with(sampler, dynamic_max_tokens)
.map_err(|e| anyhow::anyhow!("Failed to start completion: {}", e))?;
debug!("start_completing_with completed successfully");
let mut generated_text = String::new();
let mut token_count = 0;
let start_time = std::time::Instant::now();
debug!("Starting token generation loop...");
// Generate tokens with dynamic limits
while let Some(token) = completion_handle.next_token() {
// Check for timeout on each token
if start_time.elapsed() > std::time::Duration::from_secs(25) {
debug!("Token generation timeout after {} tokens", token_count);
break;
}
let token_string = session.model().token_to_piece(token);
generated_text.push_str(&token_string);
token_count += 1;
if token_count <= 10 || token_count % 50 == 0 {
debug!("Generated token {}: '{}'", token_count, token_string);
}
// Use dynamic token limit
if token_count >= dynamic_max_tokens {
debug!("Reached dynamic token limit: {}", dynamic_max_tokens);
break;
}
// Stop on completion markers
if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
break;
}
}
let token_string = session.model().token_to_piece(token);
generated_text.push_str(&token_string);
token_count += 1;
if token_count <= 10 || token_count % 50 == 0 {
debug!("Generated token {}: '{}'", token_count, token_string);
}
// Use dynamic token limit
if token_count >= dynamic_max_tokens {
debug!("Reached dynamic token limit: {}", dynamic_max_tokens);
break;
}
// Stop on completion markers
if generated_text.contains("</s>") || generated_text.contains("[/INST]") {
debug!("Hit CodeLlama stop sequence at {} tokens", token_count);
break;
}
// Stop on natural completion points after reasonable generation
if token_count >= 20 && (
generated_text.trim().ends_with("```") ||
(generated_text.contains("```") && generated_text.matches("```").count() % 2 == 0) // Complete code blocks
) {
debug!("Hit code block completion at {} tokens", token_count);
break;
}
}
debug!("Token generation loop completed. Generated {} tokens in {:?}", token_count, start_time.elapsed());
Ok((generated_text.trim().to_string(), token_count))
})).await;
debug!(
"Token generation loop completed. Generated {} tokens in {:?}",
token_count,
start_time.elapsed()
);
Ok((generated_text.trim().to_string(), token_count))
}),
)
.await;
match result {
Ok(inner_result) => match inner_result {
Ok(task_result) => match task_result {
Ok((text, token_count)) => {
info!("Completed generation: {} tokens (dynamic limit was {})", token_count, dynamic_max_tokens);
info!(
"Completed generation: {} tokens (dynamic limit was {})",
token_count, dynamic_max_tokens
);
Ok(text)
}
Err(e) => Err(e),
@@ -209,7 +238,7 @@ impl EmbeddedProvider {
}
}
}
// Helper function to estimate token count from text
fn estimate_tokens(&self, text: &str) -> u32 {
// Rough estimation: average 4 characters per token
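// e.g. a 2,000-character prompt is estimated at roughly 2000 / 4 = 500 tokens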
@@ -221,20 +250,25 @@ impl EmbeddedProvider {
#[async_trait::async_trait]
impl LLMProvider for EmbeddedProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
debug!("Processing completion request with {} messages", request.messages.len());
debug!(
"Processing completion request with {} messages",
request.messages.len()
);
let prompt = self.format_messages(&request.messages);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
debug!("Formatted prompt length: {} chars", prompt.len());
let content = self.generate_completion(&prompt, max_tokens, temperature).await?;
let content = self
.generate_completion(&prompt, max_tokens, temperature)
.await?;
// Estimate token usage (rough approximation)
let prompt_tokens = (prompt.len() / 4) as u32; // Rough estimate: 4 chars per token
let completion_tokens = (content.len() / 4) as u32;
Ok(CompletionResponse {
content,
usage: Usage {
@@ -245,34 +279,38 @@ impl LLMProvider for EmbeddedProvider {
model: self.model_name.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
debug!("Processing streaming request with {} messages", request.messages.len());
debug!(
"Processing streaming request with {} messages",
request.messages.len()
);
let prompt = self.format_messages(&request.messages);
let max_tokens = request.max_tokens.unwrap_or(self.max_tokens);
let temperature = request.temperature.unwrap_or(self.temperature);
let (tx, rx) = mpsc::channel(100);
let session = self.session.clone();
let prompt = prompt.to_string();
// Spawn streaming task
tokio::task::spawn_blocking(move || {
let mut session = match session.try_lock() {
Ok(ctx) => ctx,
Err(_) => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
let _ =
tx.blocking_send(Err(anyhow::anyhow!("Model is busy, please try again")));
return;
}
};
// Set context to the prompt
if let Err(e) = session.set_context(&prompt) {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to set context: {}", e)));
return;
}
// Create sampler with temperature
let stages = vec![
SamplerStage::Temperature(temperature),
@@ -280,67 +318,55 @@ impl LLMProvider for EmbeddedProvider {
SamplerStage::TopP(0.9),
];
let sampler = StandardSampler::new_softmax(stages, 1);
// Start completion
let mut completion_handle = match session.start_completing_with(sampler, max_tokens as usize) {
let mut completion_handle = match session
.start_completing_with(sampler, max_tokens as usize)
{
Ok(handle) => handle,
Err(e) => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e)));
let _ =
tx.blocking_send(Err(anyhow::anyhow!("Failed to start completion: {}", e)));
return;
}
};
let mut accumulated_text = String::new();
let mut token_count = 0;
// Stream tokens with proper limits
while let Some(token) = completion_handle.next_token() {
let token_string = session.model().token_to_piece(token);
accumulated_text.push_str(&token_string);
token_count += 1;
let chunk = CompletionChunk {
content: token_string.clone(),
finished: false,
};
if tx.blocking_send(Ok(chunk)).is_err() {
break; // Receiver dropped
}
// Enforce token limit
if token_count >= max_tokens as usize {
debug!("Reached max token limit in streaming: {}", max_tokens);
break;
}
// Stop if we hit common stop sequences
if accumulated_text.contains("### Human") ||
accumulated_text.contains("### System") ||
accumulated_text.contains("<|end|>") ||
accumulated_text.contains("</s>") ||
accumulated_text.trim().ends_with("```") {
if accumulated_text.contains("### Human")
|| accumulated_text.contains("### System")
|| accumulated_text.contains("<|end|>")
|| accumulated_text.contains("</s>")
{
debug!("Hit stop sequence in streaming, stopping generation");
break;
}
// Emergency brake for streaming too
if token_count > 0 && token_count % 100 == 0 {
debug!("Streaming: Generated {} tokens so far", token_count);
if accumulated_text.trim().len() > 50 &&
(accumulated_text.contains('\n') || accumulated_text.len() > 200) {
if accumulated_text.trim().ends_with('.') ||
accumulated_text.trim().ends_with('!') ||
accumulated_text.trim().ends_with('?') ||
accumulated_text.trim().ends_with('\n') {
debug!("Found natural stopping point in streaming at {} tokens", token_count);
break;
}
}
}
}
// Send final chunk
let final_chunk = CompletionChunk {
content: String::new(),
@@ -348,14 +374,14 @@ impl LLMProvider for EmbeddedProvider {
};
let _ = tx.blocking_send(Ok(final_chunk));
});
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
"embedded"
}
fn model(&self) -> &str {
&self.model_name
}