fix for buffered messages at end, colorized context bars
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
use crossterm::style::Color;
|
||||
use crossterm::style::{Stylize, SetForegroundColor, ResetColor};
|
||||
use termimad::MadSkin;
|
||||
|
||||
/// Simple output handler with markdown support
|
||||
@@ -45,11 +46,21 @@ impl SimpleOutput {
|
||||
|
||||
let filled_chars = "●".repeat(filled_width);
|
||||
let empty_chars = "○".repeat(empty_width);
|
||||
let progress_bar = format!("{}{}", filled_chars, empty_chars);
|
||||
|
||||
// Determine color based on percentage
|
||||
let color = if percentage < 60.0 {
|
||||
crossterm::style::Color::Green
|
||||
} else if percentage < 80.0 {
|
||||
crossterm::style::Color::Yellow
|
||||
} else {
|
||||
crossterm::style::Color::Red
|
||||
};
|
||||
|
||||
println!(
|
||||
"Context: {} {:.1}% | {}/{} tokens",
|
||||
progress_bar, percentage, used, total
|
||||
);
|
||||
// Print with colored progress bar
|
||||
print!("Context: ");
|
||||
print!("{}", SetForegroundColor(color));
|
||||
print!("{}{}", filled_chars, empty_chars);
|
||||
print!("{}", ResetColor);
|
||||
println!(" {:.1}% | {}/{} tokens", percentage, used, total);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1510,27 +1510,31 @@ The tool will execute immediately and you'll receive the result (success or erro
|
||||
// Display tool execution result with proper indentation
|
||||
if tool_call.tool != "final_output" {
|
||||
let output_lines: Vec<&str> = tool_result.lines().collect();
|
||||
|
||||
// Helper function to safely truncate strings at character boundaries
|
||||
let truncate_line = |line: &str, max_width: usize| -> String {
|
||||
let char_count = line.chars().count();
|
||||
if char_count <= max_width {
|
||||
line.to_string()
|
||||
} else {
|
||||
let truncated: String = line.chars().take(max_width.saturating_sub(3)).collect();
|
||||
format!("{}...", truncated)
|
||||
}
|
||||
};
|
||||
|
||||
const MAX_LINES: usize = 5;
|
||||
const MAX_LINE_WIDTH: usize = 80;
|
||||
|
||||
if output_lines.len() <= MAX_LINES {
|
||||
for line in output_lines {
|
||||
// Clip line to max width
|
||||
let clipped_line = if line.len() > MAX_LINE_WIDTH {
|
||||
format!("{}...", &line[..MAX_LINE_WIDTH.saturating_sub(3)])
|
||||
} else {
|
||||
line.to_string()
|
||||
};
|
||||
let clipped_line = truncate_line(line, MAX_LINE_WIDTH);
|
||||
self.ui_writer.print_tool_output_line(&clipped_line);
|
||||
}
|
||||
} else {
|
||||
for line in output_lines.iter().take(MAX_LINES) {
|
||||
// Clip line to max width
|
||||
let clipped_line = if line.len() > MAX_LINE_WIDTH {
|
||||
format!("{}...", &line[..MAX_LINE_WIDTH.saturating_sub(3)])
|
||||
} else {
|
||||
line.to_string()
|
||||
};
|
||||
let clipped_line = truncate_line(line, MAX_LINE_WIDTH);
|
||||
self.ui_writer.print_tool_output_line(&clipped_line);
|
||||
}
|
||||
let hidden_count = output_lines.len() - MAX_LINES;
|
||||
@@ -1682,7 +1686,20 @@ The tool will execute immediately and you'll receive the result (success or erro
|
||||
// Only use parser text if we truly have no response
|
||||
// This should be rare - only if streaming failed to display anything
|
||||
debug!("Warning: Using parser buffer text as fallback - this may duplicate output");
|
||||
// Don't add it - it's already been displayed
|
||||
// Extract only the undisplayed portion from parser buffer
|
||||
// Parser buffer accumulates across iterations, so we need to be careful
|
||||
let clean_text = text_content
|
||||
.replace("<|im_end|>", "")
|
||||
.replace("</s>", "")
|
||||
.replace("[/INST]", "")
|
||||
.replace("<</SYS>>", "");
|
||||
let filtered_text = filter_json_tool_calls(&clean_text);
|
||||
|
||||
// Only use this if we truly have nothing else
|
||||
if !filtered_text.trim().is_empty() && full_response.is_empty() {
|
||||
debug!("Using filtered parser text as last resort: {} chars", filtered_text.len());
|
||||
current_response = filtered_text;
|
||||
}
|
||||
}
|
||||
|
||||
if !has_text_response && full_response.is_empty() {
|
||||
@@ -1786,6 +1803,7 @@ The tool will execute immediately and you'll receive the result (success or erro
|
||||
// Appending would duplicate the output
|
||||
if !current_response.is_empty() && full_response.is_empty() {
|
||||
full_response = current_response.clone();
|
||||
debug!("Set full_response from current_response (no tool): {} chars", full_response.len());
|
||||
}
|
||||
|
||||
self.ui_writer.println("");
|
||||
@@ -1852,10 +1870,12 @@ The tool will execute immediately and you'll receive the result (success or erro
|
||||
|
||||
// If we get here and no tool was executed, we're done
|
||||
if !tool_executed {
|
||||
// Don't add parser text_content here - it's already been displayed during streaming
|
||||
// The parser buffer contains ALL accumulated text, including what was already shown
|
||||
// Adding it here would cause duplication of the entire response
|
||||
// IMPORTANT: Do NOT add parser text_content here!
|
||||
// The text has already been displayed during streaming via current_response.
|
||||
// The parser buffer accumulates ALL text and would cause duplication.
|
||||
debug!("Stream completed without tool execution. Response already displayed during streaming.");
|
||||
debug!("Current response length: {}, Full response length: {}",
|
||||
current_response.len(), full_response.len());
|
||||
|
||||
let has_response = !current_response.is_empty() || !full_response.is_empty();
|
||||
|
||||
@@ -1865,10 +1885,11 @@ The tool will execute immediately and you'll receive the result (success or erro
|
||||
iteration_count
|
||||
);
|
||||
} else {
|
||||
// Don't add current_response to full_response here - it was already displayed during streaming
|
||||
// Only add it if full_response is empty (meaning no tools were executed)
|
||||
// Only set full_response if it's empty (first iteration without tools)
|
||||
// This prevents duplication when the agent responds without calling final_output
|
||||
if full_response.is_empty() && !current_response.is_empty() {
|
||||
full_response = current_response.clone();
|
||||
debug!("Set full_response from current_response: {} chars", full_response.len());
|
||||
}
|
||||
self.ui_writer.println("");
|
||||
}
|
||||
|
||||
@@ -125,14 +125,16 @@ impl DatabricksAuth {
|
||||
cached_token,
|
||||
} => {
|
||||
// Use the OAuth implementation with automatic refresh
|
||||
let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
|
||||
let token =
|
||||
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes)
|
||||
.await?;
|
||||
// Cache the token for potential reuse within the same session
|
||||
*cached_token = Some(token.clone());
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Force a token refresh by clearing any cached token
|
||||
/// This is useful when we get a 403 Invalid Token error
|
||||
pub fn clear_cached_token(&mut self) {
|
||||
@@ -303,17 +305,20 @@ impl DatabricksProvider {
|
||||
Ok(chunk) => {
|
||||
// Debug: Log raw bytes received
|
||||
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
||||
|
||||
|
||||
let chunk_str = match std::str::from_utf8(&chunk) {
|
||||
Ok(s) => {
|
||||
// Debug: Log raw string content (truncated for large chunks)
|
||||
if s.len() > 1000 {
|
||||
debug!("Raw SSE string content (first 500 chars): {:?}...", &s[..500]);
|
||||
debug!(
|
||||
"Raw SSE string content (first 500 chars): {:?}...",
|
||||
&s[..500]
|
||||
);
|
||||
} else {
|
||||
debug!("Raw SSE string content: {:?}", s);
|
||||
}
|
||||
s
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Invalid UTF-8 in stream chunk: {}", e);
|
||||
let _ = tx
|
||||
@@ -393,7 +398,10 @@ impl DatabricksProvider {
|
||||
|
||||
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
|
||||
if data.len() > 1000 {
|
||||
debug!("Raw Databricks SSE JSON payload (first 500 chars): {}...", &data[..500]);
|
||||
debug!(
|
||||
"Raw Databricks SSE JSON payload (first 500 chars): {}...",
|
||||
&data[..500]
|
||||
);
|
||||
} else {
|
||||
debug!("Raw Databricks SSE JSON payload: {}", data);
|
||||
}
|
||||
@@ -423,12 +431,15 @@ impl DatabricksProvider {
|
||||
|
||||
// Handle tool calls - accumulate across chunks
|
||||
if let Some(tool_calls) = delta.tool_calls {
|
||||
debug!("Processing {} tool call deltas", tool_calls.len());
|
||||
debug!(
|
||||
"Processing {} tool call deltas",
|
||||
tool_calls.len()
|
||||
);
|
||||
for tool_call in tool_calls {
|
||||
let index = tool_call.index.unwrap_or(0);
|
||||
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
|
||||
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
|
||||
index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
|
||||
|
||||
|
||||
let entry = current_tool_calls
|
||||
.entry(index)
|
||||
.or_insert_with(|| {
|
||||
@@ -452,7 +463,7 @@ impl DatabricksProvider {
|
||||
}
|
||||
|
||||
// Append arguments
|
||||
debug!("Appending {} chars to tool call {} args (current len: {})",
|
||||
debug!("Appending {} chars to tool call {} args (current len: {})",
|
||||
tool_call.function.arguments.len(), index, entry.2.len());
|
||||
entry.2.push_str(
|
||||
&tool_call.function.arguments,
|
||||
@@ -460,12 +471,15 @@ impl DatabricksProvider {
|
||||
|
||||
debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
|
||||
index, entry.0, entry.1, entry.2.len());
|
||||
|
||||
|
||||
// Debug: Show a sample of the accumulated args if they're getting long
|
||||
if entry.2.len() > 100 {
|
||||
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
|
||||
} else if !entry.2.is_empty() {
|
||||
debug!("Tool call {} full args: {}", index, entry.2);
|
||||
debug!(
|
||||
"Tool call {} full args: {}",
|
||||
index, entry.2
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -500,7 +514,10 @@ impl DatabricksProvider {
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!("Final tool calls count: {}", final_tool_calls.len());
|
||||
debug!(
|
||||
"Final tool calls count: {}",
|
||||
final_tool_calls.len()
|
||||
);
|
||||
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
@@ -524,14 +541,14 @@ impl DatabricksProvider {
|
||||
// Check if this is likely an incomplete JSON due to line splitting
|
||||
// Common indicators: unexpected EOF, unterminated string, etc.
|
||||
let error_str = e.to_string().to_lowercase();
|
||||
if line.starts_with("data: ") && (
|
||||
error_str.contains("eof") ||
|
||||
if line.starts_with("data: ")
|
||||
&& (error_str.contains("eof") ||
|
||||
error_str.contains("unterminated") ||
|
||||
error_str.contains("unexpected end") ||
|
||||
error_str.contains("trailing") ||
|
||||
// Also check if the data doesn't end with a proper JSON terminator
|
||||
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']'))
|
||||
) {
|
||||
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']')))
|
||||
{
|
||||
// This looks like an incomplete data line, save it for the next chunk
|
||||
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
|
||||
incomplete_data_line = line.clone();
|
||||
@@ -542,7 +559,10 @@ impl DatabricksProvider {
|
||||
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
|
||||
// For debugging large payloads, log a sample
|
||||
if data.len() > 1000 {
|
||||
debug!("JSON parse error - data sample: {}", &data[..std::cmp::min(500, data.len())]);
|
||||
debug!(
|
||||
"JSON parse error - data sample: {}",
|
||||
&data[..std::cmp::min(500, data.len())]
|
||||
);
|
||||
}
|
||||
}
|
||||
// Don't error out on parse failures, just continue
|
||||
@@ -564,7 +584,10 @@ impl DatabricksProvider {
|
||||
|
||||
// If we have any incomplete data line at the end, try to process it
|
||||
if !incomplete_data_line.is_empty() {
|
||||
debug!("Processing final incomplete data line (len={})", incomplete_data_line.len());
|
||||
debug!(
|
||||
"Processing final incomplete data line (len={})",
|
||||
incomplete_data_line.len()
|
||||
);
|
||||
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
|
||||
// Try to parse it as-is, it might be complete
|
||||
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
|
||||
@@ -612,7 +635,7 @@ impl DatabricksProvider {
|
||||
|
||||
let response = match self
|
||||
.client
|
||||
.get(&format!("{}/api/2.0/serving-endpoints", self.host))
|
||||
.get(format!("{}/api/2.0/serving-endpoints", self.host))
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.send()
|
||||
.await
|
||||
@@ -724,23 +747,23 @@ impl LLMProvider for DatabricksProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
|
||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
||||
|
||||
if status == reqwest::StatusCode::FORBIDDEN
|
||||
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
|
||||
{
|
||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||
|
||||
|
||||
// Try to refresh the token if we're using OAuth
|
||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||
// Clear any cached token to force a refresh
|
||||
provider_clone.auth.clear_cached_token();
|
||||
|
||||
|
||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||
match provider_clone.auth.get_token().await {
|
||||
Ok(_new_token) => {
|
||||
info!("Successfully refreshed OAuth token, retrying request");
|
||||
|
||||
|
||||
// Retry the request with the new token
|
||||
response = provider_clone
|
||||
.create_request_builder(false)
|
||||
@@ -749,25 +772,33 @@ impl LLMProvider for DatabricksProvider {
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
|
||||
|
||||
|
||||
let retry_status = response.status();
|
||||
if !retry_status.is_success() {
|
||||
let retry_error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
||||
return Err(anyhow!(
|
||||
"Databricks API error {} after token refresh: {}",
|
||||
retry_status,
|
||||
retry_error_text
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
||||
return Err(anyhow!(
|
||||
"Failed to refresh OAuth token: {}. Original error: {}",
|
||||
e,
|
||||
error_text
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -875,23 +906,23 @@ impl LLMProvider for DatabricksProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
|
||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
||||
|
||||
if status == reqwest::StatusCode::FORBIDDEN
|
||||
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
|
||||
{
|
||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||
|
||||
|
||||
// Try to refresh the token if we're using OAuth
|
||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||
// Clear any cached token to force a refresh
|
||||
provider_clone.auth.clear_cached_token();
|
||||
|
||||
|
||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||
match provider_clone.auth.get_token().await {
|
||||
Ok(_new_token) => {
|
||||
info!("Successfully refreshed OAuth token, retrying streaming request");
|
||||
|
||||
|
||||
// Retry the request with the new token
|
||||
response = provider_clone
|
||||
.create_request_builder(true)
|
||||
@@ -900,25 +931,33 @@ impl LLMProvider for DatabricksProvider {
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
|
||||
|
||||
|
||||
let retry_status = response.status();
|
||||
if !retry_status.is_success() {
|
||||
let retry_error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
||||
return Err(anyhow!(
|
||||
"Databricks API error {} after token refresh: {}",
|
||||
retry_status,
|
||||
retry_error_text
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
||||
return Err(anyhow!(
|
||||
"Failed to refresh OAuth token: {}. Original error: {}",
|
||||
e,
|
||||
error_text
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -67,8 +67,10 @@ impl EmbeddedProvider {
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;
|
||||
|
||||
// Create session with parameters
|
||||
let mut session_params = SessionParams::default();
|
||||
session_params.n_ctx = context_size;
|
||||
let mut session_params = SessionParams {
|
||||
n_ctx: context_size,
|
||||
..Default::default()
|
||||
};
|
||||
if let Some(threads) = threads {
|
||||
session_params.n_threads = threads;
|
||||
}
|
||||
@@ -137,7 +139,7 @@ impl EmbeddedProvider {
|
||||
in_conversation = false;
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
formatted.push_str(" ");
|
||||
formatted.push(' ');
|
||||
formatted.push_str(&message.content);
|
||||
formatted.push_str("</s> ");
|
||||
in_conversation = false;
|
||||
@@ -146,8 +148,8 @@ impl EmbeddedProvider {
|
||||
}
|
||||
|
||||
// If the last message was from user, add a space for the assistant's response
|
||||
if messages.last().map_or(false, |m| matches!(m.role, MessageRole::User)) {
|
||||
formatted.push_str(" ");
|
||||
if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
|
||||
formatted.push(' ');
|
||||
}
|
||||
|
||||
formatted
|
||||
@@ -439,7 +441,7 @@ impl EmbeddedProvider {
|
||||
|
||||
// Use curl with progress bar for download
|
||||
let output = Command::new("curl")
|
||||
.args(&[
|
||||
.args([
|
||||
"-L", // Follow redirects
|
||||
"-#", // Show progress bar
|
||||
"-f", // Fail on HTTP errors
|
||||
|
||||
@@ -52,7 +52,7 @@ impl TokenCache {
|
||||
hasher.update(scopes.join(",").as_bytes());
|
||||
let hash = format!("{:x}", hasher.finalize());
|
||||
|
||||
fs::create_dir_all(get_base_path()).unwrap_or_else(|_| {});
|
||||
fs::create_dir_all(get_base_path()).unwrap_or(());
|
||||
let cache_path = get_base_path().join(format!("{}.json", hash));
|
||||
|
||||
Self { cache_path }
|
||||
|
||||
Reference in New Issue
Block a user