@@ -38,6 +38,13 @@ impl AgentModelSelector {
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
+
+ // Authenticate the provider when a model is selected
+ let registry = LanguageModelRegistry::read_global(cx);
+ if let Some(provider) = registry.provider(&model.provider_id()) {
+ provider.authenticate(cx).detach();
+ }
+
match &model_usage_context {
ModelUsageContext::Thread(thread) => {
thread.update(cx, |thread, cx| {
@@ -18,7 +18,7 @@ use ui::{ButtonLike, IconName, Indicator, prelude::*};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("local");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Local");
-const DEFAULT_MODEL: &str = "mlx-community/GLM-4.5-Air-3bit";
+const DEFAULT_MODEL: &str = "Qwen/Qwen2.5-0.5B-Instruct";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LocalSettings {
@@ -63,36 +63,47 @@ impl State {
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- if matches!(self.status, ModelStatus::Loading) {
+ // Skip if already loaded or currently loading
+ if matches!(self.status, ModelStatus::Loaded | ModelStatus::Loading) {
return Task::ready(Ok(()));
}
self.status = ModelStatus::Loading;
cx.notify();
- cx.spawn(async move |this, cx| match load_mistral_model().await {
- Ok(model) => {
- this.update(cx, |state, cx| {
- state.model = Some(model);
- state.status = ModelStatus::Loaded;
- cx.notify();
- })?;
- Ok(())
- }
- Err(e) => {
- let error_msg = e.to_string();
- this.update(cx, |state, cx| {
- state.status = ModelStatus::Error(error_msg.clone());
- cx.notify();
- })?;
- Err(AuthenticateError::Other(anyhow!(
- "Failed to load model: {}",
- error_msg
- )))
+ let background_executor = cx.background_executor().clone();
+ cx.spawn(async move |this, cx| {
+ eprintln!("Local model: Starting to load model");
+
+ // Move the model loading to a background thread
+ let model_result = background_executor
+ .spawn(async move { load_mistral_model().await })
+ .await;
+
+ match model_result {
+ Ok(model) => {
+ eprintln!("Local model: Model loaded successfully");
+ this.update(cx, |state, cx| {
+ state.model = Some(model);
+ state.status = ModelStatus::Loaded;
+ cx.notify();
+ eprintln!("Local model: Status updated to Loaded");
+ })?;
+ Ok(())
+ }
+ Err(e) => {
+ let error_msg = e.to_string();
+ eprintln!("Local model: Failed to load model - {}", error_msg);
+ this.update(cx, |state, cx| {
+ state.status = ModelStatus::Error(error_msg.clone());
+ cx.notify();
+ eprintln!("Local model: Status updated to Failed");
+ })?;
+ Err(AuthenticateError::Other(anyhow!(
+ "Failed to load model: {}",
+ error_msg
+ )))
+ }
}
})
}
@@ -100,12 +111,26 @@ impl State {
async fn load_mistral_model() -> Result<Arc<MistralModel>> {
println!("\n\n\n\nLoading mistral model...\n\n\n");
- let model = TextModelBuilder::new(DEFAULT_MODEL)
- .with_isq(IsqType::Q4_0)
- .build()
- .await?;
+ eprintln!("Starting to load model: {}", DEFAULT_MODEL);
+
+ // Configure the model builder to use background threads for downloads
+ eprintln!("Creating TextModelBuilder...");
+ let builder = TextModelBuilder::new(DEFAULT_MODEL).with_isq(IsqType::Q4K);
- Ok(Arc::new(model))
+ eprintln!("Building model (this should be quick for a 0.5B model)...");
+ let start_time = std::time::Instant::now();
+
+ match builder.build().await {
+ Ok(model) => {
+ let elapsed = start_time.elapsed();
+ eprintln!("Model loaded successfully in {:?}", elapsed);
+ Ok(Arc::new(model))
+ }
+ Err(e) => {
+ eprintln!("Failed to load model: {:?}", e);
+ Err(e)
+ }
+ }
}
impl LocalLanguageModelProvider {
@@ -256,7 +281,7 @@ impl LanguageModel for LocalLanguageModel {
}
fn supports_tools(&self) -> bool {
- false
+ true
}
fn supports_images(&self) -> bool {
@@ -264,11 +289,11 @@ impl LanguageModel for LocalLanguageModel {
}
fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
- false
+ true
}
fn max_token_count(&self) -> u64 {
- 128000 // GLM-4.5-Air supports 128k context
+ 128000 // Qwen2.5 supports 128k context
}
fn count_tokens(
@@ -315,11 +340,18 @@ impl LanguageModel for LocalLanguageModel {
> = limiter
.run(async move {
let model = cx
- .read_entity(&state, |state, _| state.model.clone())
+ .read_entity(&state, |state, _| {
+ eprintln!(
+ "Local model: Checking if model is loaded: {:?}",
+ state.status
+ );
+ state.model.clone()
+ })
.map_err(|_| {
LanguageModelCompletionError::Other(anyhow!("App state dropped"))
})?
.ok_or_else(|| {
+ eprintln!("Local model: Model is not loaded!");
LanguageModelCompletionError::Other(anyhow!("Model not loaded"))
})?;