some cleanup of conversation mgmt
This commit is contained in:
@@ -2,9 +2,9 @@ use anyhow::Result;
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use g3_config::Config;
|
use g3_config::Config;
|
||||||
use g3_core::Agent;
|
use g3_core::Agent;
|
||||||
use indicatif::{ProgressBar, ProgressStyle};
|
|
||||||
use rustyline::error::ReadlineError;
|
use rustyline::error::ReadlineError;
|
||||||
use rustyline::DefaultEditor;
|
use rustyline::DefaultEditor;
|
||||||
|
use std::io::Write;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
@@ -142,6 +142,10 @@ async fn run_interactive(mut agent: Agent, show_prompt: bool, show_code: bool) -
|
|||||||
// Add to history
|
// Add to history
|
||||||
rl.add_history_entry(input)?;
|
rl.add_history_entry(input)?;
|
||||||
|
|
||||||
|
// Show thinking indicator immediately
|
||||||
|
print!("🤔 Thinking...");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
// Create cancellation token for this request
|
// Create cancellation token for this request
|
||||||
let cancellation_token = CancellationToken::new();
|
let cancellation_token = CancellationToken::new();
|
||||||
let cancel_token_clone = cancellation_token.clone();
|
let cancel_token_clone = cancellation_token.clone();
|
||||||
|
|||||||
@@ -250,7 +250,6 @@ impl ContextWindow {
|
|||||||
|
|
||||||
pub struct Agent {
|
pub struct Agent {
|
||||||
providers: ProviderRegistry,
|
providers: ProviderRegistry,
|
||||||
config: Config,
|
|
||||||
context_window: ContextWindow,
|
context_window: ContextWindow,
|
||||||
session_id: Option<String>,
|
session_id: Option<String>,
|
||||||
}
|
}
|
||||||
@@ -311,7 +310,6 @@ impl Agent {
|
|||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
providers,
|
providers,
|
||||||
config,
|
|
||||||
context_window,
|
context_window,
|
||||||
session_id: None,
|
session_id: None,
|
||||||
})
|
})
|
||||||
@@ -838,9 +836,7 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
let mut total_execution_time = Duration::new(0, 0);
|
let mut total_execution_time = Duration::new(0, 0);
|
||||||
let mut iteration_count = 0;
|
let mut iteration_count = 0;
|
||||||
const MAX_ITERATIONS: usize = 10; // Prevent infinite loops
|
const MAX_ITERATIONS: usize = 10; // Prevent infinite loops
|
||||||
|
let mut response_started = false;
|
||||||
print!("🤖 "); // Start the response indicator
|
|
||||||
io::stdout().flush()?;
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
iteration_count += 1;
|
iteration_count += 1;
|
||||||
@@ -953,6 +949,11 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
|
|
||||||
// Only print if there's actually new content to show
|
// Only print if there's actually new content to show
|
||||||
if !new_content.trim().is_empty() {
|
if !new_content.trim().is_empty() {
|
||||||
|
// Replace thinking indicator with response indicator if not already done
|
||||||
|
if !response_started {
|
||||||
|
print!("\r🤖 "); // Clear thinking indicator and show response indicator
|
||||||
|
response_started = true;
|
||||||
|
}
|
||||||
print!("{}", new_content);
|
print!("{}", new_content);
|
||||||
io::stdout().flush()?;
|
io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
@@ -1050,10 +1051,10 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
full_response.push_str(final_display_content);
|
full_response.push_str(final_display_content);
|
||||||
full_response.push_str(&format!(
|
// full_response.push_str(&format!(
|
||||||
"\n\nTool executed: {} -> {}\n\n",
|
// "\n\nTool executed: {} -> {}\n\n",
|
||||||
tool_call.tool, tool_result
|
// tool_call.tool, tool_result
|
||||||
));
|
// ));
|
||||||
|
|
||||||
tool_executed = true;
|
tool_executed = true;
|
||||||
// Break out of current stream to start a new one with updated context
|
// Break out of current stream to start a new one with updated context
|
||||||
@@ -1069,6 +1070,12 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
.replace("<</SYS>>", "");
|
.replace("<</SYS>>", "");
|
||||||
|
|
||||||
if !clean_content.is_empty() {
|
if !clean_content.is_empty() {
|
||||||
|
// Replace thinking indicator with response indicator on first content
|
||||||
|
if !response_started {
|
||||||
|
print!("\r🤖 "); // Clear thinking indicator and show response indicator
|
||||||
|
response_started = true;
|
||||||
|
}
|
||||||
|
|
||||||
debug!("Printing clean content: '{}'", clean_content);
|
debug!("Printing clean content: '{}'", clean_content);
|
||||||
print!("{}", clean_content);
|
print!("{}", clean_content);
|
||||||
let _ = io::stdout().flush(); // Force immediate output
|
let _ = io::stdout().flush(); // Force immediate output
|
||||||
@@ -1317,8 +1324,8 @@ fn fix_nested_quotes_in_shell_command(json_str: &str) -> String {
|
|||||||
}
|
}
|
||||||
'\\' => {
|
'\\' => {
|
||||||
// Check what follows the backslash
|
// Check what follows the backslash
|
||||||
if let Some(&next_ch) = chars.peek() {
|
if let Some(&_next_ch) = chars.peek() {
|
||||||
if next_ch == '"' {
|
if _next_ch == '"' {
|
||||||
// This is an escaped quote, keep the backslash
|
// This is an escaped quote, keep the backslash
|
||||||
fixed_command.push(ch);
|
fixed_command.push(ch);
|
||||||
} else {
|
} else {
|
||||||
@@ -1382,7 +1389,7 @@ fn fix_mixed_quotes_in_json(json_str: &str) -> String {
|
|||||||
'\\' if in_string => {
|
'\\' if in_string => {
|
||||||
// Escape sequence - preserve it
|
// Escape sequence - preserve it
|
||||||
result.push(ch);
|
result.push(ch);
|
||||||
if let Some(&next_ch) = chars.peek() {
|
if let Some(&_next_ch) = chars.peek() {
|
||||||
result.push(chars.next().unwrap());
|
result.push(chars.next().unwrap());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,22 +8,18 @@ use llama_cpp::{
|
|||||||
LlamaModel, LlamaParams, LlamaSession, SessionParams,
|
LlamaModel, LlamaParams, LlamaSession, SessionParams,
|
||||||
};
|
};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
pub struct EmbeddedProvider {
|
pub struct EmbeddedProvider {
|
||||||
model: Arc<LlamaModel>,
|
|
||||||
session: Arc<Mutex<LlamaSession>>,
|
session: Arc<Mutex<LlamaSession>>,
|
||||||
model_name: String,
|
model_name: String,
|
||||||
max_tokens: u32,
|
max_tokens: u32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
context_length: u32,
|
context_length: u32,
|
||||||
generation_active: Arc<AtomicBool>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbeddedProvider {
|
impl EmbeddedProvider {
|
||||||
@@ -84,13 +80,11 @@ impl EmbeddedProvider {
|
|||||||
info!("Successfully loaded {} model", model_type);
|
info!("Successfully loaded {} model", model_type);
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model: Arc::new(model),
|
|
||||||
session: Arc::new(Mutex::new(session)),
|
session: Arc::new(Mutex::new(session)),
|
||||||
model_name: format!("embedded-{}", model_type),
|
model_name: format!("embedded-{}", model_type),
|
||||||
max_tokens: max_tokens.unwrap_or(2048),
|
max_tokens: max_tokens.unwrap_or(2048),
|
||||||
temperature: temperature.unwrap_or(0.1),
|
temperature: temperature.unwrap_or(0.1),
|
||||||
context_length: context_size,
|
context_length: context_size,
|
||||||
generation_active: Arc::new(AtomicBool::new(false)),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -429,7 +423,6 @@ impl EmbeddedProvider {
|
|||||||
// Download the Qwen 2.5 7B model if it doesn't exist
|
// Download the Qwen 2.5 7B model if it doesn't exist
|
||||||
fn download_qwen_model(model_path: &Path) -> Result<()> {
|
fn download_qwen_model(model_path: &Path) -> Result<()> {
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::io::Write;
|
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
|
|
||||||
const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
|
const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
|
||||||
|
|||||||
@@ -194,20 +194,6 @@ impl AnthropicProvider {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_anthropic_tool_calls(&self, content: &[AnthropicContent]) -> Vec<ToolCall> {
|
|
||||||
content
|
|
||||||
.iter()
|
|
||||||
.filter_map(|c| match c {
|
|
||||||
AnthropicContent::ToolUse { id, name, input } => Some(ToolCall {
|
|
||||||
id: id.clone(),
|
|
||||||
tool: name.clone(),
|
|
||||||
args: input.clone(),
|
|
||||||
}),
|
|
||||||
_ => None,
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
|
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
|
||||||
let mut system_message = None;
|
let mut system_message = None;
|
||||||
let mut anthropic_messages = Vec::new();
|
let mut anthropic_messages = Vec::new();
|
||||||
@@ -668,14 +654,8 @@ enum AnthropicContent {
|
|||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct AnthropicResponse {
|
struct AnthropicResponse {
|
||||||
id: String,
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
response_type: String,
|
|
||||||
role: String,
|
|
||||||
content: Vec<AnthropicContent>,
|
content: Vec<AnthropicContent>,
|
||||||
model: String,
|
model: String,
|
||||||
stop_reason: Option<String>,
|
|
||||||
stop_sequence: Option<String>,
|
|
||||||
usage: AnthropicUsage,
|
usage: AnthropicUsage,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -701,8 +681,6 @@ struct AnthropicStreamEvent {
|
|||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct AnthropicDelta {
|
struct AnthropicDelta {
|
||||||
#[serde(rename = "type")]
|
|
||||||
delta_type: Option<String>,
|
|
||||||
text: Option<String>,
|
text: Option<String>,
|
||||||
partial_json: Option<String>,
|
partial_json: Option<String>,
|
||||||
}
|
}
|
||||||
@@ -710,7 +688,9 @@ struct AnthropicDelta {
|
|||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct AnthropicError {
|
struct AnthropicError {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
|
#[allow(dead_code)]
|
||||||
error_type: String,
|
error_type: String,
|
||||||
|
#[allow(dead_code)]
|
||||||
message: String,
|
message: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -813,32 +793,4 @@ mod tests {
|
|||||||
assert!(anthropic_tools[0].input_schema.required.is_some());
|
assert!(anthropic_tools[0].input_schema.required.is_some());
|
||||||
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
|
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_tool_call_conversion() {
|
|
||||||
let provider = AnthropicProvider::new(
|
|
||||||
"test-key".to_string(),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
).unwrap();
|
|
||||||
|
|
||||||
let content = vec![
|
|
||||||
AnthropicContent::Text {
|
|
||||||
text: "I'll help you get the weather.".to_string(),
|
|
||||||
},
|
|
||||||
AnthropicContent::ToolUse {
|
|
||||||
id: "toolu_123".to_string(),
|
|
||||||
name: "get_weather".to_string(),
|
|
||||||
input: serde_json::json!({"location": "San Francisco, CA"}),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let tool_calls = provider.convert_anthropic_tool_calls(&content);
|
|
||||||
|
|
||||||
assert_eq!(tool_calls.len(), 1);
|
|
||||||
assert_eq!(tool_calls[0].id, "toolu_123");
|
|
||||||
assert_eq!(tool_calls[0].tool, "get_weather");
|
|
||||||
assert_eq!(tool_calls[0].args["location"], "San Francisco, CA");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user