auto refresh token

This commit is contained in:
Dhanji Prasanna
2025-10-04 17:32:48 +10:00
parent 1a57dd3b1d
commit bcba99ec6c
2 changed files with 238 additions and 27 deletions

View File

@@ -16,6 +16,7 @@ use std::io;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use std::collections::VecDeque;
// Retro sci-fi color scheme inspired by Alien terminals
const TERMINAL_GREEN: Color = Color::Rgb(136, 244, 152); // Mid green
@@ -98,6 +99,12 @@ struct TerminalState {
should_exit: bool,
/// Track the last tool header line index for updating it
last_tool_header_index: Option<usize>,
/// Token rate tracking for chart
token_rate_history: VecDeque<(f64, f64)>, // (time_seconds, tokens_per_second)
/// Start time for token tracking
session_start: Instant,
/// Last token count for rate calculation
last_token_count: u32,
}
impl TerminalState {
@@ -128,6 +135,9 @@ impl TerminalState {
is_processing: false,
should_exit: false,
last_tool_header_index: None,
token_rate_history: VecDeque::with_capacity(60), // Keep last 60 data points
session_start: Instant::now(),
last_token_count: 0,
}
}
@@ -383,6 +393,21 @@ impl RetroTui {
percentage,
} => {
state.context_info = (used, total, percentage);
// Update token rate history for the chart
let elapsed = state.session_start.elapsed().as_secs_f64();
// Tokens received since the last update — NOTE(review): this is a raw per-update delta, not divided by elapsed time, so despite the field name it is not a true tokens-per-second rate
let tokens_since_last = used.saturating_sub(state.last_token_count) as f64;
let rate = if tokens_since_last > 0.0 { tokens_since_last } else { 0.0 };
state.token_rate_history.push_back((elapsed, rate));
// Keep only last 60 data points (about 1 minute of history at 1 update/sec)
while state.token_rate_history.len() > 60 {
state.token_rate_history.pop_front();
}
state.last_token_count = used;
}
TuiMessage::Error(err) => {
state.add_output(&format!("ERROR: {}", err));
@@ -489,7 +514,7 @@ impl RetroTui {
Self::draw_output_area(f, chunks[1], &state.output_history, state.scroll_offset);
// Draw activity area (tool output)
Self::draw_activity_area(f, chunks[2], &state.tool_activity, state.tool_activity_scroll);
Self::draw_activity_area(f, chunks[2], state);
// Draw status bar
Self::draw_status_bar(
@@ -678,8 +703,7 @@ impl RetroTui {
fn draw_activity_area(
f: &mut Frame,
area: Rect,
tool_activity: &[String],
scroll_offset: usize,
state: &TerminalState,
) {
// Note: scroll_offset is managed by the state and auto-scrolls to show latest content when new data arrives
@@ -695,8 +719,8 @@ impl RetroTui {
// Draw left half - Tool Activity
// Calculate actual visible height accounting for borders
let visible_height = chunks[0].height.saturating_sub(2).max(1) as usize;
let total_lines = tool_activity.len();
let total_lines = state.tool_activity.len();
let scroll_offset = state.tool_activity_scroll;
// Calculate scroll position
let scroll = if total_lines <= visible_height {
0
@@ -705,13 +729,13 @@ impl RetroTui {
};
// Get visible lines for tool activity
let visible_lines: Vec<Line> = if tool_activity.is_empty() {
let visible_lines: Vec<Line> = if state.tool_activity.is_empty() {
vec![Line::from(Span::styled(
" No tool activity yet",
Style::default().fg(TERMINAL_DIM_GREEN).add_modifier(Modifier::ITALIC),
))]
} else {
tool_activity
state.tool_activity
.iter()
.skip(scroll)
.take(visible_height)
@@ -742,23 +766,111 @@ impl RetroTui {
f.render_widget(tool_output, chunks[0]);
// Draw right half - Activity
let reserved = Paragraph::new(vec![Line::from(Span::styled(
" Activity log will appear here",
Style::default().fg(TERMINAL_DIM_GREEN).add_modifier(Modifier::ITALIC),
))])
.block(
Block::default()
.title(" ACTIVITY ")
// Draw right half - Token Chart
Self::draw_token_chart(f, chunks[1], &state.token_rate_history, state.is_processing);
}
/// Draw a line chart showing tokens received over time
fn draw_token_chart(
f: &mut Frame,
area: Rect,
token_history: &VecDeque<(f64, f64)>,
is_processing: bool,
) {
// Create the chart block
let block = Block::default()
.title(" TOKENS RECEIVED ")
.title_alignment(Alignment::Center)
.borders(Borders::ALL)
.border_style(Style::default().fg(TERMINAL_DIM_GREEN))
.style(Style::default().bg(TERMINAL_BG)),
);
.style(Style::default().bg(TERMINAL_BG));
f.render_widget(reserved, chunks[1]);
// Calculate inner area for chart
let inner = block.inner(area);
// Render the block first
f.render_widget(block, area);
// If no data or area too small, show placeholder
if token_history.is_empty() || inner.width < 10 || inner.height < 3 {
let placeholder = Paragraph::new(vec![Line::from(Span::styled(
" Waiting for token data...",
Style::default().fg(TERMINAL_DIM_GREEN).add_modifier(Modifier::ITALIC),
))])
.alignment(Alignment::Center);
f.render_widget(placeholder, inner);
return;
}
// Calculate cumulative tokens for Y axis
let mut cumulative_tokens: Vec<(f64, f64)> = Vec::new();
let mut total = 0.0;
for (time, rate) in token_history.iter() {
total += rate;
cumulative_tokens.push((*time, total));
}
// Find max for scaling
let max_tokens = cumulative_tokens
.iter()
.map(|(_, tokens)| *tokens)
.fold(10.0, f64::max); // Minimum scale of 10 tokens
let chart_height = inner.height as usize;
let chart_width = inner.width as usize;
// Create sparkline visualization
let mut lines: Vec<Line> = Vec::new();
// Add Y-axis label at top
lines.push(Line::from(vec![
Span::styled(
format!("{:>5.0}", max_tokens),
Style::default().fg(TERMINAL_AMBER),
),
Span::styled("", Style::default().fg(TERMINAL_DIM_GREEN)),
]));
// Draw the sparkline chart
if chart_height > 3 && !cumulative_tokens.is_empty() {
let sparkline_chars = vec!['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];
let mut chart_line = String::from("");
// Sample the data to fit the width
let sample_step = cumulative_tokens.len() as f64 / (chart_width - 7) as f64;
for x in 0..(chart_width - 7) {
let idx = (x as f64 * sample_step) as usize;
if idx < cumulative_tokens.len() {
let (_, tokens) = cumulative_tokens[idx];
let normalized = (tokens / max_tokens).min(1.0);
let char_idx = ((normalized * 7.0) as usize).min(7);
chart_line.push(sparkline_chars[char_idx]);
} else {
chart_line.push(' ');
}
}
let color = if is_processing { TERMINAL_CYAN } else { TERMINAL_GREEN };
lines.push(Line::from(Span::styled(chart_line, Style::default().fg(color))));
// Add bottom axis
lines.push(Line::from(vec![
Span::styled(" 0", Style::default().fg(TERMINAL_AMBER)),
Span::styled("", Style::default().fg(TERMINAL_DIM_GREEN)),
Span::styled(
format!("{}T (seconds)", "".repeat(chart_width.saturating_sub(15))),
Style::default().fg(TERMINAL_DIM_GREEN),
),
]));
}
let chart_paragraph = Paragraph::new(lines);
f.render_widget(chart_paragraph, inner);
}
/// Draw the status bar
/// Draw the status bar
fn draw_status_bar(
f: &mut Frame,

View File

@@ -122,13 +122,24 @@ impl DatabricksAuth {
client_id,
redirect_url,
scopes,
cached_token: _,
cached_token,
} => {
// Use the OAuth implementation
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await
// Use the OAuth implementation with automatic refresh
let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
// Cache the token for potential reuse within the same session
*cached_token = Some(token.clone());
Ok(token)
}
}
}
/// Force a token refresh by clearing any cached token.
///
/// Call this after receiving a 403 "Invalid Token" response so that the
/// next `get_token()` performs a fresh OAuth fetch instead of reusing the
/// stale cached credential.
pub fn clear_cached_token(&mut self) {
    match self {
        // Only the OAuth variant carries a cached token; drop it so the
        // next request is forced through the token-acquisition path.
        DatabricksAuth::OAuth { cached_token, .. } => {
            *cached_token = None;
        }
        // Other auth variants hold no cached token — nothing to clear.
        _ => {}
    }
}
}
#[derive(Debug, Clone)]
@@ -693,7 +704,7 @@ impl LLMProvider for DatabricksProvider {
}
let mut provider_clone = self.clone();
let response = provider_clone
let mut response = provider_clone
.create_request_builder(false)
.await?
.json(&request_body)
@@ -707,8 +718,52 @@ impl LLMProvider for DatabricksProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN &&
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
// Clear any cached token to force a refresh
provider_clone.auth.clear_cached_token();
// Try to get a new token (will attempt refresh or new OAuth flow)
match provider_clone.auth.get_token().await {
Ok(_new_token) => {
info!("Successfully refreshed OAuth token, retrying request");
// Retry the request with the new token
response = provider_clone
.create_request_builder(false)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
let retry_status = response.status();
if !retry_status.is_success() {
let retry_error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
}
}
Err(e) => {
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
}
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
}
let response_text = response.text().await?;
debug!("Raw Databricks API response: {}", response_text);
@@ -800,7 +855,7 @@ impl LLMProvider for DatabricksProvider {
);
let mut provider_clone = self.clone();
let response = provider_clone
let mut response = provider_clone
.create_request_builder(true)
.await?
.json(&request_body)
@@ -814,8 +869,52 @@ impl LLMProvider for DatabricksProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check if this is a 403 Invalid Token error that we can retry with token refresh
if status == reqwest::StatusCode::FORBIDDEN &&
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
// Try to refresh the token if we're using OAuth
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
// Clear any cached token to force a refresh
provider_clone.auth.clear_cached_token();
// Try to get a new token (will attempt refresh or new OAuth flow)
match provider_clone.auth.get_token().await {
Ok(_new_token) => {
info!("Successfully refreshed OAuth token, retrying streaming request");
// Retry the request with the new token
response = provider_clone
.create_request_builder(true)
.await?
.json(&request_body)
.send()
.await
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
let retry_status = response.status();
if !retry_status.is_success() {
let retry_error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
}
}
Err(e) => {
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
}
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
} else {
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
}
}
let stream = response.bytes_stream();
let (tx, rx) = mpsc::channel(100);