assistant: Fix issues when configuring different providers (#15072)

Bennet Bo Fenner and Antonio Scandurra created

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

assets/settings/default.json                      | 12 ++
crates/assistant/src/assistant.rs                 |  4 
crates/assistant/src/assistant_settings.rs        |  6 
crates/assistant/src/inline_assistant.rs          |  2 
crates/assistant/src/terminal_inline_assistant.rs |  2 
crates/completion/src/completion.rs               | 15 ++-
crates/language_model/src/language_model.rs       | 12 +++
crates/language_model/src/provider/anthropic.rs   | 52 +++++++++----
crates/language_model/src/provider/cloud.rs       | 30 +++++--
crates/language_model/src/provider/fake.rs        | 17 +++
crates/language_model/src/provider/ollama.rs      | 64 ++++++++++------
crates/language_model/src/provider/open_ai.rs     | 52 ++++++++-----
crates/language_model/src/registry.rs             | 20 ++--
crates/language_model/src/settings.rs             | 16 ++--
crates/ollama/src/ollama.rs                       |  4 
docs/src/language-model-integration.md            | 61 +++++-----------
16 files changed, 223 insertions(+), 146 deletions(-)

Detailed changes

assets/settings/default.json 🔗

@@ -853,7 +853,17 @@
     }
   },
   // Different settings for specific language models.
-  "language_models": {},
+  "language_models": {
+    "anthropic": {
+      "api_url": "https://api.anthropic.com"
+    },
+    "openai": {
+      "api_url": "https://api.openai.com/v1"
+    },
+    "ollama": {
+      "api_url": "http://localhost:11434"
+    }
+  },
   // Zed's Prettier integration settings.
   // Allows to enable/disable formatting with Prettier
   // and configure default Prettier, used when no project-level Prettier installation is found.

crates/assistant/src/assistant.rs 🔗

@@ -23,7 +23,7 @@ use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal
 use indexed_docs::IndexedDocsRegistry;
 pub(crate) use inline_assistant::*;
 use language_model::{
-    LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage,
+    LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
 };
 pub(crate) use model_selector::*;
 use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
@@ -231,7 +231,7 @@ fn init_completion_provider(cx: &mut AppContext) {
 
 fn update_active_language_model_from_settings(cx: &mut AppContext) {
     let settings = AssistantSettings::get_global(cx);
-    let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone());
+    let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
     let model_id = LanguageModelId::from(settings.default_model.model.clone());
 
     let Some(provider) = LanguageModelRegistry::global(cx)

crates/assistant/src/assistant_settings.rs 🔗

@@ -144,8 +144,8 @@ impl AssistantSettingsContent {
                             fs,
                             cx,
                             move |content, _| {
-                                if content.open_ai.is_none() {
-                                    content.open_ai =
+                                if content.openai.is_none() {
+                                    content.openai =
                                         Some(language_model::settings::OpenAiSettingsContent {
                                             api_url,
                                             low_speed_timeout_in_seconds,
@@ -243,7 +243,7 @@ impl AssistantSettingsContent {
 
     pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
         let model = language_model.id().0.to_string();
-        let provider = language_model.provider_name().0.to_string();
+        let provider = language_model.provider_id().0.to_string();
 
         match self {
             AssistantSettingsContent::Versioned(settings) => match settings {

crates/assistant/src/inline_assistant.rs 🔗

@@ -1438,7 +1438,7 @@ impl Render for PromptEditor {
                                             {
                                                 let model_name = available_model.name().0.clone();
                                                 let provider =
-                                                    available_model.provider_name().0.clone();
+                                                    available_model.provider_id().0.clone();
                                                 move |_| {
                                                     h_flex()
                                                         .w_full()

crates/assistant/src/terminal_inline_assistant.rs 🔗

@@ -565,7 +565,7 @@ impl Render for PromptEditor {
                                             {
                                                 let model_name = available_model.name().0.clone();
                                                 let provider =
-                                                    available_model.provider_name().0.clone();
+                                                    available_model.provider_id().0.clone();
                                                 move |_| {
                                                     h_flex()
                                                         .w_full()

crates/completion/src/completion.rs 🔗

@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AppContext, Global, Model, ModelContext, Task};
 use language_model::{
-    LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry,
+    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
     LanguageModelRequest,
 };
 use smol::lock::{Semaphore, SemaphoreGuardArc};
@@ -89,7 +89,7 @@ impl LanguageModelCompletionProvider {
 
     pub fn set_active_provider(
         &mut self,
-        provider_name: LanguageModelProviderName,
+        provider_name: LanguageModelProviderId,
         cx: &mut ModelContext<Self>,
     ) {
         self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name);
@@ -103,14 +103,19 @@ impl LanguageModelCompletionProvider {
 
     pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
         if self.active_model.as_ref().map_or(false, |m| {
-            m.id() == model.id() && m.provider_name() == model.provider_name()
+            m.id() == model.id() && m.provider_id() == model.provider_id()
         }) {
             return;
         }
 
         self.active_provider =
-            LanguageModelRegistry::read_global(cx).provider(&model.provider_name());
-        self.active_model = Some(model);
+            LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
+        self.active_model = Some(model.clone());
+
+        if let Some(provider) = self.active_provider.as_ref() {
+            provider.load_model(model, cx);
+        }
+
         cx.notify();
     }
 

crates/language_model/src/language_model.rs 🔗

@@ -25,6 +25,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;
+    fn provider_id(&self) -> LanguageModelProviderId;
     fn provider_name(&self) -> LanguageModelProviderName;
     fn telemetry_id(&self) -> String;
 
@@ -44,8 +45,10 @@ pub trait LanguageModel: Send + Sync {
 }
 
 pub trait LanguageModelProvider: 'static {
+    fn id(&self) -> LanguageModelProviderId;
     fn name(&self) -> LanguageModelProviderName;
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
+    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
     fn is_authenticated(&self, cx: &AppContext) -> bool;
     fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
     fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
@@ -62,6 +65,9 @@ pub struct LanguageModelId(pub SharedString);
 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
 pub struct LanguageModelName(pub SharedString);
 
+#[derive(Clone, Eq, PartialEq, Hash, Debug)]
+pub struct LanguageModelProviderId(pub SharedString);
+
 #[derive(Clone, Eq, PartialEq, Hash, Debug)]
 pub struct LanguageModelProviderName(pub SharedString);
 
@@ -77,6 +83,12 @@ impl From<String> for LanguageModelName {
     }
 }
 
+impl From<String> for LanguageModelProviderId {
+    fn from(value: String) -> Self {
+        Self(SharedString::from(value))
+    }
+}
+
 impl From<String> for LanguageModelProviderName {
     fn from(value: String) -> Self {
         Self(SharedString::from(value))

crates/language_model/src/provider/anthropic.rs 🔗

@@ -1,6 +1,5 @@
 use anthropic::{stream_completion, Request, RequestMessage};
 use anyhow::{anyhow, Result};
-use collections::HashMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{
@@ -9,7 +8,7 @@ use gpui::{
 };
 use http_client::HttpClient;
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{collections::BTreeMap, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -17,11 +16,12 @@ use util::ResultExt;
 
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, LanguageModelRequestMessage, Role,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
 };
 
-const PROVIDER_NAME: &str = "anthropic";
+const PROVIDER_ID: &str = "anthropic";
+const PROVIDER_NAME: &str = "Anthropic";
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct AnthropicSettings {
@@ -37,7 +37,6 @@ pub struct AnthropicLanguageModelProvider {
 
 struct State {
     api_key: Option<String>,
-    settings: AnthropicSettings,
     _subscription: Subscription,
 }
 
@@ -45,9 +44,7 @@ impl AnthropicLanguageModelProvider {
     pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
         let state = cx.new_model(|cx| State {
             api_key: None,
-            settings: AnthropicSettings::default(),
-            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone();
+            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                 cx.notify();
             }),
         });
@@ -64,12 +61,16 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
 }
 
 impl LanguageModelProvider for AnthropicLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models = HashMap::default();
+        let mut models = BTreeMap::default();
 
         // Add base models from anthropic::Model::iter()
         for model in anthropic::Model::iter() {
@@ -79,7 +80,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
         }
 
         // Override with available models from settings
-        for model in &self.state.read(cx).settings.available_models {
+        for model in AllLanguageModelSettings::get_global(cx)
+            .anthropic
+            .available_models
+            .iter()
+        {
             models.insert(model.id().to_string(), model.clone());
         }
 
@@ -104,7 +109,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
-            let api_url = self.state.read(cx).settings.api_url.clone();
+            let api_url = AllLanguageModelSettings::get_global(cx)
+                .anthropic
+                .api_url
+                .clone();
             let state = self.state.clone();
             cx.spawn(|mut cx| async move {
                 let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
@@ -132,7 +140,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
 
     fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let state = self.state.clone();
-        let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
+        let delete_credentials =
+            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             state.update(&mut cx, |this, cx| {
@@ -221,6 +230,10 @@ impl LanguageModel for AnthropicModel {
         LanguageModelName::from(self.model.display_name().to_string())
     }
 
+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn provider_name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
@@ -249,11 +262,13 @@ impl LanguageModel for AnthropicModel {
         let request = self.to_anthropic_request(request);
 
         let http_client = self.http_client.clone();
-        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
+
+        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
             (
                 state.api_key.clone(),
-                state.settings.api_url.clone(),
-                state.settings.low_speed_timeout,
+                settings.api_url.clone(),
+                settings.low_speed_timeout,
             )
         }) else {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@@ -365,7 +380,10 @@ impl AuthenticationPrompt {
         }
 
         let write_credentials = cx.write_credentials(
-            &self.state.read(cx).settings.api_url,
+            AllLanguageModelSettings::get_global(cx)
+                .anthropic
+                .api_url
+                .as_str(),
             "Bearer",
             api_key.as_bytes(),
         );

crates/language_model/src/provider/cloud.rs 🔗

@@ -1,15 +1,15 @@
 use super::open_ai::count_open_ai_tokens;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
-    LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest,
 };
 use anyhow::Result;
 use client::Client;
-use collections::HashMap;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::{collections::BTreeMap, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
 
@@ -17,6 +17,7 @@ use crate::LanguageModelProvider;
 
 use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
 
+pub const PROVIDER_ID: &str = "zed.dev";
 pub const PROVIDER_NAME: &str = "zed.dev";
 
 #[derive(Default, Clone, Debug, PartialEq)]
@@ -33,7 +34,6 @@ pub struct CloudLanguageModelProvider {
 struct State {
     client: Arc<Client>,
     status: client::Status,
-    settings: ZedDotDevSettings,
     _subscription: Subscription,
 }
 
@@ -52,9 +52,7 @@ impl CloudLanguageModelProvider {
         let state = cx.new_model(|cx| State {
             client: client.clone(),
             status,
-            settings: ZedDotDevSettings::default(),
-            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
+            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                 cx.notify();
             }),
         });
@@ -90,12 +88,16 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
 }
 
 impl LanguageModelProvider for CloudLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models = HashMap::default();
+        let mut models = BTreeMap::default();
 
         // Add base models from CloudModel::iter()
         for model in CloudModel::iter() {
@@ -105,7 +107,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         }
 
         // Override with available models from settings
-        for model in &self.state.read(cx).settings.available_models {
+        for model in &AllLanguageModelSettings::get_global(cx)
+            .zed_dot_dev
+            .available_models
+        {
             models.insert(model.id().to_string(), model.clone());
         }
 
@@ -156,6 +161,10 @@ impl LanguageModel for CloudLanguageModel {
         LanguageModelName::from(self.model.display_name().to_string())
     }
 
+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn provider_name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
@@ -187,6 +196,9 @@ impl LanguageModel for CloudLanguageModel {
             | CloudModel::Claude3Opus
             | CloudModel::Claude3Sonnet
             | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
+            CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
+                count_anthropic_tokens(request, cx)
+            }
             _ => {
                 let request = self.client.request(proto::CountTokensWithLanguageModel {
                     model: self.model.id().to_string(),

crates/language_model/src/provider/fake.rs 🔗

@@ -5,7 +5,8 @@ use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, St
 
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModelRequest,
 };
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 use http_client::Result;
@@ -19,8 +20,12 @@ pub fn language_model_name() -> LanguageModelName {
     LanguageModelName::from("Fake".to_string())
 }
 
+pub fn provider_id() -> LanguageModelProviderId {
+    LanguageModelProviderId::from("fake".to_string())
+}
+
 pub fn provider_name() -> LanguageModelProviderName {
-    LanguageModelProviderName::from("fake".to_string())
+    LanguageModelProviderName::from("Fake".to_string())
 }
 
 #[derive(Clone, Default)]
@@ -35,6 +40,10 @@ impl LanguageModelProviderState for FakeLanguageModelProvider {
 }
 
 impl LanguageModelProvider for FakeLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        provider_id()
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         provider_name()
     }
@@ -125,6 +134,10 @@ impl LanguageModel for FakeLanguageModel {
         language_model_name()
     }
 
+    fn provider_id(&self) -> LanguageModelProviderId {
+        provider_id()
+    }
+
     fn provider_name(&self) -> LanguageModelProviderName {
         provider_name()
     }

crates/language_model/src/provider/ollama.rs 🔗

@@ -2,21 +2,24 @@ use anyhow::{anyhow, Result};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
 use http_client::HttpClient;
-use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest};
+use ollama::{
+    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
+};
 use settings::{Settings, SettingsStore};
 use std::{sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, ElevationIndex};
 
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, Role,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, Role,
 };
 
 const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
 const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
 
-const PROVIDER_NAME: &str = "ollama";
+const PROVIDER_ID: &str = "ollama";
+const PROVIDER_NAME: &str = "Ollama";
 
 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct OllamaSettings {
@@ -32,14 +35,14 @@ pub struct OllamaLanguageModelProvider {
 struct State {
     http_client: Arc<dyn HttpClient>,
     available_models: Vec<ollama::Model>,
-    settings: OllamaSettings,
     _subscription: Subscription,
 }
 
 impl State {
-    fn fetch_models(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+    fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
+        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
         let http_client = self.http_client.clone();
-        let api_url = self.settings.api_url.clone();
+        let api_url = settings.api_url.clone();
 
         // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
         cx.spawn(|this, mut cx| async move {
@@ -66,23 +69,25 @@ impl State {
 
 impl OllamaLanguageModelProvider {
     pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
-        Self {
+        let this = Self {
             http_client: http_client.clone(),
             state: cx.new_model(|cx| State {
                 http_client,
                 available_models: Default::default(),
-                settings: OllamaSettings::default(),
                 _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                    this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
+                    this.fetch_models(cx).detach_and_log_err(cx);
                     cx.notify();
                 }),
             }),
-        }
+        };
+        this.fetch_models(cx).detach_and_log_err(cx);
+        this
     }
 
     fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
+        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
         let http_client = self.http_client.clone();
-        let api_url = self.state.read(cx).settings.api_url.clone();
+        let api_url = settings.api_url.clone();
 
         let state = self.state.clone();
         // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
@@ -117,6 +122,10 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
 }
 
 impl LanguageModelProvider for OllamaLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
@@ -131,12 +140,20 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
                     id: LanguageModelId::from(model.name.clone()),
                     model: model.clone(),
                     http_client: self.http_client.clone(),
-                    state: self.state.clone(),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
     }
 
+    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
+        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+        let http_client = self.http_client.clone();
+        let api_url = settings.api_url.clone();
+        let id = model.id().0.to_string();
+        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
+            .detach_and_log_err(cx);
+    }
+
     fn is_authenticated(&self, cx: &AppContext) -> bool {
         !self.state.read(cx).available_models.is_empty()
     }
@@ -167,7 +184,6 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
 pub struct OllamaLanguageModel {
     id: LanguageModelId,
     model: ollama::Model,
-    state: gpui::Model<State>,
     http_client: Arc<dyn HttpClient>,
 }
 
@@ -211,6 +227,14 @@ impl LanguageModel for OllamaLanguageModel {
         LanguageModelName::from(self.model.display_name().to_string())
     }
 
+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
+    fn provider_name(&self) -> LanguageModelProviderName {
+        LanguageModelProviderName(PROVIDER_NAME.into())
+    }
+
     fn max_token_count(&self) -> usize {
         self.model.max_token_count()
     }
@@ -219,10 +243,6 @@ impl LanguageModel for OllamaLanguageModel {
         format!("ollama/{}", self.model.id())
     }
 
-    fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
-    }
-
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
@@ -248,11 +268,9 @@ impl LanguageModel for OllamaLanguageModel {
         let request = self.to_ollama_request(request);
 
         let http_client = self.http_client.clone();
-        let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
-            (
-                state.settings.api_url.clone(),
-                state.settings.low_speed_timeout,
-            )
+        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+            (settings.api_url.clone(), settings.low_speed_timeout)
         }) else {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };

crates/language_model/src/provider/open_ai.rs 🔗

@@ -1,5 +1,5 @@
 use anyhow::{anyhow, Result};
-use collections::HashMap;
+use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, FutureExt, StreamExt};
 use gpui::{
@@ -17,11 +17,12 @@ use util::ResultExt;
 
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, Role,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, Role,
 };
 
-const PROVIDER_NAME: &str = "openai";
+const PROVIDER_ID: &str = "openai";
+const PROVIDER_NAME: &str = "OpenAI";
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct OpenAiSettings {
@@ -37,7 +38,6 @@ pub struct OpenAiLanguageModelProvider {
 
 struct State {
     api_key: Option<String>,
-    settings: OpenAiSettings,
     _subscription: Subscription,
 }
 
@@ -45,9 +45,7 @@ impl OpenAiLanguageModelProvider {
     pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
         let state = cx.new_model(|cx| State {
             api_key: None,
-            settings: OpenAiSettings::default(),
-            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone();
+            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
                 cx.notify();
             }),
         });
@@ -65,12 +63,16 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
 }
 
 impl LanguageModelProvider for OpenAiLanguageModelProvider {
+    fn id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models = HashMap::default();
+        let mut models = BTreeMap::default();
 
         // Add base models from open_ai::Model::iter()
         for model in open_ai::Model::iter() {
@@ -80,7 +82,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
         }
 
         // Override with available models from settings
-        for model in &self.state.read(cx).settings.available_models {
+        for model in &AllLanguageModelSettings::get_global(cx)
+            .openai
+            .available_models
+        {
             models.insert(model.id().to_string(), model.clone());
         }
 
@@ -105,7 +110,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
-            let api_url = self.state.read(cx).settings.api_url.clone();
+            let api_url = AllLanguageModelSettings::get_global(cx)
+                .openai
+                .api_url
+                .clone();
             let state = self.state.clone();
             cx.spawn(|mut cx| async move {
                 let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
@@ -131,7 +139,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
     }
 
     fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
+        let settings = &AllLanguageModelSettings::get_global(cx).openai;
+        let delete_credentials = cx.delete_credentials(&settings.api_url);
         let state = self.state.clone();
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
@@ -188,6 +197,10 @@ impl LanguageModel for OpenAiLanguageModel {
         LanguageModelName::from(self.model.display_name().to_string())
     }
 
+    fn provider_id(&self) -> LanguageModelProviderId {
+        LanguageModelProviderId(PROVIDER_ID.into())
+    }
+
     fn provider_name(&self) -> LanguageModelProviderName {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
@@ -216,11 +229,12 @@ impl LanguageModel for OpenAiLanguageModel {
         let request = self.to_open_ai_request(request);
 
         let http_client = self.http_client.clone();
-        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
+        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).openai;
             (
                 state.api_key.clone(),
-                state.settings.api_url.clone(),
-                state.settings.low_speed_timeout,
+                settings.api_url.clone(),
+                settings.low_speed_timeout,
             )
         }) else {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@@ -307,11 +321,9 @@ impl AuthenticationPrompt {
             return;
         }
 
-        let write_credentials = cx.write_credentials(
-            &self.state.read(cx).settings.api_url,
-            "Bearer",
-            api_key.as_bytes(),
-        );
+        let settings = &AllLanguageModelSettings::get_global(cx).openai;
+        let write_credentials =
+            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
         let state = self.state.clone();
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;

crates/language_model/src/registry.rs 🔗

@@ -9,7 +9,7 @@ use crate::{
         anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
         ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
     },
-    LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
 };
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
@@ -48,7 +48,7 @@ fn register_language_model_providers(
                 registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
             } else {
                 registry.unregister_provider(
-                    &LanguageModelProviderName::from(
+                    &LanguageModelProviderId::from(
                         crate::provider::cloud::PROVIDER_NAME.to_string(),
                     ),
                     cx,
@@ -65,7 +65,7 @@ impl Global for GlobalLanguageModelRegistry {}
 
 #[derive(Default)]
 pub struct LanguageModelRegistry {
-    providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
+    providers: HashMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 }
 
 impl LanguageModelRegistry {
@@ -94,7 +94,7 @@ impl LanguageModelRegistry {
         provider: T,
         cx: &mut ModelContext<Self>,
     ) {
-        let name = provider.name();
+        let name = provider.id();
 
         if let Some(subscription) = provider.subscribe(cx) {
             subscription.detach();
@@ -106,7 +106,7 @@ impl LanguageModelRegistry {
 
     pub fn unregister_provider(
         &mut self,
-        name: &LanguageModelProviderName,
+        name: &LanguageModelProviderId,
         cx: &mut ModelContext<Self>,
     ) {
         if self.providers.remove(name).is_some() {
@@ -116,7 +116,7 @@ impl LanguageModelRegistry {
 
     pub fn providers(
         &self,
-    ) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
+    ) -> impl Iterator<Item = (&LanguageModelProviderId, &Arc<dyn LanguageModelProvider>)> {
         self.providers.iter()
     }
 
@@ -130,7 +130,7 @@ impl LanguageModelRegistry {
     pub fn available_models_grouped_by_provider(
         &self,
         cx: &AppContext,
-    ) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
+    ) -> HashMap<LanguageModelProviderId, Vec<Arc<dyn LanguageModel>>> {
         self.providers
             .iter()
             .map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
@@ -139,7 +139,7 @@ impl LanguageModelRegistry {
 
     pub fn provider(
         &self,
-        name: &LanguageModelProviderName,
+        name: &LanguageModelProviderId,
     ) -> Option<Arc<dyn LanguageModelProvider>> {
         self.providers.get(name).cloned()
     }
@@ -160,10 +160,10 @@ mod tests {
 
         let providers = registry.read(cx).providers().collect::<Vec<_>>();
         assert_eq!(providers.len(), 1);
-        assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
+        assert_eq!(providers[0].0, &crate::provider::fake::provider_id());
 
         registry.update(cx, |registry, cx| {
-            registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
+            registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
         });
 
         let providers = registry.read(cx).providers().collect::<Vec<_>>();

crates/language_model/src/settings.rs 🔗

@@ -21,9 +21,9 @@ pub fn init(cx: &mut AppContext) {
 
 #[derive(Default)]
 pub struct AllLanguageModelSettings {
-    pub open_ai: OpenAiSettings,
     pub anthropic: AnthropicSettings,
     pub ollama: OllamaSettings,
+    pub openai: OpenAiSettings,
     pub zed_dot_dev: ZedDotDevSettings,
 }
 
@@ -31,7 +31,7 @@ pub struct AllLanguageModelSettings {
 pub struct AllLanguageModelSettingsContent {
     pub anthropic: Option<AnthropicSettingsContent>,
     pub ollama: Option<OllamaSettingsContent>,
-    pub open_ai: Option<OpenAiSettingsContent>,
+    pub openai: Option<OpenAiSettingsContent>,
     #[serde(rename = "zed.dev")]
     pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
 }
@@ -110,21 +110,21 @@ impl settings::Settings for AllLanguageModelSettings {
             }
 
             merge(
-                &mut settings.open_ai.api_url,
-                value.open_ai.as_ref().and_then(|s| s.api_url.clone()),
+                &mut settings.openai.api_url,
+                value.openai.as_ref().and_then(|s| s.api_url.clone()),
             );
             if let Some(low_speed_timeout_in_seconds) = value
-                .open_ai
+                .openai
                 .as_ref()
                 .and_then(|s| s.low_speed_timeout_in_seconds)
             {
-                settings.open_ai.low_speed_timeout =
+                settings.openai.low_speed_timeout =
                     Some(Duration::from_secs(low_speed_timeout_in_seconds));
             }
             merge(
-                &mut settings.open_ai.available_models,
+                &mut settings.openai.available_models,
                 value
-                    .open_ai
+                    .openai
                     .as_ref()
                     .and_then(|s| s.available_models.clone()),
             );

crates/ollama/src/ollama.rs 🔗

@@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use isahc::config::Configurable;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use std::{convert::TryFrom, time::Duration};
+use std::{convert::TryFrom, sync::Arc, time::Duration};
 
 pub const OLLAMA_API_URL: &str = "http://localhost:11434";
 
@@ -243,7 +243,7 @@ pub async fn get_models(
 }
 
 /// Sends an empty request to Ollama to trigger loading the model
-pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
+pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
     let uri = format!("{api_url}/api/generate");
     let request = HttpRequest::builder()
         .method(Method::POST)

docs/src/language-model-integration.md 🔗

@@ -85,12 +85,8 @@ To do so, add the following to your Zed `settings.json`:
 
 ```json
 {
-  "assistant": {
-    "version": "1",
-    "provider": {
-      "name": "openai",
-      "type": "openai",
-      "default_model": "gpt-4-turbo-preview",
+  "language_models": {
+    "openai": {
       "api_url": "http://localhost:11434/v1"
     }
   }
@@ -103,51 +99,32 @@ The custom URL here is `http://localhost:11434/v1`.
 
 You can use Ollama with the Zed assistant by making Ollama appear as an OpenAPI endpoint.
 
-1. Add the following to your Zed `settings.json`:
-
-  ```json
-  {
-    "assistant": {
-      "version": "1",
-      "provider": {
-        "name": "openai",
-        "type": "openai",
-        "default_model": "gpt-4-turbo-preview",
-        "api_url": "http://localhost:11434/v1"
-      }
-    }
-  }
+1. Download, for example, the `mistral` model with Ollama:
   ```
-2. Download, for example, the `mistral` model with Ollama:
+  ollama pull mistral
   ```
-  ollama run mistral
+2. Make sure that the Ollama server is running. You can start it either via running the Ollama app, or launching:
   ```
-3. Copy the model and change its name to match the model in the Zed `settings.json`:
+  ollama serve
   ```
-  ollama cp mistral gpt-4-turbo-preview
-  ```
-4. Use `assistant: reset key` (see the [Setup](#setup) section above) and enter the following API key:
-  ```
-  ollama
-  ```
-5. Restart Zed
-
-### Using Claude 3.5 Sonnet
-
-You can use Claude with the Zed assistant by adding the following settings:
+3. In the assistant panel, select one of the Ollama models using the model dropdown.
+4. (Optional) If you want to change the default url that is used to access the Ollama server, you can do so by adding the following settings:
 
 ```json
-"assistant": {
-  "version": "1",
-  "provider": {
-    "default_model": "claude-3-5-sonnet",
-    "name": "anthropic"
+{
+  "language_models": {
+    "ollama": {
+      "api_url": "http://localhost:11434"
+    }
   }
-},
+}
 ```
 
-When you save the settings, the assistant panel will open and ask you to add your Anthropic API key.
-You need can obtain this key [here](https://console.anthropic.com/settings/keys).
+### Using Claude 3.5 Sonnet
+
+You can use Claude with the Zed assistant by choosing it via the model dropdown in the assistant panel.
+
+You need can obtain an API key [here](https://console.anthropic.com/settings/keys).
 
 Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API.