Merge pull request #1 from michaelneale/auto-download-qwen
Auto download qwen
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -720,6 +720,7 @@ version = "0.1.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"config",
|
"config",
|
||||||
|
"dirs 5.0.1",
|
||||||
"serde",
|
"serde",
|
||||||
"shellexpand",
|
"shellexpand",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
|
|||||||
@@ -11,3 +11,4 @@ anyhow = { workspace = true }
|
|||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
toml = "0.8"
|
toml = "0.8"
|
||||||
shellexpand = "3.0"
|
shellexpand = "3.0"
|
||||||
|
dirs = "5.0"
|
||||||
|
|||||||
@@ -71,6 +71,50 @@ impl Default for Config {
|
|||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn load(config_path: Option<&str>) -> Result<Self> {
|
pub fn load(config_path: Option<&str>) -> Result<Self> {
|
||||||
|
// Check if any config file exists
|
||||||
|
let config_exists = if let Some(path) = config_path {
|
||||||
|
Path::new(path).exists()
|
||||||
|
} else {
|
||||||
|
// Check default locations
|
||||||
|
let default_paths = [
|
||||||
|
"./g3.toml",
|
||||||
|
"~/.config/g3/config.toml",
|
||||||
|
"~/.g3.toml",
|
||||||
|
];
|
||||||
|
|
||||||
|
default_paths.iter().any(|path| {
|
||||||
|
let expanded_path = shellexpand::tilde(path);
|
||||||
|
Path::new(expanded_path.as_ref()).exists()
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
// If no config exists, create and save a default Qwen config
|
||||||
|
if !config_exists {
|
||||||
|
let qwen_config = Self::default_qwen_config();
|
||||||
|
|
||||||
|
// Save to default location
|
||||||
|
let config_dir = dirs::home_dir()
|
||||||
|
.map(|mut path| {
|
||||||
|
path.push(".config");
|
||||||
|
path.push("g3");
|
||||||
|
path
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|| std::path::PathBuf::from("."));
|
||||||
|
|
||||||
|
// Create directory if it doesn't exist
|
||||||
|
std::fs::create_dir_all(&config_dir).ok();
|
||||||
|
|
||||||
|
let config_file = config_dir.join("config.toml");
|
||||||
|
if let Err(e) = qwen_config.save(config_file.to_str().unwrap()) {
|
||||||
|
eprintln!("Warning: Could not save default config: {}", e);
|
||||||
|
} else {
|
||||||
|
println!("Created default Qwen configuration at: {}", config_file.display());
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(qwen_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Existing config loading logic
|
||||||
let mut settings = config::Config::builder();
|
let mut settings = config::Config::builder();
|
||||||
|
|
||||||
// Load default configuration
|
// Load default configuration
|
||||||
@@ -108,6 +152,30 @@ impl Config {
|
|||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_qwen_config() -> Self {
|
||||||
|
Self {
|
||||||
|
providers: ProvidersConfig {
|
||||||
|
openai: None,
|
||||||
|
anthropic: None,
|
||||||
|
embedded: Some(EmbeddedConfig {
|
||||||
|
model_path: "~/.cache/g3/models/qwen2.5-7b-instruct-q3_k_m.gguf".to_string(),
|
||||||
|
model_type: "qwen".to_string(),
|
||||||
|
context_length: Some(32768), // Qwen2.5 supports 32k context
|
||||||
|
max_tokens: Some(2048),
|
||||||
|
temperature: Some(0.1),
|
||||||
|
gpu_layers: Some(32),
|
||||||
|
threads: Some(8),
|
||||||
|
}),
|
||||||
|
default_provider: "embedded".to_string(),
|
||||||
|
},
|
||||||
|
agent: AgentConfig {
|
||||||
|
max_context_length: 8192,
|
||||||
|
enable_streaming: true,
|
||||||
|
timeout_seconds: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn save(&self, path: &str) -> Result<()> {
|
pub fn save(&self, path: &str) -> Result<()> {
|
||||||
let toml_string = toml::to_string_pretty(self)?;
|
let toml_string = toml::to_string_pretty(self)?;
|
||||||
std::fs::write(path, toml_string)?;
|
std::fs::write(path, toml_string)?;
|
||||||
|
|||||||
@@ -7,13 +7,14 @@ use llama_cpp::{
|
|||||||
standard_sampler::{SamplerStage, StandardSampler},
|
standard_sampler::{SamplerStage, StandardSampler},
|
||||||
LlamaModel, LlamaParams, LlamaSession, SessionParams,
|
LlamaModel, LlamaParams, LlamaSession, SessionParams,
|
||||||
};
|
};
|
||||||
use std::path::Path;
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::atomic::AtomicBool;
|
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
pub struct EmbeddedProvider {
|
pub struct EmbeddedProvider {
|
||||||
model: Arc<LlamaModel>,
|
model: Arc<LlamaModel>,
|
||||||
@@ -39,11 +40,19 @@ impl EmbeddedProvider {
|
|||||||
|
|
||||||
// Expand tilde in path
|
// Expand tilde in path
|
||||||
let expanded_path = shellexpand::tilde(&model_path);
|
let expanded_path = shellexpand::tilde(&model_path);
|
||||||
let model_path = Path::new(expanded_path.as_ref());
|
let model_path_buf = PathBuf::from(expanded_path.as_ref());
|
||||||
|
|
||||||
if !model_path.exists() {
|
// If model doesn't exist and it's the default Qwen model, offer to download it
|
||||||
anyhow::bail!("Model file not found: {}", model_path.display());
|
if !model_path_buf.exists() {
|
||||||
|
if model_path.contains("qwen2.5-7b-instruct-q3_k_m.gguf") {
|
||||||
|
info!("Model file not found. Attempting to download Qwen 2.5 7B model...");
|
||||||
|
Self::download_qwen_model(&model_path_buf)?;
|
||||||
|
} else {
|
||||||
|
anyhow::bail!("Model file not found: {}", model_path_buf.display());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let model_path = model_path_buf.as_path();
|
||||||
|
|
||||||
// Set up model parameters
|
// Set up model parameters
|
||||||
let mut params = LlamaParams::default();
|
let mut params = LlamaParams::default();
|
||||||
@@ -108,6 +117,45 @@ impl EmbeddedProvider {
|
|||||||
|
|
||||||
// Add the start of assistant response
|
// Add the start of assistant response
|
||||||
formatted.push_str("<|im_start|>assistant\n");
|
formatted.push_str("<|im_start|>assistant\n");
|
||||||
|
formatted
|
||||||
|
} else if model_name_lower.contains("mistral") {
|
||||||
|
// Mistral Instruct format: <s>[INST] ... [/INST] assistant_response</s>
|
||||||
|
let mut formatted = String::new();
|
||||||
|
let mut in_conversation = false;
|
||||||
|
|
||||||
|
for (i, message) in messages.iter().enumerate() {
|
||||||
|
match message.role {
|
||||||
|
MessageRole::System => {
|
||||||
|
// Mistral doesn't have a special system token, include it at the start
|
||||||
|
if i == 0 {
|
||||||
|
formatted.push_str("<s>[INST] ");
|
||||||
|
formatted.push_str(&message.content);
|
||||||
|
formatted.push_str("\n\n");
|
||||||
|
in_conversation = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MessageRole::User => {
|
||||||
|
if !in_conversation {
|
||||||
|
formatted.push_str("<s>[INST] ");
|
||||||
|
}
|
||||||
|
formatted.push_str(&message.content);
|
||||||
|
formatted.push_str(" [/INST]");
|
||||||
|
in_conversation = false;
|
||||||
|
}
|
||||||
|
MessageRole::Assistant => {
|
||||||
|
formatted.push_str(" ");
|
||||||
|
formatted.push_str(&message.content);
|
||||||
|
formatted.push_str("</s> ");
|
||||||
|
in_conversation = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the last message was from user, add a space for the assistant's response
|
||||||
|
if messages.last().map_or(false, |m| matches!(m.role, MessageRole::User)) {
|
||||||
|
formatted.push_str(" ");
|
||||||
|
}
|
||||||
|
|
||||||
formatted
|
formatted
|
||||||
} else {
|
} else {
|
||||||
// Use Llama/CodeLlama format for other models
|
// Use Llama/CodeLlama format for other models
|
||||||
@@ -377,6 +425,66 @@ impl EmbeddedProvider {
|
|||||||
|
|
||||||
cleaned.trim().to_string()
|
cleaned.trim().to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Download the Qwen 2.5 7B model if it doesn't exist
|
||||||
|
fn download_qwen_model(model_path: &Path) -> Result<()> {
|
||||||
|
use std::fs;
|
||||||
|
use std::io::Write;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
const MODEL_URL: &str = "https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-GGUF/resolve/main/qwen2.5-7b-instruct-q3_k_m.gguf";
|
||||||
|
const MODEL_SIZE_MB: u64 = 3631; // Approximate size in MB
|
||||||
|
|
||||||
|
// Create the parent directory if it doesn't exist
|
||||||
|
if let Some(parent) = model_path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Downloading Qwen 2.5 7B model (Q3_K_M quantization, ~3.5GB)...");
|
||||||
|
info!("This is a one-time download that may take several minutes depending on your connection.");
|
||||||
|
info!("Downloading to: {}", model_path.display());
|
||||||
|
|
||||||
|
// Use curl with progress bar for download
|
||||||
|
let output = Command::new("curl")
|
||||||
|
.args(&[
|
||||||
|
"-L", // Follow redirects
|
||||||
|
"-#", // Show progress bar
|
||||||
|
"-f", // Fail on HTTP errors
|
||||||
|
"-o", model_path.to_str().unwrap(),
|
||||||
|
MODEL_URL,
|
||||||
|
])
|
||||||
|
.output()?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
|
||||||
|
// If curl is not available, provide alternative instructions
|
||||||
|
if stderr.contains("command not found") || stderr.contains("not found") {
|
||||||
|
error!("curl is not installed. Please install curl or manually download the model.");
|
||||||
|
error!("Manual download instructions:");
|
||||||
|
error!("1. Download from: {}", MODEL_URL);
|
||||||
|
error!("2. Save to: {}", model_path.display());
|
||||||
|
anyhow::bail!("curl not found - please install curl or download the model manually");
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("Failed to download model: {}", stderr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the file was created and has reasonable size
|
||||||
|
let metadata = fs::metadata(model_path)?;
|
||||||
|
let size_mb = metadata.len() / (1024 * 1024);
|
||||||
|
|
||||||
|
if size_mb < MODEL_SIZE_MB - 100 { // Allow some variance
|
||||||
|
fs::remove_file(model_path).ok(); // Clean up partial download
|
||||||
|
anyhow::bail!(
|
||||||
|
"Downloaded file appears incomplete ({}MB vs expected ~{}MB). Please try again.",
|
||||||
|
size_mb, MODEL_SIZE_MB
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Successfully downloaded Qwen 2.5 7B model ({}MB)", size_mb);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
|
|||||||
Reference in New Issue
Block a user