adds cache_control

This commit is contained in:
Jochen
2025-11-18 22:38:52 +11:00
parent 39efa24c55
commit 296bf5a449
7 changed files with 466 additions and 125 deletions

View File

@@ -39,10 +39,7 @@
//! // Create a completion request
//! let request = CompletionRequest {
//! messages: vec![
//! Message {
//! role: MessageRole::User,
//! content: "Hello! How are you?".to_string(),
//! },
//! Message::new(MessageRole::User, "Hello! How are you?".to_string()),
//! ],
//! max_tokens: Some(1000),
//! temperature: Some(0.7),
@@ -241,6 +238,15 @@ impl DatabricksProvider {
.collect()
}
fn convert_cache_control(cache_control: &crate::CacheControl) -> DatabricksCacheControl {
let cache_type = match cache_control {
crate::CacheControl::Ephemeral => "ephemeral",
crate::CacheControl::FiveMinute => "5minute",
crate::CacheControl::OneHour => "1hour",
};
DatabricksCacheControl { cache_type: cache_type.to_string() }
}
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<DatabricksMessage>> {
let mut databricks_messages = Vec::new();
@@ -251,9 +257,24 @@ impl DatabricksProvider {
MessageRole::Assistant => "assistant",
};
// If message has cache_control, use content array format
let content = if message.cache_control.is_some() {
// Use array format with cache_control
let content_block = DatabricksContent::Text {
content_type: "text".to_string(),
text: message.content.clone(),
cache_control: message.cache_control.as_ref()
.map(Self::convert_cache_control),
};
serde_json::to_value(vec![content_block])?
} else {
// Use simple string format
serde_json::Value::String(message.content.clone())
};
databricks_messages.push(DatabricksMessage {
role: role.to_string(),
content: Some(message.content.clone()),
content: Some(content),
tool_calls: None, // Only used in responses, not requests
});
}
@@ -864,8 +885,22 @@ impl LLMProvider for DatabricksProvider {
let content = databricks_response
.choices
.first()
.and_then(|choice| choice.message.content.as_ref())
.cloned()
.and_then(|choice| {
choice.message.content.as_ref().map(|c| {
// Handle both string and array formats
if let Some(s) = c.as_str() {
s.to_string()
} else if let Some(arr) = c.as_array() {
// Extract text from content blocks
arr.iter()
.filter_map(|block| block.get("text").and_then(|t| t.as_str()))
.collect::<Vec<_>>()
.join("")
} else {
String::new()
}
})
})
.unwrap_or_default();
// Check if there are tool calls in the response
@@ -1037,6 +1072,11 @@ impl LLMProvider for DatabricksProvider {
// This includes Claude, Llama, DBRX, and most other models on the platform
true
}
fn supports_cache_control(&self) -> bool {
// Databricks supports cache control when using Anthropic models
self.model.contains("claude")
}
}
// Databricks API request/response structures
@@ -1064,10 +1104,29 @@ struct DatabricksFunction {
parameters: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
struct DatabricksCacheControl {
#[serde(rename = "type")]
cache_type: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum DatabricksContent {
Text {
#[serde(rename = "type")]
content_type: String,
text: String,
#[serde(skip_serializing_if = "Option::is_none")]
cache_control: Option<DatabricksCacheControl>,
},
}
#[derive(Debug, Serialize, Deserialize)]
struct DatabricksMessage {
role: String,
content: Option<String>, // Make content optional since tool calls might not have content
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<serde_json::Value>, // Can be string or array of content blocks
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<DatabricksToolCall>>, // Add tool_calls field for responses
}
@@ -1154,18 +1213,9 @@ mod tests {
.unwrap();
let messages = vec![
Message {
role: MessageRole::System,
content: "You are a helpful assistant.".to_string(),
},
Message {
role: MessageRole::User,
content: "Hello!".to_string(),
},
Message {
role: MessageRole::Assistant,
content: "Hi there!".to_string(),
},
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
Message::new(MessageRole::User, "Hello!".to_string()),
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
];
let databricks_messages = provider.convert_messages(&messages).unwrap();
@@ -1187,10 +1237,7 @@ mod tests {
)
.unwrap();
let messages = vec![Message {
role: MessageRole::User,
content: "Test message".to_string(),
}];
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
let request_body = provider
.create_request_body(&messages, None, false, 1000, 0.5)
@@ -1273,4 +1320,53 @@ mod tests {
assert!(llama_provider.has_native_tool_calling());
assert!(dbrx_provider.has_native_tool_calling());
}
#[test]
fn test_cache_control_serialization() {
let provider = DatabricksProvider::from_token(
"https://test.databricks.com".to_string(),
"test-token".to_string(),
"databricks-claude-sonnet-4".to_string(),
None,
None,
)
.unwrap();
// Test message WITHOUT cache_control - should use string format
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
let databricks_messages_without = provider.convert_messages(&messages_without).unwrap();
let json_without = serde_json::to_string(&databricks_messages_without).unwrap();
println!("JSON without cache_control: {}", json_without);
assert!(!json_without.contains("cache_control"),
"JSON should not contain 'cache_control' field when not configured");
// Test message WITH cache_control - should use array format
let messages_with = vec![Message::with_cache_control(
MessageRole::User,
"Hello".to_string(),
crate::CacheControl::Ephemeral,
)];
let databricks_messages_with = provider.convert_messages(&messages_with).unwrap();
let json_with = serde_json::to_string(&databricks_messages_with).unwrap();
println!("JSON with cache_control: {}", json_with);
assert!(json_with.contains("cache_control"),
"JSON should contain 'cache_control' field when configured");
assert!(json_with.contains("ephemeral"),
"JSON should contain 'ephemeral' type");
assert!(!json_with.contains("null"),
"JSON should not contain null values");
// Verify the structure is correct
let msg_with = &databricks_messages_with[0];
if let Some(content) = &msg_with.content {
if let Some(arr) = content.as_array() {
assert_eq!(arr.len(), 1, "Content array should have one element");
assert!(arr[0].get("cache_control").is_some(), "Content should have cache_control");
} else {
panic!("Content should be an array when cache_control is present");
}
}
}
}