auto refresh token
This commit is contained in:
@@ -16,6 +16,7 @@ use std::io;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::mpsc;
|
||||
use std::collections::VecDeque;
|
||||
|
||||
// Retro sci-fi color scheme inspired by Alien terminals
|
||||
const TERMINAL_GREEN: Color = Color::Rgb(136, 244, 152); // Mid green
|
||||
@@ -98,6 +99,12 @@ struct TerminalState {
|
||||
should_exit: bool,
|
||||
/// Track the last tool header line index for updating it
|
||||
last_tool_header_index: Option<usize>,
|
||||
/// Token rate tracking for chart
|
||||
token_rate_history: VecDeque<(f64, f64)>, // (time_seconds, tokens_per_second)
|
||||
/// Start time for token tracking
|
||||
session_start: Instant,
|
||||
/// Last token count for rate calculation
|
||||
last_token_count: u32,
|
||||
}
|
||||
|
||||
impl TerminalState {
|
||||
@@ -128,6 +135,9 @@ impl TerminalState {
|
||||
is_processing: false,
|
||||
should_exit: false,
|
||||
last_tool_header_index: None,
|
||||
token_rate_history: VecDeque::with_capacity(60), // Keep last 60 data points
|
||||
session_start: Instant::now(),
|
||||
last_token_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -383,6 +393,21 @@ impl RetroTui {
|
||||
percentage,
|
||||
} => {
|
||||
state.context_info = (used, total, percentage);
|
||||
|
||||
// Update token rate history for the chart
|
||||
let elapsed = state.session_start.elapsed().as_secs_f64();
|
||||
|
||||
// Calculate tokens per second since last update
|
||||
let tokens_since_last = used.saturating_sub(state.last_token_count) as f64;
|
||||
let rate = if tokens_since_last > 0.0 { tokens_since_last } else { 0.0 };
|
||||
|
||||
state.token_rate_history.push_back((elapsed, rate));
|
||||
|
||||
// Keep only last 60 data points (about 1 minute of history at 1 update/sec)
|
||||
while state.token_rate_history.len() > 60 {
|
||||
state.token_rate_history.pop_front();
|
||||
}
|
||||
state.last_token_count = used;
|
||||
}
|
||||
TuiMessage::Error(err) => {
|
||||
state.add_output(&format!("ERROR: {}", err));
|
||||
@@ -489,7 +514,7 @@ impl RetroTui {
|
||||
Self::draw_output_area(f, chunks[1], &state.output_history, state.scroll_offset);
|
||||
|
||||
// Draw activity area (tool output)
|
||||
Self::draw_activity_area(f, chunks[2], &state.tool_activity, state.tool_activity_scroll);
|
||||
Self::draw_activity_area(f, chunks[2], state);
|
||||
|
||||
// Draw status bar
|
||||
Self::draw_status_bar(
|
||||
@@ -678,8 +703,7 @@ impl RetroTui {
|
||||
fn draw_activity_area(
|
||||
f: &mut Frame,
|
||||
area: Rect,
|
||||
tool_activity: &[String],
|
||||
scroll_offset: usize,
|
||||
state: &TerminalState,
|
||||
) {
|
||||
// Note: scroll_offset is managed by the state and auto-scrolls to show latest content when new data arrives
|
||||
|
||||
@@ -695,8 +719,8 @@ impl RetroTui {
|
||||
// Draw left half - Tool Activity
|
||||
// Calculate actual visible height accounting for borders
|
||||
let visible_height = chunks[0].height.saturating_sub(2).max(1) as usize;
|
||||
let total_lines = tool_activity.len();
|
||||
|
||||
let total_lines = state.tool_activity.len();
|
||||
let scroll_offset = state.tool_activity_scroll;
|
||||
// Calculate scroll position
|
||||
let scroll = if total_lines <= visible_height {
|
||||
0
|
||||
@@ -705,13 +729,13 @@ impl RetroTui {
|
||||
};
|
||||
|
||||
// Get visible lines for tool activity
|
||||
let visible_lines: Vec<Line> = if tool_activity.is_empty() {
|
||||
let visible_lines: Vec<Line> = if state.tool_activity.is_empty() {
|
||||
vec![Line::from(Span::styled(
|
||||
" No tool activity yet",
|
||||
Style::default().fg(TERMINAL_DIM_GREEN).add_modifier(Modifier::ITALIC),
|
||||
))]
|
||||
} else {
|
||||
tool_activity
|
||||
state.tool_activity
|
||||
.iter()
|
||||
.skip(scroll)
|
||||
.take(visible_height)
|
||||
@@ -742,23 +766,111 @@ impl RetroTui {
|
||||
|
||||
f.render_widget(tool_output, chunks[0]);
|
||||
|
||||
// Draw right half - Activity
|
||||
let reserved = Paragraph::new(vec![Line::from(Span::styled(
|
||||
" Activity log will appear here",
|
||||
Style::default().fg(TERMINAL_DIM_GREEN).add_modifier(Modifier::ITALIC),
|
||||
))])
|
||||
.block(
|
||||
Block::default()
|
||||
.title(" ACTIVITY ")
|
||||
// Draw right half - Token Chart
|
||||
Self::draw_token_chart(f, chunks[1], &state.token_rate_history, state.is_processing);
|
||||
}
|
||||
|
||||
/// Draw a line chart showing tokens received over time
|
||||
fn draw_token_chart(
|
||||
f: &mut Frame,
|
||||
area: Rect,
|
||||
token_history: &VecDeque<(f64, f64)>,
|
||||
is_processing: bool,
|
||||
) {
|
||||
// Create the chart block
|
||||
let block = Block::default()
|
||||
.title(" TOKENS RECEIVED ")
|
||||
.title_alignment(Alignment::Center)
|
||||
.borders(Borders::ALL)
|
||||
.border_style(Style::default().fg(TERMINAL_DIM_GREEN))
|
||||
.style(Style::default().bg(TERMINAL_BG)),
|
||||
);
|
||||
.style(Style::default().bg(TERMINAL_BG));
|
||||
|
||||
f.render_widget(reserved, chunks[1]);
|
||||
// Calculate inner area for chart
|
||||
let inner = block.inner(area);
|
||||
|
||||
// Render the block first
|
||||
f.render_widget(block, area);
|
||||
|
||||
// If no data or area too small, show placeholder
|
||||
if token_history.is_empty() || inner.width < 10 || inner.height < 3 {
|
||||
let placeholder = Paragraph::new(vec![Line::from(Span::styled(
|
||||
" Waiting for token data...",
|
||||
Style::default().fg(TERMINAL_DIM_GREEN).add_modifier(Modifier::ITALIC),
|
||||
))])
|
||||
.alignment(Alignment::Center);
|
||||
f.render_widget(placeholder, inner);
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate cumulative tokens for Y axis
|
||||
let mut cumulative_tokens: Vec<(f64, f64)> = Vec::new();
|
||||
let mut total = 0.0;
|
||||
for (time, rate) in token_history.iter() {
|
||||
total += rate;
|
||||
cumulative_tokens.push((*time, total));
|
||||
}
|
||||
|
||||
// Find max for scaling
|
||||
let max_tokens = cumulative_tokens
|
||||
.iter()
|
||||
.map(|(_, tokens)| *tokens)
|
||||
.fold(10.0, f64::max); // Minimum scale of 10 tokens
|
||||
|
||||
let chart_height = inner.height as usize;
|
||||
let chart_width = inner.width as usize;
|
||||
|
||||
// Create sparkline visualization
|
||||
let mut lines: Vec<Line> = Vec::new();
|
||||
|
||||
// Add Y-axis label at top
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(
|
||||
format!("{:>5.0}", max_tokens),
|
||||
Style::default().fg(TERMINAL_AMBER),
|
||||
),
|
||||
Span::styled(" ┤", Style::default().fg(TERMINAL_DIM_GREEN)),
|
||||
]));
|
||||
|
||||
// Draw the sparkline chart
|
||||
if chart_height > 3 && !cumulative_tokens.is_empty() {
|
||||
let sparkline_chars = vec!['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];
|
||||
let mut chart_line = String::from(" │");
|
||||
|
||||
// Sample the data to fit the width
|
||||
let sample_step = cumulative_tokens.len() as f64 / (chart_width - 7) as f64;
|
||||
|
||||
for x in 0..(chart_width - 7) {
|
||||
let idx = (x as f64 * sample_step) as usize;
|
||||
if idx < cumulative_tokens.len() {
|
||||
let (_, tokens) = cumulative_tokens[idx];
|
||||
let normalized = (tokens / max_tokens).min(1.0);
|
||||
let char_idx = ((normalized * 7.0) as usize).min(7);
|
||||
chart_line.push(sparkline_chars[char_idx]);
|
||||
} else {
|
||||
chart_line.push(' ');
|
||||
}
|
||||
}
|
||||
|
||||
let color = if is_processing { TERMINAL_CYAN } else { TERMINAL_GREEN };
|
||||
lines.push(Line::from(Span::styled(chart_line, Style::default().fg(color))));
|
||||
|
||||
// Add bottom axis
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(" 0", Style::default().fg(TERMINAL_AMBER)),
|
||||
Span::styled(" └", Style::default().fg(TERMINAL_DIM_GREEN)),
|
||||
Span::styled(
|
||||
format!("{}T (seconds)", "─".repeat(chart_width.saturating_sub(15))),
|
||||
Style::default().fg(TERMINAL_DIM_GREEN),
|
||||
),
|
||||
]));
|
||||
}
|
||||
|
||||
let chart_paragraph = Paragraph::new(lines);
|
||||
f.render_widget(chart_paragraph, inner);
|
||||
}
|
||||
|
||||
/// Draw the status bar
|
||||
|
||||
/// Draw the status bar
|
||||
fn draw_status_bar(
|
||||
f: &mut Frame,
|
||||
|
||||
@@ -122,13 +122,24 @@ impl DatabricksAuth {
|
||||
client_id,
|
||||
redirect_url,
|
||||
scopes,
|
||||
cached_token: _,
|
||||
cached_token,
|
||||
} => {
|
||||
// Use the OAuth implementation
|
||||
crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await
|
||||
// Use the OAuth implementation with automatic refresh
|
||||
let token = crate::oauth::get_oauth_token_async(host, client_id, redirect_url, scopes).await?;
|
||||
// Cache the token for potential reuse within the same session
|
||||
*cached_token = Some(token.clone());
|
||||
Ok(token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Force a token refresh by clearing any cached token
|
||||
/// This is useful when we get a 403 Invalid Token error
|
||||
pub fn clear_cached_token(&mut self) {
|
||||
if let DatabricksAuth::OAuth { cached_token, .. } = self {
|
||||
*cached_token = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -693,7 +704,7 @@ impl LLMProvider for DatabricksProvider {
|
||||
}
|
||||
|
||||
let mut provider_clone = self.clone();
|
||||
let response = provider_clone
|
||||
let mut response = provider_clone
|
||||
.create_request_builder(false)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
@@ -707,8 +718,52 @@ impl LLMProvider for DatabricksProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
||||
|
||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||
|
||||
// Try to refresh the token if we're using OAuth
|
||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||
// Clear any cached token to force a refresh
|
||||
provider_clone.auth.clear_cached_token();
|
||||
|
||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||
match provider_clone.auth.get_token().await {
|
||||
Ok(_new_token) => {
|
||||
info!("Successfully refreshed OAuth token, retrying request");
|
||||
|
||||
// Retry the request with the new token
|
||||
response = provider_clone
|
||||
.create_request_builder(false)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send request to Databricks API after token refresh: {}", e))?;
|
||||
|
||||
let retry_status = response.status();
|
||||
if !retry_status.is_success() {
|
||||
let retry_error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
}
|
||||
|
||||
let response_text = response.text().await?;
|
||||
debug!("Raw Databricks API response: {}", response_text);
|
||||
@@ -800,7 +855,7 @@ impl LLMProvider for DatabricksProvider {
|
||||
);
|
||||
|
||||
let mut provider_clone = self.clone();
|
||||
let response = provider_clone
|
||||
let mut response = provider_clone
|
||||
.create_request_builder(true)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
@@ -814,8 +869,52 @@ impl LLMProvider for DatabricksProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
// Check if this is a 403 Invalid Token error that we can retry with token refresh
|
||||
if status == reqwest::StatusCode::FORBIDDEN &&
|
||||
(error_text.contains("Invalid Token") || error_text.contains("invalid_token")) {
|
||||
|
||||
info!("Received 403 Invalid Token error, attempting to refresh OAuth token");
|
||||
|
||||
// Try to refresh the token if we're using OAuth
|
||||
if let DatabricksAuth::OAuth { .. } = &provider_clone.auth {
|
||||
// Clear any cached token to force a refresh
|
||||
provider_clone.auth.clear_cached_token();
|
||||
|
||||
// Try to get a new token (will attempt refresh or new OAuth flow)
|
||||
match provider_clone.auth.get_token().await {
|
||||
Ok(_new_token) => {
|
||||
info!("Successfully refreshed OAuth token, retrying streaming request");
|
||||
|
||||
// Retry the request with the new token
|
||||
response = provider_clone
|
||||
.create_request_builder(true)
|
||||
.await?
|
||||
.json(&request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow!("Failed to send streaming request to Databricks API after token refresh: {}", e))?;
|
||||
|
||||
let retry_status = response.status();
|
||||
if !retry_status.is_success() {
|
||||
let retry_error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow!("Databricks API error {} after token refresh: {}", retry_status, retry_error_text));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Failed to refresh OAuth token: {}. Original error: {}", e, error_text));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("Databricks API error {}: {}", status, error_text));
|
||||
}
|
||||
}
|
||||
|
||||
let stream = response.bytes_stream();
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
|
||||
Reference in New Issue
Block a user