fix for buffered messages at end, colorized context bars
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
use crossterm::style::Color;
|
use crossterm::style::Color;
|
||||||
|
use crossterm::style::{Stylize, SetForegroundColor, ResetColor};
|
||||||
use termimad::MadSkin;
|
use termimad::MadSkin;
|
||||||
|
|
||||||
/// Simple output handler with markdown support
|
/// Simple output handler with markdown support
|
||||||
@@ -45,11 +46,21 @@ impl SimpleOutput {
|
|||||||
|
|
||||||
let filled_chars = "●".repeat(filled_width);
|
let filled_chars = "●".repeat(filled_width);
|
||||||
let empty_chars = "○".repeat(empty_width);
|
let empty_chars = "○".repeat(empty_width);
|
||||||
let progress_bar = format!("{}{}", filled_chars, empty_chars);
|
|
||||||
|
// Determine color based on percentage
|
||||||
|
let color = if percentage < 60.0 {
|
||||||
|
crossterm::style::Color::Green
|
||||||
|
} else if percentage < 80.0 {
|
||||||
|
crossterm::style::Color::Yellow
|
||||||
|
} else {
|
||||||
|
crossterm::style::Color::Red
|
||||||
|
};
|
||||||
|
|
||||||
println!(
|
// Print with colored progress bar
|
||||||
"Context: {} {:.1}% | {}/{} tokens",
|
print!("Context: ");
|
||||||
progress_bar, percentage, used, total
|
print!("{}", SetForegroundColor(color));
|
||||||
);
|
print!("{}{}", filled_chars, empty_chars);
|
||||||
|
print!("{}", ResetColor);
|
||||||
|
println!(" {:.1}% | {}/{} tokens", percentage, used, total);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1510,27 +1510,31 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
// Display tool execution result with proper indentation
|
// Display tool execution result with proper indentation
|
||||||
if tool_call.tool != "final_output" {
|
if tool_call.tool != "final_output" {
|
||||||
let output_lines: Vec<&str> = tool_result.lines().collect();
|
let output_lines: Vec<&str> = tool_result.lines().collect();
|
||||||
|
|
||||||
|
// Helper function to safely truncate strings at character boundaries
|
||||||
|
let truncate_line = |line: &str, max_width: usize| -> String {
|
||||||
|
let char_count = line.chars().count();
|
||||||
|
if char_count <= max_width {
|
||||||
|
line.to_string()
|
||||||
|
} else {
|
||||||
|
let truncated: String = line.chars().take(max_width.saturating_sub(3)).collect();
|
||||||
|
format!("{}...", truncated)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const MAX_LINES: usize = 5;
|
const MAX_LINES: usize = 5;
|
||||||
const MAX_LINE_WIDTH: usize = 80;
|
const MAX_LINE_WIDTH: usize = 80;
|
||||||
|
|
||||||
if output_lines.len() <= MAX_LINES {
|
if output_lines.len() <= MAX_LINES {
|
||||||
for line in output_lines {
|
for line in output_lines {
|
||||||
// Clip line to max width
|
// Clip line to max width
|
||||||
let clipped_line = if line.len() > MAX_LINE_WIDTH {
|
let clipped_line = truncate_line(line, MAX_LINE_WIDTH);
|
||||||
format!("{}...", &line[..MAX_LINE_WIDTH.saturating_sub(3)])
|
|
||||||
} else {
|
|
||||||
line.to_string()
|
|
||||||
};
|
|
||||||
self.ui_writer.print_tool_output_line(&clipped_line);
|
self.ui_writer.print_tool_output_line(&clipped_line);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for line in output_lines.iter().take(MAX_LINES) {
|
for line in output_lines.iter().take(MAX_LINES) {
|
||||||
// Clip line to max width
|
// Clip line to max width
|
||||||
let clipped_line = if line.len() > MAX_LINE_WIDTH {
|
let clipped_line = truncate_line(line, MAX_LINE_WIDTH);
|
||||||
format!("{}...", &line[..MAX_LINE_WIDTH.saturating_sub(3)])
|
|
||||||
} else {
|
|
||||||
line.to_string()
|
|
||||||
};
|
|
||||||
self.ui_writer.print_tool_output_line(&clipped_line);
|
self.ui_writer.print_tool_output_line(&clipped_line);
|
||||||
}
|
}
|
||||||
let hidden_count = output_lines.len() - MAX_LINES;
|
let hidden_count = output_lines.len() - MAX_LINES;
|
||||||
@@ -1682,7 +1686,20 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
// Only use parser text if we truly have no response
|
// Only use parser text if we truly have no response
|
||||||
// This should be rare - only if streaming failed to display anything
|
// This should be rare - only if streaming failed to display anything
|
||||||
debug!("Warning: Using parser buffer text as fallback - this may duplicate output");
|
debug!("Warning: Using parser buffer text as fallback - this may duplicate output");
|
||||||
// Don't add it - it's already been displayed
|
// Extract only the undisplayed portion from parser buffer
|
||||||
|
// Parser buffer accumulates across iterations, so we need to be careful
|
||||||
|
let clean_text = text_content
|
||||||
|
.replace("<|im_end|>", "")
|
||||||
|
.replace("</s>", "")
|
||||||
|
.replace("[/INST]", "")
|
||||||
|
.replace("<</SYS>>", "");
|
||||||
|
let filtered_text = filter_json_tool_calls(&clean_text);
|
||||||
|
|
||||||
|
// Only use this if we truly have nothing else
|
||||||
|
if !filtered_text.trim().is_empty() && full_response.is_empty() {
|
||||||
|
debug!("Using filtered parser text as last resort: {} chars", filtered_text.len());
|
||||||
|
current_response = filtered_text;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !has_text_response && full_response.is_empty() {
|
if !has_text_response && full_response.is_empty() {
|
||||||
@@ -1786,6 +1803,7 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
// Appending would duplicate the output
|
// Appending would duplicate the output
|
||||||
if !current_response.is_empty() && full_response.is_empty() {
|
if !current_response.is_empty() && full_response.is_empty() {
|
||||||
full_response = current_response.clone();
|
full_response = current_response.clone();
|
||||||
|
debug!("Set full_response from current_response (no tool): {} chars", full_response.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
self.ui_writer.println("");
|
self.ui_writer.println("");
|
||||||
@@ -1852,10 +1870,12 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
|
|
||||||
// If we get here and no tool was executed, we're done
|
// If we get here and no tool was executed, we're done
|
||||||
if !tool_executed {
|
if !tool_executed {
|
||||||
// Don't add parser text_content here - it's already been displayed during streaming
|
// IMPORTANT: Do NOT add parser text_content here!
|
||||||
// The parser buffer contains ALL accumulated text, including what was already shown
|
// The text has already been displayed during streaming via current_response.
|
||||||
// Adding it here would cause duplication of the entire response
|
// The parser buffer accumulates ALL text and would cause duplication.
|
||||||
debug!("Stream completed without tool execution. Response already displayed during streaming.");
|
debug!("Stream completed without tool execution. Response already displayed during streaming.");
|
||||||
|
debug!("Current response length: {}, Full response length: {}",
|
||||||
|
current_response.len(), full_response.len());
|
||||||
|
|
||||||
let has_response = !current_response.is_empty() || !full_response.is_empty();
|
let has_response = !current_response.is_empty() || !full_response.is_empty();
|
||||||
|
|
||||||
@@ -1865,10 +1885,11 @@ The tool will execute immediately and you'll receive the result (success or erro
|
|||||||
iteration_count
|
iteration_count
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
// Don't add current_response to full_response here - it was already displayed during streaming
|
// Only set full_response if it's empty (first iteration without tools)
|
||||||
// Only add it if full_response is empty (meaning no tools were executed)
|
// This prevents duplication when the agent responds without calling final_output
|
||||||
if full_response.is_empty() && !current_response.is_empty() {
|
if full_response.is_empty() && !current_response.is_empty() {
|
||||||
full_response = current_response.clone();
|
full_response = current_response.clone();
|
||||||
|
debug!("Set full_response from current_response: {} chars", full_response.len());
|
||||||
}
|
}
|
||||||
self.ui_writer.println("");
|
self.ui_writer.println("");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,14 +125,16 @@ impl DatabricksAuth {
|
|||||||
cached_token,
|
cached_token,
|
||||||
} => {
|
} => {
|
||||||
// Use the OAuth implementation with automatic refresh
|
// Use the OAuth implementation with automatic refresh
|
||||||
let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
|
let token =
|
||||||
|
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes)
|
||||||
|
.await?;
|
||||||
// Cache the token for potential reuse within the same session
|
// Cache the token for potential reuse within the same session
|
||||||
*cached_token = Some(token.clone());
|
*cached_token = Some(token.clone());
|
||||||
Ok(token)
|
Ok(token)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Force a token refresh by clearing any cached token
|
/// Force a token refresh by clearing any cached token
|
||||||
/// This is useful when we get a 403 Invalid Token error
|
/// This is useful when we get a 403 Invalid Token error
|
||||||
pub fn clear_cached_token(&mut self) {
|
pub fn clear_cached_token(&mut self) {
|
||||||
@@ -303,17 +305,20 @@ impl DatabricksProvider {
|
|||||||
Ok(chunk) => {
|
Ok(chunk) => {
|
||||||
// Debug: Log raw bytes received
|
// Debug: Log raw bytes received
|
||||||
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
||||||
|
|
||||||
let chunk_str = match std::str::from_utf8(&chunk) {
|
let chunk_str = match std::str::from_utf8(&chunk) {
|
||||||
Ok(s) => {
|
Ok(s) => {
|
||||||
// Debug: Log raw string content (truncated for large chunks)
|
// Debug: Log raw string content (truncated for large chunks)
|
||||||
if s.len() > 1000 {
|
if s.len() > 1000 {
|
||||||
debug!("Raw SSE string content (first 500 chars): {:?}...", &s[..500]);
|
debug!(
|
||||||
|
"Raw SSE string content (first 500 chars): {:?}...",
|
||||||
|
&s[..500]
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
debug!("Raw SSE string content: {:?}", s);
|
debug!("Raw SSE string content: {:?}", s);
|
||||||
}
|
}
|
||||||
s
|
s
|
||||||
},
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Invalid UTF-8 in stream chunk: {}", e);
|
error!("Invalid UTF-8 in stream chunk: {}", e);
|
||||||
let _ = tx
|
let _ = tx
|
||||||
@@ -393,7 +398,10 @@ impl DatabricksProvider {
|
|||||||
|
|
||||||
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
|
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
|
||||||
if data.len() > 1000 {
|
if data.len() > 1000 {
|
||||||
debug!("Raw Databricks SSE JSON payload (first 500 chars): {}...", &data[..500]);
|
debug!(
|
||||||
|
"Raw Databricks SSE JSON payload (first 500 chars): {}...",
|
||||||
|
&data[..500]
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
debug!("Raw Databricks SSE JSON payload: {}", data);
|
debug!("Raw Databricks SSE JSON payload: {}", data);
|
||||||
}
|
}
|
||||||
@@ -423,12 +431,15 @@ impl DatabricksProvider {
|
|||||||
|
|
||||||
// Handle tool calls - accumulate across chunks
|
// Handle tool calls - accumulate across chunks
|
||||||
if let Some(tool_calls) = delta.tool_calls {
|
if let Some(tool_calls) = delta.tool_calls {
|
||||||
debug!("Processing {} tool call deltas", tool_calls.len());
|
debug!(
|
||||||
|
"Processing {} tool call deltas",
|
||||||
|
tool_calls.len()
|
||||||
|
);
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
let index = tool_call.index.unwrap_or(0);
|
let index = tool_call.index.unwrap_or(0);
|
||||||
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
|
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
|
||||||
index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
|
index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
|
||||||
|
|
||||||
let entry = current_tool_calls
|
let entry = current_tool_calls
|
||||||
.entry(index)
|
.entry(index)
|
||||||
.or_insert_with(|| {
|
.or_insert_with(|| {
|
||||||
@@ -452,7 +463,7 @@ impl DatabricksProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Append arguments
|
// Append arguments
|
||||||
debug!("Appending {} chars to tool call {} args (current len: {})",
|
debug!("Appending {} chars to tool call {} args (current len: {})",
|
||||||
tool_call.function.arguments.len(), index, entry.2.len());
|
tool_call.function.arguments.len(), index, entry.2.len());
|
||||||
entry.2.push_str(
|
entry.2.push_str(
|
||||||
&tool_call.function.arguments,
|
&tool_call.function.arguments,
|
||||||
@@ -460,12 +471,15 @@ impl DatabricksProvider {
|
|||||||
|
|
||||||
debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
|
debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
|
||||||
index, entry.0, entry.1, entry.2.len());
|
index, entry.0, entry.1, entry.2.len());
|
||||||
|
|
||||||
// Debug: Show a sample of the accumulated args if they're getting long
|
// Debug: Show a sample of the accumulated args if they're getting long
|
||||||
if entry.2.len() > 100 {
|
if entry.2.len() > 100 {
|
||||||
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
|
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
|
||||||
} else if !entry.2.is_empty() {
|
} else if !entry.2.is_empty() {
|
||||||
debug!("Tool call {} full args: {}", index, entry.2);
|
debug!(
|
||||||
|
"Tool call {} full args: {}",
|
||||||
|
index, entry.2
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -500,7 +514,10 @@ impl DatabricksProvider {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
debug!("Final tool calls count: {}", final_tool_calls.len());
|
debug!(
|
||||||
|
"Final tool calls count: {}",
|
||||||
|
final_tool_calls.len()
|
||||||
|
);
|
||||||
|
|
||||||
let final_chunk = CompletionChunk {
|
let final_chunk = CompletionChunk {
|
||||||
content: String::new(),
|
content: String::new(),
|
||||||
@@ -524,14 +541,14 @@ impl DatabricksProvider {
|
|||||||
// Check if this is likely an incomplete JSON due to line splitting
|
// Check if this is likely an incomplete JSON due to line splitting
|
||||||
// Common indicators: unexpected EOF, unterminated string, etc.
|
// Common indicators: unexpected EOF, unterminated string, etc.
|
||||||
let error_str = e.to_string().to_lowercase();
|
let error_str = e.to_string().to_lowercase();
|
||||||
if line.starts_with("data: ") && (
|
if line.starts_with("data: ")
|
||||||
error_str.contains("eof") ||
|
&& (error_str.contains("eof") ||
|
||||||
error_str.contains("unterminated") ||
|
error_str.contains("unterminated") ||
|
||||||
error_str.contains("unexpected end") ||
|
error_str.contains("unexpected end") ||
|
||||||
error_str.contains("trailing") ||
|
error_str.contains("trailing") ||
|
||||||
// Also check if the data doesn't end with a proper JSON terminator
|
// Also check if the data doesn't end with a proper JSON terminator
|
||||||
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']'))
|
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']')))
|
||||||
) {
|
{
|
||||||
// This looks like an incomplete data line, save it for the next chunk
|
// This looks like an incomplete data line, save it for the next chunk
|
||||||
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
|
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
|
||||||
incomplete_data_line = line.clone();
|
incomplete_data_line = line.clone();
|
||||||
@@ -542,7 +559,10 @@ impl DatabricksProvider {
|
|||||||
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
|
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
|
||||||
// For debugging large payloads, log a sample
|
// For debugging large payloads, log a sample
|
||||||
if data.len() > 1000 {
|
if data.len() > 1000 {
|
||||||
debug!("JSON parse error - data sample: {}", &data[..std::cmp::min(500, data.len())]);
|
debug!(
|
||||||
|
"JSON parse error - data sample: {}",
|
||||||
|
&data[..std::cmp::min(500, data.len())]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Don't error out on parse failures, just continue
|
// Don't error out on parse failures, just continue
|
||||||
@@ -564,7 +584,10 @@ impl DatabricksProvider {
|
|||||||
|
|
||||||
// If we have any incomplete data line at the end, try to process it
|
// If we have any incomplete data line at the end, try to process it
|
||||||
if !incomplete_data_line.is_empty() {
|
if !incomplete_data_line.is_empty() {
|
||||||
debug!("Processing final incomplete data line (len={})", incomplete_data_line.len());
|
debug!(
|
||||||
|
"Processing final incomplete data line (len={})",
|
||||||
|
incomplete_data_line.len()
|
||||||
|
);
|
||||||
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
|
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
|
||||||
// Try to parse it as-is, it might be complete
|
// Try to parse it as-is, it might be complete
|
||||||
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
|
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
|
||||||
@@ -612,7 +635,7 @@ impl DatabricksProvider {
|
|||||||
|
|
||||||
let response = match self
|
let response = match self
|
||||||
.client
|
.client
|
||||||
.get(&format!("{}/api/2.0/serving-endpoints", self.host))
|
.get(format!("{}/api/2.0/serving-endpoints", self.host))
|
||||||
.header("Authorization", format!("Bearer {}", token))
|
.header("Authorization", format!("Bearer {}", token))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
@@ -724,23 +747,23 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
.text()
|
.text()
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
|
||||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
if status == reqwest::StatusCode::FORBIDDEN
|
||||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
|
||||||
|
{
|
||||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||||
|
|
||||||
// Try to refresh the token if we're using OAuth
|
// Try to refresh the token if we're using OAuth
|
||||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||||
// Clear any cached token to force a refresh
|
// Clear any cached token to force a refresh
|
||||||
provider_clone.auth.clear_cached_token();
|
provider_clone.auth.clear_cached_token();
|
||||||
|
|
||||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||||
match provider_clone.auth.get_token().await {
|
match provider_clone.auth.get_token().await {
|
||||||
Ok(_new_token) => {
|
Ok(_new_token) => {
|
||||||
info!("Successfully refreshed OAuth token, retrying request");
|
info!("Successfully refreshed OAuth token, retrying request");
|
||||||
|
|
||||||
// Retry the request with the new token
|
// Retry the request with the new token
|
||||||
response = provider_clone
|
response = provider_clone
|
||||||
.create_request_builder(false)
|
.create_request_builder(false)
|
||||||
@@ -749,25 +772,33 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
|
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
|
||||||
|
|
||||||
let retry_status = response.status();
|
let retry_status = response.status();
|
||||||
if !retry_status.is_success() {
|
if !retry_status.is_success() {
|
||||||
let retry_error_text = response
|
let retry_error_text = response
|
||||||
.text()
|
.text()
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
return Err(anyhow!(
|
||||||
|
"Databricks API error {} after token refresh: {}",
|
||||||
|
retry_status,
|
||||||
|
retry_error_text
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
return Err(anyhow!(
|
||||||
|
"Failed to refresh OAuth token: {}. Original error: {}",
|
||||||
|
e,
|
||||||
|
error_text
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -875,23 +906,23 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
.text()
|
.text()
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
|
|
||||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
if status == reqwest::StatusCode::FORBIDDEN
|
||||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
|
||||||
|
{
|
||||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||||
|
|
||||||
// Try to refresh the token if we're using OAuth
|
// Try to refresh the token if we're using OAuth
|
||||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||||
// Clear any cached token to force a refresh
|
// Clear any cached token to force a refresh
|
||||||
provider_clone.auth.clear_cached_token();
|
provider_clone.auth.clear_cached_token();
|
||||||
|
|
||||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||||
match provider_clone.auth.get_token().await {
|
match provider_clone.auth.get_token().await {
|
||||||
Ok(_new_token) => {
|
Ok(_new_token) => {
|
||||||
info!("Successfully refreshed OAuth token, retrying streaming request");
|
info!("Successfully refreshed OAuth token, retrying streaming request");
|
||||||
|
|
||||||
// Retry the request with the new token
|
// Retry the request with the new token
|
||||||
response = provider_clone
|
response = provider_clone
|
||||||
.create_request_builder(true)
|
.create_request_builder(true)
|
||||||
@@ -900,25 +931,33 @@ impl LLMProvider for DatabricksProvider {
|
|||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
|
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
|
||||||
|
|
||||||
let retry_status = response.status();
|
let retry_status = response.status();
|
||||||
if !retry_status.is_success() {
|
if !retry_status.is_success() {
|
||||||
let retry_error_text = response
|
let retry_error_text = response
|
||||||
.text()
|
.text()
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
return Err(anyhow!(
|
||||||
|
"Databricks API error {} after token refresh: {}",
|
||||||
|
retry_status,
|
||||||
|
retry_error_text
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
return Err(anyhow!(
|
||||||
|
"Failed to refresh OAuth token: {}. Original error: {}",
|
||||||
|
e,
|
||||||
|
error_text
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -67,8 +67,10 @@ impl EmbeddedProvider {
|
|||||||
.map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to load model: {}", e))?;
|
||||||
|
|
||||||
// Create session with parameters
|
// Create session with parameters
|
||||||
let mut session_params = SessionParams::default();
|
let mut session_params = SessionParams {
|
||||||
session_params.n_ctx = context_size;
|
n_ctx: context_size,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
if let Some(threads) = threads {
|
if let Some(threads) = threads {
|
||||||
session_params.n_threads = threads;
|
session_params.n_threads = threads;
|
||||||
}
|
}
|
||||||
@@ -137,7 +139,7 @@ impl EmbeddedProvider {
|
|||||||
in_conversation = false;
|
in_conversation = false;
|
||||||
}
|
}
|
||||||
MessageRole::Assistant => {
|
MessageRole::Assistant => {
|
||||||
formatted.push_str(" ");
|
formatted.push(' ');
|
||||||
formatted.push_str(&message.content);
|
formatted.push_str(&message.content);
|
||||||
formatted.push_str("</s> ");
|
formatted.push_str("</s> ");
|
||||||
in_conversation = false;
|
in_conversation = false;
|
||||||
@@ -146,8 +148,8 @@ impl EmbeddedProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the last message was from user, add a space for the assistant's response
|
// If the last message was from user, add a space for the assistant's response
|
||||||
if messages.last().map_or(false, |m| matches!(m.role, MessageRole::User)) {
|
if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
|
||||||
formatted.push_str(" ");
|
formatted.push(' ');
|
||||||
}
|
}
|
||||||
|
|
||||||
formatted
|
formatted
|
||||||
@@ -439,7 +441,7 @@ impl EmbeddedProvider {
|
|||||||
|
|
||||||
// Use curl with progress bar for download
|
// Use curl with progress bar for download
|
||||||
let output = Command::new("curl")
|
let output = Command::new("curl")
|
||||||
.args(&[
|
.args([
|
||||||
"-L", // Follow redirects
|
"-L", // Follow redirects
|
||||||
"-#", // Show progress bar
|
"-#", // Show progress bar
|
||||||
"-f", // Fail on HTTP errors
|
"-f", // Fail on HTTP errors
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ impl TokenCache {
|
|||||||
hasher.update(scopes.join(",").as_bytes());
|
hasher.update(scopes.join(",").as_bytes());
|
||||||
let hash = format!("{:x}", hasher.finalize());
|
let hash = format!("{:x}", hasher.finalize());
|
||||||
|
|
||||||
fs::create_dir_all(get_base_path()).unwrap_or_else(|_| {});
|
fs::create_dir_all(get_base_path()).unwrap_or(());
|
||||||
let cache_path = get_base_path().join(format!("{}.json", hash));
|
let cache_path = get_base_path().join(format!("{}.json", hash));
|
||||||
|
|
||||||
Self { cache_path }
|
Self { cache_path }
|
||||||
|
|||||||
Reference in New Issue
Block a user