Get a smaller model working

Created by Richard Feldman

Change summary

crates/agent_ui/src/agent_model_selector.rs  |   7 +
crates/language_models/src/provider/local.rs | 100 ++++++++++++++-------
2 files changed, 73 insertions(+), 34 deletions(-)

Detailed changes

crates/agent_ui/src/agent_model_selector.rs

@@ -38,6 +38,13 @@ impl AgentModelSelector {
                     move |model, cx| {
                         let provider = model.provider_id().0.to_string();
                         let model_id = model.id().0.to_string();
+
+                        // Authenticate the provider when a model is selected
+                        let registry = LanguageModelRegistry::read_global(cx);
+                        if let Some(provider) = registry.provider(&model.provider_id()) {
+                            provider.authenticate(cx).detach();
+                        }
+
                         match &model_usage_context {
                             ModelUsageContext::Thread(thread) => {
                                 thread.update(cx, |thread, cx| {
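
The hook is fire-and-forget: picking a model immediately starts `authenticate` for that model's provider and detaches the returned task, so the picker closes without waiting on a (possibly slow) local model load. A minimal sketch of the same pattern, using only the calls visible in the hunk above, with comments added:

// Sketch of the selection hook (same calls as the diff, annotated).
let registry = LanguageModelRegistry::read_global(cx);
if let Some(provider) = registry.provider(&model.provider_id()) {
    // `authenticate` returns a Task. Detaching lets it run to completion in
    // the background; if the error type allows it, `detach_and_log_err(cx)`
    // would log failures instead of discarding them.
    provider.authenticate(cx).detach();
}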

crates/language_models/src/provider/local.rs

@@ -18,7 +18,7 @@ use ui::{ButtonLike, IconName, Indicator, prelude::*};
 
 const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
 const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
-const DEFAULT_MODEL: &str = "mlx-community/GLM-4.5-Air-3bit";
+const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
 
 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct LocalSettings {
@@ -63,36 +63,47 @@ impl State {
     }
 
     fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
-        if self.is_authenticated() {
-            return Task::ready(Ok(()));
-        }
-
-        if matches!(self.status, ModelStatus::Loading) {
+        // Skip if already loaded or currently loading
+        if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) {
             return Task::ready(Ok(()));
         }
 
         self.status = ModelStatus::Loading;
         cx.notify();
 
-        cx.spawn(async move |this, cx| match load_mistral_model().await {
-            Ok(model) => {
-                this.update(cx, |state, cx| {
-                    state.model = Some(model);
-                    state.status = ModelStatus::Loaded;
-                    cx.notify();
-                })?;
-                Ok(())
-            }
-            Err(e) => {
-                let error_msg = e.to_string();
-                this.update(cx, |state, cx| {
-                    state.status = ModelStatus::Error(error_msg.clone());
-                    cx.notify();
-                })?;
-                Err(AuthenticateError::Other(anyhow!(
-                    "Failed to load model: {}",
-                    error_msg
-                )))
+        let background_executor = cx.background_executor().clone();
+        cx.spawn(async move |this, cx| {
+            eprintln!("Local model: Starting to load model");
+
+            // Move the model loading to a background thread
+            let model_result = background_executor
+                .spawn(async move { load_mistral_model().await })
+                .await;
+
+            match model_result {
+                Ok(model) => {
+                    eprintln!("Local model: Model loaded successfully");
+                    this.update(cx, |state, cx| {
+                        state.model = Some(model);
+                        state.status = ModelStatus::Loaded;
+                        cx.notify();
+                        eprintln!("Local model: Status updated to Loaded");
+                    })?;
+                    Ok(())
+                }
+                Err(e) => {
+                    let error_msg = e.to_string();
+                    eprintln!("Local model: Failed to load model - {}", error_msg);
+                    this.update(cx, |state, cx| {
+                        state.status = ModelStatus::Error(error_msg.clone());
+                        cx.notify();
+                        eprintln!("Local model: Status updated to Failed");
+                    })?;
+                    Err(AuthenticateError::Other(anyhow!(
+                        "Failed to load model: {}",
+                        error_msg
+                    )))
+                }
             }
         })
     }
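
Stripped of GPUI, the control flow this hunk sets up can be sketched with just the standard library. This is an illustration only: `load_model`, the channel, and the raw thread stand in for `load_mistral_model`, entity updates, and `cx.background_executor().spawn(...)`.

use std::sync::mpsc;
use std::thread;
use std::time::Duration;

// The same three-state machine the provider uses to drive its UI.
#[derive(Debug)]
enum ModelStatus {
    Loading,
    Loaded,
    Error(String),
}

// Stand-in for load_mistral_model(): slow, fallible work.
fn load_model() -> Result<String, String> {
    thread::sleep(Duration::from_millis(50)); // pretend to download + quantize
    Ok("Qwen/Qwen2.5-0.5B-Instruct".to_string())
}

fn main() {
    let (tx, rx) = mpsc::channel();

    let mut status = ModelStatus::Loading;
    println!("status: {status:?}"); // the UI would re-render here (cx.notify())

    // The expensive part runs off the main thread, so the "UI" stays responsive.
    thread::spawn(move || {
        tx.send(load_model()).ok();
    });

    // Flip the status only once the result arrives.
    status = match rx.recv().expect("worker thread dropped") {
        Ok(name) => {
            println!("loaded {name}");
            ModelStatus::Loaded
        }
        Err(err) => ModelStatus::Error(err),
    };
    println!("status: {status:?}");
}

In the real implementation each transition is followed by `cx.notify()`, so the selector and configuration view update as soon as the status changes.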
@@ -100,12 +111,26 @@ impl State {
 
 async fn load_mistral_model() -> Result<Arc<MistralModel>> {
     println!("\n\n\n\nLoading mistral model...\n\n\n");
-    let model = TextModelBuilder::new(DEFAULT_MODEL)
-        .with_isq(IsqType::Q4_0)
-        .build()
-        .await?;
+    eprintln!("Starting to load model: {}", DEFAULT_MODEL);
+
+    // Configure the model builder to use background threads for downloads
+    eprintln!("Creating TextModelBuilder...");
+    let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K);
 
-    Ok(Arc::new(model))
+    eprintln!("Building model (this should be quick for a 0.5B model)...");
+    let start_time = std::time::Instant::now();
+
+    match builder.build().await {
+        Ok(model) => {
+            let elapsed = start_time.elapsed();
+            eprintln!("Model loaded successfully in {:?}", elapsed);
+            Ok(Arc::new(model))
+        }
+        Err(e) => {
+            eprintln!("Failed to load model: {:?}", e);
+            Err(e)
+        }
+    }
 }
 
 impl LocalLanguageModelProvider {
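
For reference, the same load can be exercised outside Zed. The sketch below assumes a scratch binary crate with `mistralrs`, `tokio`, and `anyhow` as dependencies; it reuses only the builder calls that appear in the hunk (`TextModelBuilder::new`, `with_isq(IsqType::Q4K)`, `build`), so treat it as an approximation rather than a verified example.

use std::time::Instant;

use anyhow::Result;
use mistralrs::{IsqType, TextModelBuilder};

#[tokio::main]
async fn main() -> Result<()> {
    let start = Instant::now();

    // Same calls as load_mistral_model: fetch the 0.5B instruct model and
    // quantize it in place (ISQ) to Q4K.
    let _model = TextModelBuilder::new("Qwen/Qwen2.5-0.5B-Instruct")
        .with_isq(IsqType::Q4K)
        .build()
        .await?;

    println!("model ready in {:?}", start.elapsed());
    Ok(())
}

Timing the build this way mirrors the `start_time`/`elapsed` logging the commit adds, which is the quickest check that the smaller model actually downloads and quantizes in a reasonable amount of time.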
@@ -256,7 +281,7 @@ impl LanguageModel for LocalLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        true
     }
 
     fn supports_images(&self) -> bool {
@@ -264,11 +289,11 @@ impl LanguageModel for LocalLanguageModel {
     }
 
     fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
-        false
+        true
     }
 
     fn max_token_count(&self) -> u64 {
-        128000 // GLM-4.5-Air supports 128k context
+        128000 // Qwen2.5 supports 128k context
     }
 
     fn count_tokens(
@@ -315,11 +340,18 @@ impl LanguageModel for LocalLanguageModel {
             > = limiter
                 .run(async move {
                     let model = cx
-                        .read_entity(&state, |state, _| state.model.clone())
+                        .read_entity(&state, |state, _| {
+                            eprintln!(
+                                "Local model: Checking if model is loaded: {:?}",
+                                state.status
+                            );
+                            state.model.clone()
+                        })
                         .map_err(|_| {
                             LanguageModelCompletionError::Other(anyhow!("App state dropped"))
                         })?
                         .ok_or_else(|| {
+                            eprintln!("Local model: Model is not loaded!");
                             LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
                         })?;