adds cache_control
This commit is contained in:
@@ -39,10 +39,7 @@
|
||||
//! // Create a completion request
|
||||
//! let request = CompletionRequest {
|
||||
//! messages: vec![
|
||||
//! Message {
|
||||
//! role: MessageRole::User,
|
||||
//! content: "Hello! How are you?".to_string(),
|
||||
//! },
|
||||
//! Message::new(MessageRole::User, "Hello! How are you?".to_string()),
|
||||
//! ],
|
||||
//! max_tokens: Some(1000),
|
||||
//! temperature: Some(0.7),
|
||||
@@ -241,6 +238,15 @@ impl DatabricksProvider {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn convert_cache_control(cache_control: &crate::CacheControl) -> DatabricksCacheControl {
|
||||
let cache_type = match cache_control {
|
||||
crate::CacheControl::Ephemeral => "ephemeral",
|
||||
crate::CacheControl::FiveMinute => "5minute",
|
||||
crate::CacheControl::OneHour => "1hour",
|
||||
};
|
||||
DatabricksCacheControl { cache_type: cache_type.to_string() }
|
||||
}
|
||||
|
||||
fn convert_messages(&self, messages: &[Message]) -> Result<Vec<DatabricksMessage>> {
|
||||
let mut databricks_messages = Vec::new();
|
||||
|
||||
@@ -251,9 +257,24 @@ impl DatabricksProvider {
|
||||
MessageRole::Assistant => "assistant",
|
||||
};
|
||||
|
||||
// If message has cache_control, use content array format
|
||||
let content = if message.cache_control.is_some() {
|
||||
// Use array format with cache_control
|
||||
let content_block = DatabricksContent::Text {
|
||||
content_type: "text".to_string(),
|
||||
text: message.content.clone(),
|
||||
cache_control: message.cache_control.as_ref()
|
||||
.map(Self::convert_cache_control),
|
||||
};
|
||||
serde_json::to_value(vec![content_block])?
|
||||
} else {
|
||||
// Use simple string format
|
||||
serde_json::Value::String(message.content.clone())
|
||||
};
|
||||
|
||||
databricks_messages.push(DatabricksMessage {
|
||||
role: role.to_string(),
|
||||
content: Some(message.content.clone()),
|
||||
content: Some(content),
|
||||
tool_calls: None, // Only used in responses, not requests
|
||||
});
|
||||
}
|
||||
@@ -864,8 +885,22 @@ impl LLMProvider for DatabricksProvider {
|
||||
let content = databricks_response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.message.content.as_ref())
|
||||
.cloned()
|
||||
.and_then(|choice| {
|
||||
choice.message.content.as_ref().map(|c| {
|
||||
// Handle both string and array formats
|
||||
if let Some(s) = c.as_str() {
|
||||
s.to_string()
|
||||
} else if let Some(arr) = c.as_array() {
|
||||
// Extract text from content blocks
|
||||
arr.iter()
|
||||
.filter_map(|block| block.get("text").and_then(|t| t.as_str()))
|
||||
.collect::<Vec<_>>()
|
||||
.join("")
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
})
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Check if there are tool calls in the response
|
||||
@@ -1037,6 +1072,11 @@ impl LLMProvider for DatabricksProvider {
|
||||
// This includes Claude, Llama, DBRX, and most other models on the platform
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_cache_control(&self) -> bool {
|
||||
// Databricks supports cache control when using Anthropic models
|
||||
self.model.contains("claude")
|
||||
}
|
||||
}
|
||||
|
||||
// Databricks API request/response structures
|
||||
@@ -1064,10 +1104,29 @@ struct DatabricksFunction {
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct DatabricksCacheControl {
|
||||
#[serde(rename = "type")]
|
||||
cache_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum DatabricksContent {
|
||||
Text {
|
||||
#[serde(rename = "type")]
|
||||
content_type: String,
|
||||
text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<DatabricksCacheControl>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct DatabricksMessage {
|
||||
role: String,
|
||||
content: Option<String>, // Make content optional since tool calls might not have content
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<serde_json::Value>, // Can be string or array of content blocks
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<DatabricksToolCall>>, // Add tool_calls field for responses
|
||||
}
|
||||
@@ -1154,18 +1213,9 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: MessageRole::System,
|
||||
content: "You are a helpful assistant.".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: MessageRole::User,
|
||||
content: "Hello!".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: MessageRole::Assistant,
|
||||
content: "Hi there!".to_string(),
|
||||
},
|
||||
Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
|
||||
Message::new(MessageRole::User, "Hello!".to_string()),
|
||||
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
|
||||
];
|
||||
|
||||
let databricks_messages = provider.convert_messages(&messages).unwrap();
|
||||
@@ -1187,10 +1237,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let messages = vec![Message {
|
||||
role: MessageRole::User,
|
||||
content: "Test message".to_string(),
|
||||
}];
|
||||
let messages = vec![Message::new(MessageRole::User, "Test message".to_string())];
|
||||
|
||||
let request_body = provider
|
||||
.create_request_body(&messages, None, false, 1000, 0.5)
|
||||
@@ -1273,4 +1320,53 @@ mod tests {
|
||||
assert!(llama_provider.has_native_tool_calling());
|
||||
assert!(dbrx_provider.has_native_tool_calling());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_control_serialization() {
|
||||
let provider = DatabricksProvider::from_token(
|
||||
"https://test.databricks.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
"databricks-claude-sonnet-4".to_string(),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Test message WITHOUT cache_control - should use string format
|
||||
let messages_without = vec![Message::new(MessageRole::User, "Hello".to_string())];
|
||||
let databricks_messages_without = provider.convert_messages(&messages_without).unwrap();
|
||||
let json_without = serde_json::to_string(&databricks_messages_without).unwrap();
|
||||
|
||||
println!("JSON without cache_control: {}", json_without);
|
||||
assert!(!json_without.contains("cache_control"),
|
||||
"JSON should not contain 'cache_control' field when not configured");
|
||||
|
||||
// Test message WITH cache_control - should use array format
|
||||
let messages_with = vec![Message::with_cache_control(
|
||||
MessageRole::User,
|
||||
"Hello".to_string(),
|
||||
crate::CacheControl::Ephemeral,
|
||||
)];
|
||||
let databricks_messages_with = provider.convert_messages(&messages_with).unwrap();
|
||||
let json_with = serde_json::to_string(&databricks_messages_with).unwrap();
|
||||
|
||||
println!("JSON with cache_control: {}", json_with);
|
||||
assert!(json_with.contains("cache_control"),
|
||||
"JSON should contain 'cache_control' field when configured");
|
||||
assert!(json_with.contains("ephemeral"),
|
||||
"JSON should contain 'ephemeral' type");
|
||||
assert!(!json_with.contains("null"),
|
||||
"JSON should not contain null values");
|
||||
|
||||
// Verify the structure is correct
|
||||
let msg_with = &databricks_messages_with[0];
|
||||
if let Some(content) = &msg_with.content {
|
||||
if let Some(arr) = content.as_array() {
|
||||
assert_eq!(arr.len(), 1, "Content array should have one element");
|
||||
assert!(arr[0].get("cache_control").is_some(), "Content should have cache_control");
|
||||
} else {
|
||||
panic!("Content should be an array when cache_control is present");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user