diff --git a/crates/g3-cli/src/utils.rs b/crates/g3-cli/src/utils.rs index 8d892f9..cec67c6 100644 --- a/crates/g3-cli/src/utils.rs +++ b/crates/g3-cli/src/utils.rs @@ -138,7 +138,7 @@ pub fn load_config_with_cli_overrides(cli: &Cli) -> Result { // Validate provider if specified if let Some(ref provider) = cli.provider { - let valid_providers = ["anthropic", "databricks", "embedded", "openai"]; + let valid_providers = ["anthropic", "databricks", "embedded", "gemini", "openai"]; let provider_type = provider.split('.').next().unwrap_or(provider); if !valid_providers.contains(&provider_type) { return Err(anyhow::anyhow!( diff --git a/crates/g3-config/src/lib.rs b/crates/g3-config/src/lib.rs index c57e727..f7137cf 100644 --- a/crates/g3-config/src/lib.rs +++ b/crates/g3-config/src/lib.rs @@ -46,6 +46,10 @@ pub struct ProvidersConfig { #[serde(default)] pub embedded: HashMap, + /// Named Gemini provider configs + #[serde(default)] + pub gemini: HashMap, + /// Multiple named OpenAI-compatible providers (e.g., openrouter, groq, etc.) #[serde(default)] pub openai_compatible: HashMap, @@ -92,6 +96,14 @@ pub struct EmbeddedConfig { pub threads: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeminiConfig { + pub api_key: String, + pub model: String, + pub max_tokens: Option, + pub temperature: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AgentConfig { pub max_context_length: Option, @@ -230,6 +242,7 @@ impl Default for Config { openai: HashMap::new(), databricks: databricks_configs, embedded: HashMap::new(), + gemini: HashMap::new(), openai_compatible: HashMap::new(), }, agent: AgentConfig { @@ -445,11 +458,20 @@ impl Config { ); } } + "gemini" => { + if !self.providers.gemini.contains_key(config_name) { + anyhow::bail!( + "Provider config 'gemini.{}' not found. Available: {:?}", + config_name, + self.providers.gemini.keys().collect::>() + ); + } + } _ => { // Check openai_compatible providers if !self.providers.openai_compatible.contains_key(provider_type) { anyhow::bail!( - "Unknown provider type '{}'. Valid types: anthropic, openai, databricks, embedded, or openai_compatible names", + "Unknown provider type '{}'. Valid types: anthropic, openai, databricks, embedded, gemini, or openai_compatible names", provider_type ); } @@ -550,6 +572,18 @@ impl Config { )); } } + "gemini" => { + if let Some(ref mut gemini_config) = + config.providers.gemini.get_mut(&config_name) + { + gemini_config.model = model; + } else { + return Err(anyhow::anyhow!( + "Provider config 'gemini.{}' not found.", + config_name + )); + } + } _ => { // Check openai_compatible if let Some(ref mut compat_config) = @@ -635,6 +669,11 @@ impl Config { self.providers.embedded.get(name) } + /// Get Gemini config by name + pub fn get_gemini_config(&self, name: &str) -> Option<&GeminiConfig> { + self.providers.gemini.get(name) + } + /// Get the current default provider's config pub fn get_default_provider_config(&self) -> Result> { let (provider_type, config_name) = @@ -665,6 +704,12 @@ impl Config { .get(&config_name) .map(ProviderConfigRef::Embedded) .ok_or_else(|| anyhow::anyhow!("Embedded config '{}' not found", config_name)), + "gemini" => self + .providers + .gemini + .get(&config_name) + .map(ProviderConfigRef::Gemini) + .ok_or_else(|| anyhow::anyhow!("Gemini config '{}' not found", config_name)), _ => self .providers .openai_compatible @@ -684,6 +729,7 @@ pub enum ProviderConfigRef<'a> { OpenAI(&'a OpenAIConfig), Databricks(&'a DatabricksConfig), Embedded(&'a EmbeddedConfig), + Gemini(&'a GeminiConfig), OpenAICompatible(&'a OpenAIConfig), } diff --git a/crates/g3-core/src/lib.rs b/crates/g3-core/src/lib.rs index b4009cb..a3206f8 100644 --- a/crates/g3-core/src/lib.rs +++ b/crates/g3-core/src/lib.rs @@ -708,6 +708,18 @@ impl Agent { 16384 // Conservative default for other Databricks models } } + "gemini" => { + // Gemini models - use provider's context_window_size() + if let Some(ctx_size) = provider.context_window_size() { + debug!( + "Using context window size {} from Gemini provider", + ctx_size + ); + ctx_size + } else { + 1_000_000 // Default for Gemini models + } + } _ => config.agent.fallback_default_max_tokens as u32, }; diff --git a/crates/g3-core/src/provider_registration.rs b/crates/g3-core/src/provider_registration.rs index 44baf39..746de51 100644 --- a/crates/g3-core/src/provider_registration.rs +++ b/crates/g3-core/src/provider_registration.rs @@ -57,6 +57,7 @@ pub async fn register_providers( register_openai_providers(config, providers_to_register, &mut registry)?; register_openai_compatible_providers(config, providers_to_register, &mut registry)?; register_anthropic_providers(config, providers_to_register, &mut registry)?; + register_gemini_providers(config, providers_to_register, &mut registry)?; register_databricks_providers(config, providers_to_register, &mut registry).await?; // Set default provider @@ -162,6 +163,27 @@ fn register_anthropic_providers( Ok(()) } +/// Register Gemini providers from configuration. +fn register_gemini_providers( + config: &Config, + providers_to_register: &[String], + registry: &mut ProviderRegistry, +) -> Result<()> { + for (name, gemini_config) in &config.providers.gemini { + if should_register(providers_to_register, "gemini", name) { + let gemini_provider = g3_providers::GeminiProvider::new_with_name( + format!("gemini.{}", name), + gemini_config.api_key.clone(), + Some(gemini_config.model.clone()), + gemini_config.max_tokens, + gemini_config.temperature, + )?; + registry.register(gemini_provider); + } + } + Ok(()) +} + /// Register Databricks providers from configuration. /// /// This is async because OAuth authentication requires async operations. diff --git a/crates/g3-providers/src/gemini.rs b/crates/g3-providers/src/gemini.rs new file mode 100644 index 0000000..005b7a4 --- /dev/null +++ b/crates/g3-providers/src/gemini.rs @@ -0,0 +1,776 @@ +//! Google Gemini provider implementation for the g3-providers crate. +//! +//! This module provides an implementation of the `LLMProvider` trait for Google's Gemini models, +//! supporting both completion and streaming modes through the Gemini API. +//! +//! # Features +//! +//! - Support for Gemini models (gemini-2.0-flash, gemini-1.5-pro, etc.) +//! - Both completion and streaming response modes +//! - Proper message format conversion between g3 and Gemini formats +//! - Native tool calling support +//! +//! # Usage +//! +//! ```rust,no_run +//! use g3_providers::{GeminiProvider, LLMProvider, CompletionRequest, Message, MessageRole}; +//! +//! #[tokio::main] +//! async fn main() -> anyhow::Result<()> { +//! let provider = GeminiProvider::new( +//! "your-api-key".to_string(), +//! Some("gemini-2.0-flash".to_string()), +//! Some(8192), +//! Some(0.7), +//! )?; +//! +//! let request = CompletionRequest { +//! messages: vec![ +//! Message::new(MessageRole::System, "You are a helpful assistant.".to_string()), +//! Message::new(MessageRole::User, "Hello! How are you?".to_string()), +//! ], +//! max_tokens: Some(1000), +//! temperature: Some(0.7), +//! stream: false, +//! tools: None, +//! disable_thinking: false, +//! }; +//! +//! let response = provider.complete(request).await?; +//! println!("Response: {}", response.content); +//! +//! Ok(()) +//! } +//! ``` + +use anyhow::Result; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tracing::{debug, error}; + +use crate::{ + CompletionChunk, CompletionRequest, CompletionResponse, CompletionStream, LLMProvider, Message, + MessageRole, Tool, ToolCall, Usage, streaming::make_text_chunk, +}; + +// ============================================================================ +// Provider Struct +// ============================================================================ + +#[derive(Clone)] +pub struct GeminiProvider { + client: Client, + api_key: String, + model: String, + max_tokens: u32, + temperature: f32, + name: String, +} + +impl GeminiProvider { + pub fn new( + api_key: String, + model: Option, + max_tokens: Option, + temperature: Option, + ) -> Result { + Ok(Self { + client: Client::new(), + api_key, + model: model.unwrap_or_else(|| "gemini-2.0-flash".to_string()), + max_tokens: max_tokens.unwrap_or(16384), + temperature: temperature.unwrap_or(0.1), + name: "gemini".to_string(), + }) + } + + pub fn new_with_name( + name: String, + api_key: String, + model: Option, + max_tokens: Option, + temperature: Option, + ) -> Result { + Ok(Self { + client: Client::new(), + api_key, + model: model.unwrap_or_else(|| "gemini-2.0-flash".to_string()), + max_tokens: max_tokens.unwrap_or(16384), + temperature: temperature.unwrap_or(0.1), + name, + }) + } + + fn get_api_url(&self, stream: bool) -> String { + let method = if stream { "streamGenerateContent" } else { "generateContent" }; + format!( + "https://generativelanguage.googleapis.com/v1beta/models/{}:{}?key={}", + self.model, method, self.api_key + ) + } +} + +// ============================================================================ +// Gemini API Request/Response Types +// ============================================================================ + +/// Gemini API request body +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct GeminiRequest { + contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + system_instruction: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + generation_config: GeminiGenerationConfig, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct GeminiContent { + #[serde(skip_serializing_if = "Option::is_none")] + role: Option, + parts: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(untagged)] +enum GeminiPart { + Text { + text: String, + }, + FunctionCall { + #[serde(rename = "functionCall")] + function_call: GeminiFunctionCall, + }, + FunctionResponse { + #[serde(rename = "functionResponse")] + function_response: GeminiFunctionResponse, + }, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct GeminiFunctionCall { + name: String, + args: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct GeminiFunctionResponse { + name: String, + response: serde_json::Value, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct GeminiGenerationConfig { + #[serde(skip_serializing_if = "Option::is_none")] + max_output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct GeminiTool { + function_declarations: Vec, +} + +#[derive(Debug, Serialize)] +struct GeminiFunctionDeclaration { + name: String, + description: String, + #[serde(skip_serializing_if = "Option::is_none")] + parameters: Option, +} + +/// Gemini API response +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct GeminiResponse { + candidates: Option>, + usage_metadata: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct GeminiCandidate { + content: Option, + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct GeminiUsageMetadata { + prompt_token_count: Option, + candidates_token_count: Option, + total_token_count: Option, +} + +// ============================================================================ +// Message Conversion +// ============================================================================ + +/// Convert g3 messages to Gemini format +/// +/// Key differences: +/// - Gemini uses "model" instead of "assistant" +/// - System messages go in system_instruction, not contents +/// - Gemini uses "parts" array with text objects +fn convert_messages(messages: &[Message]) -> (Vec, Option) { + let mut contents = Vec::new(); + let mut system_instruction = None; + + for msg in messages { + match msg.role { + MessageRole::System => { + // System messages go to system_instruction + system_instruction = Some(GeminiContent { + role: None, // system_instruction doesn't need a role + parts: vec![GeminiPart::Text { text: msg.content.clone() }], + }); + } + MessageRole::User => { + contents.push(GeminiContent { + role: Some("user".to_string()), + parts: vec![GeminiPart::Text { text: msg.content.clone() }], + }); + } + MessageRole::Assistant => { + // Gemini uses "model" instead of "assistant" + contents.push(GeminiContent { + role: Some("model".to_string()), + parts: vec![GeminiPart::Text { text: msg.content.clone() }], + }); + } + } + } + + (contents, system_instruction) +} + +/// Convert g3 tools to Gemini format +fn convert_tools(tools: &[Tool]) -> Vec { + let declarations: Vec = tools + .iter() + .map(|tool| GeminiFunctionDeclaration { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: if tool.input_schema.is_null() { + None + } else { + Some(tool.input_schema.clone()) + }, + }) + .collect(); + + vec![GeminiTool { + function_declarations: declarations, + }] +} + +/// Extract text content from Gemini response parts +fn extract_text_from_parts(parts: &[GeminiPart]) -> String { + parts + .iter() + .filter_map(|part| { + if let GeminiPart::Text { text } = part { + Some(text.as_str()) + } else { + None + } + }) + .collect::>() + .join("") +} + +/// Extract tool calls from Gemini response parts +fn extract_tool_calls_from_parts(parts: &[GeminiPart]) -> Vec { + parts + .iter() + .filter_map(|part| { + if let GeminiPart::FunctionCall { function_call } = part { + Some(ToolCall { + id: format!("call_{}", nanoid::nanoid!(8)), + tool: function_call.name.clone(), + args: function_call.args.clone(), + }) + } else { + None + } + }) + .collect() +} + +/// Convert Gemini usage metadata to g3 Usage +fn convert_usage(metadata: Option<&GeminiUsageMetadata>) -> Usage { + match metadata { + Some(m) => Usage { + prompt_tokens: m.prompt_token_count.unwrap_or(0), + completion_tokens: m.candidates_token_count.unwrap_or(0), + total_tokens: m.total_token_count.unwrap_or(0), + cache_creation_tokens: 0, + cache_read_tokens: 0, + }, + None => Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + cache_creation_tokens: 0, + cache_read_tokens: 0, + }, + } +} + +/// Convert Gemini finish reason to g3 stop reason +fn convert_finish_reason(reason: Option<&str>) -> Option { + reason.map(|r| match r { + "STOP" => "end_turn".to_string(), + "MAX_TOKENS" => "max_tokens".to_string(), + "SAFETY" => "content_filter".to_string(), + "RECITATION" => "content_filter".to_string(), + other => other.to_lowercase(), + }) +} + +// ============================================================================ +// Streaming Parser +// ============================================================================ + +/// Parse a streaming chunk from Gemini's SSE response +/// +/// Gemini streams JSON objects, one per line (not SSE format with "data:" prefix) +fn parse_streaming_chunk(data: &str) -> Option<(String, Option>, Option, Option)> { + // Skip empty lines + let data = data.trim(); + if data.is_empty() { + return None; + } + + // Try to parse as JSON + let response: GeminiResponse = match serde_json::from_str(data) { + Ok(r) => r, + Err(e) => { + debug!("Failed to parse Gemini streaming chunk: {} - data: {}", e, data); + return None; + } + }; + + // Extract content from candidates + let candidates = response.candidates?; + let candidate = candidates.first()?; + let content = candidate.content.as_ref()?; + + let text = extract_text_from_parts(&content.parts); + let tool_calls = extract_tool_calls_from_parts(&content.parts); + let finish_reason = convert_finish_reason(candidate.finish_reason.as_deref()); + + Some(( + text, + if tool_calls.is_empty() { None } else { Some(tool_calls) }, + finish_reason, + response.usage_metadata, + )) +} + +/// Process streaming response from Gemini +async fn process_stream( + mut response: reqwest::Response, + tx: mpsc::Sender>, +) { + let mut buffer = String::new(); + let mut accumulated_text = String::new(); + let mut last_usage: Option = None; + let mut last_finish_reason: Option = None; + let mut pending_tool_calls: Vec = Vec::new(); + + while let Some(chunk_result) = response.chunk().await.transpose() { + match chunk_result { + Ok(bytes) => { + let text = match String::from_utf8(bytes.to_vec()) { + Ok(t) => t, + Err(e) => { + error!("Invalid UTF-8 in Gemini stream: {}", e); + continue; + } + }; + + buffer.push_str(&text); + + // Gemini streams as JSON array elements or newline-delimited JSON + // Try to parse complete JSON objects from the buffer + while let Some(parsed) = try_parse_json_from_buffer(&mut buffer) { + if let Some((content, tool_calls, finish_reason, usage)) = parse_streaming_chunk(&parsed) { + // Track usage and finish reason + if usage.is_some() { + last_usage = usage; + } + if finish_reason.is_some() { + last_finish_reason = finish_reason; + } + + // Handle tool calls + if let Some(calls) = tool_calls { + pending_tool_calls.extend(calls); + } + + // Send text content + if !content.is_empty() { + accumulated_text.push_str(&content); + if tx.send(Ok(make_text_chunk(content))).await.is_err() { + return; + } + } + } + } + } + Err(e) => { + error!("Error reading Gemini stream: {}", e); + let _ = tx.send(Err(anyhow::anyhow!("Stream error: {}", e))).await; + return; + } + } + } + + // Send any pending tool calls + if !pending_tool_calls.is_empty() { + let chunk = CompletionChunk { + content: String::new(), + finished: false, + tool_calls: Some(pending_tool_calls), + usage: None, + stop_reason: None, + tool_call_streaming: None, + }; + if tx.send(Ok(chunk)).await.is_err() { + return; + } + } + + // Send final chunk with usage + let final_chunk = CompletionChunk { + content: String::new(), + finished: true, + tool_calls: None, + usage: Some(convert_usage(last_usage.as_ref())), + stop_reason: last_finish_reason, + tool_call_streaming: None, + }; + let _ = tx.send(Ok(final_chunk)).await; +} + +/// Try to extract a complete JSON object from the buffer +/// +/// Gemini streams responses as a JSON array: [{...}, {...}, ...] +/// We need to handle the array brackets and extract individual objects +fn try_parse_json_from_buffer(buffer: &mut String) -> Option { + let trimmed = buffer.trim_start(); + + // Skip leading array bracket or comma + let start_idx = if trimmed.starts_with('[') { + buffer.find('[')? + 1 + } else if trimmed.starts_with(',') { + buffer.find(',')? + 1 + } else { + 0 + }; + + // Find the start of a JSON object + let remaining = &buffer[start_idx..]; + let obj_start = remaining.find('{')?; + let absolute_start = start_idx + obj_start; + + // Find matching closing brace + let mut depth = 0; + let mut in_string = false; + let mut escape_next = false; + let mut end_idx = None; + + for (i, c) in buffer[absolute_start..].char_indices() { + if escape_next { + escape_next = false; + continue; + } + + match c { + '\\' if in_string => escape_next = true, + '"' => in_string = !in_string, + '{' if !in_string => depth += 1, + '}' if !in_string => { + depth -= 1; + if depth == 0 { + end_idx = Some(absolute_start + i + 1); + break; + } + } + _ => {} + } + } + + if let Some(end) = end_idx { + let json_str = buffer[absolute_start..end].to_string(); + *buffer = buffer[end..].to_string(); + Some(json_str) + } else { + None + } +} + +// ============================================================================ +// LLMProvider Implementation +// ============================================================================ + +#[async_trait] +impl LLMProvider for GeminiProvider { + async fn complete(&self, request: CompletionRequest) -> Result { + let (contents, system_instruction) = convert_messages(&request.messages); + + let gemini_request = GeminiRequest { + contents, + system_instruction, + tools: request.tools.as_ref().map(|t| convert_tools(t)), + generation_config: GeminiGenerationConfig { + max_output_tokens: request.max_tokens.or(Some(self.max_tokens)), + temperature: request.temperature.or(Some(self.temperature)), + }, + }; + + let url = self.get_api_url(false); + debug!("Gemini request URL: {}", url); + debug!("Gemini request body: {}", serde_json::to_string_pretty(&gemini_request).unwrap_or_default()); + + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&gemini_request) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + error!("Gemini API error ({}): {}", status, error_text); + anyhow::bail!("Gemini API error ({}): {}", status, error_text); + } + + let gemini_response: GeminiResponse = response.json().await?; + debug!("Gemini response: {:?}", gemini_response); + + // Extract content from response + let content = gemini_response + .candidates + .as_ref() + .and_then(|c| c.first()) + .and_then(|c| c.content.as_ref()) + .map(|c| extract_text_from_parts(&c.parts)) + .unwrap_or_default(); + + let usage = convert_usage(gemini_response.usage_metadata.as_ref()); + + Ok(CompletionResponse { + content, + usage, + model: self.model.clone(), + }) + } + + async fn stream(&self, request: CompletionRequest) -> Result { + let (contents, system_instruction) = convert_messages(&request.messages); + + let gemini_request = GeminiRequest { + contents, + system_instruction, + tools: request.tools.as_ref().map(|t| convert_tools(t)), + generation_config: GeminiGenerationConfig { + max_output_tokens: request.max_tokens.or(Some(self.max_tokens)), + temperature: request.temperature.or(Some(self.temperature)), + }, + }; + + // For streaming, add alt=sse parameter + let url = format!("{}&alt=sse", self.get_api_url(true)); + debug!("Gemini streaming request URL: {}", url); + + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&gemini_request) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response.text().await.unwrap_or_default(); + error!("Gemini API error ({}): {}", status, error_text); + anyhow::bail!("Gemini API error ({}): {}", status, error_text); + } + + let (tx, rx) = mpsc::channel(32); + tokio::spawn(process_stream(response, tx)); + + Ok(ReceiverStream::new(rx)) + } + + fn name(&self) -> &str { + &self.name + } + + fn model(&self) -> &str { + &self.model + } + + fn has_native_tool_calling(&self) -> bool { + true + } + + fn max_tokens(&self) -> u32 { + self.max_tokens + } + + fn temperature(&self) -> f32 { + self.temperature + } + + fn context_window_size(&self) -> Option { + // Context window sizes by model + // https://ai.google.dev/gemini-api/docs/models + let size = if self.model.contains("1.5-pro") || self.model.contains("1.5-flash") { + 2_000_000 // Gemini 1.5 models have 2M context + } else if self.model.contains("2.5-pro") || self.model.contains("2.5-flash") { + 1_000_000 // Gemini 2.5 models have 1M context + } else if self.model.contains("2.0") { + 1_000_000 // Gemini 2.0 models have 1M context + } else { + 128_000 // Conservative default for unknown models + }; + Some(size) + } +} + +// ============================================================================ +// Unit Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_convert_messages_basic() { + let messages = vec![ + Message::new(MessageRole::User, "Hello".to_string()), + Message::new(MessageRole::Assistant, "Hi there!".to_string()), + ]; + + let (contents, system) = convert_messages(&messages); + + assert!(system.is_none()); + assert_eq!(contents.len(), 2); + assert_eq!(contents[0].role, Some("user".to_string())); + assert_eq!(contents[1].role, Some("model".to_string())); // assistant -> model + } + + #[test] + fn test_convert_messages_with_system() { + let messages = vec![ + Message::new(MessageRole::System, "You are helpful.".to_string()), + Message::new(MessageRole::User, "Hello".to_string()), + ]; + + let (contents, system) = convert_messages(&messages); + + assert!(system.is_some()); + let sys = system.unwrap(); + assert!(sys.role.is_none()); // system_instruction has no role + assert_eq!(contents.len(), 1); + assert_eq!(contents[0].role, Some("user".to_string())); + } + + #[test] + fn test_convert_tools() { + let tools = vec![Tool { + name: "get_weather".to_string(), + description: "Get the weather".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "location": { "type": "string" } + } + }), + }]; + + let gemini_tools = convert_tools(&tools); + + assert_eq!(gemini_tools.len(), 1); + assert_eq!(gemini_tools[0].function_declarations.len(), 1); + assert_eq!(gemini_tools[0].function_declarations[0].name, "get_weather"); + } + + #[test] + fn test_parse_streaming_chunk() { + let chunk = r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}"#; + + let result = parse_streaming_chunk(chunk); + assert!(result.is_some()); + + let (text, tool_calls, finish_reason, usage) = result.unwrap(); + assert_eq!(text, "Hello"); + assert!(tool_calls.is_none()); + assert_eq!(finish_reason, Some("end_turn".to_string())); + assert!(usage.is_some()); + assert_eq!(usage.unwrap().total_token_count, Some(15)); + } + + #[test] + fn test_parse_streaming_chunk_with_tool_call() { + let chunk = r#"{"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"location":"NYC"}}}],"role":"model"}}]}"#; + + let result = parse_streaming_chunk(chunk); + assert!(result.is_some()); + + let (text, tool_calls, _, _) = result.unwrap(); + assert_eq!(text, ""); + assert!(tool_calls.is_some()); + let calls = tool_calls.unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].tool, "get_weather"); + } + + #[test] + fn test_try_parse_json_from_buffer() { + let mut buffer = r#"[{"test": 1}, {"test": 2}]"#.to_string(); + + let first = try_parse_json_from_buffer(&mut buffer); + assert!(first.is_some()); + assert_eq!(first.unwrap(), r#"{"test": 1}"#); + + let second = try_parse_json_from_buffer(&mut buffer); + assert!(second.is_some()); + assert_eq!(second.unwrap(), r#"{"test": 2}"#); + } + + #[test] + fn test_convert_finish_reason() { + assert_eq!(convert_finish_reason(Some("STOP")), Some("end_turn".to_string())); + assert_eq!(convert_finish_reason(Some("MAX_TOKENS")), Some("max_tokens".to_string())); + assert_eq!(convert_finish_reason(Some("SAFETY")), Some("content_filter".to_string())); + assert_eq!(convert_finish_reason(None), None); + } + + #[test] + fn test_extract_text_from_parts() { + let parts = vec![ + GeminiPart::Text { text: "Hello ".to_string() }, + GeminiPart::Text { text: "world!".to_string() }, + ]; + + let text = extract_text_from_parts(&parts); + assert_eq!(text, "Hello world!"); + } +} diff --git a/crates/g3-providers/src/lib.rs b/crates/g3-providers/src/lib.rs index 18cce1f..7793a94 100644 --- a/crates/g3-providers/src/lib.rs +++ b/crates/g3-providers/src/lib.rs @@ -241,12 +241,14 @@ pub struct Tool { pub mod anthropic; pub mod databricks; pub mod embedded; +pub mod gemini; pub mod oauth; pub mod openai; pub use anthropic::AnthropicProvider; pub use databricks::DatabricksProvider; pub use embedded::EmbeddedProvider; +pub use gemini::GeminiProvider; pub use openai::OpenAIProvider; impl Message {