diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index acb39d3b45365eaefbf6c287e14dd68efb3e968e..8e8f00fc9aadbeef4c5be8420fb8d3a23535f76c 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -223,9 +223,17 @@ fn init_language_model_settings(cx: &mut AppContext) { cx.observe_global::(update_active_language_model_from_settings) .detach(); - cx.observe(&LanguageModelRegistry::global(cx), |_, cx| { - update_active_language_model_from_settings(cx) - }) + cx.subscribe( + &LanguageModelRegistry::global(cx), + |_, event: &language_model::Event, cx| match event { + language_model::Event::ProviderStateChanged + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + update_active_language_model_from_settings(cx); + } + _ => {} + }, + ) .detach(); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index ae3925498016891690c015a0f0bbd74ae5016632..515f357b5677c474d52633cf73be0dd3deab87c6 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -394,8 +394,15 @@ impl AssistantPanel { cx.subscribe(&context_store, Self::handle_context_store_event), cx.subscribe( &LanguageModelRegistry::global(cx), - |this, _, _: &language_model::ActiveModelChanged, cx| { - this.completion_provider_changed(cx); + |this, _, event: &language_model::Event, cx| match event { + language_model::Event::ActiveModelChanged => { + this.completion_provider_changed(cx); + } + language_model::Event::ProviderStateChanged + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + this.ensure_authenticated(cx); + } }, ), ]; @@ -588,6 +595,11 @@ impl AssistantPanel { } fn ensure_authenticated(&mut self, cx: &mut ViewContext) { + if self.is_authenticated(cx) { + self.set_authentication_prompt(None, cx); + return; + } + let Some(provider_id) = LanguageModelRegistry::read_global(cx) .active_provider() .map(|p| p.id()) @@ -596,29 +608,35 @@ impl AssistantPanel { }; let load_credentials = self.authenticate(cx); - let task = cx.spawn(|this, mut cx| async move { - let _ = load_credentials.await; - this.update(&mut cx, |this, cx| { - this.show_authentication_prompt(cx); - }) - .log_err(); - }); - self.authenticate_provider_task = Some((provider_id, task)); + self.authenticate_provider_task = Some(( + provider_id, + cx.spawn(|this, mut cx| async move { + let _ = load_credentials.await; + this.update(&mut cx, |this, cx| { + this.show_authentication_prompt(cx); + this.authenticate_provider_task = None; + }) + .log_err(); + }), + )); } fn show_authentication_prompt(&mut self, cx: &mut ViewContext) { + let prompt = Self::authentication_prompt(cx); + self.set_authentication_prompt(prompt, cx); + } + + fn set_authentication_prompt(&mut self, prompt: Option, cx: &mut ViewContext) { if self.active_context_editor(cx).is_none() { self.new_context(cx); } - let authentication_prompt = Self::authentication_prompt(cx); for context_editor in self.context_editors(cx) { context_editor.update(cx, |editor, cx| { - editor.set_authentication_prompt(authentication_prompt.clone(), cx); + editor.set_authentication_prompt(prompt.clone(), cx); }); } - cx.notify(); } diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 6dcc874721a993238244e777c8e81810b3b61771..611e4208c2dd396527f2e17eaba141dff5c4a81d 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -89,7 +89,20 @@ pub trait LanguageModelProvider: 'static { } pub trait LanguageModelProviderState: 'static { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option; + type ObservableEntity; + + fn observable_entity(&self) -> Option>; + + fn subscribe( + &self, + cx: &mut gpui::ModelContext, + callback: impl Fn(&mut T, &mut gpui::ModelContext) + 'static, + ) -> Option { + let entity = self.observable_entity()?; + Some(cx.observe(&entity, move |this, _, cx| { + callback(this, cx); + })) + } } #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)] diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index ddaad618c4612f1305b99cfdc4a12c125128a4a4..3999483da07c2de026216d7315bb8a602c34d065 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -44,7 +44,7 @@ pub struct AnthropicLanguageModelProvider { state: gpui::Model, } -struct State { +pub struct State { api_key: Option, _subscription: Subscription, } @@ -61,11 +61,12 @@ impl AnthropicLanguageModelProvider { Self { http_client, state } } } + impl LanguageModelProviderState for AnthropicLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 362539fd852f542b2cffc340dbe977b3a31ee734..dac341b3082ca43e57a39afef0d08220e0dd931e 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -8,7 +8,7 @@ use anyhow::{anyhow, Context as _, Result}; use client::Client; use collections::BTreeMap; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; -use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task}; +use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; @@ -50,16 +50,19 @@ pub struct CloudLanguageModelProvider { _maintain_client_status: Task<()>, } -struct State { +pub struct State { client: Arc, status: client::Status, _subscription: Subscription, } impl State { - fn authenticate(&self, cx: &mut AppContext) -> Task> { + fn authenticate(&self, cx: &mut ModelContext) -> Task> { let client = self.client.clone(); - cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await }) + cx.spawn(move |this, mut cx| async move { + client.authenticate_and_connect(true, &cx).await?; + this.update(&mut cx, |_, cx| cx.notify()) + }) } } @@ -99,10 +102,10 @@ impl CloudLanguageModelProvider { } impl LanguageModelProviderState for CloudLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs index 072c87b92ef7fd014312c5912af4b63f23280ad6..f73ddb74bfb6549d4359817e3fc6f8ee1b159f47 100644 --- a/crates/language_model/src/provider/copilot_chat.rs +++ b/crates/language_model/src/provider/copilot_chat.rs @@ -11,8 +11,8 @@ use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{FutureExt, StreamExt}; use gpui::{ - percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, - ModelContext, Render, Subscription, Task, Transformation, + percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render, + Subscription, Task, Transformation, }; use settings::{Settings, SettingsStore}; use std::time::Duration; @@ -67,10 +67,10 @@ impl CopilotChatLanguageModelProvider { } impl LanguageModelProviderState for CopilotChatLanguageModelProvider { - fn subscribe(&self, cx: &mut ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs index f92ecaf467530e0ae719d799410ee619bdda9465..70f8402bccf827755b60b501f8a7c211eb24f048 100644 --- a/crates/language_model/src/provider/fake.rs +++ b/crates/language_model/src/provider/fake.rs @@ -36,7 +36,9 @@ pub struct FakeLanguageModelProvider { } impl LanguageModelProviderState for FakeLanguageModelProvider { - fn subscribe(&self, _: &mut gpui::ModelContext) -> Option { + type ObservableEntity = (); + + fn observable_entity(&self) -> Option> { None } } diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs index 2739623c6aceb60f4017ddc15a65180134d20cb8..a1a6cbcceb5d3c11520d498f2fd6ac9a4dd694c3 100644 --- a/crates/language_model/src/provider/google.rs +++ b/crates/language_model/src/provider/google.rs @@ -44,7 +44,7 @@ pub struct GoogleLanguageModelProvider { state: gpui::Model, } -struct State { +pub struct State { api_key: Option, _subscription: Subscription, } @@ -63,10 +63,10 @@ impl GoogleLanguageModelProvider { } impl LanguageModelProviderState for GoogleLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs index 0364866ccd80c0a787fcd75812ff954dfe731a68..9afa3825b0b4af4c9a5edc7efbb82aabdbb979c2 100644 --- a/crates/language_model/src/provider/ollama.rs +++ b/crates/language_model/src/provider/ollama.rs @@ -32,7 +32,7 @@ pub struct OllamaLanguageModelProvider { state: gpui::Model, } -struct State { +pub struct State { http_client: Arc, available_models: Vec, _subscription: Subscription, @@ -87,10 +87,10 @@ impl OllamaLanguageModelProvider { } impl LanguageModelProviderState for OllamaLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index 9f24dabb094358d2ae4f28459f65c4ac6631ea98..e0239d959bea80e9fe2f642679f5459fbd1af716 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -45,7 +45,7 @@ pub struct OpenAiLanguageModelProvider { state: gpui::Model, } -struct State { +pub struct State { api_key: Option, _subscription: Subscription, } @@ -64,10 +64,10 @@ impl OpenAiLanguageModelProvider { } impl LanguageModelProviderState for OpenAiLanguageModelProvider { - fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { - Some(cx.observe(&self.state, |_, _, cx| { - cx.notify(); - })) + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) } } diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index a3af7e6b181f4c534d3e2a98c46284362dc3ee47..94a69a9d2f2fd081e5dbf04e7834cc5534579fb8 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -54,9 +54,7 @@ fn register_language_model_providers( registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx); } else { registry.unregister_provider( - &LanguageModelProviderId::from( - crate::provider::cloud::PROVIDER_NAME.to_string(), - ), + LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()), cx, ); } @@ -80,9 +78,14 @@ pub struct ActiveModel { model: Option>, } -pub struct ActiveModelChanged; +pub enum Event { + ActiveModelChanged, + ProviderStateChanged, + AddedProvider(LanguageModelProviderId), + RemovedProvider(LanguageModelProviderId), +} -impl EventEmitter for LanguageModelRegistry {} +impl EventEmitter for LanguageModelRegistry {} impl LanguageModelRegistry { pub fn global(cx: &AppContext) -> Model { @@ -112,23 +115,26 @@ impl LanguageModelRegistry { provider: T, cx: &mut ModelContext, ) { - let name = provider.id(); + let id = provider.id(); - if let Some(subscription) = provider.subscribe(cx) { + let subscription = provider.subscribe(cx, |_, cx| { + cx.emit(Event::ProviderStateChanged); + }); + if let Some(subscription) = subscription { subscription.detach(); } - self.providers.insert(name, Arc::new(provider)); - cx.notify(); + self.providers.insert(id.clone(), Arc::new(provider)); + cx.emit(Event::AddedProvider(id)); } pub fn unregister_provider( &mut self, - name: &LanguageModelProviderId, + id: LanguageModelProviderId, cx: &mut ModelContext, ) { - if self.providers.remove(name).is_some() { - cx.notify(); + if self.providers.remove(&id).is_some() { + cx.emit(Event::RemovedProvider(id)); } } @@ -187,7 +193,7 @@ impl LanguageModelRegistry { provider, model: None, }); - cx.emit(ActiveModelChanged); + cx.emit(Event::ActiveModelChanged); } pub fn set_active_model( @@ -202,13 +208,13 @@ impl LanguageModelRegistry { provider, model: Some(model), }); - cx.emit(ActiveModelChanged); + cx.emit(Event::ActiveModelChanged); } else { log::warn!("Active model's provider not found in registry"); } } else { self.active_model = None; - cx.emit(ActiveModelChanged); + cx.emit(Event::ActiveModelChanged); } } @@ -239,7 +245,7 @@ mod tests { assert_eq!(providers[0].id(), crate::provider::fake::provider_id()); registry.update(cx, |registry, cx| { - registry.unregister_provider(&crate::provider::fake::provider_id(), cx); + registry.unregister_provider(crate::provider::fake::provider_id(), cx); }); let providers = registry.read(cx).providers();