fix for buffered messages at end, colorized context bars

This commit is contained in:
Dhanji Prasanna
2025-10-13 13:36:37 +11:00
parent 318355e864
commit 062e6de63f
5 changed files with 143 additions and 70 deletions

View File

@@ -125,14 +125,16 @@ impl DatabricksAuth {
cached_token,
} => {
// Use the OAuth implementation with automatic refresh
let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
let token =
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes)
.await?;
// Cache the token for potential reuse within the same session
*cached_token = Some(token.clone());
Ok(token)
}
}
}
/// Force a token refresh by clearing any cached token
/// This is useful when we get a 403 Invalid Token error
pub fn clear_cached_token(&mut self) {
@@ -303,17 +305,20 @@ impl DatabricksProvider {
Ok(chunk) => {
// Debug: Log raw bytes received
debug!("Raw SSE bytes received: {} bytes", chunk.len());
let chunk_str = match std::str::from_utf8(&chunk) {
Ok(s) => {
// Debug: Log raw string content (truncated for large chunks)
if s.len() > 1000 {
debug!("Raw SSE string content (first 500 chars): {:?}...", &s[..500]);
debug!(
"Raw SSE string content (first 500 chars): {:?}...",
&s[..500]
);
} else {
debug!("Raw SSE string content: {:?}", s);
}
s
},
}
Err(e) => {
error!("Invalid UTF-8 in stream chunk: {}", e);
let _ = tx
@@ -393,7 +398,10 @@ impl DatabricksProvider {
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
if data.len() > 1000 {
debug!("Raw Databricks SSE JSON payload (first 500 chars): {}...", &data[..500]);
debug!(
"Raw Databricks SSE JSON payload (first 500 chars): {}...",
&data[..500]
);
} else {
debug!("Raw Databricks SSE JSON payload: {}", data);
}
@@ -423,12 +431,15 @@ impl DatabricksProvider {
// Handle tool calls - accumulate across chunks
if let Some(tool_calls) = delta.tool_calls {
debug!("Processing {} tool call deltas", tool_calls.len());
debug!(
"Processing {} tool call deltas",
tool_calls.len()
);
for tool_call in tool_calls {
let index = tool_call.index.unwrap_or(0);
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
let entry = current_tool_calls
.entry(index)
.or_insert_with(|| {
@@ -452,7 +463,7 @@ impl DatabricksProvider {
}
// Append arguments
debug!("Appending {} chars to tool call {} args (current len: {})",
debug!("Appending {} chars to tool call {} args (current len: {})",
tool_call.function.arguments.len(), index, entry.2.len());
entry.2.push_str(
&tool_call.function.arguments,
@@ -460,12 +471,15 @@ impl DatabricksProvider {
debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
index, entry.0, entry.1, entry.2.len());
// Debug: Show a sample of the accumulated args if they're getting long
if entry.2.len() > 100 {
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
} else if !entry.2.is_empty() {
debug!("Tool call {} full args: {}", index, entry.2);
debug!(
"Tool call {} full args: {}",
index, entry.2
);
}
}
}
@@ -500,7 +514,10 @@ impl DatabricksProvider {
})
.collect();
debug!("Final tool calls count: {}", final_tool_calls.len());
debug!(
"Final tool calls count: {}",
final_tool_calls.len()
);
let final_chunk = CompletionChunk {
content: String::new(),
@@ -524,14 +541,14 @@ impl DatabricksProvider {
// Check if this is likely an incomplete JSON due to line splitting
// Common indicators: unexpected EOF, unterminated string, etc.
let error_str = e.to_string().to_lowercase();
if line.starts_with("data: ") && (
error_str.contains("eof") ||
if line.starts_with("data: ")
&& (error_str.contains("eof") ||
error_str.contains("unterminated") ||
error_str.contains("unexpected end") ||
error_str.contains("trailing") ||
// Also check if the data doesn't end with a proper JSON terminator
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']'))
) {
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']')))
{
// This looks like an incomplete data line, save it for the next chunk
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
incomplete_data_line = line.clone();
@@ -542,7 +559,10 @@ impl DatabricksProvider {
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
// For debugging large payloads, log a sample
if data.len() > 1000 {
debug!("JSON parse error - data sample: {}", &data[..std::cmp::min(500, data.len())]);
debug!(
"JSON parse error - data sample: {}",
&data[..std::cmp::min(500, data.len())]
);
}
}
// Don't error out on parse failures, just continue
@@ -564,7 +584,10 @@ impl DatabricksProvider {
// If we have any incomplete data line at the end, try to process it
if !incomplete_data_line.is_empty() {
debug!("Processing final incomplete data line (len={})", incomplete_data_line.len());
debug!(
"Processing final incomplete data line (len={})",
incomplete_data_line.len()
);
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
// Try to parse it as-is, it might be complete
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
@@ -612,7 +635,7 @@ impl DatabricksProvider {
let response = match self
.client
.get(&format!("{}/api/2.0/serving-endpoints", self.host))
.get(format!("{}/api/2.0/serving-endpoints", self.host))
.header("Authorization", format!("Bearer {}", token))
.send()
.await
@@ -724,23 +747,23 @@ impl LLMProvider for DatabricksProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN &&
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
if status == reqwest::StatusCode::FORBIDDEN
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
{
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
// Clear any cached token to force a refresh
provider_clone.auth.clear_cached_token();
// Try to get a new token (will attempt refresh or new OAuth flow)
match provider_clone.auth.get_token().await {
Ok(_new_token) => {
info!("Successfully refreshed OAuth token, retrying request");
// Retry the request with the new token
response = provider_clone
.create_request_builder(false)
@@ -749,25 +772,33 @@ impl LLMProvider for DatabricksProvider {
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
let retry_status = response.status();
if !retry_status.is_success() {
let retry_error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
return Err(anyhow!(
"Databricks API error {} after token refresh: {}",
retry_status,
retry_error_text
));
}
}
Err(e) => {
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
return Err(anyhow!(
"Failed to refresh OAuth token: {}. Original error: {}",
e,
error_text
));
}
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
}
@@ -875,23 +906,23 @@ impl LLMProvider for DatabricksProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN &&
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
if status == reqwest::StatusCode::FORBIDDEN
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
{
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
// Clear any cached token to force a refresh
provider_clone.auth.clear_cached_token();
// Try to get a new token (will attempt refresh or new OAuth flow)
match provider_clone.auth.get_token().await {
Ok(_new_token) => {
info!("Successfully refreshed OAuth token, retrying streaming request");
// Retry the request with the new token
response = provider_clone
.create_request_builder(true)
@@ -900,25 +931,33 @@ impl LLMProvider for DatabricksProvider {
.send()
.await
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
let retry_status = response.status();
if !retry_status.is_success() {
let retry_error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
return Err(anyhow!(
"Databricks API error {} after token refresh: {}",
retry_status,
retry_error_text
));
}
}
Err(e) => {
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
return Err(anyhow!(
"Failed to refresh OAuth token: {}. Original error: {}",
e,
error_text
));
}
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
}