Add Google Gemini provider support

- Add GeminiProvider with streaming and native tool calling
- Support gemini-2.5-pro, gemini-2.0-flash, gemini-1.5-pro/flash models
- Model-specific context window detection (1M-2M tokens)
- Message conversion: assistant -> model role mapping
- System messages extracted to system_instruction field
- Tool schema conversion with functionCall/functionResponse parts
- SSE streaming with JSON array buffer parsing
- 8 unit tests for conversion and parsing logic
- Register provider in g3-core and validate in g3-cli
This commit is contained in:
Dhanji R. Prasanna
2026-01-29 10:11:42 +11:00
parent fe33568ee0
commit 735e9c9312
6 changed files with 860 additions and 2 deletions

View File

@@ -138,7 +138,7 @@ pub fn load_config_with_cli_overrides(cli: &Cli) -> Result<Config> {
// Validate provider if specified
if let Some(ref provider) = cli.provider {
let valid_providers = ["anthropic", "databricks", "embedded", "openai"];
let valid_providers = ["anthropic", "databricks", "embedded", "gemini", "openai"];
let provider_type = provider.split('.').next().unwrap_or(provider);
if !valid_providers.contains(&provider_type) {
return Err(anyhow::anyhow!(

View File

@@ -46,6 +46,10 @@ pub struct ProvidersConfig {
#[serde(default)]
pub embedded: HashMap<String, EmbeddedConfig>,
/// Named Gemini provider configs
#[serde(default)]
pub gemini: HashMap<String, GeminiConfig>,
/// Multiple named OpenAI-compatible providers (e.g., openrouter, groq, etc.)
#[serde(default)]
pub openai_compatible: HashMap<String, OpenAIConfig>,
@@ -92,6 +96,14 @@ pub struct EmbeddedConfig {
pub threads: Option<u32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiConfig {
pub api_key: String,
pub model: String,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
pub max_context_length: Option<u32>,
@@ -230,6 +242,7 @@ impl Default for Config {
openai: HashMap::new(),
databricks: databricks_configs,
embedded: HashMap::new(),
gemini: HashMap::new(),
openai_compatible: HashMap::new(),
},
agent: AgentConfig {
@@ -445,11 +458,20 @@ impl Config {
);
}
}
"gemini" => {
if !self.providers.gemini.contains_key(config_name) {
anyhow::bail!(
"Provider config 'gemini.{}' not found. Available: {:?}",
config_name,
self.providers.gemini.keys().collect::<Vec<_>>()
);
}
}
_ => {
// Check openai_compatible providers
if !self.providers.openai_compatible.contains_key(provider_type) {
anyhow::bail!(
"Unknown provider type '{}'. Valid types: anthropic, openai, databricks, embedded, or openai_compatible names",
"Unknown provider type '{}'. Valid types: anthropic, openai, databricks, embedded, gemini, or openai_compatible names",
provider_type
);
}
@@ -550,6 +572,18 @@ impl Config {
));
}
}
"gemini" => {
if let Some(ref mut gemini_config) =
config.providers.gemini.get_mut(&config_name)
{
gemini_config.model = model;
} else {
return Err(anyhow::anyhow!(
"Provider config 'gemini.{}' not found.",
config_name
));
}
}
_ => {
// Check openai_compatible
if let Some(ref mut compat_config) =
@@ -635,6 +669,11 @@ impl Config {
self.providers.embedded.get(name)
}
/// Get Gemini config by name
pub fn get_gemini_config(&self, name: &str) -> Option<&GeminiConfig> {
self.providers.gemini.get(name)
}
/// Get the current default provider's config
pub fn get_default_provider_config(&self) -> Result<ProviderConfigRef<'_>> {
let (provider_type, config_name) =
@@ -665,6 +704,12 @@ impl Config {
.get(&config_name)
.map(ProviderConfigRef::Embedded)
.ok_or_else(|| anyhow::anyhow!("Embedded config '{}' not found", config_name)),
"gemini" => self
.providers
.gemini
.get(&config_name)
.map(ProviderConfigRef::Gemini)
.ok_or_else(|| anyhow::anyhow!("Gemini config '{}' not found", config_name)),
_ => self
.providers
.openai_compatible
@@ -684,6 +729,7 @@ pub enum ProviderConfigRef<'a> {
OpenAI(&'a OpenAIConfig),
Databricks(&'a DatabricksConfig),
Embedded(&'a EmbeddedConfig),
Gemini(&'a GeminiConfig),
OpenAICompatible(&'a OpenAIConfig),
}

View File

@@ -708,6 +708,18 @@ impl<W: UiWriter> Agent<W> {
16384 // Conservative default for other Databricks models
}
}
"gemini" => {
// Gemini models - use provider's context_window_size()
if let Some(ctx_size) = provider.context_window_size() {
debug!(
"Using context window size {} from Gemini provider",
ctx_size
);
ctx_size
} else {
1_000_000 // Default for Gemini models
}
}
_ => config.agent.fallback_default_max_tokens as u32,
};

View File

@@ -57,6 +57,7 @@ pub async fn register_providers(
register_openai_providers(config, providers_to_register, &mut registry)?;
register_openai_compatible_providers(config, providers_to_register, &mut registry)?;
register_anthropic_providers(config, providers_to_register, &mut registry)?;
register_gemini_providers(config, providers_to_register, &mut registry)?;
register_databricks_providers(config, providers_to_register, &mut registry).await?;
// Set default provider
@@ -162,6 +163,27 @@ fn register_anthropic_providers(
Ok(())
}
/// Register Gemini providers from configuration.
fn register_gemini_providers(
config: &Config,
providers_to_register: &[String],
registry: &mut ProviderRegistry,
) -> Result<()> {
for (name, gemini_config) in &config.providers.gemini {
if should_register(providers_to_register, "gemini", name) {
let gemini_provider = g3_providers::GeminiProvider::new_with_name(
format!("gemini.{}", name),
gemini_config.api_key.clone(),
Some(gemini_config.model.clone()),
gemini_config.max_tokens,
gemini_config.temperature,
)?;
registry.register(gemini_provider);
}
}
Ok(())
}
/// Register Databricks providers from configuration.
///
/// This is async because OAuth authentication requires async operations.

View File

@@ -0,0 +1,776 @@
//! Google Gemini provider implementation for the g3-providers crate.
//!
//! This module provides an implementation of the `LLMProvider` trait for Google's Gemini models,
//! supporting both completion and streaming modes through the Gemini API.
//!
//! # Features
//!
//! - Support for Gemini models (gemini-2.0-flash, gemini-1.5-pro, etc.)
//! - Both completion and streaming response modes
//! - Proper message format conversion between g3 and Gemini formats
//! - Native tool calling support
//!
//! # Usage
//!
//! ```rust,no_run
//! use g3_providers::{GeminiProvider, LLMProvider, CompletionRequest, Message, MessageRole};
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! let provider = GeminiProvider::new(
//! "your-api-key".to_string(),
//! Some("gemini-2.0-flash".to_string()),
//! Some(8192),
//! Some(0.7),
//! )?;
//!
//! let request = CompletionRequest {
//! messages: vec![
//! Message::new(MessageRole::System, "You are a helpful assistant.".to_string()),
//! Message::new(MessageRole::User, "Hello! How are you?".to_string()),
//! ],
//! max_tokens: Some(1000),
//! temperature: Some(0.7),
//! stream: false,
//! tools: None,
//! disable_thinking: false,
//! };
//!
//! let response = provider.complete(request).await?;
//! println!("Response: {}", response.content);
//!
//! Ok(())
//! }
//! ```
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, error};
use crate::{
CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message,
MessageRole, Tool, ToolCall, Usage, streaming::make_text_chunk,
};
// ============================================================================
// Provider Struct
// ============================================================================
#[derive(Clone)]
pub struct GeminiProvider {
client: Client,
api_key: String,
model: String,
max_tokens: u32,
temperature: f32,
name: String,
}
impl GeminiProvider {
pub fn new(
api_key: String,
model: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
Ok(Self {
client: Client::new(),
api_key,
model: model.unwrap_or_else(|| "gemini-2.0-flash".to_string()),
max_tokens: max_tokens.unwrap_or(16384),
temperature: temperature.unwrap_or(0.1),
name: "gemini".to_string(),
})
}
pub fn new_with_name(
name: String,
api_key: String,
model: Option<String>,
max_tokens: Option<u32>,
temperature: Option<f32>,
) -> Result<Self> {
Ok(Self {
client: Client::new(),
api_key,
model: model.unwrap_or_else(|| "gemini-2.0-flash".to_string()),
max_tokens: max_tokens.unwrap_or(16384),
temperature: temperature.unwrap_or(0.1),
name,
})
}
fn get_api_url(&self, stream: bool) -> String {
let method = if stream { "streamGenerateContent" } else { "generateContent" };
format!(
"https://generativelanguage.googleapis.com/v1beta/models/{}:{}?key={}",
self.model, method, self.api_key
)
}
}
// ============================================================================
// Gemini API Request/Response Types
// ============================================================================
/// Gemini API request body
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
contents: Vec<GeminiContent>,
#[serde(skip_serializing_if = "Option::is_none")]
system_instruction: Option<GeminiContent>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<GeminiTool>>,
generation_config: GeminiGenerationConfig,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GeminiContent {
#[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
parts: Vec<GeminiPart>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
enum GeminiPart {
Text {
text: String,
},
FunctionCall {
#[serde(rename = "functionCall")]
function_call: GeminiFunctionCall,
},
FunctionResponse {
#[serde(rename = "functionResponse")]
function_response: GeminiFunctionResponse,
},
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GeminiFunctionCall {
name: String,
args: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
struct GeminiFunctionResponse {
name: String,
response: serde_json::Value,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
max_output_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiTool {
function_declarations: Vec<GeminiFunctionDeclaration>,
}
#[derive(Debug, Serialize)]
struct GeminiFunctionDeclaration {
name: String,
description: String,
#[serde(skip_serializing_if = "Option::is_none")]
parameters: Option<serde_json::Value>,
}
/// Gemini API response
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
candidates: Option<Vec<GeminiCandidate>>,
usage_metadata: Option<GeminiUsageMetadata>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
content: Option<GeminiContent>,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
prompt_token_count: Option<u32>,
candidates_token_count: Option<u32>,
total_token_count: Option<u32>,
}
// ============================================================================
// Message Conversion
// ============================================================================
/// Convert g3 messages to Gemini format
///
/// Key differences:
/// - Gemini uses "model" instead of "assistant"
/// - System messages go in system_instruction, not contents
/// - Gemini uses "parts" array with text objects
fn convert_messages(messages: &[Message]) -> (Vec<GeminiContent>, Option<GeminiContent>) {
let mut contents = Vec::new();
let mut system_instruction = None;
for msg in messages {
match msg.role {
MessageRole::System => {
// System messages go to system_instruction
system_instruction = Some(GeminiContent {
role: None, // system_instruction doesn't need a role
parts: vec![GeminiPart::Text { text: msg.content.clone() }],
});
}
MessageRole::User => {
contents.push(GeminiContent {
role: Some("user".to_string()),
parts: vec![GeminiPart::Text { text: msg.content.clone() }],
});
}
MessageRole::Assistant => {
// Gemini uses "model" instead of "assistant"
contents.push(GeminiContent {
role: Some("model".to_string()),
parts: vec![GeminiPart::Text { text: msg.content.clone() }],
});
}
}
}
(contents, system_instruction)
}
/// Convert g3 tools to Gemini format
fn convert_tools(tools: &[Tool]) -> Vec<GeminiTool> {
let declarations: Vec<GeminiFunctionDeclaration> = tools
.iter()
.map(|tool| GeminiFunctionDeclaration {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: if tool.input_schema.is_null() {
None
} else {
Some(tool.input_schema.clone())
},
})
.collect();
vec![GeminiTool {
function_declarations: declarations,
}]
}
/// Extract text content from Gemini response parts
fn extract_text_from_parts(parts: &[GeminiPart]) -> String {
parts
.iter()
.filter_map(|part| {
if let GeminiPart::Text { text } = part {
Some(text.as_str())
} else {
None
}
})
.collect::<Vec<_>>()
.join("")
}
/// Extract tool calls from Gemini response parts
fn extract_tool_calls_from_parts(parts: &[GeminiPart]) -> Vec<ToolCall> {
parts
.iter()
.filter_map(|part| {
if let GeminiPart::FunctionCall { function_call } = part {
Some(ToolCall {
id: format!("call_{}", nanoid::nanoid!(8)),
tool: function_call.name.clone(),
args: function_call.args.clone(),
})
} else {
None
}
})
.collect()
}
/// Convert Gemini usage metadata to g3 Usage
fn convert_usage(metadata: Option<&GeminiUsageMetadata>) -> Usage {
match metadata {
Some(m) => Usage {
prompt_tokens: m.prompt_token_count.unwrap_or(0),
completion_tokens: m.candidates_token_count.unwrap_or(0),
total_tokens: m.total_token_count.unwrap_or(0),
cache_creation_tokens: 0,
cache_read_tokens: 0,
},
None => Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
cache_creation_tokens: 0,
cache_read_tokens: 0,
},
}
}
/// Convert Gemini finish reason to g3 stop reason
fn convert_finish_reason(reason: Option<&str>) -> Option<String> {
reason.map(|r| match r {
"STOP" => "end_turn".to_string(),
"MAX_TOKENS" => "max_tokens".to_string(),
"SAFETY" => "content_filter".to_string(),
"RECITATION" => "content_filter".to_string(),
other => other.to_lowercase(),
})
}
// ============================================================================
// Streaming Parser
// ============================================================================
/// Parse a streaming chunk from Gemini's SSE response
///
/// Gemini streams JSON objects, one per line (not SSE format with "data:" prefix)
fn parse_streaming_chunk(data: &str) -> Option<(String, Option<Vec<ToolCall>>, Option<String>, Option<GeminiUsageMetadata>)> {
// Skip empty lines
let data = data.trim();
if data.is_empty() {
return None;
}
// Try to parse as JSON
let response: GeminiResponse = match serde_json::from_str(data) {
Ok(r) => r,
Err(e) => {
debug!("Failed to parse Gemini streaming chunk: {} - data: {}", e, data);
return None;
}
};
// Extract content from candidates
let candidates = response.candidates?;
let candidate = candidates.first()?;
let content = candidate.content.as_ref()?;
let text = extract_text_from_parts(&content.parts);
let tool_calls = extract_tool_calls_from_parts(&content.parts);
let finish_reason = convert_finish_reason(candidate.finish_reason.as_deref());
Some((
text,
if tool_calls.is_empty() { None } else { Some(tool_calls) },
finish_reason,
response.usage_metadata,
))
}
/// Process streaming response from Gemini
async fn process_stream(
mut response: reqwest::Response,
tx: mpsc::Sender<Result<CompletionChunk>>,
) {
let mut buffer = String::new();
let mut accumulated_text = String::new();
let mut last_usage: Option<GeminiUsageMetadata> = None;
let mut last_finish_reason: Option<String> = None;
let mut pending_tool_calls: Vec<ToolCall> = Vec::new();
while let Some(chunk_result) = response.chunk().await.transpose() {
match chunk_result {
Ok(bytes) => {
let text = match String::from_utf8(bytes.to_vec()) {
Ok(t) => t,
Err(e) => {
error!("Invalid UTF-8 in Gemini stream: {}", e);
continue;
}
};
buffer.push_str(&text);
// Gemini streams as JSON array elements or newline-delimited JSON
// Try to parse complete JSON objects from the buffer
while let Some(parsed) = try_parse_json_from_buffer(&mut buffer) {
if let Some((content, tool_calls, finish_reason, usage)) = parse_streaming_chunk(&parsed) {
// Track usage and finish reason
if usage.is_some() {
last_usage = usage;
}
if finish_reason.is_some() {
last_finish_reason = finish_reason;
}
// Handle tool calls
if let Some(calls) = tool_calls {
pending_tool_calls.extend(calls);
}
// Send text content
if !content.is_empty() {
accumulated_text.push_str(&content);
if tx.send(Ok(make_text_chunk(content))).await.is_err() {
return;
}
}
}
}
}
Err(e) => {
error!("Error reading Gemini stream: {}", e);
let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await;
return;
}
}
}
// Send any pending tool calls
if !pending_tool_calls.is_empty() {
let chunk = CompletionChunk {
content: String::new(),
finished: false,
tool_calls: Some(pending_tool_calls),
usage: None,
stop_reason: None,
tool_call_streaming: None,
};
if tx.send(Ok(chunk)).await.is_err() {
return;
}
}
// Send final chunk with usage
let final_chunk = CompletionChunk {
content: String::new(),
finished: true,
tool_calls: None,
usage: Some(convert_usage(last_usage.as_ref())),
stop_reason: last_finish_reason,
tool_call_streaming: None,
};
let _ = tx.send(Ok(final_chunk)).await;
}
/// Try to extract a complete JSON object from the buffer
///
/// Gemini streams responses as a JSON array: [{...}, {...}, ...]
/// We need to handle the array brackets and extract individual objects
fn try_parse_json_from_buffer(buffer: &mut String) -> Option<String> {
let trimmed = buffer.trim_start();
// Skip leading array bracket or comma
let start_idx = if trimmed.starts_with('[') {
buffer.find('[')? + 1
} else if trimmed.starts_with(',') {
buffer.find(',')? + 1
} else {
0
};
// Find the start of a JSON object
let remaining = &buffer[start_idx..];
let obj_start = remaining.find('{')?;
let absolute_start = start_idx + obj_start;
// Find matching closing brace
let mut depth = 0;
let mut in_string = false;
let mut escape_next = false;
let mut end_idx = None;
for (i, c) in buffer[absolute_start..].char_indices() {
if escape_next {
escape_next = false;
continue;
}
match c {
'\\' if in_string => escape_next = true,
'"' => in_string = !in_string,
'{' if !in_string => depth += 1,
'}' if !in_string => {
depth -= 1;
if depth == 0 {
end_idx = Some(absolute_start + i + 1);
break;
}
}
_ => {}
}
}
if let Some(end) = end_idx {
let json_str = buffer[absolute_start..end].to_string();
*buffer = buffer[end..].to_string();
Some(json_str)
} else {
None
}
}
// ============================================================================
// LLMProvider Implementation
// ============================================================================
#[async_trait]
impl LLMProvider for GeminiProvider {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
let (contents, system_instruction) = convert_messages(&request.messages);
let gemini_request = GeminiRequest {
contents,
system_instruction,
tools: request.tools.as_ref().map(|t| convert_tools(t)),
generation_config: GeminiGenerationConfig {
max_output_tokens: request.max_tokens.or(Some(self.max_tokens)),
temperature: request.temperature.or(Some(self.temperature)),
},
};
let url = self.get_api_url(false);
debug!("Gemini request URL: {}", url);
debug!("Gemini request body: {}", serde_json::to_string_pretty(&gemini_request).unwrap_or_default());
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.json(&gemini_request)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
error!("Gemini API error ({}): {}", status, error_text);
anyhow::bail!("Gemini API error ({}): {}", status, error_text);
}
let gemini_response: GeminiResponse = response.json().await?;
debug!("Gemini response: {:?}", gemini_response);
// Extract content from response
let content = gemini_response
.candidates
.as_ref()
.and_then(|c| c.first())
.and_then(|c| c.content.as_ref())
.map(|c| extract_text_from_parts(&c.parts))
.unwrap_or_default();
let usage = convert_usage(gemini_response.usage_metadata.as_ref());
Ok(CompletionResponse {
content,
usage,
model: self.model.clone(),
})
}
async fn stream(&self, request: CompletionRequest) -> Result<CompletionStream> {
let (contents, system_instruction) = convert_messages(&request.messages);
let gemini_request = GeminiRequest {
contents,
system_instruction,
tools: request.tools.as_ref().map(|t| convert_tools(t)),
generation_config: GeminiGenerationConfig {
max_output_tokens: request.max_tokens.or(Some(self.max_tokens)),
temperature: request.temperature.or(Some(self.temperature)),
},
};
// For streaming, add alt=sse parameter
let url = format!("{}&alt=sse", self.get_api_url(true));
debug!("Gemini streaming request URL: {}", url);
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.json(&gemini_request)
.send()
.await?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
error!("Gemini API error ({}): {}", status, error_text);
anyhow::bail!("Gemini API error ({}): {}", status, error_text);
}
let (tx, rx) = mpsc::channel(32);
tokio::spawn(process_stream(response, tx));
Ok(ReceiverStream::new(rx))
}
fn name(&self) -> &str {
&self.name
}
fn model(&self) -> &str {
&self.model
}
fn has_native_tool_calling(&self) -> bool {
true
}
fn max_tokens(&self) -> u32 {
self.max_tokens
}
fn temperature(&self) -> f32 {
self.temperature
}
fn context_window_size(&self) -> Option<u32> {
// Context window sizes by model
// https://ai.google.dev/gemini-api/docs/models
let size = if self.model.contains("1.5-pro") || self.model.contains("1.5-flash") {
2_000_000 // Gemini 1.5 models have 2M context
} else if self.model.contains("2.5-pro") || self.model.contains("2.5-flash") {
1_000_000 // Gemini 2.5 models have 1M context
} else if self.model.contains("2.0") {
1_000_000 // Gemini 2.0 models have 1M context
} else {
128_000 // Conservative default for unknown models
};
Some(size)
}
}
// ============================================================================
// Unit Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_convert_messages_basic() {
let messages = vec![
Message::new(MessageRole::User, "Hello".to_string()),
Message::new(MessageRole::Assistant, "Hi there!".to_string()),
];
let (contents, system) = convert_messages(&messages);
assert!(system.is_none());
assert_eq!(contents.len(), 2);
assert_eq!(contents[0].role, Some("user".to_string()));
assert_eq!(contents[1].role, Some("model".to_string())); // assistant -> model
}
#[test]
fn test_convert_messages_with_system() {
let messages = vec![
Message::new(MessageRole::System, "You are helpful.".to_string()),
Message::new(MessageRole::User, "Hello".to_string()),
];
let (contents, system) = convert_messages(&messages);
assert!(system.is_some());
let sys = system.unwrap();
assert!(sys.role.is_none()); // system_instruction has no role
assert_eq!(contents.len(), 1);
assert_eq!(contents[0].role, Some("user".to_string()));
}
#[test]
fn test_convert_tools() {
let tools = vec![Tool {
name: "get_weather".to_string(),
description: "Get the weather".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"location": { "type": "string" }
}
}),
}];
let gemini_tools = convert_tools(&tools);
assert_eq!(gemini_tools.len(), 1);
assert_eq!(gemini_tools[0].function_declarations.len(), 1);
assert_eq!(gemini_tools[0].function_declarations[0].name, "get_weather");
}
#[test]
fn test_parse_streaming_chunk() {
let chunk = r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}"#;
let result = parse_streaming_chunk(chunk);
assert!(result.is_some());
let (text, tool_calls, finish_reason, usage) = result.unwrap();
assert_eq!(text, "Hello");
assert!(tool_calls.is_none());
assert_eq!(finish_reason, Some("end_turn".to_string()));
assert!(usage.is_some());
assert_eq!(usage.unwrap().total_token_count, Some(15));
}
#[test]
fn test_parse_streaming_chunk_with_tool_call() {
let chunk = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"location":"NYC"}}}],"role":"model"}}]}"#;
let result = parse_streaming_chunk(chunk);
assert!(result.is_some());
let (text, tool_calls, _, _) = result.unwrap();
assert_eq!(text, "");
assert!(tool_calls.is_some());
let calls = tool_calls.unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].tool, "get_weather");
}
#[test]
fn test_try_parse_json_from_buffer() {
let mut buffer = r#"[{"test": 1}, {"test": 2}]"#.to_string();
let first = try_parse_json_from_buffer(&mut buffer);
assert!(first.is_some());
assert_eq!(first.unwrap(), r#"{"test": 1}"#);
let second = try_parse_json_from_buffer(&mut buffer);
assert!(second.is_some());
assert_eq!(second.unwrap(), r#"{"test": 2}"#);
}
#[test]
fn test_convert_finish_reason() {
assert_eq!(convert_finish_reason(Some("STOP")), Some("end_turn".to_string()));
assert_eq!(convert_finish_reason(Some("MAX_TOKENS")), Some("max_tokens".to_string()));
assert_eq!(convert_finish_reason(Some("SAFETY")), Some("content_filter".to_string()));
assert_eq!(convert_finish_reason(None), None);
}
#[test]
fn test_extract_text_from_parts() {
let parts = vec![
GeminiPart::Text { text: "Hello ".to_string() },
GeminiPart::Text { text: "world!".to_string() },
];
let text = extract_text_from_parts(&parts);
assert_eq!(text, "Hello world!");
}
}

View File

@@ -241,12 +241,14 @@ pub struct Tool {
pub mod anthropic;
pub mod databricks;
pub mod embedded;
pub mod gemini;
pub mod oauth;
pub mod openai;
pub use anthropic::AnthropicProvider;
pub use databricks::DatabricksProvider;
pub use embedded::EmbeddedProvider;
pub use gemini::GeminiProvider;
pub use openai::OpenAIProvider;
impl Message {