auto refresh token

This commit is contained in:
Dhanji Prasanna
2025-10-04 17:32:48 +10:00
parent 1a57dd3b1d
commit bcba99ec6c
2 changed files with 238 additions and 27 deletions

View File

@@ -122,13 +122,24 @@ impl DatabricksAuth {
client_id,
redirect_url,
scopes,
cached_token: _,
cached_token,
} => {
// Use the OAuth implementation
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await
// Use the OAuth implementation with automatic refresh
let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
// Cache the token for potential reuse within the same session
*cached_token = Some(token.clone());
Ok(token)
}
}
}
/// Force a token refresh by clearing any cached token
/// This is useful when we get a 403 Invalid Token error
pub fn clear_cached_token(&mut self) {
if let DatabricksAuth::OAuth { cached_token, .. } = self {
*cached_token = None;
}
}
}
#[derive(Debug, Clone)]
@@ -693,7 +704,7 @@ impl LLMProvider for DatabricksProvider {
}
let mut provider_clone = self.clone();
let response = provider_clone
let mut response = provider_clone
.create_request_builder(false)
.await?
.json(&request_body)
@@ -707,7 +718,51 @@ impl LLMProvider for DatabricksProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN &&
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
// Clear any cached token to force a refresh
provider_clone.auth.clear_cached_token();
// Try to get a new token (will attempt refresh or new OAuth flow)
match provider_clone.auth.get_token().await {
Ok(_new_token) => {
info!("Successfully refreshed OAuth token, retrying request");
// Retry the request with the new token
response = provider_clone
.create_request_builder(false)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
let retry_status = response.status();
if !retry_status.is_success() {
let retry_error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
}
}
Err(e) => {
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
}
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
}
let response_text = response.text().await?;
@@ -800,7 +855,7 @@ impl LLMProvider for DatabricksProvider {
);
let mut provider_clone = self.clone();
let response = provider_clone
let mut response = provider_clone
.create_request_builder(true)
.await?
.json(&request_body)
@@ -814,7 +869,51 @@ impl LLMProvider for DatabricksProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN &&
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
// Clear any cached token to force a refresh
provider_clone.auth.clear_cached_token();
// Try to get a new token (will attempt refresh or new OAuth flow)
match provider_clone.auth.get_token().await {
Ok(_new_token) => {
info!("Successfully refreshed OAuth token, retrying streaming request");
// Retry the request with the new token
response = provider_clone
.create_request_builder(true)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
let retry_status = response.status();
if !retry_status.is_success() {
let retry_error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
}
}
Err(e) => {
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
}
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
}
let stream = response.bytes_stream();