add context window monitor
Writes the current context window to logs/current_context_window (uses a symlink to a session ID). This PR was unfortunately generated by a different LLM and did a ton of superficial reformating, it's actually a fairly small and benign change, but I don't want to roll back everything. Hope that's ok.
This commit is contained in:
@@ -139,7 +139,7 @@ impl AnthropicProvider {
|
||||
.map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
let model = model.unwrap_or_else(|| "claude-3-5-sonnet-20241022".to_string());
|
||||
|
||||
|
||||
debug!("Initialized Anthropic provider with model: {}", model);
|
||||
|
||||
Ok(Self {
|
||||
@@ -160,11 +160,11 @@ impl AnthropicProvider {
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", ANTHROPIC_VERSION)
|
||||
.header("content-type", "application/json");
|
||||
|
||||
|
||||
if self.enable_1m_context {
|
||||
builder = builder.header("anthropic-beta", "context-1m-2025-08-07");
|
||||
}
|
||||
|
||||
|
||||
if streaming {
|
||||
builder = builder.header("accept", "text/event-stream");
|
||||
}
|
||||
@@ -188,12 +188,17 @@ impl AnthropicProvider {
|
||||
};
|
||||
|
||||
// Extract properties and required fields from the input schema
|
||||
if let Ok(schema_obj) = serde_json::from_value::<serde_json::Map<String, serde_json::Value>>(tool.input_schema.clone()) {
|
||||
if let Ok(schema_obj) = serde_json::from_value::<
|
||||
serde_json::Map<String, serde_json::Value>,
|
||||
>(tool.input_schema.clone())
|
||||
{
|
||||
if let Some(properties) = schema_obj.get("properties") {
|
||||
schema.properties = properties.clone();
|
||||
}
|
||||
if let Some(required) = schema_obj.get("required") {
|
||||
if let Ok(required_vec) = serde_json::from_value::<Vec<String>>(required.clone()) {
|
||||
if let Ok(required_vec) =
|
||||
serde_json::from_value::<Vec<String>>(required.clone())
|
||||
{
|
||||
schema.required = Some(required_vec);
|
||||
}
|
||||
}
|
||||
@@ -208,7 +213,10 @@ impl AnthropicProvider {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn convert_messages(&self, messages: &[Message]) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
|
||||
fn convert_messages(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
) -> Result<(Option<String>, Vec<AnthropicMessage>)> {
|
||||
let mut system_message = None;
|
||||
let mut anthropic_messages = Vec::new();
|
||||
|
||||
@@ -225,7 +233,9 @@ impl AnthropicProvider {
|
||||
role: "user".to_string(),
|
||||
content: vec![AnthropicContent::Text {
|
||||
text: message.content.clone(),
|
||||
cache_control: message.cache_control.as_ref()
|
||||
cache_control: message
|
||||
.cache_control
|
||||
.as_ref()
|
||||
.map(Self::convert_cache_control),
|
||||
}],
|
||||
});
|
||||
@@ -235,7 +245,9 @@ impl AnthropicProvider {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![AnthropicContent::Text {
|
||||
text: message.content.clone(),
|
||||
cache_control: message.cache_control.as_ref()
|
||||
cache_control: message
|
||||
.cache_control
|
||||
.as_ref()
|
||||
.map(Self::convert_cache_control),
|
||||
}],
|
||||
});
|
||||
@@ -257,7 +269,9 @@ impl AnthropicProvider {
|
||||
let (system, anthropic_messages) = self.convert_messages(messages)?;
|
||||
|
||||
if anthropic_messages.is_empty() {
|
||||
return Err(anyhow!("At least one user or assistant message is required"));
|
||||
return Err(anyhow!(
|
||||
"At least one user or assistant message is required"
|
||||
));
|
||||
}
|
||||
|
||||
// Convert tools if provided
|
||||
@@ -292,13 +306,13 @@ impl AnthropicProvider {
|
||||
let mut accumulated_usage: Option<Usage> = None;
|
||||
let mut byte_buffer = Vec::new(); // Buffer for incomplete UTF-8 sequences
|
||||
let mut message_stopped = false; // Track if we've received message_stop
|
||||
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
// Append new bytes to our buffer
|
||||
byte_buffer.extend_from_slice(&chunk);
|
||||
|
||||
|
||||
// Try to convert the entire buffer to UTF-8
|
||||
let chunk_str = match std::str::from_utf8(&byte_buffer) {
|
||||
Ok(s) => {
|
||||
@@ -312,7 +326,8 @@ impl AnthropicProvider {
|
||||
let valid_up_to = e.valid_up_to();
|
||||
if valid_up_to > 0 {
|
||||
// We have some valid UTF-8, extract it and keep the rest for next iteration
|
||||
let valid_bytes = byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
|
||||
let valid_bytes =
|
||||
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
|
||||
std::str::from_utf8(&valid_bytes).unwrap().to_string()
|
||||
} else {
|
||||
// No valid UTF-8 at all, skip this chunk and continue
|
||||
@@ -346,7 +361,11 @@ impl AnthropicProvider {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
|
||||
tool_calls: if current_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(current_tool_calls.clone())
|
||||
},
|
||||
};
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
@@ -358,7 +377,10 @@ impl AnthropicProvider {
|
||||
|
||||
match serde_json::from_str::<AnthropicStreamEvent>(data) {
|
||||
Ok(event) => {
|
||||
debug!("Parsed event type: {}, event: {:?}", event.event_type, event);
|
||||
debug!(
|
||||
"Parsed event type: {}, event: {:?}",
|
||||
event.event_type, event
|
||||
);
|
||||
match event.event_type.as_str() {
|
||||
"message_start" => {
|
||||
// Extract usage data from message_start event
|
||||
@@ -367,19 +389,30 @@ impl AnthropicProvider {
|
||||
accumulated_usage = Some(Usage {
|
||||
prompt_tokens: usage.input_tokens,
|
||||
completion_tokens: usage.output_tokens,
|
||||
total_tokens: usage.input_tokens + usage.output_tokens,
|
||||
total_tokens: usage.input_tokens
|
||||
+ usage.output_tokens,
|
||||
});
|
||||
debug!("Captured usage from message_start: {:?}", accumulated_usage);
|
||||
debug!(
|
||||
"Captured usage from message_start: {:?}",
|
||||
accumulated_usage
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_start" => {
|
||||
debug!("Received content_block_start event: {:?}", event);
|
||||
debug!(
|
||||
"Received content_block_start event: {:?}",
|
||||
event
|
||||
);
|
||||
if let Some(content_block) = event.content_block {
|
||||
match content_block {
|
||||
AnthropicContent::ToolUse { id, name, input } => {
|
||||
AnthropicContent::ToolUse {
|
||||
id,
|
||||
name,
|
||||
input,
|
||||
} => {
|
||||
debug!("Found tool use in content_block_start: id={}, name={}, input={:?}", id, name, input);
|
||||
|
||||
|
||||
// For native tool calls, create the tool call immediately if we have complete args
|
||||
// If args are empty, we'll wait for partial_json to accumulate them
|
||||
let tool_call = ToolCall {
|
||||
@@ -387,9 +420,14 @@ impl AnthropicProvider {
|
||||
tool: name.clone(),
|
||||
args: input.clone(),
|
||||
};
|
||||
|
||||
|
||||
// Check if we already have complete arguments
|
||||
if !input.is_null() && input != serde_json::Value::Object(serde_json::Map::new()) {
|
||||
if !input.is_null()
|
||||
&& input
|
||||
!= serde_json::Value::Object(
|
||||
serde_json::Map::new(),
|
||||
)
|
||||
{
|
||||
// We have complete arguments, send the tool call immediately
|
||||
debug!("Tool call has complete args, sending immediately: {:?}", tool_call);
|
||||
let chunk = CompletionChunk {
|
||||
@@ -410,7 +448,10 @@ impl AnthropicProvider {
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
debug!("Non-tool content block: {:?}", content_block);
|
||||
debug!(
|
||||
"Non-tool content block: {:?}",
|
||||
content_block
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -418,7 +459,11 @@ impl AnthropicProvider {
|
||||
"content_block_delta" => {
|
||||
if let Some(delta) = event.delta {
|
||||
if let Some(text) = delta.text {
|
||||
debug!("Sending text chunk of length {}: '{}'", text.len(), text);
|
||||
debug!(
|
||||
"Sending text chunk of length {}: '{}'",
|
||||
text.len(),
|
||||
text
|
||||
);
|
||||
let chunk = CompletionChunk {
|
||||
content: text,
|
||||
finished: false,
|
||||
@@ -432,31 +477,51 @@ impl AnthropicProvider {
|
||||
}
|
||||
// Handle partial JSON for tool calls
|
||||
if let Some(partial_json) = delta.partial_json {
|
||||
debug!("Received partial JSON: {}", partial_json);
|
||||
debug!(
|
||||
"Received partial JSON: {}",
|
||||
partial_json
|
||||
);
|
||||
partial_tool_json.push_str(&partial_json);
|
||||
debug!("Accumulated tool JSON: {}", partial_tool_json);
|
||||
debug!(
|
||||
"Accumulated tool JSON: {}",
|
||||
partial_tool_json
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_stop" => {
|
||||
// Tool call block is complete - now parse the accumulated JSON
|
||||
if !current_tool_calls.is_empty() && !partial_tool_json.is_empty() {
|
||||
debug!("Parsing complete tool JSON: {}", partial_tool_json);
|
||||
|
||||
if !current_tool_calls.is_empty()
|
||||
&& !partial_tool_json.is_empty()
|
||||
{
|
||||
debug!(
|
||||
"Parsing complete tool JSON: {}",
|
||||
partial_tool_json
|
||||
);
|
||||
|
||||
// Parse the accumulated JSON and update the last tool call
|
||||
if let Ok(parsed_args) = serde_json::from_str::<serde_json::Value>(&partial_tool_json) {
|
||||
if let Some(last_tool) = current_tool_calls.last_mut() {
|
||||
if let Ok(parsed_args) =
|
||||
serde_json::from_str::<serde_json::Value>(
|
||||
&partial_tool_json,
|
||||
)
|
||||
{
|
||||
if let Some(last_tool) =
|
||||
current_tool_calls.last_mut()
|
||||
{
|
||||
last_tool.args = parsed_args;
|
||||
debug!("Updated tool call with complete args: {:?}", last_tool);
|
||||
}
|
||||
} else {
|
||||
debug!("Failed to parse accumulated JSON: {}", partial_tool_json);
|
||||
debug!(
|
||||
"Failed to parse accumulated JSON: {}",
|
||||
partial_tool_json
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Clear the accumulator
|
||||
partial_tool_json.clear();
|
||||
}
|
||||
|
||||
|
||||
// Send the complete tool call
|
||||
if !current_tool_calls.is_empty() {
|
||||
let chunk = CompletionChunk {
|
||||
@@ -478,7 +543,11 @@ impl AnthropicProvider {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls.clone()) },
|
||||
tool_calls: if current_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(current_tool_calls.clone())
|
||||
},
|
||||
};
|
||||
if tx.send(Ok(final_chunk)).await.is_err() {
|
||||
debug!("Receiver dropped, stopping stream");
|
||||
@@ -490,7 +559,10 @@ impl AnthropicProvider {
|
||||
if let Some(error) = event.error {
|
||||
error!("Anthropic API error: {:?}", error);
|
||||
let _ = tx
|
||||
.send(Err(anyhow!("Anthropic API error: {:?}", error)))
|
||||
.send(Err(anyhow!(
|
||||
"Anthropic API error: {:?}",
|
||||
error
|
||||
)))
|
||||
.await;
|
||||
break; // Break to let stream exhaust naturally
|
||||
}
|
||||
@@ -524,7 +596,11 @@ impl AnthropicProvider {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: accumulated_usage.clone(),
|
||||
tool_calls: if current_tool_calls.is_empty() { None } else { Some(current_tool_calls) },
|
||||
tool_calls: if current_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(current_tool_calls)
|
||||
},
|
||||
};
|
||||
let _ = tx.send(Ok(final_chunk)).await;
|
||||
accumulated_usage
|
||||
@@ -543,15 +619,17 @@ impl LLMProvider for AnthropicProvider {
|
||||
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||
|
||||
let request_body = self.create_request_body(
|
||||
&request.messages,
|
||||
request.tools.as_deref(),
|
||||
false,
|
||||
max_tokens,
|
||||
temperature
|
||||
&request.messages,
|
||||
request.tools.as_deref(),
|
||||
false,
|
||||
max_tokens,
|
||||
temperature,
|
||||
)?;
|
||||
|
||||
debug!("Sending request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
||||
request_body.model, request_body.max_tokens, request_body.temperature);
|
||||
debug!(
|
||||
"Sending request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
||||
request_body.model, request_body.max_tokens, request_body.temperature
|
||||
);
|
||||
|
||||
let response = self
|
||||
.create_request_builder(false)
|
||||
@@ -588,7 +666,8 @@ impl LLMProvider for AnthropicProvider {
|
||||
let usage = Usage {
|
||||
prompt_tokens: anthropic_response.usage.input_tokens,
|
||||
completion_tokens: anthropic_response.usage.output_tokens,
|
||||
total_tokens: anthropic_response.usage.input_tokens + anthropic_response.usage.output_tokens,
|
||||
total_tokens: anthropic_response.usage.input_tokens
|
||||
+ anthropic_response.usage.output_tokens,
|
||||
};
|
||||
|
||||
debug!(
|
||||
@@ -613,18 +692,24 @@ impl LLMProvider for AnthropicProvider {
|
||||
let temperature = request.temperature.unwrap_or(self.temperature);
|
||||
|
||||
let request_body = self.create_request_body(
|
||||
&request.messages,
|
||||
request.tools.as_deref(),
|
||||
true,
|
||||
max_tokens,
|
||||
temperature
|
||||
&request.messages,
|
||||
request.tools.as_deref(),
|
||||
true,
|
||||
max_tokens,
|
||||
temperature,
|
||||
)?;
|
||||
|
||||
debug!("Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
||||
request_body.model, request_body.max_tokens, request_body.temperature);
|
||||
|
||||
debug!(
|
||||
"Sending streaming request to Anthropic API: model={}, max_tokens={}, temperature={}",
|
||||
request_body.model, request_body.max_tokens, request_body.temperature
|
||||
);
|
||||
|
||||
// Debug: Log the full request body
|
||||
debug!("Full request body: {}", serde_json::to_string_pretty(&request_body).unwrap_or_else(|_| "Failed to serialize".to_string()));
|
||||
debug!(
|
||||
"Full request body: {}",
|
||||
serde_json::to_string_pretty(&request_body)
|
||||
.unwrap_or_else(|_| "Failed to serialize".to_string())
|
||||
);
|
||||
|
||||
let response = self
|
||||
.create_request_builder(true)
|
||||
@@ -673,16 +758,16 @@ impl LLMProvider for AnthropicProvider {
|
||||
// Claude models support native tool calling
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
fn supports_cache_control(&self) -> bool {
|
||||
// Anthropic supports cache control
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
fn max_tokens(&self) -> u32 {
|
||||
self.max_tokens
|
||||
}
|
||||
|
||||
|
||||
fn temperature(&self) -> f32 {
|
||||
self.temperature
|
||||
}
|
||||
@@ -729,7 +814,7 @@ struct AnthropicMessage {
|
||||
#[serde(tag = "type")]
|
||||
enum AnthropicContent {
|
||||
#[serde(rename = "text")]
|
||||
Text {
|
||||
Text {
|
||||
text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<crate::CacheControl>,
|
||||
@@ -798,17 +883,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_message_conversion() {
|
||||
let provider = AnthropicProvider::new(
|
||||
"test-key".to_string(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
let provider =
|
||||
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap();
|
||||
|
||||
let messages = vec![
|
||||
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
|
||||
Message::new(
|
||||
MessageRole::System,
|
||||
"You are a helpful assistant.".to_string(),
|
||||
),
|
||||
Message::new(MessageRole::User, "Hello!".to_string()),
|
||||
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
|
||||
];
|
||||
@@ -830,7 +912,8 @@ mod tests {
|
||||
Some(0.5),
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
|
||||
|
||||
@@ -848,31 +931,23 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let provider = AnthropicProvider::new(
|
||||
"test-key".to_string(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
let provider =
|
||||
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap();
|
||||
|
||||
let tools = vec![
|
||||
Tool {
|
||||
name: "get_weather".to_string(),
|
||||
description: "Get the current weather".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
},
|
||||
];
|
||||
let tools = vec![Tool {
|
||||
name: "get_weather".to_string(),
|
||||
description: "Get the current weather".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}),
|
||||
}];
|
||||
|
||||
let anthropic_tools = provider.convert_tools(&tools);
|
||||
|
||||
@@ -881,31 +956,30 @@ mod tests {
|
||||
assert_eq!(anthropic_tools[0].description, "Get the current weather");
|
||||
assert_eq!(anthropic_tools[0].input_schema.schema_type, "object");
|
||||
assert!(anthropic_tools[0].input_schema.required.is_some());
|
||||
assert_eq!(anthropic_tools[0].input_schema.required.as_ref().unwrap()[0], "location");
|
||||
assert_eq!(
|
||||
anthropic_tools[0].input_schema.required.as_ref().unwrap()[0],
|
||||
"location"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_control_serialization() {
|
||||
let provider = AnthropicProvider::new(
|
||||
"test-key".to_string(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
).unwrap();
|
||||
let provider =
|
||||
AnthropicProvider::new("test-key".to_string(), None, None, None, None, None).unwrap();
|
||||
|
||||
// Test message WITHOUT cache_control
|
||||
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
|
||||
let (_, anthropic_messages_without) = provider.convert_messages(&messages_without).unwrap();
|
||||
let json_without = serde_json::to_string(&anthropic_messages_without).unwrap();
|
||||
|
||||
|
||||
println!("Anthropic JSON without cache_control: {}", json_without);
|
||||
// Check if cache_control appears in the JSON
|
||||
if json_without.contains("cache_control") {
|
||||
println!("WARNING: JSON contains 'cache_control' field when not configured!");
|
||||
assert!(!json_without.contains("\"cache_control\":null"),
|
||||
"JSON should not contain 'cache_control: null'");
|
||||
assert!(
|
||||
!json_without.contains("\"cache_control\":null"),
|
||||
"JSON should not contain 'cache_control: null'"
|
||||
);
|
||||
}
|
||||
|
||||
// Test message WITH cache_control
|
||||
@@ -916,15 +990,21 @@ mod tests {
|
||||
)];
|
||||
let (_, anthropic_messages_with) = provider.convert_messages(&messages_with).unwrap();
|
||||
let json_with = serde_json::to_string(&anthropic_messages_with).unwrap();
|
||||
|
||||
|
||||
println!("Anthropic JSON with cache_control: {}", json_with);
|
||||
assert!(json_with.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field when configured");
|
||||
assert!(json_with.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' type");
|
||||
|
||||
assert!(
|
||||
json_with.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field when configured"
|
||||
);
|
||||
assert!(
|
||||
json_with.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' type"
|
||||
);
|
||||
|
||||
// The key assertion: when cache_control is None, it should not appear in JSON
|
||||
assert!(!json_without.contains("cache_control") || !json_without.contains("null"),
|
||||
"JSON should not contain 'cache_control' field or null values when not configured");
|
||||
assert!(
|
||||
!json_without.contains("cache_control") || !json_without.contains("null"),
|
||||
"JSON should not contain 'cache_control' field or null values when not configured"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,7 +312,7 @@ impl DatabricksProvider {
|
||||
|
||||
// Append new bytes to our buffer
|
||||
byte_buffer.extend_from_slice(&chunk);
|
||||
|
||||
|
||||
// Try to convert the entire buffer to UTF-8
|
||||
let chunk_str = match std::str::from_utf8(&byte_buffer) {
|
||||
Ok(s) => {
|
||||
@@ -326,7 +326,8 @@ impl DatabricksProvider {
|
||||
let valid_up_to = e.valid_up_to();
|
||||
if valid_up_to > 0 {
|
||||
// We have some valid UTF-8, extract it and keep the rest for next iteration
|
||||
let valid_bytes = byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
|
||||
let valid_bytes =
|
||||
byte_buffer.drain(..valid_up_to).collect::<Vec<_>>();
|
||||
std::str::from_utf8(&valid_bytes).unwrap().to_string()
|
||||
} else {
|
||||
// No valid UTF-8 at all, skip this chunk and continue
|
||||
@@ -593,7 +594,7 @@ impl DatabricksProvider {
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Stream error at chunk {}: {}", chunk_count, e);
|
||||
|
||||
|
||||
// Check if this is a connection error that might be recoverable
|
||||
let error_msg = e.to_string();
|
||||
if error_msg.contains("unexpected EOF") || error_msg.contains("connection") {
|
||||
@@ -610,10 +611,14 @@ impl DatabricksProvider {
|
||||
|
||||
// Log final state
|
||||
debug!("Stream ended after {} chunks", chunk_count);
|
||||
debug!("Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
|
||||
buffer.len(), incomplete_data_line.len(), byte_buffer.len());
|
||||
debug!(
|
||||
"Final state: buffer_len={}, incomplete_data_line_len={}, byte_buffer_len={}",
|
||||
buffer.len(),
|
||||
incomplete_data_line.len(),
|
||||
byte_buffer.len()
|
||||
);
|
||||
debug!("Accumulated tool calls: {}", current_tool_calls.len());
|
||||
|
||||
|
||||
// If we have any remaining data in buffers, log it for debugging
|
||||
if !buffer.is_empty() {
|
||||
debug!("Remaining buffer content: {:?}", buffer);
|
||||
@@ -924,7 +929,7 @@ impl LLMProvider for DatabricksProvider {
|
||||
"Processing Databricks streaming request with {} messages",
|
||||
request.messages.len()
|
||||
);
|
||||
|
||||
|
||||
// Debug: Log tool count
|
||||
if let Some(ref tools) = request.tools {
|
||||
debug!("Request has {} tools", tools.len());
|
||||
@@ -1051,15 +1056,15 @@ impl LLMProvider for DatabricksProvider {
|
||||
// This includes Claude, Llama, DBRX, and most other models on the platform
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
fn supports_cache_control(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
fn max_tokens(&self) -> u32 {
|
||||
self.max_tokens
|
||||
}
|
||||
|
||||
|
||||
fn temperature(&self) -> f32 {
|
||||
self.temperature
|
||||
}
|
||||
@@ -1181,7 +1186,10 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
let messages = vec![
|
||||
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
|
||||
Message::new(
|
||||
MessageRole::System,
|
||||
"You are a helpful assistant.".to_string(),
|
||||
),
|
||||
Message::new(MessageRole::User, "Hello!".to_string()),
|
||||
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
|
||||
];
|
||||
@@ -1304,10 +1312,12 @@ mod tests {
|
||||
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
|
||||
let databricks_messages_without = provider.convert_messages(&messages_without).unwrap();
|
||||
let json_without = serde_json::to_string(&databricks_messages_without).unwrap();
|
||||
|
||||
|
||||
println!("JSON without cache_control: {}", json_without);
|
||||
assert!(!json_without.contains("cache_control"),
|
||||
"JSON should not contain 'cache_control' field when not configured");
|
||||
assert!(
|
||||
!json_without.contains("cache_control"),
|
||||
"JSON should not contain 'cache_control' field when not configured"
|
||||
);
|
||||
|
||||
// Test message WITH cache_control - should still NOT include it (Databricks doesn't support it)
|
||||
let messages_with = vec![Message::with_cache_control(
|
||||
@@ -1317,10 +1327,12 @@ mod tests {
|
||||
)];
|
||||
let databricks_messages_with = provider.convert_messages(&messages_with).unwrap();
|
||||
let json_with = serde_json::to_string(&databricks_messages_with).unwrap();
|
||||
|
||||
|
||||
println!("JSON with cache_control: {}", json_with);
|
||||
assert!(!json_with.contains("cache_control"),
|
||||
"JSON should NOT contain 'cache_control' field - Databricks doesn't support it");
|
||||
assert!(
|
||||
!json_with.contains("cache_control"),
|
||||
"JSON should NOT contain 'cache_control' field - Databricks doesn't support it"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1343,7 +1355,13 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(!claude_provider.supports_cache_control(), "Databricks should not support cache_control even for Claude models");
|
||||
assert!(!llama_provider.supports_cache_control(), "Databricks should not support cache_control for Llama models");
|
||||
assert!(
|
||||
!claude_provider.supports_cache_control(),
|
||||
"Databricks should not support cache_control even for Claude models"
|
||||
);
|
||||
assert!(
|
||||
!llama_provider.supports_cache_control(),
|
||||
"Databricks should not support cache_control for Llama models"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use anyhow::Result;
|
||||
use crate::{
|
||||
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
|
||||
MessageRole, Usage,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use llama_cpp::{
|
||||
standard_sampler::{SamplerStage, StandardSampler},
|
||||
LlamaModel, LlamaParams, LlamaSession, SessionParams,
|
||||
@@ -37,7 +37,7 @@ impl EmbeddedProvider {
|
||||
// Expand tilde in path
|
||||
let expanded_path = shellexpand::tilde(&model_path);
|
||||
let model_path_buf = PathBuf::from(expanded_path.as_ref());
|
||||
|
||||
|
||||
// If model doesn't exist and it's the default Qwen model, offer to download it
|
||||
if !model_path_buf.exists() {
|
||||
if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") {
|
||||
@@ -47,7 +47,7 @@ impl EmbeddedProvider {
|
||||
anyhow::bail!("Model file not found: {}", model_path_buf.display());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let model_path = model_path_buf.as_path();
|
||||
|
||||
// Set up model parameters
|
||||
@@ -93,24 +93,24 @@ impl EmbeddedProvider {
|
||||
fn format_messages(&self, messages: &[Message]) -> String {
|
||||
// Determine the appropriate format based on model type
|
||||
let model_name_lower = self.model_name.to_lowercase();
|
||||
|
||||
|
||||
if model_name_lower.contains("qwen") {
|
||||
// Qwen format: <|im_start|>role\ncontent<|im_end|>
|
||||
let mut formatted = String::new();
|
||||
|
||||
|
||||
for message in messages {
|
||||
let role = match message.role {
|
||||
MessageRole::System => "system",
|
||||
MessageRole::User => "user",
|
||||
MessageRole::User => "user",
|
||||
MessageRole::Assistant => "assistant",
|
||||
};
|
||||
|
||||
|
||||
formatted.push_str(&format!(
|
||||
"<|im_start|>{}\n{}<|im_end|>\n",
|
||||
role, message.content
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
// Add the start of assistant response
|
||||
formatted.push_str("<|im_start|>assistant\n");
|
||||
formatted
|
||||
@@ -118,7 +118,7 @@ impl EmbeddedProvider {
|
||||
// Mistral Instruct format: <s>[INST] ... [/INST] assistant_response</s>
|
||||
let mut formatted = String::new();
|
||||
let mut in_conversation = false;
|
||||
|
||||
|
||||
for (i, message) in messages.iter().enumerate() {
|
||||
match message.role {
|
||||
MessageRole::System => {
|
||||
@@ -146,12 +146,15 @@ impl EmbeddedProvider {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// If the last message was from user, add a space for the assistant's response
|
||||
if messages.last().is_some_and(|m| matches!(m.role, MessageRole::User)) {
|
||||
if messages
|
||||
.last()
|
||||
.is_some_and(|m| matches!(m.role, MessageRole::User))
|
||||
{
|
||||
formatted.push(' ');
|
||||
}
|
||||
|
||||
|
||||
formatted
|
||||
} else {
|
||||
// Use Llama/CodeLlama format for other models
|
||||
@@ -216,16 +219,25 @@ impl EmbeddedProvider {
|
||||
}
|
||||
Err(_) => {
|
||||
if attempt < 4 {
|
||||
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
|
||||
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
|
||||
debug!(
|
||||
"Session busy, retrying in {}ms (attempt {}/5)",
|
||||
100 * (attempt + 1),
|
||||
attempt + 1
|
||||
);
|
||||
std::thread::sleep(std::time::Duration::from_millis(
|
||||
100 * (attempt + 1) as u64,
|
||||
));
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again"));
|
||||
return Err(anyhow::anyhow!(
|
||||
"Model is busy after 5 attempts, please try again"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut session = session_guard.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
|
||||
|
||||
let mut session = session_guard
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to acquire session lock"))?;
|
||||
|
||||
debug!(
|
||||
"Starting inference with prompt length: {} chars, estimated {} tokens",
|
||||
@@ -297,7 +309,7 @@ impl EmbeddedProvider {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if hit_stop {
|
||||
break;
|
||||
}
|
||||
@@ -308,7 +320,7 @@ impl EmbeddedProvider {
|
||||
token_count,
|
||||
start_time.elapsed()
|
||||
);
|
||||
|
||||
|
||||
Ok((generated_text, token_count))
|
||||
}),
|
||||
)
|
||||
@@ -347,21 +359,22 @@ impl EmbeddedProvider {
|
||||
fn get_stop_sequences(&self) -> Vec<&'static str> {
|
||||
// Determine model type from model_name
|
||||
let model_name_lower = self.model_name.to_lowercase();
|
||||
|
||||
|
||||
if model_name_lower.contains("qwen") {
|
||||
vec![
|
||||
"<|im_end|>", // Qwen ChatML format end token
|
||||
"<|endoftext|>", // Alternative end token
|
||||
"</s>", // Generic end of sequence
|
||||
"<|im_start|>", // Start of new message (shouldn't appear in response)
|
||||
"<|im_end|>", // Qwen ChatML format end token
|
||||
"<|endoftext|>", // Alternative end token
|
||||
"</s>", // Generic end of sequence
|
||||
"<|im_start|>", // Start of new message (shouldn't appear in response)
|
||||
]
|
||||
} else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama") {
|
||||
} else if model_name_lower.contains("codellama") || model_name_lower.contains("code-llama")
|
||||
{
|
||||
vec![
|
||||
"</s>", // End of sequence
|
||||
"[/INST]", // End of instruction
|
||||
"<</SYS>>", // End of system message
|
||||
"[INST]", // Start of new instruction (shouldn't appear in response)
|
||||
"<<SYS>>", // Start of system (shouldn't appear in response)
|
||||
"</s>", // End of sequence
|
||||
"[/INST]", // End of instruction
|
||||
"<</SYS>>", // End of system message
|
||||
"[INST]", // Start of new instruction (shouldn't appear in response)
|
||||
"<<SYS>>", // Start of system (shouldn't appear in response)
|
||||
]
|
||||
} else if model_name_lower.contains("llama") {
|
||||
vec![
|
||||
@@ -374,9 +387,9 @@ impl EmbeddedProvider {
|
||||
]
|
||||
} else if model_name_lower.contains("mistral") {
|
||||
vec![
|
||||
"</s>", // End of sequence
|
||||
"[/INST]", // End of instruction
|
||||
"<|im_end|>", // ChatML format
|
||||
"</s>", // End of sequence
|
||||
"[/INST]", // End of instruction
|
||||
"<|im_end|>", // ChatML format
|
||||
]
|
||||
} else if model_name_lower.contains("vicuna") || model_name_lower.contains("wizard") {
|
||||
vec![
|
||||
@@ -391,7 +404,7 @@ impl EmbeddedProvider {
|
||||
"### Instruction:", // Alpaca format
|
||||
"### Response:", // Alpaca format
|
||||
"### Input:", // Alpaca format
|
||||
"</s>", // End of sequence
|
||||
"</s>", // End of sequence
|
||||
]
|
||||
} else {
|
||||
// Generic/unknown model - use common stop sequences
|
||||
@@ -411,14 +424,14 @@ impl EmbeddedProvider {
|
||||
fn clean_stop_sequences(&self, text: &str) -> String {
|
||||
let mut cleaned = text.to_string();
|
||||
let stop_sequences = self.get_stop_sequences();
|
||||
|
||||
|
||||
for stop_seq in &stop_sequences {
|
||||
if let Some(pos) = cleaned.find(stop_seq) {
|
||||
cleaned.truncate(pos);
|
||||
break; // Only remove the first occurrence to avoid over-truncation
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
cleaned.trim().to_string()
|
||||
}
|
||||
|
||||
@@ -426,57 +439,64 @@ impl EmbeddedProvider {
|
||||
fn download_qwen_model(model_path: &Path) -> Result<()> {
|
||||
use std::fs;
|
||||
use std::process::Command;
|
||||
|
||||
|
||||
const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
|
||||
const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB
|
||||
|
||||
|
||||
// Create the parent directory if it doesn't exist
|
||||
if let Some(parent) = model_path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
|
||||
info!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)...");
|
||||
info!("This is a one-time download that may take several minutes depending on your connection.");
|
||||
info!("Downloading to: {}", model_path.display());
|
||||
|
||||
|
||||
// Use curl with progress bar for download
|
||||
let output = Command::new("curl")
|
||||
.args([
|
||||
"-L", // Follow redirects
|
||||
"-#", // Show progress bar
|
||||
"-f", // Fail on HTTP errors
|
||||
"-o", model_path.to_str().unwrap(),
|
||||
"-L", // Follow redirects
|
||||
"-#", // Show progress bar
|
||||
"-f", // Fail on HTTP errors
|
||||
"-o",
|
||||
model_path.to_str().unwrap(),
|
||||
MODEL_URL,
|
||||
])
|
||||
.output()?;
|
||||
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
|
||||
|
||||
// If curl is not available, provide alternative instructions
|
||||
if stderr.contains("command not found") || stderr.contains("not found") {
|
||||
error!("curl is not installed. Please install curl or manually download the model.");
|
||||
error!(
|
||||
"curl is not installed. Please install curl or manually download the model."
|
||||
);
|
||||
error!("Manual download instructions:");
|
||||
error!("1. Download from: {}", MODEL_URL);
|
||||
error!("2. Save to: {}", model_path.display());
|
||||
anyhow::bail!("curl not found - please install curl or download the model manually");
|
||||
anyhow::bail!(
|
||||
"curl not found - please install curl or download the model manually"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
anyhow::bail!("Failed to download model: {}", stderr);
|
||||
}
|
||||
|
||||
|
||||
// Verify the file was created and has reasonable size
|
||||
let metadata = fs::metadata(model_path)?;
|
||||
let size_mb = metadata.len() / (1024 * 1024);
|
||||
|
||||
if size_mb < MODEL_SIZE_MB - 100 { // Allow some variance
|
||||
fs::remove_file(model_path).ok(); // Clean up partial download
|
||||
|
||||
if size_mb < MODEL_SIZE_MB - 100 {
|
||||
// Allow some variance
|
||||
fs::remove_file(model_path).ok(); // Clean up partial download
|
||||
anyhow::bail!(
|
||||
"Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.",
|
||||
size_mb, MODEL_SIZE_MB
|
||||
size_mb,
|
||||
MODEL_SIZE_MB
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
info!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb);
|
||||
Ok(())
|
||||
}
|
||||
@@ -541,20 +561,29 @@ impl LLMProvider for EmbeddedProvider {
|
||||
}
|
||||
Err(_) => {
|
||||
if attempt < 4 {
|
||||
debug!("Session busy, retrying in {}ms (attempt {}/5)", 100 * (attempt + 1), attempt + 1);
|
||||
std::thread::sleep(std::time::Duration::from_millis(100 * (attempt + 1) as u64));
|
||||
debug!(
|
||||
"Session busy, retrying in {}ms (attempt {}/5)",
|
||||
100 * (attempt + 1),
|
||||
attempt + 1
|
||||
);
|
||||
std::thread::sleep(std::time::Duration::from_millis(
|
||||
100 * (attempt + 1) as u64,
|
||||
));
|
||||
} else {
|
||||
let _ = tx.blocking_send(Err(anyhow::anyhow!("Model is busy after 5 attempts, please try again")));
|
||||
let _ = tx.blocking_send(Err(anyhow::anyhow!(
|
||||
"Model is busy after 5 attempts, please try again"
|
||||
)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut session = match session_guard {
|
||||
Some(ctx) => ctx,
|
||||
None => {
|
||||
let _ = tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
|
||||
let _ =
|
||||
tx.blocking_send(Err(anyhow::anyhow!("Failed to acquire session lock")));
|
||||
return;
|
||||
}
|
||||
};
|
||||
@@ -588,17 +617,33 @@ impl LLMProvider for EmbeddedProvider {
|
||||
let mut accumulated_text = String::new();
|
||||
let mut token_count = 0;
|
||||
let mut unsent_tokens = String::new(); // Buffer for tokens we're holding back
|
||||
|
||||
|
||||
// Get stop sequences dynamically based on model type
|
||||
let stop_sequences = if prompt.contains("<|im_start|>") {
|
||||
// Qwen ChatML format detected
|
||||
vec!["<|im_end|>", "<|endoftext|>", "</s>", "<|im_start|>"]
|
||||
} else if prompt.contains("[INST]") || prompt.contains("<<SYS>>") {
|
||||
// Llama/CodeLlama format detected
|
||||
vec!["</s>", "[/INST]", "<</SYS>>", "[INST]", "<<SYS>>", "### Human:", "### Assistant:"]
|
||||
vec![
|
||||
"</s>",
|
||||
"[/INST]",
|
||||
"<</SYS>>",
|
||||
"[INST]",
|
||||
"<<SYS>>",
|
||||
"### Human:",
|
||||
"### Assistant:",
|
||||
]
|
||||
} else {
|
||||
// Generic format
|
||||
vec!["</s>", "<|endoftext|>", "<|im_end|>", "### Human:", "### Assistant:", "[/INST]", "<</SYS>>"]
|
||||
vec![
|
||||
"</s>",
|
||||
"<|endoftext|>",
|
||||
"<|im_end|>",
|
||||
"### Human:",
|
||||
"### Assistant:",
|
||||
"[/INST]",
|
||||
"<</SYS>>",
|
||||
]
|
||||
};
|
||||
|
||||
// Stream tokens with proper limits
|
||||
@@ -622,10 +667,10 @@ impl LLMProvider for EmbeddedProvider {
|
||||
if hit_stop {
|
||||
// Before stopping, check if there might be an incomplete tool call
|
||||
// Look for JSON tool call patterns that might be cut off by the stop sequence
|
||||
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#) ||
|
||||
accumulated_text.contains(r#"{"{""tool"":"#) ||
|
||||
accumulated_text.contains(r#"{{""tool"":"#);
|
||||
|
||||
let has_potential_tool_call = accumulated_text.contains(r#"{"tool":"#)
|
||||
|| accumulated_text.contains(r#"{"{""tool"":"#)
|
||||
|| accumulated_text.contains(r#"{{""tool"":"#);
|
||||
|
||||
if has_potential_tool_call {
|
||||
// Check if the tool call appears to be complete (has closing brace after the stop sequence)
|
||||
let mut complete_tool_call = false;
|
||||
@@ -645,7 +690,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// If tool call is incomplete, send the raw content including stop sequences
|
||||
// so the main parser can handle it properly
|
||||
if !complete_tool_call {
|
||||
@@ -666,7 +711,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Send any remaining clean content before stopping (original behavior)
|
||||
let mut clean_accumulated = accumulated_text.clone();
|
||||
for stop_seq in &stop_sequences {
|
||||
@@ -675,7 +720,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Calculate what part we haven't sent yet
|
||||
let already_sent_len = accumulated_text.len() - unsent_tokens.len();
|
||||
if clean_accumulated.len() > already_sent_len {
|
||||
@@ -711,7 +756,8 @@ impl LLMProvider for EmbeddedProvider {
|
||||
|
||||
if might_be_stop {
|
||||
// Hold back tokens, but only for a limited buffer size
|
||||
if unsent_tokens.len() > 20 { // Don't hold back more than 20 characters
|
||||
if unsent_tokens.len() > 20 {
|
||||
// Don't hold back more than 20 characters
|
||||
// Send the oldest part and keep only the recent part that might be a stop sequence
|
||||
let to_send = &unsent_tokens[..unsent_tokens.len() - 10];
|
||||
if !to_send.is_empty() {
|
||||
@@ -755,7 +801,7 @@ impl LLMProvider for EmbeddedProvider {
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
usage: None, // Embedded models calculate usage differently
|
||||
usage: None, // Embedded models calculate usage differently
|
||||
tool_calls: None,
|
||||
};
|
||||
let _ = tx.blocking_send(Ok(final_chunk));
|
||||
@@ -771,11 +817,11 @@ impl LLMProvider for EmbeddedProvider {
|
||||
fn model(&self) -> &str {
|
||||
&self.model_name
|
||||
}
|
||||
|
||||
|
||||
fn max_tokens(&self) -> u32 {
|
||||
self.max_tokens
|
||||
}
|
||||
|
||||
|
||||
fn temperature(&self) -> f32 {
|
||||
self.temperature
|
||||
}
|
||||
|
||||
@@ -1,36 +1,36 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Trait for LLM providers
|
||||
#[async_trait::async_trait]
|
||||
pub trait LLMProvider: Send + Sync {
|
||||
/// Generate a completion for the given messages
|
||||
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse>;
|
||||
|
||||
|
||||
/// Stream a completion for the given messages
|
||||
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream>;
|
||||
|
||||
|
||||
/// Get the provider name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
|
||||
/// Get the model name
|
||||
fn model(&self) -> &str;
|
||||
|
||||
|
||||
/// Check if the provider supports native tool calling
|
||||
fn has_native_tool_calling(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
/// Check if the provider supports cache control
|
||||
fn supports_cache_control(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
/// Get the configured max_tokens for this provider
|
||||
fn max_tokens(&self) -> u32;
|
||||
|
||||
|
||||
/// Get the configured temperature for this provider
|
||||
fn temperature(&self) -> f32;
|
||||
}
|
||||
@@ -60,15 +60,24 @@ pub enum CacheType {
|
||||
|
||||
impl CacheControl {
|
||||
pub fn ephemeral() -> Self {
|
||||
Self { cache_type: CacheType::Ephemeral, ttl: None }
|
||||
Self {
|
||||
cache_type: CacheType::Ephemeral,
|
||||
ttl: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn five_minute() -> Self {
|
||||
Self { cache_type: CacheType::Ephemeral, ttl: Some("5m".to_string()) }
|
||||
Self {
|
||||
cache_type: CacheType::Ephemeral,
|
||||
ttl: Some("5m".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn one_hour() -> Self {
|
||||
Self { cache_type: CacheType::Ephemeral, ttl: Some("1h".to_string()) }
|
||||
Self {
|
||||
cache_type: CacheType::Ephemeral,
|
||||
ttl: Some("1h".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +85,7 @@ impl CacheControl {
|
||||
pub struct Message {
|
||||
pub role: MessageRole,
|
||||
pub content: String,
|
||||
#[serde(skip)]
|
||||
pub id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_control: Option<CacheControl>,
|
||||
@@ -110,7 +120,7 @@ pub struct CompletionChunk {
|
||||
pub content: String,
|
||||
pub finished: bool,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
pub usage: Option<Usage>, // Add usage tracking for streaming
|
||||
pub usage: Option<Usage>, // Add usage tracking for streaming
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -144,7 +154,7 @@ impl Message {
|
||||
fn generate_id() -> String {
|
||||
let now = chrono::Local::now();
|
||||
let timestamp = now.format("%H%M%S").to_string();
|
||||
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let random_chars: String = (0..3)
|
||||
.map(|_| {
|
||||
@@ -153,10 +163,10 @@ impl Message {
|
||||
chars[idx] as char
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
format!("{}-{}", timestamp, random_chars)
|
||||
}
|
||||
|
||||
|
||||
/// Create a new message with optional cache control
|
||||
pub fn new(role: MessageRole, content: String) -> Self {
|
||||
Self {
|
||||
@@ -168,7 +178,11 @@ impl Message {
|
||||
}
|
||||
|
||||
/// Create a new message with cache control
|
||||
pub fn with_cache_control(role: MessageRole, content: String, cache_control: CacheControl) -> Self {
|
||||
pub fn with_cache_control(
|
||||
role: MessageRole,
|
||||
content: String,
|
||||
cache_control: CacheControl,
|
||||
) -> Self {
|
||||
Self {
|
||||
role,
|
||||
content,
|
||||
@@ -176,13 +190,13 @@ impl Message {
|
||||
cache_control: Some(cache_control),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Create a message with cache control, with provider validation
|
||||
pub fn with_cache_control_validated(
|
||||
role: MessageRole,
|
||||
content: String,
|
||||
role: MessageRole,
|
||||
content: String,
|
||||
cache_control: CacheControl,
|
||||
provider: &dyn LLMProvider
|
||||
provider: &dyn LLMProvider,
|
||||
) -> Self {
|
||||
if !provider.supports_cache_control() {
|
||||
tracing::warn!(
|
||||
@@ -192,7 +206,7 @@ impl Message {
|
||||
);
|
||||
return Self::new(role, content);
|
||||
}
|
||||
|
||||
|
||||
Self::with_cache_control(role, content, cache_control)
|
||||
}
|
||||
}
|
||||
@@ -210,16 +224,16 @@ impl ProviderRegistry {
|
||||
default_provider: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn register<P: LLMProvider + 'static>(&mut self, provider: P) {
|
||||
let name = provider.name().to_string();
|
||||
self.providers.insert(name.clone(), Box::new(provider));
|
||||
|
||||
|
||||
if self.default_provider.is_empty() {
|
||||
self.default_provider = name;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub fn set_default(&mut self, provider_name: &str) -> Result<()> {
|
||||
if !self.providers.contains_key(provider_name) {
|
||||
anyhow::bail!("Provider '{}' not found", provider_name);
|
||||
@@ -227,7 +241,7 @@ impl ProviderRegistry {
|
||||
self.default_provider = provider_name.to_string();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
pub fn get(&self, provider_name: Option<&str>) -> Result<&dyn LLMProvider> {
|
||||
let name = provider_name.unwrap_or(&self.default_provider);
|
||||
self.providers
|
||||
@@ -235,7 +249,7 @@ impl ProviderRegistry {
|
||||
.map(|p| p.as_ref())
|
||||
.ok_or_else(|| anyhow::anyhow!("Provider '{}' not found", name))
|
||||
}
|
||||
|
||||
|
||||
pub fn list_providers(&self) -> Vec<&str> {
|
||||
self.providers.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
@@ -255,10 +269,12 @@ mod tests {
|
||||
fn test_message_serialization_without_cache_control() {
|
||||
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON without cache_control: {}", json);
|
||||
assert!(!json.contains("cache_control"),
|
||||
"JSON should not contain 'cache_control' field when not configured");
|
||||
assert!(
|
||||
!json.contains("cache_control"),
|
||||
"JSON should not contain 'cache_control' field when not configured"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -269,16 +285,24 @@ mod tests {
|
||||
CacheControl::ephemeral(),
|
||||
);
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON with cache_control: {}", json);
|
||||
assert!(json.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field when configured");
|
||||
assert!(json.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' value");
|
||||
assert!(json.contains("\"type\":"),
|
||||
"JSON should contain 'type' field in cache_control");
|
||||
assert!(!json.contains("null"),
|
||||
"JSON should not contain null values");
|
||||
assert!(
|
||||
json.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field when configured"
|
||||
);
|
||||
assert!(
|
||||
json.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' value"
|
||||
);
|
||||
assert!(
|
||||
json.contains("\"type\":"),
|
||||
"JSON should contain 'type' field in cache_control"
|
||||
);
|
||||
assert!(
|
||||
!json.contains("null"),
|
||||
"JSON should not contain null values"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -289,11 +313,20 @@ mod tests {
|
||||
CacheControl::five_minute(),
|
||||
);
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON with 5-minute cache_control: {}", json);
|
||||
assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field");
|
||||
assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type");
|
||||
assert!(json.contains("\"ttl\":\"5m\""), "JSON should contain ttl field with 5m value");
|
||||
assert!(
|
||||
json.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field"
|
||||
);
|
||||
assert!(
|
||||
json.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' type"
|
||||
);
|
||||
assert!(
|
||||
json.contains("\"ttl\":\"5m\""),
|
||||
"JSON should contain ttl field with 5m value"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -304,39 +337,53 @@ mod tests {
|
||||
CacheControl::one_hour(),
|
||||
);
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON with 1-hour cache_control: {}", json);
|
||||
assert!(json.contains("cache_control"), "JSON should contain 'cache_control' field");
|
||||
assert!(json.contains("ephemeral"), "JSON should contain 'ephemeral' type");
|
||||
assert!(json.contains("\"ttl\":\"1h\""), "JSON should contain ttl field with 1h value");
|
||||
assert!(
|
||||
json.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field"
|
||||
);
|
||||
assert!(
|
||||
json.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' type"
|
||||
);
|
||||
assert!(
|
||||
json.contains("\"ttl\":\"1h\""),
|
||||
"JSON should contain ttl field with 1h value"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_id_generation() {
|
||||
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||
|
||||
|
||||
// Check that id is not empty
|
||||
assert!(!msg.id.is_empty(), "Message ID should not be empty");
|
||||
|
||||
|
||||
// Check format: HHMMSS-XXX
|
||||
let parts: Vec<&str> = msg.id.split('-').collect();
|
||||
assert_eq!(parts.len(), 2, "Message ID should have format HHMMSS-XXX");
|
||||
|
||||
|
||||
// Check timestamp part is 6 digits
|
||||
assert_eq!(parts[0].len(), 6, "Timestamp should be 6 digits (HHMMSS)");
|
||||
assert!(parts[0].chars().all(|c| c.is_ascii_digit()), "Timestamp should be all digits");
|
||||
|
||||
assert!(
|
||||
parts[0].chars().all(|c| c.is_ascii_digit()),
|
||||
"Timestamp should be all digits"
|
||||
);
|
||||
|
||||
// Check random part is 3 alpha characters
|
||||
assert_eq!(parts[1].len(), 3, "Random part should be 3 characters");
|
||||
assert!(parts[1].chars().all(|c| c.is_ascii_alphabetic()),
|
||||
"Random part should be all alphabetic characters");
|
||||
assert!(
|
||||
parts[1].chars().all(|c| c.is_ascii_alphabetic()),
|
||||
"Random part should be all alphabetic characters"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_id_uniqueness() {
|
||||
let msg1 = Message::new(MessageRole::User, "Hello".to_string());
|
||||
let msg2 = Message::new(MessageRole::User, "Hello".to_string());
|
||||
|
||||
|
||||
// IDs should be different (due to random component)
|
||||
// Note: There's a tiny chance they could be the same, but very unlikely
|
||||
println!("msg1.id: {}, msg2.id: {}", msg1.id, msg2.id);
|
||||
@@ -346,9 +393,12 @@ mod tests {
|
||||
fn test_message_id_not_serialized() {
|
||||
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Message JSON: {}", json);
|
||||
assert!(!json.contains("\"id\""), "JSON should not contain 'id' field");
|
||||
assert!(
|
||||
!json.contains("\"id\""),
|
||||
"JSON should not contain 'id' field"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -358,8 +408,14 @@ mod tests {
|
||||
"Hello".to_string(),
|
||||
CacheControl::ephemeral(),
|
||||
);
|
||||
|
||||
assert!(!msg.id.is_empty(), "Message with cache control should have an ID");
|
||||
assert!(msg.id.contains('-'), "Message ID should contain hyphen separator");
|
||||
|
||||
assert!(
|
||||
!msg.id.is_empty(),
|
||||
"Message with cache control should have an ID"
|
||||
);
|
||||
assert!(
|
||||
msg.id.contains('-'),
|
||||
"Message ID should contain hyphen separator"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ use tokio_stream::wrappers::ReceiverStream;
|
||||
use tracing::{debug, error};
|
||||
|
||||
use crate::{
|
||||
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider,
|
||||
Message, MessageRole, Tool, ToolCall, Usage,
|
||||
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
|
||||
MessageRole, Tool, ToolCall, Usage,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -138,7 +138,8 @@ impl OpenAIProvider {
|
||||
debug!("Received stream completion marker");
|
||||
|
||||
// Send final chunk with accumulated content and tool calls
|
||||
if !accumulated_content.is_empty() || !current_tool_calls.is_empty() {
|
||||
if !accumulated_content.is_empty() || !current_tool_calls.is_empty()
|
||||
{
|
||||
let tool_calls = if current_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
@@ -188,8 +189,9 @@ impl OpenAIProvider {
|
||||
if let Some(index) = delta_tool_call.index {
|
||||
// Ensure we have enough tool calls in our vector
|
||||
while current_tool_calls.len() <= index {
|
||||
current_tool_calls
|
||||
.push(OpenAIStreamingToolCall::default());
|
||||
current_tool_calls.push(
|
||||
OpenAIStreamingToolCall::default(),
|
||||
);
|
||||
}
|
||||
|
||||
let tool_call = &mut current_tool_calls[index];
|
||||
@@ -198,11 +200,14 @@ impl OpenAIProvider {
|
||||
tool_call.id = Some(id.clone());
|
||||
}
|
||||
|
||||
if let Some(function) = &delta_tool_call.function {
|
||||
if let Some(function) =
|
||||
&delta_tool_call.function
|
||||
{
|
||||
if let Some(name) = &function.name {
|
||||
tool_call.name = Some(name.clone());
|
||||
}
|
||||
if let Some(arguments) = &function.arguments {
|
||||
if let Some(arguments) = &function.arguments
|
||||
{
|
||||
tool_call.arguments.push_str(arguments);
|
||||
}
|
||||
}
|
||||
@@ -246,7 +251,7 @@ impl OpenAIProvider {
|
||||
.collect(),
|
||||
)
|
||||
};
|
||||
|
||||
|
||||
let final_chunk = CompletionChunk {
|
||||
content: String::new(),
|
||||
finished: true,
|
||||
@@ -254,7 +259,7 @@ impl OpenAIProvider {
|
||||
usage: accumulated_usage.clone(),
|
||||
};
|
||||
let _ = tx.send(Ok(final_chunk)).await;
|
||||
|
||||
|
||||
accumulated_usage
|
||||
}
|
||||
}
|
||||
@@ -291,7 +296,11 @@ impl LLMProvider for OpenAIProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
|
||||
return Err(anyhow::anyhow!(
|
||||
"OpenAI API error {}: {}",
|
||||
status,
|
||||
error_text
|
||||
));
|
||||
}
|
||||
|
||||
let openai_response: OpenAIResponse = response.json().await?;
|
||||
@@ -334,7 +343,10 @@ impl LLMProvider for OpenAIProvider {
|
||||
request.temperature,
|
||||
);
|
||||
|
||||
debug!("Sending streaming request to OpenAI API: model={}", self.model);
|
||||
debug!(
|
||||
"Sending streaming request to OpenAI API: model={}",
|
||||
self.model
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
@@ -350,7 +362,11 @@ impl LLMProvider for OpenAIProvider {
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(anyhow::anyhow!("OpenAI API error {}: {}", status, error_text));
|
||||
return Err(anyhow::anyhow!(
|
||||
"OpenAI API error {}: {}",
|
||||
status,
|
||||
error_text
|
||||
));
|
||||
}
|
||||
|
||||
let stream = response.bytes_stream();
|
||||
@@ -384,11 +400,11 @@ impl LLMProvider for OpenAIProvider {
|
||||
// OpenAI models support native tool calling
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
fn max_tokens(&self) -> u32 {
|
||||
self.max_tokens.unwrap_or(16000)
|
||||
}
|
||||
|
||||
|
||||
fn temperature(&self) -> f32 {
|
||||
self._temperature.unwrap_or(0.1)
|
||||
}
|
||||
@@ -472,9 +488,9 @@ impl OpenAIStreamingToolCall {
|
||||
fn to_tool_call(&self) -> Option<ToolCall> {
|
||||
let id = self.id.as_ref()?;
|
||||
let name = self.name.as_ref()?;
|
||||
|
||||
|
||||
let args = serde_json::from_str(&self.arguments).unwrap_or(serde_json::Value::Null);
|
||||
|
||||
|
||||
Some(ToolCall {
|
||||
id: id.clone(),
|
||||
tool: name.clone(),
|
||||
|
||||
@@ -20,18 +20,24 @@ fn test_no_wrong_serialization_format() {
|
||||
CacheControl::ephemeral(),
|
||||
);
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("Ephemeral message JSON: {}", json);
|
||||
|
||||
|
||||
// Should NOT contain the wrong format
|
||||
assert!(!json.contains("system.0.cache_control"),
|
||||
"JSON should not contain 'system.0.cache_control' path");
|
||||
assert!(!json.contains("cache_control.ephemeral"),
|
||||
"JSON should not contain 'cache_control.ephemeral' path");
|
||||
|
||||
assert!(
|
||||
!json.contains("system.0.cache_control"),
|
||||
"JSON should not contain 'system.0.cache_control' path"
|
||||
);
|
||||
assert!(
|
||||
!json.contains("cache_control.ephemeral"),
|
||||
"JSON should not contain 'cache_control.ephemeral' path"
|
||||
);
|
||||
|
||||
// Should contain the correct format
|
||||
assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#),
|
||||
"JSON should contain correct cache_control format");
|
||||
assert!(
|
||||
json.contains(r#""cache_control":{"type":"ephemeral"}"#),
|
||||
"JSON should contain correct cache_control format"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -42,20 +48,28 @@ fn test_five_minute_no_wrong_format() {
|
||||
CacheControl::five_minute(),
|
||||
);
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("5-minute message JSON: {}", json);
|
||||
|
||||
|
||||
// Should NOT contain the wrong format
|
||||
assert!(!json.contains("system.0.cache_control"),
|
||||
"JSON should not contain 'system.0.cache_control' path");
|
||||
assert!(!json.contains("cache_control.ephemeral.ttl"),
|
||||
"JSON should not contain 'cache_control.ephemeral.ttl' path");
|
||||
|
||||
assert!(
|
||||
!json.contains("system.0.cache_control"),
|
||||
"JSON should not contain 'system.0.cache_control' path"
|
||||
);
|
||||
assert!(
|
||||
!json.contains("cache_control.ephemeral.ttl"),
|
||||
"JSON should not contain 'cache_control.ephemeral.ttl' path"
|
||||
);
|
||||
|
||||
// Should contain the correct format with ttl as a direct field
|
||||
assert!(json.contains(r#""type":"ephemeral""#),
|
||||
"JSON should contain type field");
|
||||
assert!(json.contains(r#""ttl":"5m""#),
|
||||
"JSON should contain ttl field with value 5m");
|
||||
assert!(
|
||||
json.contains(r#""type":"ephemeral""#),
|
||||
"JSON should contain type field"
|
||||
);
|
||||
assert!(
|
||||
json.contains(r#""ttl":"5m""#),
|
||||
"JSON should contain ttl field with value 5m"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -66,44 +80,59 @@ fn test_one_hour_no_wrong_format() {
|
||||
CacheControl::one_hour(),
|
||||
);
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
|
||||
|
||||
println!("1-hour message JSON: {}", json);
|
||||
|
||||
|
||||
// Should NOT contain the wrong format
|
||||
assert!(!json.contains("system.0.cache_control"),
|
||||
"JSON should not contain 'system.0.cache_control' path");
|
||||
assert!(!json.contains("cache_control.ephemeral.ttl"),
|
||||
"JSON should not contain 'cache_control.ephemeral.ttl' path");
|
||||
|
||||
assert!(
|
||||
!json.contains("system.0.cache_control"),
|
||||
"JSON should not contain 'system.0.cache_control' path"
|
||||
);
|
||||
assert!(
|
||||
!json.contains("cache_control.ephemeral.ttl"),
|
||||
"JSON should not contain 'cache_control.ephemeral.ttl' path"
|
||||
);
|
||||
|
||||
// Should contain the correct format with ttl as a direct field
|
||||
assert!(json.contains(r#""type":"ephemeral""#),
|
||||
"JSON should contain type field");
|
||||
assert!(json.contains(r#""ttl":"1h""#),
|
||||
"JSON should contain ttl field with value 1h");
|
||||
assert!(
|
||||
json.contains(r#""type":"ephemeral""#),
|
||||
"JSON should contain type field"
|
||||
);
|
||||
assert!(
|
||||
json.contains(r#""ttl":"1h""#),
|
||||
"JSON should contain ttl field with value 1h"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_control_structure_is_flat() {
|
||||
// Verify that the cache_control object has a flat structure
|
||||
// with 'type' and optional 'ttl' at the same level
|
||||
|
||||
|
||||
let cache_control = CacheControl::five_minute();
|
||||
let json_value = serde_json::to_value(&cache_control).unwrap();
|
||||
|
||||
println!("Cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap());
|
||||
|
||||
|
||||
println!(
|
||||
"Cache control as JSON value: {}",
|
||||
serde_json::to_string_pretty(&json_value).unwrap()
|
||||
);
|
||||
|
||||
let obj = json_value.as_object().expect("Should be an object");
|
||||
|
||||
|
||||
// Should have exactly 2 keys at the top level
|
||||
assert_eq!(obj.len(), 2, "Cache control should have exactly 2 top-level fields");
|
||||
|
||||
assert_eq!(
|
||||
obj.len(),
|
||||
2,
|
||||
"Cache control should have exactly 2 top-level fields"
|
||||
);
|
||||
|
||||
// Both 'type' and 'ttl' should be at the same level
|
||||
assert!(obj.contains_key("type"), "Should have 'type' field");
|
||||
assert!(obj.contains_key("ttl"), "Should have 'ttl' field");
|
||||
|
||||
|
||||
// 'type' should be a string, not an object
|
||||
assert!(obj["type"].is_string(), "'type' should be a string value");
|
||||
|
||||
|
||||
// 'ttl' should be a string, not nested
|
||||
assert!(obj["ttl"].is_string(), "'ttl' should be a string value");
|
||||
}
|
||||
@@ -112,20 +141,30 @@ fn test_cache_control_structure_is_flat() {
|
||||
fn test_ephemeral_cache_control_structure() {
|
||||
let cache_control = CacheControl::ephemeral();
|
||||
let json_value = serde_json::to_value(&cache_control).unwrap();
|
||||
|
||||
println!("Ephemeral cache control as JSON value: {}", serde_json::to_string_pretty(&json_value).unwrap());
|
||||
|
||||
|
||||
println!(
|
||||
"Ephemeral cache control as JSON value: {}",
|
||||
serde_json::to_string_pretty(&json_value).unwrap()
|
||||
);
|
||||
|
||||
let obj = json_value.as_object().expect("Should be an object");
|
||||
|
||||
|
||||
// Should have exactly 1 key (only 'type', no 'ttl')
|
||||
assert_eq!(obj.len(), 1, "Ephemeral cache control should have exactly 1 top-level field");
|
||||
|
||||
assert_eq!(
|
||||
obj.len(),
|
||||
1,
|
||||
"Ephemeral cache control should have exactly 1 top-level field"
|
||||
);
|
||||
|
||||
// Should have 'type' field
|
||||
assert!(obj.contains_key("type"), "Should have 'type' field");
|
||||
|
||||
|
||||
// Should NOT have 'ttl' field
|
||||
assert!(!obj.contains_key("ttl"), "Ephemeral should not have 'ttl' field");
|
||||
|
||||
assert!(
|
||||
!obj.contains_key("ttl"),
|
||||
"Ephemeral should not have 'ttl' field"
|
||||
);
|
||||
|
||||
// 'type' should be a string with value "ephemeral"
|
||||
assert_eq!(obj["type"].as_str().unwrap(), "ephemeral");
|
||||
}
|
||||
|
||||
@@ -10,13 +10,19 @@ use serde_json::json;
|
||||
fn test_ephemeral_cache_control_serialization() {
|
||||
let cache_control = CacheControl::ephemeral();
|
||||
let json = serde_json::to_value(&cache_control).unwrap();
|
||||
|
||||
println!("Ephemeral cache_control JSON: {}", serde_json::to_string(&json).unwrap());
|
||||
|
||||
assert_eq!(json, json!({
|
||||
"type": "ephemeral"
|
||||
}));
|
||||
|
||||
|
||||
println!(
|
||||
"Ephemeral cache_control JSON: {}",
|
||||
serde_json::to_string(&json).unwrap()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"type": "ephemeral"
|
||||
})
|
||||
);
|
||||
|
||||
// Verify no ttl field is present
|
||||
assert!(!json.as_object().unwrap().contains_key("ttl"));
|
||||
}
|
||||
@@ -25,26 +31,38 @@ fn test_ephemeral_cache_control_serialization() {
|
||||
fn test_five_minute_cache_control_serialization() {
|
||||
let cache_control = CacheControl::five_minute();
|
||||
let json = serde_json::to_value(&cache_control).unwrap();
|
||||
|
||||
println!("5-minute cache_control JSON: {}", serde_json::to_string(&json).unwrap());
|
||||
|
||||
assert_eq!(json, json!({
|
||||
"type": "ephemeral",
|
||||
"ttl": "5m"
|
||||
}));
|
||||
|
||||
println!(
|
||||
"5-minute cache_control JSON: {}",
|
||||
serde_json::to_string(&json).unwrap()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"type": "ephemeral",
|
||||
"ttl": "5m"
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_one_hour_cache_control_serialization() {
|
||||
let cache_control = CacheControl::one_hour();
|
||||
let json = serde_json::to_value(&cache_control).unwrap();
|
||||
|
||||
println!("1-hour cache_control JSON: {}", serde_json::to_string(&json).unwrap());
|
||||
|
||||
assert_eq!(json, json!({
|
||||
"type": "ephemeral",
|
||||
"ttl": "1h"
|
||||
}));
|
||||
|
||||
println!(
|
||||
"1-hour cache_control JSON: {}",
|
||||
serde_json::to_string(&json).unwrap()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
json,
|
||||
json!({
|
||||
"type": "ephemeral",
|
||||
"ttl": "1h"
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -54,11 +72,16 @@ fn test_message_with_ephemeral_cache_control() {
|
||||
"System prompt".to_string(),
|
||||
CacheControl::ephemeral(),
|
||||
);
|
||||
|
||||
|
||||
let json = serde_json::to_value(&msg).unwrap();
|
||||
println!("Message with ephemeral cache_control: {}", serde_json::to_string(&json).unwrap());
|
||||
|
||||
let cache_control = json.get("cache_control").expect("cache_control field should exist");
|
||||
println!(
|
||||
"Message with ephemeral cache_control: {}",
|
||||
serde_json::to_string(&json).unwrap()
|
||||
);
|
||||
|
||||
let cache_control = json
|
||||
.get("cache_control")
|
||||
.expect("cache_control field should exist");
|
||||
assert_eq!(cache_control.get("type").unwrap(), "ephemeral");
|
||||
assert!(!cache_control.as_object().unwrap().contains_key("ttl"));
|
||||
}
|
||||
@@ -70,11 +93,16 @@ fn test_message_with_five_minute_cache_control() {
|
||||
"System prompt".to_string(),
|
||||
CacheControl::five_minute(),
|
||||
);
|
||||
|
||||
|
||||
let json = serde_json::to_value(&msg).unwrap();
|
||||
println!("Message with 5-minute cache_control: {}", serde_json::to_string(&json).unwrap());
|
||||
|
||||
let cache_control = json.get("cache_control").expect("cache_control field should exist");
|
||||
println!(
|
||||
"Message with 5-minute cache_control: {}",
|
||||
serde_json::to_string(&json).unwrap()
|
||||
);
|
||||
|
||||
let cache_control = json
|
||||
.get("cache_control")
|
||||
.expect("cache_control field should exist");
|
||||
assert_eq!(cache_control.get("type").unwrap(), "ephemeral");
|
||||
assert_eq!(cache_control.get("ttl").unwrap(), "5m");
|
||||
}
|
||||
@@ -86,11 +114,16 @@ fn test_message_with_one_hour_cache_control() {
|
||||
"System prompt".to_string(),
|
||||
CacheControl::one_hour(),
|
||||
);
|
||||
|
||||
|
||||
let json = serde_json::to_value(&msg).unwrap();
|
||||
println!("Message with 1-hour cache_control: {}", serde_json::to_string(&json).unwrap());
|
||||
|
||||
let cache_control = json.get("cache_control").expect("cache_control field should exist");
|
||||
println!(
|
||||
"Message with 1-hour cache_control: {}",
|
||||
serde_json::to_string(&json).unwrap()
|
||||
);
|
||||
|
||||
let cache_control = json
|
||||
.get("cache_control")
|
||||
.expect("cache_control field should exist");
|
||||
assert_eq!(cache_control.get("type").unwrap(), "ephemeral");
|
||||
assert_eq!(cache_control.get("ttl").unwrap(), "1h");
|
||||
}
|
||||
@@ -98,10 +131,13 @@ fn test_message_with_one_hour_cache_control() {
|
||||
#[test]
|
||||
fn test_message_without_cache_control() {
|
||||
let msg = Message::new(MessageRole::User, "Hello".to_string());
|
||||
|
||||
|
||||
let json = serde_json::to_value(&msg).unwrap();
|
||||
println!("Message without cache_control: {}", serde_json::to_string(&json).unwrap());
|
||||
|
||||
println!(
|
||||
"Message without cache_control: {}",
|
||||
serde_json::to_string(&json).unwrap()
|
||||
);
|
||||
|
||||
// cache_control field should not be present when not set
|
||||
assert!(!json.as_object().unwrap().contains_key("cache_control"));
|
||||
}
|
||||
@@ -110,9 +146,9 @@ fn test_message_without_cache_control() {
|
||||
fn test_cache_control_json_format_ephemeral() {
|
||||
let cache_control = CacheControl::ephemeral();
|
||||
let json_str = serde_json::to_string(&cache_control).unwrap();
|
||||
|
||||
|
||||
println!("Ephemeral JSON string: {}", json_str);
|
||||
|
||||
|
||||
// Verify exact JSON format
|
||||
assert_eq!(json_str, r#"{"type":"ephemeral"}"#);
|
||||
}
|
||||
@@ -121,9 +157,9 @@ fn test_cache_control_json_format_ephemeral() {
|
||||
fn test_cache_control_json_format_five_minute() {
|
||||
let cache_control = CacheControl::five_minute();
|
||||
let json_str = serde_json::to_string(&cache_control).unwrap();
|
||||
|
||||
|
||||
println!("5-minute JSON string: {}", json_str);
|
||||
|
||||
|
||||
// Verify exact JSON format
|
||||
assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"5m"}"#);
|
||||
}
|
||||
@@ -132,9 +168,9 @@ fn test_cache_control_json_format_five_minute() {
|
||||
fn test_cache_control_json_format_one_hour() {
|
||||
let cache_control = CacheControl::one_hour();
|
||||
let json_str = serde_json::to_string(&cache_control).unwrap();
|
||||
|
||||
|
||||
println!("1-hour JSON string: {}", json_str);
|
||||
|
||||
|
||||
// Verify exact JSON format
|
||||
assert_eq!(json_str, r#"{"type":"ephemeral","ttl":"1h"}"#);
|
||||
}
|
||||
@@ -143,7 +179,7 @@ fn test_cache_control_json_format_one_hour() {
|
||||
fn test_deserialization_ephemeral() {
|
||||
let json_str = r#"{"type":"ephemeral"}"#;
|
||||
let cache_control: CacheControl = serde_json::from_str(json_str).unwrap();
|
||||
|
||||
|
||||
assert_eq!(cache_control.ttl, None);
|
||||
}
|
||||
|
||||
@@ -151,7 +187,7 @@ fn test_deserialization_ephemeral() {
|
||||
fn test_deserialization_five_minute() {
|
||||
let json_str = r#"{"type":"ephemeral","ttl":"5m"}"#;
|
||||
let cache_control: CacheControl = serde_json::from_str(json_str).unwrap();
|
||||
|
||||
|
||||
assert_eq!(cache_control.ttl, Some("5m".to_string()));
|
||||
}
|
||||
|
||||
@@ -159,6 +195,6 @@ fn test_deserialization_five_minute() {
|
||||
fn test_deserialization_one_hour() {
|
||||
let json_str = r#"{"type":"ephemeral","ttl":"1h"}"#;
|
||||
let cache_control: CacheControl = serde_json::from_str(json_str).unwrap();
|
||||
|
||||
|
||||
assert_eq!(cache_control.ttl, Some("1h".to_string()));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user