auto refresh token
This commit is contained in:
@@ -122,13 +122,24 @@ impl DatabricksAuth {
|
||||
client_id,
|
||||
redirect_url,
|
||||
scopes,
|
||||
cached_token: _,
|
||||
cached_token,
|
||||
} => {
|
||||
// Use the OAuth implementation
|
||||
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await
|
||||
// Use the OAuth implementation with automatic refresh
|
||||
let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
|
||||
// Cache the token for potential reuse within the same session
|
||||
*cached_token = Some(token.clone());
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Force a token refresh by clearing any cached token
|
||||
/// This is useful when we get a 403 Invalid Token error
|
||||
pub fn clear_cached_token(&mut self) {
|
||||
if let DatabricksAuth::OAuth { cached_token, .. } = self {
|
||||
*cached_token = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -693,7 +704,7 @@ impl LLMProvider for DatabricksProvider {
|
||||
}
|
||||
|
||||
let mut provider_clone = self.clone();
|
||||
let response = provider_clone
|
||||
let mut response = provider_clone
|
||||
.create_request_builder(false)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
@@ -707,7 +718,51 @@ impl LLMProvider for DatabricksProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
||||
|
||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||
|
||||
// Try to refresh the token if we're using OAuth
|
||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||
// Clear any cached token to force a refresh
|
||||
provider_clone.auth.clear_cached_token();
|
||||
|
||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||
match provider_clone.auth.get_token().await {
|
||||
Ok(_new_token) => {
|
||||
info!("Successfully refreshed OAuth token, retrying request");
|
||||
|
||||
// Retry the request with the new token
|
||||
response = provider_clone
|
||||
.create_request_builder(false)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
|
||||
|
||||
let retry_status = response.status();
|
||||
if !retry_status.is_success() {
|
||||
let retry_error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
@@ -800,7 +855,7 @@ impl LLMProvider for DatabricksProvider {
|
||||
);
|
||||
|
||||
let mut provider_clone = self.clone();
|
||||
let response = provider_clone
|
||||
let mut response = provider_clone
|
||||
.create_request_builder(true)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
@@ -814,7 +869,51 @@ impl LLMProvider for DatabricksProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
||||
|
||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||
|
||||
// Try to refresh the token if we're using OAuth
|
||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||
// Clear any cached token to force a refresh
|
||||
provider_clone.auth.clear_cached_token();
|
||||
|
||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||
match provider_clone.auth.get_token().await {
|
||||
Ok(_new_token) => {
|
||||
info!("Successfully refreshed OAuth token, retrying streaming request");
|
||||
|
||||
// Retry the request with the new token
|
||||
response = provider_clone
|
||||
.create_request_builder(true)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
|
||||
|
||||
let retry_status = response.status();
|
||||
if !retry_status.is_success() {
|
||||
let retry_error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
}
|
||||
|
||||
let stream = response.bytes_stream();
|
||||
|
||||
Reference in New Issue
Block a user