add context window monitor
Writes the current context window to logs/current_context_window (uses a symlink to a session ID). This PR was unfortunately generated by a different LLM and did a ton of superficial reformating, it's actually a fairly small and benign change, but I don't want to roll back everything. Hope that's ok.
This commit is contained in:
@@ -1,36 +1,36 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Trait for LLM providers
|
||||
#[async_trait::async_trait]
|
||||
pub trait LLMProvider: Send + Sync {
|
||||
/// Generate a completion for the given messages
|
||||
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
|
||||
|
||||
|
||||
/// Stream a completion for the given messages
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream>;
|
||||
|
||||
|
||||
/// Get the provider name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
|
||||
/// Get the model name
|
||||
fn model(&self) -> &str;
|
||||
|
||||
|
||||
/// Check if the provider supports native tool calling
|
||||
fn has_native_tool_calling(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
/// Check if the provider supports cache control
|
||||
fn supports_cache_control(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
/// Get the configured max_tokens for this provider
|
||||
fn max_tokens(&self) -> u32;
|
||||
|
||||
|
||||
/// Get the configured temperature for this provider
|
||||
fn temperature(&self) -> f32;
|
||||
}
|
||||
@@ -60,15 +60,24 @@ pub enum CacheType {
|
||||
|
||||
impl CacheControl {
|
||||
pub fn ephemeral() -> Self {
|
||||
Self { cache_type: CacheType::Ephemeral, ttl: None }
|
||||
Self {
|
||||
cache_type: CacheType::Ephemeral,
|
||||
ttl: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn five_minute() -> Self {
|
||||
Self { cache_type: CacheType::Ephemeral, ttl: Some("5m".to_string()) }
|
||||
Self {
|
||||
cache_type: CacheType::Ephemeral,
|
||||
ttl: Some("5m".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn one_hour() -> Self {
|
||||
Self { cache_type: CacheType::Ephemeral, ttl: Some("1h".to_string()) }
|
||||
Self {
|
||||
cache_type: CacheType::Ephemeral,
|
||||
ttl: Some("1h".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +85,7 @@ impl CacheControl {
|
||||
pub struct Message {
|
||||
pub role: MessageRole,
|
||||
pub content: String,
|
||||
#[serde(skip)]
|
||||
pub id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_control: Option<CacheControl>,
|
||||
@@ -110,7 +120,7 @@ pub struct CompletionChunk {
|
||||
pub content: String,
|
||||
pub finished: bool,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
pub usage: Option<Usage>, // Add usage tracking for streaming
|
||||
pub usage: Option<Usage>, // Add usage tracking for streaming
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -144,7 +154,7 @@ impl Message {
|
||||
fn generate_id() -> String {
|
||||
let now = chrono::Local::now();
|
||||
let timestamp = now.format("%H%M%S").to_string();
|
||||
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let random_chars: String = (0..3)
|
||||
.map(|_| {
|
||||
@@ -153,10 +163,10 @@ impl Message {
|
||||
chars[idx] as char
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
format!("{}-{}", timestamp, random_chars)
|
||||
}
|
||||
|
||||
|
||||
/// Create a new message with optional cache control
|
||||
pub fn new(role: MessageRole, content: String) -> Self {
|
||||
Self {
|
||||
@@ -168,7 +178,11 @@ impl Message {
|
||||
}
|
||||
|
||||
/// Create a new message with cache control
|
||||
pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self {
|
||||
pub fn with_cache_control(
|
||||
role: MessageRole,
|
||||
content: String,
|
||||
cache_control: CacheControl,
|
||||
) -> Self {
|
||||
Self {
|
||||
role,
|
||||
content,
|
||||
@@ -176,13 +190,13 @@ impl Message {
|
||||
cache_control: Some(cache_control),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Create a message with cache control, with provider validation
|
||||
pub fn with_cache_control_validated(
|
||||
role: MessageRole,
|
||||
content: String,
|
||||
role: MessageRole,
|
||||
content: String,
|
||||
cache_control: CacheControl,
|
||||
provider: &dyn LLMProvider
|
||||
provider: &dyn LLMProvider,
|
||||
) -> Self {
|
||||
if !provider.supports_cache_control() {
|
||||
tracing::warn!(
|
||||
@@ -192,7 +206,7 @@ impl Message {
|
||||
);
|
||||
return Self::new(role, content);
|
||||
}
|
||||
|
||||
|
||||
Self::with_cache_control(role, content, cache_control)
|
||||
}
|
||||
}
|
||||
@@ -210,16 +224,16 @@ impl ProviderRegistry {
|
||||
default_provider: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn register<P: LLMProvider + 'static>(&mut self, provider: P) {
|
||||
let name = provider.name().to_string();
|
||||
self.providers.insert(name.clone(), Box::new(provider));
|
||||
|
||||
|
||||
if self.default_provider.is_empty() {
|
||||
self.default_provider = name;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn set_default(&mut self, provider_name: &str) -> Result<()> {
|
||||
if !self.providers.contains_key(provider_name) {
|
||||
anyhow::bail!("Provider '{}' not found", provider_name);
|
||||
@@ -227,7 +241,7 @@ impl ProviderRegistry {
|
||||
self.default_provider = provider_name.to_string();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
pub fn get(&self, provider_name: Option<&str>) -> Result<&dyn LLMProvider> {
|
||||
let name = provider_name.unwrap_or(&self.default_provider);
|
||||
self.providers
|
||||
@@ -235,7 +249,7 @@ impl ProviderRegistry {
|
||||
.map(|p| p.as_ref())
|
||||
.ok_or_else(|| anyhow::anyhow!("Provider '{}' not found", name))
|
||||
}
|
||||
|
||||
|
||||
pub fn list_providers(&self) -> Vec<&str> {
|
||||
self.providers.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
@@ -255,10 +269,12 @@ mod tests {
|
||||
fn test_message_serialization_without_cache_control() {
|
||||
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON without cache_control: {}", json);
|
||||
assert!(!json.contains("cache_control"),
|
||||
"JSON should not contain 'cache_control' field when not configured");
|
||||
assert!(
|
||||
!json.contains("cache_control"),
|
||||
"JSON should not contain 'cache_control' field when not configured"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -269,16 +285,24 @@ mod tests {
|
||||
CacheControl::ephemeral(),
|
||||
);
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON with cache_control: {}", json);
|
||||
assert!(json.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field when configured");
|
||||
assert!(json.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' value");
|
||||
assert!(json.contains("\"type\":"),
|
||||
"JSON should contain 'type' field in cache_control");
|
||||
assert!(!json.contains("null"),
|
||||
"JSON should not contain null values");
|
||||
assert!(
|
||||
json.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field when configured"
|
||||
);
|
||||
assert!(
|
||||
json.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' value"
|
||||
);
|
||||
assert!(
|
||||
json.contains("\"type\":"),
|
||||
"JSON should contain 'type' field in cache_control"
|
||||
);
|
||||
assert!(
|
||||
!json.contains("null"),
|
||||
"JSON should not contain null values"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -289,11 +313,20 @@ mod tests {
|
||||
CacheControl::five_minute(),
|
||||
);
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON with 5-minute cache_control: {}", json);
|
||||
assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field");
|
||||
assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type");
|
||||
assert!(json.contains("\"ttl\":\"5m\""), "JSON should contain ttl field with 5m value");
|
||||
assert!(
|
||||
json.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field"
|
||||
);
|
||||
assert!(
|
||||
json.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' type"
|
||||
);
|
||||
assert!(
|
||||
json.contains("\"ttl\":\"5m\""),
|
||||
"JSON should contain ttl field with 5m value"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -304,39 +337,53 @@ mod tests {
|
||||
CacheControl::one_hour(),
|
||||
);
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON with 1-hour cache_control: {}", json);
|
||||
assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field");
|
||||
assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type");
|
||||
assert!(json.contains("\"ttl\":\"1h\""), "JSON should contain ttl field with 1h value");
|
||||
assert!(
|
||||
json.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field"
|
||||
);
|
||||
assert!(
|
||||
json.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' type"
|
||||
);
|
||||
assert!(
|
||||
json.contains("\"ttl\":\"1h\""),
|
||||
"JSON should contain ttl field with 1h value"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_id_generation() {
|
||||
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||
|
||||
|
||||
// Check that id is not empty
|
||||
assert!(!msg.id.is_empty(), "Message ID should not be empty");
|
||||
|
||||
|
||||
// Check format: HHMMSS-XXX
|
||||
let parts: Vec<&str> = msg.id.split('-').collect();
|
||||
assert_eq!(parts.len(), 2, "Message ID should have format HHMMSS-XXX");
|
||||
|
||||
|
||||
// Check timestamp part is 6 digits
|
||||
assert_eq!(parts[0].len(), 6, "Timestamp should be 6 digits (HHMMSS)");
|
||||
assert!(parts[0].chars().all(|c| c.is_ascii_digit()), "Timestamp should be all digits");
|
||||
|
||||
assert!(
|
||||
parts[0].chars().all(|c| c.is_ascii_digit()),
|
||||
"Timestamp should be all digits"
|
||||
);
|
||||
|
||||
// Check random part is 3 alpha characters
|
||||
assert_eq!(parts[1].len(), 3, "Random part should be 3 characters");
|
||||
assert!(parts[1].chars().all(|c| c.is_ascii_alphabetic()),
|
||||
"Random part should be all alphabetic characters");
|
||||
assert!(
|
||||
parts[1].chars().all(|c| c.is_ascii_alphabetic()),
|
||||
"Random part should be all alphabetic characters"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_id_uniqueness() {
|
||||
let msg1 = Message::new(MessageRole::User, "Hello".to_string());
|
||||
let msg2 = Message::new(MessageRole::User, "Hello".to_string());
|
||||
|
||||
|
||||
// IDs should be different (due to random component)
|
||||
// Note: There's a tiny chance they could be the same, but very unlikely
|
||||
println!("msg1.id: {}, msg2.id: {}", msg1.id, msg2.id);
|
||||
@@ -346,9 +393,12 @@ mod tests {
|
||||
fn test_message_id_not_serialized() {
|
||||
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON: {}", json);
|
||||
assert!(!json.contains("\"id\""), "JSON should not contain 'id' field");
|
||||
assert!(
|
||||
!json.contains("\"id\""),
|
||||
"JSON should not contain 'id' field"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -358,8 +408,14 @@ mod tests {
|
||||
"Hello".to_string(),
|
||||
CacheControl::ephemeral(),
|
||||
);
|
||||
|
||||
assert!(!msg.id.is_empty(), "Message with cache control should have an ID");
|
||||
assert!(msg.id.contains('-'), "Message ID should contain hyphen separator");
|
||||
|
||||
assert!(
|
||||
!msg.id.is_empty(),
|
||||
"Message with cache control should have an ID"
|
||||
);
|
||||
assert!(
|
||||
msg.id.contains('-'),
|
||||
"Message ID should contain hyphen separator"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user