add context window monitor

Writes the current context window to logs/current_context_window (uses a symlink to a session ID).

This PR was unfortunately generated by a different LLM and did a ton of superficial reformating, it's actually a fairly small and benign change, but I don't want to roll back everything. Hope that's ok.
This commit is contained in:
Jochen
2025-11-27 21:00:02 +11:00
parent 93dc4acf86
commit 52f78653b4
89 changed files with 4040 additions and 2576 deletions

View File

@@ -139,7 +139,7 @@ impl AnthropicProvider {
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
let model = model.unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string());
debug!("Initialized Anthropic provider with model: {}", model);
Ok(Self {
@@ -160,11 +160,11 @@ impl AnthropicProvider {
.header("x-api-key", &self.api_key)
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json");
if self.enable_1m_context {
builder = builder.header("anthropic-beta", "context-1m-2025-08-07");
}
if streaming {
builder = builder.header("accept", "text/event-stream");
}
@@ -188,12 +188,17 @@ impl AnthropicProvider {
};
// Extract properties and required fields from the input schema
if let Ok(schema_obj) = serde_json::from_value::<serde_json::Map<String, serde_json::Value>>(tool.input_schema.clone()) {
if let Ok(schema_obj) = serde_json::from_value::<
serde_json::Map<String, serde_json::Value>,
>(tool.input_schema.clone())
{
if let Some(properties) = schema_obj.get("properties") {
schema.properties = properties.clone();
}
if let Some(required) = schema_obj.get("required") {
if let Ok(required_vec) = serde_json::from_value::<Vec<String>>(required.clone()) {
if let Ok(required_vec) =
serde_json::from_value::<Vec<String>>(required.clone())
{
schema.required = Some(required_vec);
}
}
@@ -208,7 +213,10 @@ impl AnthropicProvider {
.collect()
}
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
fn convert_messages(
&self,
messages: &[Message],
) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
let mut system_message = None;
let mut anthropic_messages = Vec::new();
@@ -225,7 +233,9 @@ impl AnthropicProvider {
role: "user".to_string(),
content: vec![AnthropicContent::Text {
text: message.content.clone(),
cache_control: message.cache_control.as_ref()
cache_control: message
.cache_control
.as_ref()
.map(Self::convert_cache_control),
}],
});
@@ -235,7 +245,9 @@ impl AnthropicProvider {
role: "assistant".to_string(),
content: vec![AnthropicContent::Text {
text: message.content.clone(),
cache_control: message.cache_control.as_ref()
cache_control: message
.cache_control
.as_ref()
.map(Self::convert_cache_control),
}],
});
@@ -257,7 +269,9 @@ impl AnthropicProvider {
let (system, anthropic_messages) = self.convert_messages(messages)?;
if anthropic_messages.is_empty() {
return Err(anyhow!("At least one user or assistant message is required"));
return Err(anyhow!(
"At least one user or assistant message is required"
));
}
// Convert tools if provided
@@ -292,13 +306,13 @@ impl AnthropicProvider {
let mut accumulated_usage: Option<Usage> = None;
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
let mut message_stopped = false; // Track if we've received message_stop
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Append new bytes to our buffer
byte_buffer.extend_from_slice(&chunk);
// Try to convert the entire buffer to UTF-8
let chunk_str = match std::str::from_utf8(&byte_buffer) {
Ok(s) => {
@@ -312,7 +326,8 @@ impl AnthropicProvider {
let valid_up_to = e.valid_up_to();
if valid_up_to > 0 {
// We have some valid UTF-8, extract it and keep the rest for next iteration
let valid_bytes = byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
let valid_bytes =
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
std::str::from_utf8(&valid_bytes).unwrap().to_string()
} else {
// No valid UTF-8 at all, skip this chunk and continue
@@ -346,7 +361,11 @@ impl AnthropicProvider {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
tool_calls: if current_tool_calls.is_empty() {
None
} else {
Some(current_tool_calls.clone())
},
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
@@ -358,7 +377,10 @@ impl AnthropicProvider {
match serde_json::from_str::<AnthropicStreamEvent>(data) {
Ok(event) => {
debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
debug!(
"Parsed event type: {}, event: {:?}",
event.event_type, event
);
match event.event_type.as_str() {
"message_start" => {
// Extract usage data from message_start event
@@ -367,19 +389,30 @@ impl AnthropicProvider {
accumulated_usage = Some(Usage {
prompt_tokens: usage.input_tokens,
completion_tokens: usage.output_tokens,
total_tokens: usage.input_tokens + usage.output_tokens,
total_tokens: usage.input_tokens
+ usage.output_tokens,
});
debug!("Captured usage from message_start: {:?}", accumulated_usage);
debug!(
"Captured usage from message_start: {:?}",
accumulated_usage
);
}
}
}
"content_block_start" => {
debug!("Received content_block_start event: {:?}", event);
debug!(
"Received content_block_start event: {:?}",
event
);
if let Some(content_block) = event.content_block {
match content_block {
AnthropicContent::ToolUse { id, name, input } => {
AnthropicContent::ToolUse {
id,
name,
input,
} => {
debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input);
// For native tool calls, create the tool call immediately if we have complete args
// If args are empty, we'll wait for partial_json to accumulate them
let tool_call = ToolCall {
@@ -387,9 +420,14 @@ impl AnthropicProvider {
tool: name.clone(),
args: input.clone(),
};
// Check if we already have complete arguments
if !input.is_null() && input != serde_json::Value::Object(serde_json::Map::new()) {
if !input.is_null()
&& input
!= serde_json::Value::Object(
serde_json::Map::new(),
)
{
// We have complete arguments, send the tool call immediately
debug!("Tool call has complete args, sending immediately: {:?}", tool_call);
let chunk = CompletionChunk {
@@ -410,7 +448,10 @@ impl AnthropicProvider {
}
}
_ => {
debug!("Non-tool content block: {:?}", content_block);
debug!(
"Non-tool content block: {:?}",
content_block
);
}
}
}
@@ -418,7 +459,11 @@ impl AnthropicProvider {
"content_block_delta" => {
if let Some(delta) = event.delta {
if let Some(text) = delta.text {
debug!("Sending text chunk of length {}: '{}'", text.len(), text);
debug!(
"Sending text chunk of length {}: '{}'",
text.len(),
text
);
let chunk = CompletionChunk {
content: text,
finished: false,
@@ -432,31 +477,51 @@ impl AnthropicProvider {
}
// Handle partial JSON for tool calls
if let Some(partial_json) = delta.partial_json {
debug!("Received partial JSON: {}", partial_json);
debug!(
"Received partial JSON: {}",
partial_json
);
partial_tool_json.push_str(&partial_json);
debug!("Accumulated tool JSON: {}", partial_tool_json);
debug!(
"Accumulated tool JSON: {}",
partial_tool_json
);
}
}
}
"content_block_stop" => {
// Tool call block is complete - now parse the accumulated JSON
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
debug!("Parsing complete tool JSON: {}", partial_tool_json);
if !current_tool_calls.is_empty()
&& !partial_tool_json.is_empty()
{
debug!(
"Parsing complete tool JSON: {}",
partial_tool_json
);
// Parse the accumulated JSON and update the last tool call
if let Ok(parsed_args) = serde_json::from_str::<serde_json::Value>(&partial_tool_json) {
if let Some(last_tool) = current_tool_calls.last_mut() {
if let Ok(parsed_args) =
serde_json::from_str::<serde_json::Value>(
&partial_tool_json,
)
{
if let Some(last_tool) =
current_tool_calls.last_mut()
{
last_tool.args = parsed_args;
debug!("Updated tool call with complete args: {:?}", last_tool);
}
} else {
debug!("Failed to parse accumulated JSON: {}", partial_tool_json);
debug!(
"Failed to parse accumulated JSON: {}",
partial_tool_json
);
}
// Clear the accumulator
partial_tool_json.clear();
}
// Send the complete tool call
if !current_tool_calls.is_empty() {
let chunk = CompletionChunk {
@@ -478,7 +543,11 @@ impl AnthropicProvider {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
tool_calls: if current_tool_calls.is_empty() {
None
} else {
Some(current_tool_calls.clone())
},
};
if tx.send(Ok(final_chunk)).await.is_err() {
debug!("Receiver dropped, stopping stream");
@@ -490,7 +559,10 @@ impl AnthropicProvider {
if let Some(error) = event.error {
error!("Anthropic API error: {:?}", error);
let _ = tx
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
.send(Err(anyhow!(
"Anthropic API error: {:?}",
error
)))
.await;
break; // Break to let stream exhaust naturally
}
@@ -524,7 +596,11 @@ impl AnthropicProvider {
content: String::new(),
finished: true,
usage: accumulated_usage.clone(),
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
tool_calls: if current_tool_calls.is_empty() {
None
} else {
Some(current_tool_calls)
},
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
@@ -543,15 +619,17 @@ impl LLMProvider for AnthropicProvider {
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
false,
max_tokens,
temperature
&request.messages,
request.tools.as_deref(),
false,
max_tokens,
temperature,
)?;
debug!("Sending request to Anthropic API: model={}, max_tokens={}, temperature={}",
request_body.model, request_body.max_tokens, request_body.temperature);
debug!(
"Sending request to Anthropic API: model={}, max_tokens={}, temperature={}",
request_body.model, request_body.max_tokens, request_body.temperature
);
let response = self
.create_request_builder(false)
@@ -588,7 +666,8 @@ impl LLMProvider for AnthropicProvider {
let usage = Usage {
prompt_tokens: anthropic_response.usage.input_tokens,
completion_tokens: anthropic_response.usage.output_tokens,
total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
total_tokens: anthropic_response.usage.input_tokens
+ anthropic_response.usage.output_tokens,
};
debug!(
@@ -613,18 +692,24 @@ impl LLMProvider for AnthropicProvider {
let temperature = request.temperature.unwrap_or(self.temperature);
let request_body = self.create_request_body(
&request.messages,
request.tools.as_deref(),
true,
max_tokens,
temperature
&request.messages,
request.tools.as_deref(),
true,
max_tokens,
temperature,
)?;
debug!("Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}",
request_body.model, request_body.max_tokens, request_body.temperature);
debug!(
"Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}",
request_body.model, request_body.max_tokens, request_body.temperature
);
// Debug: Log the full request body
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
debug!(
"Full request body: {}",
serde_json::to_string_pretty(&request_body)
.unwrap_or_else(|_| "Failed to serialize".to_string())
);
let response = self
.create_request_builder(true)
@@ -673,16 +758,16 @@ impl LLMProvider for AnthropicProvider {
// Claude models support native tool calling
true
}
fn supports_cache_control(&self) -> bool {
// Anthropic supports cache control
true
}
fn max_tokens(&self) -> u32 {
self.max_tokens
}
fn temperature(&self) -> f32 {
self.temperature
}
@@ -729,7 +814,7 @@ struct AnthropicMessage {
#[serde(tag = "type")]
enum AnthropicContent {
#[serde(rename = "text")]
Text {
Text {
text: String,
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<crate::CacheControl>,
@@ -798,17 +883,14 @@ mod tests {
#[test]
fn test_message_conversion() {
let provider = AnthropicProvider::new(
"test-key".to_string(),
None,
None,
None,
None,
None,
).unwrap();
let provider =
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap();
let messages = vec![
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
Message::new(
MessageRole::System,
"You are a helpful assistant.".to_string(),
),
Message::new(MessageRole::User, "Hello!".to_string()),
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
];
@@ -830,7 +912,8 @@ mod tests {
Some(0.5),
None,
None,
).unwrap();
)
.unwrap();
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
@@ -848,31 +931,23 @@ mod tests {
#[test]
fn test_tool_conversion() {
let provider = AnthropicProvider::new(
"test-key".to_string(),
None,
None,
None,
None,
None,
).unwrap();
let provider =
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap();
let tools = vec![
Tool {
name: "get_weather".to_string(),
description: "Get the current weather".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}),
},
];
let tools = vec![Tool {
name: "get_weather".to_string(),
description: "Get the current weather".to_string(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state"
}
},
"required": ["location"]
}),
}];
let anthropic_tools = provider.convert_tools(&tools);
@@ -881,31 +956,30 @@ mod tests {
assert_eq!(anthropic_tools[0].description, "Get the current weather");
assert_eq!(anthropic_tools[0].input_schema.schema_type, "object");
assert!(anthropic_tools[0].input_schema.required.is_some());
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
assert_eq!(
anthropic_tools[0].input_schema.required.as_ref().unwrap()[0],
"location"
);
}
#[test]
fn test_cache_control_serialization() {
let provider = AnthropicProvider::new(
"test-key".to_string(),
None,
None,
None,
None,
None,
).unwrap();
let provider =
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap();
// Test message WITHOUT cache_control
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
let (_, anthropic_messages_without) = provider.convert_messages(&messages_without).unwrap();
let json_without = serde_json::to_string(&anthropic_messages_without).unwrap();
println!("Anthropic JSON without cache_control: {}", json_without);
// Check if cache_control appears in the JSON
if json_without.contains("cache_control") {
println!("WARNING: JSON contains 'cache_control' field when not configured!");
assert!(!json_without.contains("\"cache_control\":null"),
"JSON should not contain 'cache_control: null'");
assert!(
!json_without.contains("\"cache_control\":null"),
"JSON should not contain 'cache_control: null'"
);
}
// Test message WITH cache_control
@@ -916,15 +990,21 @@ mod tests {
)];
let (_, anthropic_messages_with) = provider.convert_messages(&messages_with).unwrap();
let json_with = serde_json::to_string(&anthropic_messages_with).unwrap();
println!("Anthropic JSON with cache_control: {}", json_with);
assert!(json_with.contains("cache_control"),
"JSON should contain 'cache_control' field when configured");
assert!(json_with.contains("ephemeral"),
"JSON should contain 'ephemeral' type");
assert!(
json_with.contains("cache_control"),
"JSON should contain 'cache_control' field when configured"
);
assert!(
json_with.contains("ephemeral"),
"JSON should contain 'ephemeral' type"
);
// The key assertion: when cache_control is None, it should not appear in JSON
assert!(!json_without.contains("cache_control") || !json_without.contains("null"),
"JSON should not contain 'cache_control' field or null values when not configured");
assert!(
!json_without.contains("cache_control") || !json_without.contains("null"),
"JSON should not contain 'cache_control' field or null values when not configured"
);
}
}

View File

@@ -312,7 +312,7 @@ impl DatabricksProvider {
// Append new bytes to our buffer
byte_buffer.extend_from_slice(&chunk);
// Try to convert the entire buffer to UTF-8
let chunk_str = match std::str::from_utf8(&byte_buffer) {
Ok(s) => {
@@ -326,7 +326,8 @@ impl DatabricksProvider {
let valid_up_to = e.valid_up_to();
if valid_up_to > 0 {
// We have some valid UTF-8, extract it and keep the rest for next iteration
let valid_bytes = byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
let valid_bytes =
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
std::str::from_utf8(&valid_bytes).unwrap().to_string()
} else {
// No valid UTF-8 at all, skip this chunk and continue
@@ -593,7 +594,7 @@ impl DatabricksProvider {
}
Err(e) => {
error!("Stream error at chunk {}: {}", chunk_count, e);
// Check if this is a connection error that might be recoverable
let error_msg = e.to_string();
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
@@ -610,10 +611,14 @@ impl DatabricksProvider {
// Log final state
debug!("Stream ended after {} chunks", chunk_count);
debug!("Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
buffer.len(), incomplete_data_line.len(), byte_buffer.len());
debug!(
"Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
buffer.len(),
incomplete_data_line.len(),
byte_buffer.len()
);
debug!("Accumulated tool calls: {}", current_tool_calls.len());
// If we have any remaining data in buffers, log it for debugging
if !buffer.is_empty() {
debug!("Remaining buffer content: {:?}", buffer);
@@ -924,7 +929,7 @@ impl LLMProvider for DatabricksProvider {
"Processing Databricks streaming request with {} messages",
request.messages.len()
);
// Debug: Log tool count
if let Some(ref tools) = request.tools {
debug!("Request has {} tools", tools.len());
@@ -1051,15 +1056,15 @@ impl LLMProvider for DatabricksProvider {
// This includes Claude, Llama, DBRX, and most other models on the platform
true
}
fn supports_cache_control(&self) -> bool {
false
}
fn max_tokens(&self) -> u32 {
self.max_tokens
}
fn temperature(&self) -> f32 {
self.temperature
}
@@ -1181,7 +1186,10 @@ mod tests {
.unwrap();
let messages = vec![
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
Message::new(
MessageRole::System,
"You are a helpful assistant.".to_string(),
),
Message::new(MessageRole::User, "Hello!".to_string()),
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
];
@@ -1304,10 +1312,12 @@ mod tests {
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
let databricks_messages_without = provider.convert_messages(&messages_without).unwrap();
let json_without = serde_json::to_string(&databricks_messages_without).unwrap();
println!("JSON without cache_control: {}", json_without);
assert!(!json_without.contains("cache_control"),
"JSON should not contain 'cache_control' field when not configured");
assert!(
!json_without.contains("cache_control"),
"JSON should not contain 'cache_control' field when not configured"
);
// Test message WITH cache_control - should still NOT include it (Databricks doesn't support it)
let messages_with = vec![Message::with_cache_control(
@@ -1317,10 +1327,12 @@ mod tests {
)];
let databricks_messages_with = provider.convert_messages(&messages_with).unwrap();
let json_with = serde_json::to_string(&databricks_messages_with).unwrap();
println!("JSON with cache_control: {}", json_with);
assert!(!json_with.contains("cache_control"),
"JSON should NOT contain 'cache_control' field - Databricks doesn't support it");
assert!(
!json_with.contains("cache_control"),
"JSON should NOT contain 'cache_control' field - Databricks doesn't support it"
);
}
#[test]
@@ -1343,7 +1355,13 @@ mod tests {
)
.unwrap();
assert!(!claude_provider.supports_cache_control(), "Databricks should not support cache_control even for Claude models");
assert!(!llama_provider.supports_cache_control(), "Databricks should not support cache_control for Llama models");
assert!(
!claude_provider.supports_cache_control(),
"Databricks should not support cache_control even for Claude models"
);
assert!(
!llama_provider.supports_cache_control(),
"Databricks should not support cache_control for Llama models"
);
}
}

View File

@@ -1,8 +1,8 @@
use anyhow::Result;
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Usage,
};
use anyhow::Result;
use llama_cpp::{
standard_sampler::{SamplerStage, StandardSampler},
LlamaModel, LlamaParams, LlamaSession, SessionParams,
@@ -37,7 +37,7 @@ impl EmbeddedProvider {
// Expand tilde in path
let expanded_path = shellexpand::tilde(&model_path);
let model_path_buf = PathBuf::from(expanded_path.as_ref());
// If model doesn't exist and it's the default Qwen model, offer to download it
if !model_path_buf.exists() {
if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") {
@@ -47,7 +47,7 @@ impl EmbeddedProvider {
anyhow::bail!("Model file not found: {}", model_path_buf.display());
}
}
let model_path = model_path_buf.as_path();
// Set up model parameters
@@ -93,24 +93,24 @@ impl EmbeddedProvider {
fn format_messages(&self, messages: &[Message]) -> String {
// Determine the appropriate format based on model type
let model_name_lower = self.model_name.to_lowercase();
if model_name_lower.contains("qwen") {
// Qwen format: <|im_start|>role\ncontent<|im_end|>
let mut formatted = String::new();
for message in messages {
let role = match message.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
};
formatted.push_str(&format!(
"<|im_start|>{}\n{}<|im_end|>\n",
role, message.content
));
}
// Add the start of assistant response
formatted.push_str("<|im_start|>assistant\n");
formatted
@@ -118,7 +118,7 @@ impl EmbeddedProvider {
// Mistral Instruct format: <s>[INST] ... [/INST] assistant_response</s>
let mut formatted = String::new();
let mut in_conversation = false;
for (i, message) in messages.iter().enumerate() {
match message.role {
MessageRole::System => {
@@ -146,12 +146,15 @@ impl EmbeddedProvider {
}
}
}
// If the last message was from user, add a space for the assistant's response
if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
if messages
.last()
.is_some_and(|m| matches!(m.role, MessageRole::User))
{
formatted.push(' ');
}
formatted
} else {
// Use Llama/CodeLlama format for other models
@@ -216,16 +219,25 @@ impl EmbeddedProvider {
}
Err(_) => {
if attempt < 4 {
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
debug!(
"Session busy, retrying in {}ms (attempt {}/5)",
100 * (attempt + 1),
attempt + 1
);
std::thread::sleep(std::time::Duration::from_millis(
100 * (attempt + 1) as u64,
));
} else {
return Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again"));
return Err(anyhow::anyhow!(
"Model is busy after 5 attempts, please try again"
));
}
}
}
}
let mut session = session_guard.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
let mut session = session_guard
.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
debug!(
"Starting inference with prompt length: {} chars, estimated {} tokens",
@@ -297,7 +309,7 @@ impl EmbeddedProvider {
break;
}
}
if hit_stop {
break;
}
@@ -308,7 +320,7 @@ impl EmbeddedProvider {
token_count,
start_time.elapsed()
);
Ok((generated_text, token_count))
}),
)
@@ -347,21 +359,22 @@ impl EmbeddedProvider {
fn get_stop_sequences(&self) -> Vec<&'static str> {
// Determine model type from model_name
let model_name_lower = self.model_name.to_lowercase();
if model_name_lower.contains("qwen") {
vec![
"<|im_end|>", // Qwen ChatML format end token
"<|endoftext|>", // Alternative end token
"</s>", // Generic end of sequence
"<|im_start|>", // Start of new message (shouldn't appear in response)
"<|im_end|>", // Qwen ChatML format end token
"<|endoftext|>", // Alternative end token
"</s>", // Generic end of sequence
"<|im_start|>", // Start of new message (shouldn't appear in response)
]
} else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
} else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama")
{
vec![
"</s>", // End of sequence
"[/INST]", // End of instruction
"<</SYS>>", // End of system message
"[INST]", // Start of new instruction (shouldn't appear in response)
"<<SYS>>", // Start of system (shouldn't appear in response)
"</s>", // End of sequence
"[/INST]", // End of instruction
"<</SYS>>", // End of system message
"[INST]", // Start of new instruction (shouldn't appear in response)
"<<SYS>>", // Start of system (shouldn't appear in response)
]
} else if model_name_lower.contains("llama") {
vec![
@@ -374,9 +387,9 @@ impl EmbeddedProvider {
]
} else if model_name_lower.contains("mistral") {
vec![
"</s>", // End of sequence
"[/INST]", // End of instruction
"<|im_end|>", // ChatML format
"</s>", // End of sequence
"[/INST]", // End of instruction
"<|im_end|>", // ChatML format
]
} else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") {
vec![
@@ -391,7 +404,7 @@ impl EmbeddedProvider {
"### Instruction:", // Alpaca format
"### Response:", // Alpaca format
"### Input:", // Alpaca format
"</s>", // End of sequence
"</s>", // End of sequence
]
} else {
// Generic/unknown model - use common stop sequences
@@ -411,14 +424,14 @@ impl EmbeddedProvider {
fn clean_stop_sequences(&self, text: &str) -> String {
let mut cleaned = text.to_string();
let stop_sequences = self.get_stop_sequences();
for stop_seq in &stop_sequences {
if let Some(pos) = cleaned.find(stop_seq) {
cleaned.truncate(pos);
break; // Only remove the first occurrence to avoid over-truncation
}
}
cleaned.trim().to_string()
}
@@ -426,57 +439,64 @@ impl EmbeddedProvider {
fn download_qwen_model(model_path: &Path) -> Result<()> {
use std::fs;
use std::process::Command;
const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB
// Create the parent directory if it doesn't exist
if let Some(parent) = model_path.parent() {
fs::create_dir_all(parent)?;
}
info!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)...");
info!("This is a one-time download that may take several minutes depending on your connection.");
info!("Downloading to: {}", model_path.display());
// Use curl with progress bar for download
let output = Command::new("curl")
.args([
"-L", // Follow redirects
"-#", // Show progress bar
"-f", // Fail on HTTP errors
"-o", model_path.to_str().unwrap(),
"-L", // Follow redirects
"-#", // Show progress bar
"-f", // Fail on HTTP errors
"-o",
model_path.to_str().unwrap(),
MODEL_URL,
])
.output()?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
// If curl is not available, provide alternative instructions
if stderr.contains("command not found") || stderr.contains("not found") {
error!("curl is not installed. Please install curl or manually download the model.");
error!(
"curl is not installed. Please install curl or manually download the model."
);
error!("Manual download instructions:");
error!("1. Download from: {}", MODEL_URL);
error!("2. Save to: {}", model_path.display());
anyhow::bail!("curl not found - please install curl or download the model manually");
anyhow::bail!(
"curl not found - please install curl or download the model manually"
);
}
anyhow::bail!("Failed to download model: {}", stderr);
}
// Verify the file was created and has reasonable size
let metadata = fs::metadata(model_path)?;
let size_mb = metadata.len() / (1024 * 1024);
if size_mb < MODEL_SIZE_MB - 100 { // Allow some variance
fs::remove_file(model_path).ok(); // Clean up partial download
if size_mb < MODEL_SIZE_MB - 100 {
// Allow some variance
fs::remove_file(model_path).ok(); // Clean up partial download
anyhow::bail!(
"Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.",
size_mb, MODEL_SIZE_MB
size_mb,
MODEL_SIZE_MB
);
}
info!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb);
Ok(())
}
@@ -541,20 +561,29 @@ impl LLMProvider for EmbeddedProvider {
}
Err(_) => {
if attempt < 4 {
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
debug!(
"Session busy, retrying in {}ms (attempt {}/5)",
100 * (attempt + 1),
attempt + 1
);
std::thread::sleep(std::time::Duration::from_millis(
100 * (attempt + 1) as u64,
));
} else {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again")));
let _ = tx.blocking_send(Err(anyhow::anyhow!(
"Model is busy after 5 attempts, please try again"
)));
return;
}
}
}
}
let mut session = match session_guard {
Some(ctx) => ctx,
None => {
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
let _ =
tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
return;
}
};
@@ -588,17 +617,33 @@ impl LLMProvider for EmbeddedProvider {
let mut accumulated_text = String::new();
let mut token_count = 0;
let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back
// Get stop sequences dynamically based on model type
let stop_sequences = if prompt.contains("<|im_start|>") {
// Qwen ChatML format detected
vec!["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
} else if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
// Llama/CodeLlama format detected
vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
vec![
"</s>",
"[/INST]",
"<</SYS>>",
"[INST]",
"<<SYS>>",
"### Human:",
"### Assistant:",
]
} else {
// Generic format
vec!["</s>", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<</SYS>>"]
vec![
"</s>",
"<|endoftext|>",
"<|im_end|>",
"### Human:",
"### Assistant:",
"[/INST]",
"<</SYS>>",
]
};
// Stream tokens with proper limits
@@ -622,10 +667,10 @@ impl LLMProvider for EmbeddedProvider {
if hit_stop {
// Before stopping, check if there might be an incomplete tool call
// Look for JSON tool call patterns that might be cut off by the stop sequence
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) ||
accumulated_text.contains(r#"{"{""tool"":"#) ||
accumulated_text.contains(r#"{{""tool"":"#);
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#)
|| accumulated_text.contains(r#"{"{""tool"":"#)
|| accumulated_text.contains(r#"{{""tool"":"#);
if has_potential_tool_call {
// Check if the tool call appears to be complete (has closing brace after the stop sequence)
let mut complete_tool_call = false;
@@ -645,7 +690,7 @@ impl LLMProvider for EmbeddedProvider {
}
}
}
// If tool call is incomplete, send the raw content including stop sequences
// so the main parser can handle it properly
if !complete_tool_call {
@@ -666,7 +711,7 @@ impl LLMProvider for EmbeddedProvider {
break;
}
}
// Send any remaining clean content before stopping (original behavior)
let mut clean_accumulated = accumulated_text.clone();
for stop_seq in &stop_sequences {
@@ -675,7 +720,7 @@ impl LLMProvider for EmbeddedProvider {
break;
}
}
// Calculate what part we haven't sent yet
let already_sent_len = accumulated_text.len() - unsent_tokens.len();
if clean_accumulated.len() > already_sent_len {
@@ -711,7 +756,8 @@ impl LLMProvider for EmbeddedProvider {
if might_be_stop {
// Hold back tokens, but only for a limited buffer size
if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters
if unsent_tokens.len() > 20 {
// Don't hold back more than 20 characters
// Send the oldest part and keep only the recent part that might be a stop sequence
let to_send = &unsent_tokens[..unsent_tokens.len() - 10];
if !to_send.is_empty() {
@@ -755,7 +801,7 @@ impl LLMProvider for EmbeddedProvider {
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
usage: None, // Embedded models calculate usage differently
usage: None, // Embedded models calculate usage differently
tool_calls: None,
};
let _ = tx.blocking_send(Ok(final_chunk));
@@ -771,11 +817,11 @@ impl LLMProvider for EmbeddedProvider {
fn model(&self) -> &str {
&self.model_name
}
fn max_tokens(&self) -> u32 {
self.max_tokens
}
fn temperature(&self) -> f32 {
self.temperature
}

View File

@@ -1,36 +1,36 @@
use serde::{Deserialize, Serialize};
use anyhow::Result;
use std::collections::HashMap;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Trait for LLM providers
#[async_trait::async_trait]
pub trait LLMProvider: Send + Sync {
/// Generate a completion for the given messages
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
/// Stream a completion for the given messages
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream>;
/// Get the provider name
fn name(&self) -> &str;
/// Get the model name
fn model(&self) -> &str;
/// Check if the provider supports native tool calling
fn has_native_tool_calling(&self) -> bool {
false
}
/// Check if the provider supports cache control
fn supports_cache_control(&self) -> bool {
false
}
/// Get the configured max_tokens for this provider
fn max_tokens(&self) -> u32;
/// Get the configured temperature for this provider
fn temperature(&self) -> f32;
}
@@ -60,15 +60,24 @@ pub enum CacheType {
impl CacheControl {
pub fn ephemeral() -> Self {
Self { cache_type: CacheType::Ephemeral, ttl: None }
Self {
cache_type: CacheType::Ephemeral,
ttl: None,
}
}
pub fn five_minute() -> Self {
Self { cache_type: CacheType::Ephemeral, ttl: Some("5m".to_string()) }
Self {
cache_type: CacheType::Ephemeral,
ttl: Some("5m".to_string()),
}
}
pub fn one_hour() -> Self {
Self { cache_type: CacheType::Ephemeral, ttl: Some("1h".to_string()) }
Self {
cache_type: CacheType::Ephemeral,
ttl: Some("1h".to_string()),
}
}
}
@@ -76,6 +85,7 @@ impl CacheControl {
pub struct Message {
pub role: MessageRole,
pub content: String,
#[serde(skip)]
pub id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_control: Option<CacheControl>,
@@ -110,7 +120,7 @@ pub struct CompletionChunk {
pub content: String,
pub finished: bool,
pub tool_calls: Option<Vec<ToolCall>>,
pub usage: Option<Usage>, // Add usage tracking for streaming
pub usage: Option<Usage>, // Add usage tracking for streaming
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -144,7 +154,7 @@ impl Message {
fn generate_id() -> String {
let now = chrono::Local::now();
let timestamp = now.format("%H%M%S").to_string();
let mut rng = rand::thread_rng();
let random_chars: String = (0..3)
.map(|_| {
@@ -153,10 +163,10 @@ impl Message {
chars[idx] as char
})
.collect();
format!("{}-{}", timestamp, random_chars)
}
/// Create a new message with optional cache control
pub fn new(role: MessageRole, content: String) -> Self {
Self {
@@ -168,7 +178,11 @@ impl Message {
}
/// Create a new message with cache control
pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self {
pub fn with_cache_control(
role: MessageRole,
content: String,
cache_control: CacheControl,
) -> Self {
Self {
role,
content,
@@ -176,13 +190,13 @@ impl Message {
cache_control: Some(cache_control),
}
}
/// Create a message with cache control, with provider validation
pub fn with_cache_control_validated(
role: MessageRole,
content: String,
role: MessageRole,
content: String,
cache_control: CacheControl,
provider: &dyn LLMProvider
provider: &dyn LLMProvider,
) -> Self {
if !provider.supports_cache_control() {
tracing::warn!(
@@ -192,7 +206,7 @@ impl Message {
);
return Self::new(role, content);
}
Self::with_cache_control(role, content, cache_control)
}
}
@@ -210,16 +224,16 @@ impl ProviderRegistry {
default_provider: String::new(),
}
}
pub fn register<P: LLMProvider + 'static>(&mut self, provider: P) {
let name = provider.name().to_string();
self.providers.insert(name.clone(), Box::new(provider));
if self.default_provider.is_empty() {
self.default_provider = name;
}
}
pub fn set_default(&mut self, provider_name: &str) -> Result<()> {
if !self.providers.contains_key(provider_name) {
anyhow::bail!("Provider '{}' not found", provider_name);
@@ -227,7 +241,7 @@ impl ProviderRegistry {
self.default_provider = provider_name.to_string();
Ok(())
}
pub fn get(&self, provider_name: Option<&str>) -> Result<&dyn LLMProvider> {
let name = provider_name.unwrap_or(&self.default_provider);
self.providers
@@ -235,7 +249,7 @@ impl ProviderRegistry {
.map(|p| p.as_ref())
.ok_or_else(|| anyhow::anyhow!("Provider '{}' not found", name))
}
pub fn list_providers(&self) -> Vec<&str> {
self.providers.keys().map(|s| s.as_str()).collect()
}
@@ -255,10 +269,12 @@ mod tests {
fn test_message_serialization_without_cache_control() {
let msg = Message::new(MessageRole::User, "Hello".to_string());
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON without cache_control: {}", json);
assert!(!json.contains("cache_control"),
"JSON should not contain 'cache_control' field when not configured");
assert!(
!json.contains("cache_control"),
"JSON should not contain 'cache_control' field when not configured"
);
}
#[test]
@@ -269,16 +285,24 @@ mod tests {
CacheControl::ephemeral(),
);
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON with cache_control: {}", json);
assert!(json.contains("cache_control"),
"JSON should contain 'cache_control' field when configured");
assert!(json.contains("ephemeral"),
"JSON should contain 'ephemeral' value");
assert!(json.contains("\"type\":"),
"JSON should contain 'type' field in cache_control");
assert!(!json.contains("null"),
"JSON should not contain null values");
assert!(
json.contains("cache_control"),
"JSON should contain 'cache_control' field when configured"
);
assert!(
json.contains("ephemeral"),
"JSON should contain 'ephemeral' value"
);
assert!(
json.contains("\"type\":"),
"JSON should contain 'type' field in cache_control"
);
assert!(
!json.contains("null"),
"JSON should not contain null values"
);
}
#[test]
@@ -289,11 +313,20 @@ mod tests {
CacheControl::five_minute(),
);
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON with 5-minute cache_control: {}", json);
assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field");
assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type");
assert!(json.contains("\"ttl\":\"5m\""), "JSON should contain ttl field with 5m value");
assert!(
json.contains("cache_control"),
"JSON should contain 'cache_control' field"
);
assert!(
json.contains("ephemeral"),
"JSON should contain 'ephemeral' type"
);
assert!(
json.contains("\"ttl\":\"5m\""),
"JSON should contain ttl field with 5m value"
);
}
#[test]
@@ -304,39 +337,53 @@ mod tests {
CacheControl::one_hour(),
);
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON with 1-hour cache_control: {}", json);
assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field");
assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type");
assert!(json.contains("\"ttl\":\"1h\""), "JSON should contain ttl field with 1h value");
assert!(
json.contains("cache_control"),
"JSON should contain 'cache_control' field"
);
assert!(
json.contains("ephemeral"),
"JSON should contain 'ephemeral' type"
);
assert!(
json.contains("\"ttl\":\"1h\""),
"JSON should contain ttl field with 1h value"
);
}
#[test]
fn test_message_id_generation() {
let msg = Message::new(MessageRole::User, "Hello".to_string());
// Check that id is not empty
assert!(!msg.id.is_empty(), "Message ID should not be empty");
// Check format: HHMMSS-XXX
let parts: Vec<&str> = msg.id.split('-').collect();
assert_eq!(parts.len(), 2, "Message ID should have format HHMMSS-XXX");
// Check timestamp part is 6 digits
assert_eq!(parts[0].len(), 6, "Timestamp should be 6 digits (HHMMSS)");
assert!(parts[0].chars().all(|c| c.is_ascii_digit()), "Timestamp should be all digits");
assert!(
parts[0].chars().all(|c| c.is_ascii_digit()),
"Timestamp should be all digits"
);
// Check random part is 3 alpha characters
assert_eq!(parts[1].len(), 3, "Random part should be 3 characters");
assert!(parts[1].chars().all(|c| c.is_ascii_alphabetic()),
"Random part should be all alphabetic characters");
assert!(
parts[1].chars().all(|c| c.is_ascii_alphabetic()),
"Random part should be all alphabetic characters"
);
}
#[test]
fn test_message_id_uniqueness() {
let msg1 = Message::new(MessageRole::User, "Hello".to_string());
let msg2 = Message::new(MessageRole::User, "Hello".to_string());
// IDs should be different (due to random component)
// Note: There's a tiny chance they could be the same, but very unlikely
println!("msg1.id: {}, msg2.id: {}", msg1.id, msg2.id);
@@ -346,9 +393,12 @@ mod tests {
fn test_message_id_not_serialized() {
let msg = Message::new(MessageRole::User, "Hello".to_string());
let json = serde_json::to_string(&msg).unwrap();
println!("Message JSON: {}", json);
assert!(!json.contains("\"id\""), "JSON should not contain 'id' field");
assert!(
!json.contains("\"id\""),
"JSON should not contain 'id' field"
);
}
#[test]
@@ -358,8 +408,14 @@ mod tests {
"Hello".to_string(),
CacheControl::ephemeral(),
);
assert!(!msg.id.is_empty(), "Message with cache control should have an ID");
assert!(msg.id.contains('-'), "Message ID should contain hyphen separator");
assert!(
!msg.id.is_empty(),
"Message with cache control should have an ID"
);
assert!(
msg.id.contains('-'),
"Message ID should contain hyphen separator"
);
}
}

View File

@@ -10,8 +10,8 @@ use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
Message, MessageRole, Tool, ToolCall, Usage,
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Tool, ToolCall, Usage,
};
#[derive(Clone)]
@@ -138,7 +138,8 @@ impl OpenAIProvider {
debug!("Received stream completion marker");
// Send final chunk with accumulated content and tool calls
if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
if !accumulated_content.is_empty() || !current_tool_calls.is_empty()
{
let tool_calls = if current_tool_calls.is_empty() {
None
} else {
@@ -188,8 +189,9 @@ impl OpenAIProvider {
if let Some(index) = delta_tool_call.index {
// Ensure we have enough tool calls in our vector
while current_tool_calls.len() <= index {
current_tool_calls
.push(OpenAIStreamingToolCall::default());
current_tool_calls.push(
OpenAIStreamingToolCall::default(),
);
}
let tool_call = &mut current_tool_calls[index];
@@ -198,11 +200,14 @@ impl OpenAIProvider {
tool_call.id = Some(id.clone());
}
if let Some(function) = &delta_tool_call.function {
if let Some(function) =
&delta_tool_call.function
{
if let Some(name) = &function.name {
tool_call.name = Some(name.clone());
}
if let Some(arguments) = &function.arguments {
if let Some(arguments) = &function.arguments
{
tool_call.arguments.push_str(arguments);
}
}
@@ -246,7 +251,7 @@ impl OpenAIProvider {
.collect(),
)
};
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
@@ -254,7 +259,7 @@ impl OpenAIProvider {
usage: accumulated_usage.clone(),
};
let _ = tx.send(Ok(final_chunk)).await;
accumulated_usage
}
}
@@ -291,7 +296,11 @@ impl LLMProvider for OpenAIProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
return Err(anyhow::anyhow!(
"OpenAI API error {}: {}",
status,
error_text
));
}
let openai_response: OpenAIResponse = response.json().await?;
@@ -334,7 +343,10 @@ impl LLMProvider for OpenAIProvider {
request.temperature,
);
debug!("Sending streaming request to OpenAI API: model={}", self.model);
debug!(
"Sending streaming request to OpenAI API: model={}",
self.model
);
let response = self
.client
@@ -350,7 +362,11 @@ impl LLMProvider for OpenAIProvider {
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
return Err(anyhow::anyhow!(
"OpenAI API error {}: {}",
status,
error_text
));
}
let stream = response.bytes_stream();
@@ -384,11 +400,11 @@ impl LLMProvider for OpenAIProvider {
// OpenAI models support native tool calling
true
}
fn max_tokens(&self) -> u32 {
self.max_tokens.unwrap_or(16000)
}
fn temperature(&self) -> f32 {
self._temperature.unwrap_or(0.1)
}
@@ -472,9 +488,9 @@ impl OpenAIStreamingToolCall {
fn to_tool_call(&self) -> Option<ToolCall> {
let id = self.id.as_ref()?;
let name = self.name.as_ref()?;
let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);
Some(ToolCall {
id: id.clone(),
tool: name.clone(),

View File

@@ -20,18 +20,24 @@ fn test_no_wrong_serialization_format() {
CacheControl::ephemeral(),
);
let json = serde_json::to_string(&msg).unwrap();
println!("Ephemeral message JSON: {}", json);
// Should NOT contain the wrong format
assert!(!json.contains("system.0.cache_control"),
"JSON should not contain 'system.0.cache_control' path");
assert!(!json.contains("cache_control.ephemeral"),
"JSON should not contain 'cache_control.ephemeral' path");
assert!(
!json.contains("system.0.cache_control"),
"JSON should not contain 'system.0.cache_control' path"
);
assert!(
!json.contains("cache_control.ephemeral"),
"JSON should not contain 'cache_control.ephemeral' path"
);
// Should contain the correct format
assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#),
"JSON should contain correct cache_control format");
assert!(
json.contains(r#""cache_control":{"type":"ephemeral"}"#),
"JSON should contain correct cache_control format"
);
}
#[test]
@@ -42,20 +48,28 @@ fn test_five_minute_no_wrong_format() {
CacheControl::five_minute(),
);
let json = serde_json::to_string(&msg).unwrap();
println!("5-minute message JSON: {}", json);
// Should NOT contain the wrong format
assert!(!json.contains("system.0.cache_control"),
"JSON should not contain 'system.0.cache_control' path");
assert!(!json.contains("cache_control.ephemeral.ttl"),
"JSON should not contain 'cache_control.ephemeral.ttl' path");
assert!(
!json.contains("system.0.cache_control"),
"JSON should not contain 'system.0.cache_control' path"
);
assert!(
!json.contains("cache_control.ephemeral.ttl"),
"JSON should not contain 'cache_control.ephemeral.ttl' path"
);
// Should contain the correct format with ttl as a direct field
assert!(json.contains(r#""type":"ephemeral""#),
"JSON should contain type field");
assert!(json.contains(r#""ttl":"5m""#),
"JSON should contain ttl field with value 5m");
assert!(
json.contains(r#""type":"ephemeral""#),
"JSON should contain type field"
);
assert!(
json.contains(r#""ttl":"5m""#),
"JSON should contain ttl field with value 5m"
);
}
#[test]
@@ -66,44 +80,59 @@ fn test_one_hour_no_wrong_format() {
CacheControl::one_hour(),
);
let json = serde_json::to_string(&msg).unwrap();
println!("1-hour message JSON: {}", json);
// Should NOT contain the wrong format
assert!(!json.contains("system.0.cache_control"),
"JSON should not contain 'system.0.cache_control' path");
assert!(!json.contains("cache_control.ephemeral.ttl"),
"JSON should not contain 'cache_control.ephemeral.ttl' path");
assert!(
!json.contains("system.0.cache_control"),
"JSON should not contain 'system.0.cache_control' path"
);
assert!(
!json.contains("cache_control.ephemeral.ttl"),
"JSON should not contain 'cache_control.ephemeral.ttl' path"
);
// Should contain the correct format with ttl as a direct field
assert!(json.contains(r#""type":"ephemeral""#),
"JSON should contain type field");
assert!(json.contains(r#""ttl":"1h""#),
"JSON should contain ttl field with value 1h");
assert!(
json.contains(r#""type":"ephemeral""#),
"JSON should contain type field"
);
assert!(
json.contains(r#""ttl":"1h""#),
"JSON should contain ttl field with value 1h"
);
}
#[test]
fn test_cache_control_structure_is_flat() {
// Verify that the cache_control object has a flat structure
// with 'type' and optional 'ttl' at the same level
let cache_control = CacheControl::five_minute();
let json_value = serde_json::to_value(&cache_control).unwrap();
println!("Cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap());
println!(
"Cache control as JSON value: {}",
serde_json::to_string_pretty(&json_value).unwrap()
);
let obj = json_value.as_object().expect("Should be an object");
// Should have exactly 2 keys at the top level
assert_eq!(obj.len(), 2, "Cache control should have exactly 2 top-level fields");
assert_eq!(
obj.len(),
2,
"Cache control should have exactly 2 top-level fields"
);
// Both 'type' and 'ttl' should be at the same level
assert!(obj.contains_key("type"), "Should have 'type' field");
assert!(obj.contains_key("ttl"), "Should have 'ttl' field");
// 'type' should be a string, not an object
assert!(obj["type"].is_string(), "'type' should be a string value");
// 'ttl' should be a string, not nested
assert!(obj["ttl"].is_string(), "'ttl' should be a string value");
}
@@ -112,20 +141,30 @@ fn test_cache_control_structure_is_flat() {
fn test_ephemeral_cache_control_structure() {
let cache_control = CacheControl::ephemeral();
let json_value = serde_json::to_value(&cache_control).unwrap();
println!("Ephemeral cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap());
println!(
"Ephemeral cache control as JSON value: {}",
serde_json::to_string_pretty(&json_value).unwrap()
);
let obj = json_value.as_object().expect("Should be an object");
// Should have exactly 1 key (only 'type', no 'ttl')
assert_eq!(obj.len(), 1, "Ephemeral cache control should have exactly 1 top-level field");
assert_eq!(
obj.len(),
1,
"Ephemeral cache control should have exactly 1 top-level field"
);
// Should have 'type' field
assert!(obj.contains_key("type"), "Should have 'type' field");
// Should NOT have 'ttl' field
assert!(!obj.contains_key("ttl"), "Ephemeral should not have 'ttl' field");
assert!(
!obj.contains_key("ttl"),
"Ephemeral should not have 'ttl' field"
);
// 'type' should be a string with value "ephemeral"
assert_eq!(obj["type"].as_str().unwrap(), "ephemeral");
}

View File

@@ -10,13 +10,19 @@ use serde_json::json;
fn test_ephemeral_cache_control_serialization() {
let cache_control = CacheControl::ephemeral();
let json = serde_json::to_value(&cache_control).unwrap();
println!("Ephemeral cache_control JSON: {}", serde_json::to_string(&json).unwrap());
assert_eq!(json, json!({
"type": "ephemeral"
}));
println!(
"Ephemeral cache_control JSON: {}",
serde_json::to_string(&json).unwrap()
);
assert_eq!(
json,
json!({
"type": "ephemeral"
})
);
// Verify no ttl field is present
assert!(!json.as_object().unwrap().contains_key("ttl"));
}
@@ -25,26 +31,38 @@ fn test_ephemeral_cache_control_serialization() {
fn test_five_minute_cache_control_serialization() {
let cache_control = CacheControl::five_minute();
let json = serde_json::to_value(&cache_control).unwrap();
println!("5-minute cache_control JSON: {}", serde_json::to_string(&json).unwrap());
assert_eq!(json, json!({
"type": "ephemeral",
"ttl": "5m"
}));
println!(
"5-minute cache_control JSON: {}",
serde_json::to_string(&json).unwrap()
);
assert_eq!(
json,
json!({
"type": "ephemeral",
"ttl": "5m"
})
);
}
#[test]
fn test_one_hour_cache_control_serialization() {
let cache_control = CacheControl::one_hour();
let json = serde_json::to_value(&cache_control).unwrap();
println!("1-hour cache_control JSON: {}", serde_json::to_string(&json).unwrap());
assert_eq!(json, json!({
"type": "ephemeral",
"ttl": "1h"
}));
println!(
"1-hour cache_control JSON: {}",
serde_json::to_string(&json).unwrap()
);
assert_eq!(
json,
json!({
"type": "ephemeral",
"ttl": "1h"
})
);
}
#[test]
@@ -54,11 +72,16 @@ fn test_message_with_ephemeral_cache_control() {
"System prompt".to_string(),
CacheControl::ephemeral(),
);
let json = serde_json::to_value(&msg).unwrap();
println!("Message with ephemeral cache_control: {}", serde_json::to_string(&json).unwrap());
let cache_control = json.get("cache_control").expect("cache_control field should exist");
println!(
"Message with ephemeral cache_control: {}",
serde_json::to_string(&json).unwrap()
);
let cache_control = json
.get("cache_control")
.expect("cache_control field should exist");
assert_eq!(cache_control.get("type").unwrap(), "ephemeral");
assert!(!cache_control.as_object().unwrap().contains_key("ttl"));
}
@@ -70,11 +93,16 @@ fn test_message_with_five_minute_cache_control() {
"System prompt".to_string(),
CacheControl::five_minute(),
);
let json = serde_json::to_value(&msg).unwrap();
println!("Message with 5-minute cache_control: {}", serde_json::to_string(&json).unwrap());
let cache_control = json.get("cache_control").expect("cache_control field should exist");
println!(
"Message with 5-minute cache_control: {}",
serde_json::to_string(&json).unwrap()
);
let cache_control = json
.get("cache_control")
.expect("cache_control field should exist");
assert_eq!(cache_control.get("type").unwrap(), "ephemeral");
assert_eq!(cache_control.get("ttl").unwrap(), "5m");
}
@@ -86,11 +114,16 @@ fn test_message_with_one_hour_cache_control() {
"System prompt".to_string(),
CacheControl::one_hour(),
);
let json = serde_json::to_value(&msg).unwrap();
println!("Message with 1-hour cache_control: {}", serde_json::to_string(&json).unwrap());
let cache_control = json.get("cache_control").expect("cache_control field should exist");
println!(
"Message with 1-hour cache_control: {}",
serde_json::to_string(&json).unwrap()
);
let cache_control = json
.get("cache_control")
.expect("cache_control field should exist");
assert_eq!(cache_control.get("type").unwrap(), "ephemeral");
assert_eq!(cache_control.get("ttl").unwrap(), "1h");
}
@@ -98,10 +131,13 @@ fn test_message_with_one_hour_cache_control() {
#[test]
fn test_message_without_cache_control() {
let msg = Message::new(MessageRole::User, "Hello".to_string());
let json = serde_json::to_value(&msg).unwrap();
println!("Message without cache_control: {}", serde_json::to_string(&json).unwrap());
println!(
"Message without cache_control: {}",
serde_json::to_string(&json).unwrap()
);
// cache_control field should not be present when not set
assert!(!json.as_object().unwrap().contains_key("cache_control"));
}
@@ -110,9 +146,9 @@ fn test_message_without_cache_control() {
fn test_cache_control_json_format_ephemeral() {
let cache_control = CacheControl::ephemeral();
let json_str = serde_json::to_string(&cache_control).unwrap();
println!("Ephemeral JSON string: {}", json_str);
// Verify exact JSON format
assert_eq!(json_str, r#"{"type":"ephemeral"}"#);
}
@@ -121,9 +157,9 @@ fn test_cache_control_json_format_ephemeral() {
fn test_cache_control_json_format_five_minute() {
let cache_control = CacheControl::five_minute();
let json_str = serde_json::to_string(&cache_control).unwrap();
println!("5-minute JSON string: {}", json_str);
// Verify exact JSON format
assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"5m"}"#);
}
@@ -132,9 +168,9 @@ fn test_cache_control_json_format_five_minute() {
fn test_cache_control_json_format_one_hour() {
let cache_control = CacheControl::one_hour();
let json_str = serde_json::to_string(&cache_control).unwrap();
println!("1-hour JSON string: {}", json_str);
// Verify exact JSON format
assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"1h"}"#);
}
@@ -143,7 +179,7 @@ fn test_cache_control_json_format_one_hour() {
fn test_deserialization_ephemeral() {
let json_str = r#"{"type":"ephemeral"}"#;
let cache_control: CacheControl = serde_json::from_str(json_str).unwrap();
assert_eq!(cache_control.ttl, None);
}
@@ -151,7 +187,7 @@ fn test_deserialization_ephemeral() {
fn test_deserialization_five_minute() {
let json_str = r#"{"type":"ephemeral","ttl":"5m"}"#;
let cache_control: CacheControl = serde_json::from_str(json_str).unwrap();
assert_eq!(cache_control.ttl, Some("5m".to_string()));
}
@@ -159,6 +195,6 @@ fn test_deserialization_five_minute() {
fn test_deserialization_one_hour() {
let json_str = r#"{"type":"ephemeral","ttl":"1h"}"#;
let cache_control: CacheControl = serde_json::from_str(json_str).unwrap();
assert_eq!(cache_control.ttl, Some("1h".to_string()));
}