fix for buffered messages at end, colorized context bars

This commit is contained in:
Dhanji Prasanna
2025-10-13 13:36:37 +11:00
parent 318355e864
commit 062e6de63f
5 changed files with 143 additions and 70 deletions

View File

@@ -1,4 +1,5 @@
use crossterm::style::Color;
use crossterm::style::{Stylize, SetForegroundColor, ResetColor};
use termimad::MadSkin;
/// Simple output handler with markdown support
@@ -45,11 +46,21 @@ impl SimpleOutput {
let filled_chars = "".repeat(filled_width);
let empty_chars = "".repeat(empty_width);
let progress_bar = format!("{}{}", filled_chars, empty_chars);
println!(
"Context: {} {:.1}% | {}/{} tokens",
progress_bar, percentage, used, total
);
// Determine color based on percentage
let color = if percentage < 60.0 {
crossterm::style::Color::Green
} else if percentage < 80.0 {
crossterm::style::Color::Yellow
} else {
crossterm::style::Color::Red
};
// Print with colored progress bar
print!("Context: ");
print!("{}", SetForegroundColor(color));
print!("{}{}", filled_chars, empty_chars);
print!("{}", ResetColor);
println!(" {:.1}% | {}/{} tokens", percentage, used, total);
}
}

View File

@@ -1510,27 +1510,31 @@ The tool will execute immediately and you'll receive the result (success or erro
// Display tool execution result with proper indentation
if tool_call.tool != "final_output" {
let output_lines: Vec<&str> = tool_result.lines().collect();
// Helper function to safely truncate strings at character boundaries
let truncate_line = |line: &str, max_width: usize| -> String {
let char_count = line.chars().count();
if char_count <= max_width {
line.to_string()
} else {
let truncated: String = line.chars().take(max_width.saturating_sub(3)).collect();
format!("{}...", truncated)
}
};
const MAX_LINES: usize = 5;
const MAX_LINE_WIDTH: usize = 80;
if output_lines.len() <= MAX_LINES {
for line in output_lines {
// Clip line to max width
let clipped_line = if line.len() > MAX_LINE_WIDTH {
format!("{}...", &line[..MAX_LINE_WIDTH.saturating_sub(3)])
} else {
line.to_string()
};
let clipped_line = truncate_line(line, MAX_LINE_WIDTH);
self.ui_writer.print_tool_output_line(&clipped_line);
}
} else {
for line in output_lines.iter().take(MAX_LINES) {
// Clip line to max width
let clipped_line = if line.len() > MAX_LINE_WIDTH {
format!("{}...", &line[..MAX_LINE_WIDTH.saturating_sub(3)])
} else {
line.to_string()
};
let clipped_line = truncate_line(line, MAX_LINE_WIDTH);
self.ui_writer.print_tool_output_line(&clipped_line);
}
let hidden_count = output_lines.len() - MAX_LINES;
@@ -1682,7 +1686,20 @@ The tool will execute immediately and you'll receive the result (success or erro
// Only use parser text if we truly have no response
// This should be rare - only if streaming failed to display anything
debug!("Warning: Using parser buffer text as fallback - this may duplicate output");
// Don't add it - it's already been displayed
// Extract only the undisplayed portion from parser buffer
// Parser buffer accumulates across iterations, so we need to be careful
let clean_text = text_content
.replace("<|im_end|>", "")
.replace("</s>", "")
.replace("[/INST]", "")
.replace("<</SYS>>", "");
let filtered_text = filter_json_tool_calls(&clean_text);
// Only use this if we truly have nothing else
if !filtered_text.trim().is_empty() && full_response.is_empty() {
debug!("Using filtered parser text as last resort: {} chars", filtered_text.len());
current_response = filtered_text;
}
}
if !has_text_response && full_response.is_empty() {
@@ -1786,6 +1803,7 @@ The tool will execute immediately and you'll receive the result (success or erro
// Appending would duplicate the output
if !current_response.is_empty() && full_response.is_empty() {
full_response = current_response.clone();
debug!("Set full_response from current_response (no tool): {} chars", full_response.len());
}
self.ui_writer.println("");
@@ -1852,10 +1870,12 @@ The tool will execute immediately and you'll receive the result (success or erro
// If we get here and no tool was executed, we're done
if !tool_executed {
// Don't add parser text_content here - it's already been displayed during streaming
// The parser buffer contains ALL accumulated text, including what was already shown
// Adding it here would cause duplication of the entire response
// IMPORTANT: Do NOT add parser text_content here!
// The text has already been displayed during streaming via current_response.
// The parser buffer accumulates ALL text and would cause duplication.
debug!("Stream completed without tool execution. Response already displayed during streaming.");
debug!("Current response length: {}, Full response length: {}",
current_response.len(), full_response.len());
let has_response = !current_response.is_empty() || !full_response.is_empty();
@@ -1865,10 +1885,11 @@ The tool will execute immediately and you'll receive the result (success or erro
iteration_count
);
} else {
// Don't add current_response to full_response here - it was already displayed during streaming
// Only add it if full_response is empty (meaning no tools were executed)
// Only set full_response if it's empty (first iteration without tools)
// This prevents duplication when the agent responds without calling final_output
if full_response.is_empty() && !current_response.is_empty() {
full_response = current_response.clone();
debug!("Set full_response from current_response: {} chars", full_response.len());
}
self.ui_writer.println("");
}

View File

@@ -125,7 +125,9 @@ impl DatabricksAuth {
cached_token,
} => {
// Use the OAuth implementation with automatic refresh
let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
let token =
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes)
.await?;
// Cache the token for potential reuse within the same session
*cached_token = Some(token.clone());
Ok(token)
@@ -308,12 +310,15 @@ impl DatabricksProvider {
Ok(s) => {
// Debug: Log raw string content (truncated for large chunks)
if s.len() > 1000 {
debug!("Raw SSE string content (first 500 chars): {:?}...", &s[..500]);
debug!(
"Raw SSE string content (first 500 chars): {:?}...",
&s[..500]
);
} else {
debug!("Raw SSE string content: {:?}", s);
}
s
},
}
Err(e) => {
error!("Invalid UTF-8 in stream chunk: {}", e);
let _ = tx
@@ -393,7 +398,10 @@ impl DatabricksProvider {
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
if data.len() > 1000 {
debug!("Raw Databricks SSE JSON payload (first 500 chars): {}...", &data[..500]);
debug!(
"Raw Databricks SSE JSON payload (first 500 chars): {}...",
&data[..500]
);
} else {
debug!("Raw Databricks SSE JSON payload: {}", data);
}
@@ -423,7 +431,10 @@ impl DatabricksProvider {
// Handle tool calls - accumulate across chunks
if let Some(tool_calls) = delta.tool_calls {
debug!("Processing {} tool call deltas", tool_calls.len());
debug!(
"Processing {} tool call deltas",
tool_calls.len()
);
for tool_call in tool_calls {
let index = tool_call.index.unwrap_or(0);
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
@@ -465,7 +476,10 @@ impl DatabricksProvider {
if entry.2.len() > 100 {
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
} else if !entry.2.is_empty() {
debug!("Tool call {} full args: {}", index, entry.2);
debug!(
"Tool call {} full args: {}",
index, entry.2
);
}
}
}
@@ -500,7 +514,10 @@ impl DatabricksProvider {
})
.collect();
debug!("Final tool calls count: {}", final_tool_calls.len());
debug!(
"Final tool calls count: {}",
final_tool_calls.len()
);
let final_chunk = CompletionChunk {
content: String::new(),
@@ -524,14 +541,14 @@ impl DatabricksProvider {
// Check if this is likely an incomplete JSON due to line splitting
// Common indicators: unexpected EOF, unterminated string, etc.
let error_str = e.to_string().to_lowercase();
if line.starts_with("data: ") && (
error_str.contains("eof") ||
if line.starts_with("data: ")
&& (error_str.contains("eof") ||
error_str.contains("unterminated") ||
error_str.contains("unexpected end") ||
error_str.contains("trailing") ||
// Also check if the data doesn't end with a proper JSON terminator
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']'))
) {
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']')))
{
// This looks like an incomplete data line, save it for the next chunk
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
incomplete_data_line = line.clone();
@@ -542,7 +559,10 @@ impl DatabricksProvider {
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
// For debugging large payloads, log a sample
if data.len() > 1000 {
debug!("JSON parse error - data sample: {}", &data[..std::cmp::min(500, data.len())]);
debug!(
"JSON parse error - data sample: {}",
&data[..std::cmp::min(500, data.len())]
);
}
}
// Don't error out on parse failures, just continue
@@ -564,7 +584,10 @@ impl DatabricksProvider {
// If we have any incomplete data line at the end, try to process it
if !incomplete_data_line.is_empty() {
debug!("Processing final incomplete data line (len={})", incomplete_data_line.len());
debug!(
"Processing final incomplete data line (len={})",
incomplete_data_line.len()
);
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
// Try to parse it as-is, it might be complete
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
@@ -612,7 +635,7 @@ impl DatabricksProvider {
let response = match self
.client
.get(&format!("{}/api/2.0/serving-endpoints", self.host))
.get(format!("{}/api/2.0/serving-endpoints", self.host))
.header("Authorization", format!("Bearer {}", token))
.send()
.await
@@ -726,9 +749,9 @@ impl LLMProvider for DatabricksProvider {
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN &&
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
if status == reqwest::StatusCode::FORBIDDEN
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
{
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
@@ -756,11 +779,19 @@ impl LLMProvider for DatabricksProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
return Err(anyhow!(
"Databricks API error {} after token refresh: {}",
retry_status,
retry_error_text
));
}
}
Err(e) => {
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
return Err(anyhow!(
"Failed to refresh OAuth token: {}. Original error: {}",
e,
error_text
));
}
}
} else {
@@ -877,9 +908,9 @@ impl LLMProvider for DatabricksProvider {
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN &&
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
if status == reqwest::StatusCode::FORBIDDEN
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
{
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
@@ -907,11 +938,19 @@ impl LLMProvider for DatabricksProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
return Err(anyhow!(
"Databricks API error {} after token refresh: {}",
retry_status,
retry_error_text
));
}
}
Err(e) => {
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
return Err(anyhow!(
"Failed to refresh OAuth token: {}. Original error: {}",
e,
error_text
));
}
}
} else {

View File

@@ -67,8 +67,10 @@ impl EmbeddedProvider {
.map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;
// Create session with parameters
let mut session_params = SessionParams::default();
session_params.n_ctx = context_size;
let mut session_params = SessionParams {
n_ctx: context_size,
..Default::default()
};
if let Some(threads) = threads {
session_params.n_threads = threads;
}
@@ -137,7 +139,7 @@ impl EmbeddedProvider {
in_conversation = false;
}
MessageRole::Assistant => {
formatted.push_str(" ");
formatted.push(' ');
formatted.push_str(&message.content);
formatted.push_str("</s> ");
in_conversation = false;
@@ -146,8 +148,8 @@ impl EmbeddedProvider {
}
// If the last message was from user, add a space for the assistant's response
if messages.last().map_or(false, |m| matches!(m.role, MessageRole::User)) {
formatted.push_str(" ");
if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
formatted.push(' ');
}
formatted
@@ -439,7 +441,7 @@ impl EmbeddedProvider {
// Use curl with progress bar for download
let output = Command::new("curl")
.args(&[
.args([
"-L", // Follow redirects
"-#", // Show progress bar
"-f", // Fail on HTTP errors

View File

@@ -52,7 +52,7 @@ impl TokenCache {
hasher.update(scopes.join(",").as_bytes());
let hash = format!("{:x}", hasher.finalize());
fs::create_dir_all(get_base_path()).unwrap_or_else(|_| {});
fs::create_dir_all(get_base_path()).unwrap_or(());
let cache_path = get_base_path().join(format!("{}.json", hash));
Self { cache_path }