fix for buffered messages at end, colorized context bars
This commit is contained in:
@@ -125,14 +125,16 @@ impl DatabricksAuth {
|
||||
cached_token,
|
||||
} => {
|
||||
// Use the OAuth implementation with automatic refresh
|
||||
let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
|
||||
let token =
|
||||
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes)
|
||||
.await?;
|
||||
// Cache the token for potential reuse within the same session
|
||||
*cached_token = Some(token.clone());
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Force a token refresh by clearing any cached token
|
||||
/// This is useful when we get a 403 Invalid Token error
|
||||
pub fn clear_cached_token(&mut self) {
|
||||
@@ -303,17 +305,20 @@ impl DatabricksProvider {
|
||||
Ok(chunk) => {
|
||||
// Debug: Log raw bytes received
|
||||
debug!("Raw SSE bytes received: {} bytes", chunk.len());
|
||||
|
||||
|
||||
let chunk_str = match std::str::from_utf8(&chunk) {
|
||||
Ok(s) => {
|
||||
// Debug: Log raw string content (truncated for large chunks)
|
||||
if s.len() > 1000 {
|
||||
debug!("Raw SSE string content (first 500 chars): {:?}...", &s[..500]);
|
||||
debug!(
|
||||
"Raw SSE string content (first 500 chars): {:?}...",
|
||||
&s[..500]
|
||||
);
|
||||
} else {
|
||||
debug!("Raw SSE string content: {:?}", s);
|
||||
}
|
||||
s
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Invalid UTF-8 in stream chunk: {}", e);
|
||||
let _ = tx
|
||||
@@ -393,7 +398,10 @@ impl DatabricksProvider {
|
||||
|
||||
// Debug: Log every raw JSON payload from Databricks API (truncated for large payloads)
|
||||
if data.len() > 1000 {
|
||||
debug!("Raw Databricks SSE JSON payload (first 500 chars): {}...", &data[..500]);
|
||||
debug!(
|
||||
"Raw Databricks SSE JSON payload (first 500 chars): {}...",
|
||||
&data[..500]
|
||||
);
|
||||
} else {
|
||||
debug!("Raw Databricks SSE JSON payload: {}", data);
|
||||
}
|
||||
@@ -423,12 +431,15 @@ impl DatabricksProvider {
|
||||
|
||||
// Handle tool calls - accumulate across chunks
|
||||
if let Some(tool_calls) = delta.tool_calls {
|
||||
debug!("Processing {} tool call deltas", tool_calls.len());
|
||||
debug!(
|
||||
"Processing {} tool call deltas",
|
||||
tool_calls.len()
|
||||
);
|
||||
for tool_call in tool_calls {
|
||||
let index = tool_call.index.unwrap_or(0);
|
||||
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
|
||||
debug!("Tool call delta for index {}: id={:?}, name='{}', args_len={}",
|
||||
index, tool_call.id, tool_call.function.name, tool_call.function.arguments.len());
|
||||
|
||||
|
||||
let entry = current_tool_calls
|
||||
.entry(index)
|
||||
.or_insert_with(|| {
|
||||
@@ -452,7 +463,7 @@ impl DatabricksProvider {
|
||||
}
|
||||
|
||||
// Append arguments
|
||||
debug!("Appending {} chars to tool call {} args (current len: {})",
|
||||
debug!("Appending {} chars to tool call {} args (current len: {})",
|
||||
tool_call.function.arguments.len(), index, entry.2.len());
|
||||
entry.2.push_str(
|
||||
&tool_call.function.arguments,
|
||||
@@ -460,12 +471,15 @@ impl DatabricksProvider {
|
||||
|
||||
debug!("Accumulated tool call {}: id='{}', name='{}', args_len={}",
|
||||
index, entry.0, entry.1, entry.2.len());
|
||||
|
||||
|
||||
// Debug: Show a sample of the accumulated args if they're getting long
|
||||
if entry.2.len() > 100 {
|
||||
debug!("Tool call {} args sample (first 100 chars): {}", index, &entry.2[..100]);
|
||||
} else if !entry.2.is_empty() {
|
||||
debug!("Tool call {} full args: {}", index, entry.2);
|
||||
debug!(
|
||||
"Tool call {} full args: {}",
|
||||
index, entry.2
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -500,7 +514,10 @@ impl DatabricksProvider {
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!("Final tool calls count: {}", final_tool_calls.len());
|
||||
debug!(
|
||||
"Final tool calls count: {}",
|
||||
final_tool_calls.len()
|
||||
);
|
||||
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
@@ -524,14 +541,14 @@ impl DatabricksProvider {
|
||||
// Check if this is likely an incomplete JSON due to line splitting
|
||||
// Common indicators: unexpected EOF, unterminated string, etc.
|
||||
let error_str = e.to_string().to_lowercase();
|
||||
if line.starts_with("data: ") && (
|
||||
error_str.contains("eof") ||
|
||||
if line.starts_with("data: ")
|
||||
&& (error_str.contains("eof") ||
|
||||
error_str.contains("unterminated") ||
|
||||
error_str.contains("unexpected end") ||
|
||||
error_str.contains("trailing") ||
|
||||
// Also check if the data doesn't end with a proper JSON terminator
|
||||
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']'))
|
||||
) {
|
||||
(!data.trim_end().ends_with('}') && !data.trim_end().ends_with(']')))
|
||||
{
|
||||
// This looks like an incomplete data line, save it for the next chunk
|
||||
debug!("Detected incomplete data line (len={}), buffering for next chunk", line.len());
|
||||
incomplete_data_line = line.clone();
|
||||
@@ -542,7 +559,10 @@ impl DatabricksProvider {
|
||||
debug!("Failed to parse Databricks stream chunk JSON: {} - Data length: {}", e, data.len());
|
||||
// For debugging large payloads, log a sample
|
||||
if data.len() > 1000 {
|
||||
debug!("JSON parse error - data sample: {}", &data[..std::cmp::min(500, data.len())]);
|
||||
debug!(
|
||||
"JSON parse error - data sample: {}",
|
||||
&data[..std::cmp::min(500, data.len())]
|
||||
);
|
||||
}
|
||||
}
|
||||
// Don't error out on parse failures, just continue
|
||||
@@ -564,7 +584,10 @@ impl DatabricksProvider {
|
||||
|
||||
// If we have any incomplete data line at the end, try to process it
|
||||
if !incomplete_data_line.is_empty() {
|
||||
debug!("Processing final incomplete data line (len={})", incomplete_data_line.len());
|
||||
debug!(
|
||||
"Processing final incomplete data line (len={})",
|
||||
incomplete_data_line.len()
|
||||
);
|
||||
if let Some(data) = incomplete_data_line.strip_prefix("data: ") {
|
||||
// Try to parse it as-is, it might be complete
|
||||
if let Ok(_chunk) = serde_json::from_str::<DatabricksStreamChunk>(data) {
|
||||
@@ -612,7 +635,7 @@ impl DatabricksProvider {
|
||||
|
||||
let response = match self
|
||||
.client
|
||||
.get(&format!("{}/api/2.0/serving-endpoints", self.host))
|
||||
.get(format!("{}/api/2.0/serving-endpoints", self.host))
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.send()
|
||||
.await
|
||||
@@ -724,23 +747,23 @@ impl LLMProvider for DatabricksProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
|
||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
||||
|
||||
if status == reqwest::StatusCode::FORBIDDEN
|
||||
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
|
||||
{
|
||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||
|
||||
|
||||
// Try to refresh the token if we're using OAuth
|
||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||
// Clear any cached token to force a refresh
|
||||
provider_clone.auth.clear_cached_token();
|
||||
|
||||
|
||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||
match provider_clone.auth.get_token().await {
|
||||
Ok(_new_token) => {
|
||||
info!("Successfully refreshed OAuth token, retrying request");
|
||||
|
||||
|
||||
// Retry the request with the new token
|
||||
response = provider_clone
|
||||
.create_request_builder(false)
|
||||
@@ -749,25 +772,33 @@ impl LLMProvider for DatabricksProvider {
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
|
||||
|
||||
|
||||
let retry_status = response.status();
|
||||
if !retry_status.is_success() {
|
||||
let retry_error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
||||
return Err(anyhow!(
|
||||
"Databricks API error {} after token refresh: {}",
|
||||
retry_status,
|
||||
retry_error_text
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
||||
return Err(anyhow!(
|
||||
"Failed to refresh OAuth token: {}. Original error: {}",
|
||||
e,
|
||||
error_text
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -875,23 +906,23 @@ impl LLMProvider for DatabricksProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
|
||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
||||
|
||||
if status == reqwest::StatusCode::FORBIDDEN
|
||||
&& (error_text.contains("Invalid Token") || error_text.contains("invalid_token"))
|
||||
{
|
||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||
|
||||
|
||||
// Try to refresh the token if we're using OAuth
|
||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||
// Clear any cached token to force a refresh
|
||||
provider_clone.auth.clear_cached_token();
|
||||
|
||||
|
||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||
match provider_clone.auth.get_token().await {
|
||||
Ok(_new_token) => {
|
||||
info!("Successfully refreshed OAuth token, retrying streaming request");
|
||||
|
||||
|
||||
// Retry the request with the new token
|
||||
response = provider_clone
|
||||
.create_request_builder(true)
|
||||
@@ -900,25 +931,33 @@ impl LLMProvider for DatabricksProvider {
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
|
||||
|
||||
|
||||
let retry_status = response.status();
|
||||
if !retry_status.is_success() {
|
||||
let retry_error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
||||
return Err(anyhow!(
|
||||
"Databricks API error {} after token refresh: {}",
|
||||
retry_status,
|
||||
retry_error_text
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
||||
return Err(anyhow!(
|
||||
"Failed to refresh OAuth token: {}. Original error: {}",
|
||||
e,
|
||||
error_text
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user