From bcba99ec6c852cf43695ac23e662cf797a35f3c5 Mon Sep 17 00:00:00 2001 From: Dhanji Prasanna Date: Sat, 4 Oct 2025 17:32:48 +1000 Subject: [PATCH] auto refresh token --- crates/g3-cli/src/retro_tui.rs | 156 ++++++++++++++++++++++---- crates/g3-providers/src/databricks.rs | 109 +++++++++++++++++- 2 files changed, 238 insertions(+), 27 deletions(-) diff --git a/crates/g3-cli/src/retro_tui.rs b/crates/g3-cli/src/retro_tui.rs index c0587fb..ae4eb0e 100644 --- a/crates/g3-cli/src/retro_tui.rs +++ b/crates/g3-cli/src/retro_tui.rs @@ -16,6 +16,7 @@ use std::io; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use tokio::sync::mpsc; +use std::collections::VecDeque; // Retro sci-fi color scheme inspired by Alien terminals const TERMINAL_GREEN: Color = Color::Rgb(136, 244, 152); // Mid green @@ -98,6 +99,12 @@ struct TerminalState { should_exit: bool, /// Track the last tool header line index for updating it last_tool_header_index: Option, + /// Token rate tracking for chart + token_rate_history: VecDeque<(f64, f64)>, // (time_seconds, tokens_per_second) + /// Start time for token tracking + session_start: Instant, + /// Last token count for rate calculation + last_token_count: u32, } impl TerminalState { @@ -128,6 +135,9 @@ impl TerminalState { is_processing: false, should_exit: false, last_tool_header_index: None, + token_rate_history: VecDeque::with_capacity(60), // Keep last 60 data points + session_start: Instant::now(), + last_token_count: 0, } } @@ -383,6 +393,21 @@ impl RetroTui { percentage, } => { state.context_info = (used, total, percentage); + + // Update token rate history for the chart + let elapsed = state.session_start.elapsed().as_secs_f64(); + + // Calculate tokens per second since last update + let tokens_since_last = used.saturating_sub(state.last_token_count) as f64; + let rate = if tokens_since_last > 0.0 { tokens_since_last } else { 0.0 }; + + state.token_rate_history.push_back((elapsed, rate)); + + // Keep only last 60 data points (about 1 minute of history at 1 update/sec) + while state.token_rate_history.len() > 60 { + state.token_rate_history.pop_front(); + } + state.last_token_count = used; } TuiMessage::Error(err) => { state.add_output(&format!("ERROR: {}", err)); @@ -489,7 +514,7 @@ impl RetroTui { Self::draw_output_area(f, chunks[1], &state.output_history, state.scroll_offset); // Draw activity area (tool output) - Self::draw_activity_area(f, chunks[2], &state.tool_activity, state.tool_activity_scroll); + Self::draw_activity_area(f, chunks[2], state); // Draw status bar Self::draw_status_bar( @@ -678,8 +703,7 @@ impl RetroTui { fn draw_activity_area( f: &mut Frame, area: Rect, - tool_activity: &[String], - scroll_offset: usize, + state: &TerminalState, ) { // Note: scroll_offset is managed by the state and auto-scrolls to show latest content when new data arrives @@ -695,8 +719,8 @@ impl RetroTui { // Draw left half - Tool Activity // Calculate actual visible height accounting for borders let visible_height = chunks[0].height.saturating_sub(2).max(1) as usize; - let total_lines = tool_activity.len(); - + let total_lines = state.tool_activity.len(); + let scroll_offset = state.tool_activity_scroll; // Calculate scroll position let scroll = if total_lines <= visible_height { 0 @@ -705,13 +729,13 @@ impl RetroTui { }; // Get visible lines for tool activity - let visible_lines: Vec = if tool_activity.is_empty() { + let visible_lines: Vec = if state.tool_activity.is_empty() { vec![Line::from(Span::styled( " No tool activity yet", Style::default().fg(TERMINAL_DIM_GREEN).add_modifier(Modifier::ITALIC), ))] } else { - tool_activity + state.tool_activity .iter() .skip(scroll) .take(visible_height) @@ -742,22 +766,110 @@ impl RetroTui { f.render_widget(tool_output, chunks[0]); - // Draw right half - Activity - let reserved = Paragraph::new(vec![Line::from(Span::styled( - " Activity log will appear here", - Style::default().fg(TERMINAL_DIM_GREEN).add_modifier(Modifier::ITALIC), - ))]) - .block( - Block::default() - .title(" ACTIVITY ") - .title_alignment(Alignment::Center) - .borders(Borders::ALL) - .border_style(Style::default().fg(TERMINAL_DIM_GREEN)) - .style(Style::default().bg(TERMINAL_BG)), - ); - - f.render_widget(reserved, chunks[1]); + // Draw right half - Token Chart + Self::draw_token_chart(f, chunks[1], &state.token_rate_history, state.is_processing); } + + /// Draw a line chart showing tokens received over time + fn draw_token_chart( + f: &mut Frame, + area: Rect, + token_history: &VecDeque<(f64, f64)>, + is_processing: bool, + ) { + // Create the chart block + let block = Block::default() + .title(" TOKENS RECEIVED ") + .title_alignment(Alignment::Center) + .borders(Borders::ALL) + .border_style(Style::default().fg(TERMINAL_DIM_GREEN)) + .style(Style::default().bg(TERMINAL_BG)); + + // Calculate inner area for chart + let inner = block.inner(area); + + // Render the block first + f.render_widget(block, area); + + // If no data or area too small, show placeholder + if token_history.is_empty() || inner.width < 10 || inner.height < 3 { + let placeholder = Paragraph::new(vec![Line::from(Span::styled( + " Waiting for token data...", + Style::default().fg(TERMINAL_DIM_GREEN).add_modifier(Modifier::ITALIC), + ))]) + .alignment(Alignment::Center); + f.render_widget(placeholder, inner); + return; + } + + // Calculate cumulative tokens for Y axis + let mut cumulative_tokens: Vec<(f64, f64)> = Vec::new(); + let mut total = 0.0; + for (time, rate) in token_history.iter() { + total += rate; + cumulative_tokens.push((*time, total)); + } + + // Find max for scaling + let max_tokens = cumulative_tokens + .iter() + .map(|(_, tokens)| *tokens) + .fold(10.0, f64::max); // Minimum scale of 10 tokens + + let chart_height = inner.height as usize; + let chart_width = inner.width as usize; + + // Create sparkline visualization + let mut lines: Vec = Vec::new(); + + // Add Y-axis label at top + lines.push(Line::from(vec![ + Span::styled( + format!("{:>5.0}", max_tokens), + Style::default().fg(TERMINAL_AMBER), + ), + Span::styled(" ┤", Style::default().fg(TERMINAL_DIM_GREEN)), + ])); + + // Draw the sparkline chart + if chart_height > 3 && !cumulative_tokens.is_empty() { + let sparkline_chars = vec!['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']; + let mut chart_line = String::from(" │"); + + // Sample the data to fit the width + let sample_step = cumulative_tokens.len() as f64 / (chart_width - 7) as f64; + + for x in 0..(chart_width - 7) { + let idx = (x as f64 * sample_step) as usize; + if idx < cumulative_tokens.len() { + let (_, tokens) = cumulative_tokens[idx]; + let normalized = (tokens / max_tokens).min(1.0); + let char_idx = ((normalized * 7.0) as usize).min(7); + chart_line.push(sparkline_chars[char_idx]); + } else { + chart_line.push(' '); + } + } + + let color = if is_processing { TERMINAL_CYAN } else { TERMINAL_GREEN }; + lines.push(Line::from(Span::styled(chart_line, Style::default().fg(color)))); + + // Add bottom axis + lines.push(Line::from(vec![ + Span::styled(" 0", Style::default().fg(TERMINAL_AMBER)), + Span::styled(" └", Style::default().fg(TERMINAL_DIM_GREEN)), + Span::styled( + format!("{}T (seconds)", "─".repeat(chart_width.saturating_sub(15))), + Style::default().fg(TERMINAL_DIM_GREEN), + ), + ])); + } + + let chart_paragraph = Paragraph::new(lines); + f.render_widget(chart_paragraph, inner); + } + + /// Draw the status bar /// Draw the status bar fn draw_status_bar( diff --git a/crates/g3-providers/src/databricks.rs b/crates/g3-providers/src/databricks.rs index b38d755..6bae162 100644 --- a/crates/g3-providers/src/databricks.rs +++ b/crates/g3-providers/src/databricks.rs @@ -122,13 +122,24 @@ impl DatabricksAuth { client_id, redirect_url, scopes, - cached_token: _, + cached_token, } => { - // Use the OAuth implementation - crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await + // Use the OAuth implementation with automatic refresh + let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?; + // Cache the token for potential reuse within the same session + *cached_token = Some(token.clone()); + Ok(token) } } } + + /// Force a token refresh by clearing any cached token + /// This is useful when we get a 403 Invalid Token error + pub fn clear_cached_token(&mut self) { + if let DatabricksAuth::OAuth { cached_token, .. } = self { + *cached_token = None; + } + } } #[derive(Debug, Clone)] @@ -693,7 +704,7 @@ impl LLMProvider for DatabricksProvider { } let mut provider_clone = self.clone(); - let response = provider_clone + let mut response = provider_clone .create_request_builder(false) .await? .json(&request_body) @@ -707,7 +718,51 @@ impl LLMProvider for DatabricksProvider { .text() .await .unwrap_or_else(|_| "Unknown error".to_string()); + + // Check if this is a 403 Invalid Token error that we can retry with token refresh + if status == reqwest::StatusCode::FORBIDDEN && + (error_text.contains("Invalid Token") || error_text.contains("invalid_token")) { + + info!("Received 403 Invalid Token error, attempting to refresh OAuth token"); + + // Try to refresh the token if we're using OAuth + if let DatabricksAuth::OAuth { .. } = &provider_clone.auth { + // Clear any cached token to force a refresh + provider_clone.auth.clear_cached_token(); + + // Try to get a new token (will attempt refresh or new OAuth flow) + match provider_clone.auth.get_token().await { + Ok(_new_token) => { + info!("Successfully refreshed OAuth token, retrying request"); + + // Retry the request with the new token + response = provider_clone + .create_request_builder(false) + .await? + .json(&request_body) + .send() + .await + .map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?; + + let retry_status = response.status(); + if !retry_status.is_success() { + let retry_error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text)); + } + } + Err(e) => { + return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text)); + } + } + } else { + return Err(anyhow!("Databricks API error {}: {}", status, error_text)); + } + } else { return Err(anyhow!("Databricks API error {}: {}", status, error_text)); + } } let response_text = response.text().await?; @@ -800,7 +855,7 @@ impl LLMProvider for DatabricksProvider { ); let mut provider_clone = self.clone(); - let response = provider_clone + let mut response = provider_clone .create_request_builder(true) .await? .json(&request_body) @@ -814,7 +869,51 @@ impl LLMProvider for DatabricksProvider { .text() .await .unwrap_or_else(|_| "Unknown error".to_string()); + + // Check if this is a 403 Invalid Token error that we can retry with token refresh + if status == reqwest::StatusCode::FORBIDDEN && + (error_text.contains("Invalid Token") || error_text.contains("invalid_token")) { + + info!("Received 403 Invalid Token error, attempting to refresh OAuth token"); + + // Try to refresh the token if we're using OAuth + if let DatabricksAuth::OAuth { .. } = &provider_clone.auth { + // Clear any cached token to force a refresh + provider_clone.auth.clear_cached_token(); + + // Try to get a new token (will attempt refresh or new OAuth flow) + match provider_clone.auth.get_token().await { + Ok(_new_token) => { + info!("Successfully refreshed OAuth token, retrying streaming request"); + + // Retry the request with the new token + response = provider_clone + .create_request_builder(true) + .await? + .json(&request_body) + .send() + .await + .map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?; + + let retry_status = response.status(); + if !retry_status.is_success() { + let retry_error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text)); + } + } + Err(e) => { + return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text)); + } + } + } else { + return Err(anyhow!("Databricks API error {}: {}", status, error_text)); + } + } else { return Err(anyhow!("Databricks API error {}: {}", status, error_text)); + } } let stream = response.bytes_stream();