databricks support
This commit is contained in:
457
crates/g3-providers/src/oauth.rs
Normal file
457
crates/g3-providers/src/oauth.rs
Normal file
@@ -0,0 +1,457 @@
|
||||
use anyhow::Result;
|
||||
use axum::{extract::Query, response::Html, routing::get, Router};
|
||||
use base64::Engine;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sha2::Digest;
|
||||
use std::{collections::HashMap, fs, net::SocketAddr, path::PathBuf, sync::Arc};
|
||||
use tokio::sync::{oneshot, Mutex as TokioMutex};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct OidcEndpoints {
|
||||
authorization_endpoint: String,
|
||||
token_endpoint: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct TokenData {
|
||||
/// The access token used to authenticate API requests
|
||||
access_token: String,
|
||||
|
||||
/// Optional refresh token that can be used to obtain a new access token
|
||||
/// when the current one expires, enabling offline access without user interaction
|
||||
refresh_token: Option<String>,
|
||||
|
||||
/// When the access token expires (if known)
|
||||
/// Used to determine when a token needs to be refreshed
|
||||
expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
struct TokenCache {
|
||||
cache_path: PathBuf,
|
||||
}
|
||||
|
||||
fn get_base_path() -> PathBuf {
|
||||
// Use a similar pattern to Goose but for g3
|
||||
// macOS/Linux: ~/.config/g3/databricks/oauth
|
||||
// Windows: ~\AppData\Roaming\g3\config\databricks\oauth\
|
||||
let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
|
||||
path.push("g3");
|
||||
path.push("databricks");
|
||||
path.push("oauth");
|
||||
path
|
||||
}
|
||||
|
||||
impl TokenCache {
|
||||
fn new(host: &str, client_id: &str, scopes: &[String]) -> Self {
|
||||
let mut hasher = sha2::Sha256::new();
|
||||
hasher.update(host.as_bytes());
|
||||
hasher.update(client_id.as_bytes());
|
||||
hasher.update(scopes.join(",").as_bytes());
|
||||
let hash = format!("{:x}", hasher.finalize());
|
||||
|
||||
fs::create_dir_all(get_base_path()).unwrap_or_else(|_| {});
|
||||
let cache_path = get_base_path().join(format!("{}.json", hash));
|
||||
|
||||
Self { cache_path }
|
||||
}
|
||||
|
||||
fn load_token(&self) -> Option<TokenData> {
|
||||
if let Ok(contents) = fs::read_to_string(&self.cache_path) {
|
||||
if let Ok(token_data) = serde_json::from_str::<TokenData>(&contents) {
|
||||
// Only return tokens that have a refresh token
|
||||
if token_data.refresh_token.is_some() {
|
||||
// If token is not expired, return it for immediate use
|
||||
if let Some(expires_at) = token_data.expires_at {
|
||||
if expires_at > Utc::now() {
|
||||
return Some(token_data);
|
||||
}
|
||||
// If token is expired but has refresh token, return it so we can refresh
|
||||
return Some(token_data);
|
||||
}
|
||||
// No expiration time but has refresh token, return it
|
||||
return Some(token_data);
|
||||
}
|
||||
// Token doesn't have a refresh token, ignore it to force a new OAuth flow
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn save_token(&self, token_data: &TokenData) -> Result<()> {
|
||||
if let Some(parent) = self.cache_path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let contents = serde_json::to_string(token_data)?;
|
||||
fs::write(&self.cache_path, contents)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_workspace_endpoints(host: &str) -> Result<OidcEndpoints> {
|
||||
let base_url = Url::parse(host).expect("Invalid host URL");
|
||||
let oidc_url = base_url
|
||||
.join("oidc/.well-known/oauth-authorization-server")
|
||||
.expect("Invalid OIDC URL");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client.get(oidc_url.clone()).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to get OIDC configuration from {}",
|
||||
oidc_url.to_string()
|
||||
));
|
||||
}
|
||||
|
||||
let oidc_config: Value = resp.json().await?;
|
||||
|
||||
let authorization_endpoint = oidc_config
|
||||
.get("authorization_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OIDC configuration"))?
|
||||
.to_string();
|
||||
|
||||
let token_endpoint = oidc_config
|
||||
.get("token_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OIDC configuration"))?
|
||||
.to_string();
|
||||
|
||||
Ok(OidcEndpoints {
|
||||
authorization_endpoint,
|
||||
token_endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
struct OAuthFlow {
|
||||
endpoints: OidcEndpoints,
|
||||
client_id: String,
|
||||
redirect_url: String,
|
||||
scopes: Vec<String>,
|
||||
state: String,
|
||||
verifier: String,
|
||||
}
|
||||
|
||||
impl OAuthFlow {
|
||||
fn new(
|
||||
endpoints: OidcEndpoints,
|
||||
client_id: String,
|
||||
redirect_url: String,
|
||||
scopes: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
endpoints,
|
||||
client_id,
|
||||
redirect_url,
|
||||
scopes,
|
||||
state: nanoid::nanoid!(16),
|
||||
verifier: nanoid::nanoid!(64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts token data from an OAuth 2.0 token response.
|
||||
fn extract_token_data(
|
||||
&self,
|
||||
token_response: &Value,
|
||||
old_refresh_token: Option<&str>,
|
||||
) -> Result<TokenData> {
|
||||
// Extract access token (required)
|
||||
let access_token = token_response
|
||||
.get("access_token")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))?
|
||||
.to_string();
|
||||
|
||||
// Extract refresh token if available
|
||||
let refresh_token = token_response
|
||||
.get("refresh_token")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| old_refresh_token.map(|s| s.to_string()));
|
||||
|
||||
// Handle token expiration
|
||||
let expires_at =
|
||||
if let Some(expires_in) = token_response.get("expires_in").and_then(|v| v.as_u64()) {
|
||||
// Traditional OAuth flow with expires_in seconds
|
||||
Some(Utc::now() + chrono::Duration::seconds(expires_in as i64))
|
||||
} else {
|
||||
// If the server doesn't provide any expiration info, log it but don't set an expiration
|
||||
tracing::debug!(
|
||||
"No expiration information provided by server, token expiration unknown."
|
||||
);
|
||||
None
|
||||
};
|
||||
|
||||
Ok(TokenData {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_authorization_url(&self) -> String {
|
||||
let challenge = {
|
||||
let digest = sha2::Sha256::digest(self.verifier.as_bytes());
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
|
||||
};
|
||||
|
||||
let params = [
|
||||
("response_type", "code"),
|
||||
("client_id", &self.client_id),
|
||||
("redirect_uri", &self.redirect_url),
|
||||
("scope", &self.scopes.join(" ")),
|
||||
("state", &self.state),
|
||||
("code_challenge", &challenge),
|
||||
("code_challenge_method", "S256"),
|
||||
];
|
||||
|
||||
format!(
|
||||
"{}?{}",
|
||||
self.endpoints.authorization_endpoint,
|
||||
serde_urlencoded::to_string(params).unwrap()
|
||||
)
|
||||
}
|
||||
|
||||
async fn exchange_code_for_token(&self, code: &str) -> Result<TokenData> {
|
||||
let params = [
|
||||
("grant_type", "authorization_code"),
|
||||
("code", code),
|
||||
("redirect_uri", &self.redirect_url),
|
||||
("code_verifier", &self.verifier),
|
||||
("client_id", &self.client_id),
|
||||
];
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&self.endpoints.token_endpoint)
|
||||
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err_text = resp.text().await?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to exchange code for token: {}",
|
||||
err_text
|
||||
));
|
||||
}
|
||||
|
||||
let token_response: Value = resp.json().await?;
|
||||
self.extract_token_data(&token_response, None)
|
||||
}
|
||||
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<TokenData> {
|
||||
let params = [
|
||||
("grant_type", "refresh_token"),
|
||||
("refresh_token", refresh_token),
|
||||
("client_id", &self.client_id),
|
||||
];
|
||||
|
||||
tracing::debug!("Refreshing token using refresh_token");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&self.endpoints.token_endpoint)
|
||||
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err_text = resp.text().await?;
|
||||
return Err(anyhow::anyhow!("Failed to refresh token: {}", err_text));
|
||||
}
|
||||
|
||||
let token_response: Value = resp.json().await?;
|
||||
self.extract_token_data(&token_response, Some(refresh_token))
|
||||
}
|
||||
|
||||
async fn execute(&self) -> Result<TokenData> {
|
||||
// Create a channel that will send the auth code from the app process
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let state = self.state.clone();
|
||||
let tx = Arc::new(TokioMutex::new(Some(tx)));
|
||||
|
||||
// Setup a server that will receive the redirect, capture the code, and display success/failure
|
||||
let app = Router::new().route(
|
||||
"/",
|
||||
get(move |Query(params): Query<HashMap<String, String>>| {
|
||||
let tx = Arc::clone(&tx);
|
||||
let state = state.clone();
|
||||
async move {
|
||||
let code = params.get("code").cloned();
|
||||
let received_state = params.get("state").cloned();
|
||||
|
||||
if let (Some(code), Some(received_state)) = (code, received_state) {
|
||||
if received_state == state {
|
||||
if let Some(sender) = tx.lock().await.take() {
|
||||
if sender.send(code).is_ok() {
|
||||
return Html(
|
||||
"<h2>G3 Authentication Success</h2><p>You can close this window and return to your terminal.</p>",
|
||||
);
|
||||
}
|
||||
}
|
||||
Html("<h2>Error</h2><p>Authentication already completed.</p>")
|
||||
} else {
|
||||
Html("<h2>Error</h2><p>State mismatch.</p>")
|
||||
}
|
||||
} else {
|
||||
Html("<h2>Error</h2><p>Authentication failed.</p>")
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Start the server to accept the oauth code
|
||||
let redirect_url = Url::parse(&self.redirect_url)?;
|
||||
let port = redirect_url.port().unwrap_or(80);
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
let server = axum::serve(listener, app);
|
||||
server.await.unwrap();
|
||||
});
|
||||
|
||||
// Open the browser which will redirect with the code to the server
|
||||
let authorization_url = self.get_authorization_url();
|
||||
println!("🔐 Opening browser for Databricks authentication...");
|
||||
if webbrowser::open(&authorization_url).is_err() {
|
||||
println!(
|
||||
"Please open this URL in your browser:\n{}",
|
||||
authorization_url
|
||||
);
|
||||
}
|
||||
|
||||
// Wait for the authorization code with a timeout
|
||||
let code = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(120), // 2 minute timeout
|
||||
rx,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??;
|
||||
|
||||
// Stop the server
|
||||
server_handle.abort();
|
||||
|
||||
println!("✅ Authentication successful! Exchanging code for token...");
|
||||
|
||||
// Exchange the code for a token
|
||||
self.exchange_code_for_token(&code).await
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_oauth_token_async(
|
||||
host: &str,
|
||||
client_id: &str,
|
||||
redirect_url: &str,
|
||||
scopes: &[String],
|
||||
) -> Result<String> {
|
||||
let token_cache = TokenCache::new(host, client_id, scopes);
|
||||
|
||||
// Try cache first
|
||||
if let Some(token) = token_cache.load_token() {
|
||||
// If token has an expiration time, check if it's expired
|
||||
if let Some(expires_at) = token.expires_at {
|
||||
if expires_at > Utc::now() {
|
||||
tracing::debug!("Using cached token");
|
||||
return Ok(token.access_token);
|
||||
}
|
||||
// Token is expired, will try to refresh below
|
||||
tracing::debug!("Token is expired, attempting to refresh");
|
||||
} else {
|
||||
// No expiration time was provided by the server
|
||||
tracing::debug!("Token has no expiration time, using cached token");
|
||||
return Ok(token.access_token);
|
||||
}
|
||||
|
||||
// Token is expired or has no expiration, try to refresh if we have a refresh token
|
||||
if let Some(refresh_token) = token.refresh_token {
|
||||
// Get endpoints for token refresh
|
||||
match get_workspace_endpoints(host).await {
|
||||
Ok(endpoints) => {
|
||||
let flow = OAuthFlow::new(
|
||||
endpoints,
|
||||
client_id.to_string(),
|
||||
redirect_url.to_string(),
|
||||
scopes.to_vec(),
|
||||
);
|
||||
|
||||
// Try to refresh the token
|
||||
match flow.refresh_token(&refresh_token).await {
|
||||
Ok(new_token) => {
|
||||
if let Err(e) = token_cache.save_token(&new_token) {
|
||||
tracing::warn!("Failed to save refreshed token: {}", e);
|
||||
}
|
||||
tracing::info!("Successfully refreshed token");
|
||||
return Ok(new_token.access_token);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to refresh token, will try new auth flow: {}",
|
||||
e
|
||||
);
|
||||
// Continue to new auth flow
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to get endpoints for token refresh: {}", e);
|
||||
// Continue to new auth flow
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get endpoints and execute flow for a new token
|
||||
let endpoints = get_workspace_endpoints(host).await?;
|
||||
let flow = OAuthFlow::new(
|
||||
endpoints,
|
||||
client_id.to_string(),
|
||||
redirect_url.to_string(),
|
||||
scopes.to_vec(),
|
||||
);
|
||||
|
||||
// Execute the OAuth flow and get token
|
||||
let token = flow.execute().await?;
|
||||
|
||||
// Cache and return
|
||||
token_cache.save_token(&token)?;
|
||||
println!("🎉 Databricks authentication complete!");
|
||||
Ok(token.access_token)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_token_cache() -> Result<()> {
|
||||
let cache = TokenCache::new(
|
||||
"https://example.com",
|
||||
"test-client",
|
||||
&["scope1".to_string()],
|
||||
);
|
||||
|
||||
// Test with expiration time
|
||||
let token_data = TokenData {
|
||||
access_token: "test-token".to_string(),
|
||||
refresh_token: Some("test-refresh-token".to_string()),
|
||||
expires_at: Some(Utc::now() + chrono::Duration::hours(1)),
|
||||
};
|
||||
|
||||
cache.save_token(&token_data)?;
|
||||
|
||||
let loaded_token = cache.load_token().unwrap();
|
||||
assert_eq!(loaded_token.access_token, token_data.access_token);
|
||||
assert_eq!(loaded_token.refresh_token, token_data.refresh_token);
|
||||
assert!(loaded_token.expires_at.is_some());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user