assistant panel: Fix panel not reloading after entering credentials (#15531)

Bennet Bo Fenner , Thorsten , and Thorsten Ball created

This is the revised version of #15527.

We also added new events to notify subscribers when new providers are
added or removed.

Co-Authored-by: Thorsten <thorsten@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Thorsten Ball <mrnugget@gmail.com>

Change summary

crates/assistant/src/assistant.rs                  | 14 ++++-
crates/assistant/src/assistant_panel.rs            | 44 +++++++++++----
crates/language_model/src/language_model.rs        | 15 +++++
crates/language_model/src/provider/anthropic.rs    | 11 ++-
crates/language_model/src/provider/cloud.rs        | 19 ++++--
crates/language_model/src/provider/copilot_chat.rs | 12 ++--
crates/language_model/src/provider/fake.rs         |  4 +
crates/language_model/src/provider/google.rs       | 10 +-
crates/language_model/src/provider/ollama.rs       | 10 +-
crates/language_model/src/provider/open_ai.rs      | 10 +-
crates/language_model/src/registry.rs              | 38 ++++++++-----
11 files changed, 119 insertions(+), 68 deletions(-)

Detailed changes

crates/assistant/src/assistant.rs 🔗

@@ -223,9 +223,17 @@ fn init_language_model_settings(cx: &mut AppContext) {
 
     cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
         .detach();
-    cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
-        update_active_language_model_from_settings(cx)
-    })
+    cx.subscribe(
+        &LanguageModelRegistry::global(cx),
+        |_, event: &language_model::Event, cx| match event {
+            language_model::Event::ProviderStateChanged
+            | language_model::Event::AddedProvider(_)
+            | language_model::Event::RemovedProvider(_) => {
+                update_active_language_model_from_settings(cx);
+            }
+            _ => {}
+        },
+    )
     .detach();
 }
 

crates/assistant/src/assistant_panel.rs 🔗

@@ -394,8 +394,15 @@ impl AssistantPanel {
             cx.subscribe(&context_store, Self::handle_context_store_event),
             cx.subscribe(
                 &LanguageModelRegistry::global(cx),
-                |this, _, _: &language_model::ActiveModelChanged, cx| {
-                    this.completion_provider_changed(cx);
+                |this, _, event: &language_model::Event, cx| match event {
+                    language_model::Event::ActiveModelChanged => {
+                        this.completion_provider_changed(cx);
+                    }
+                    language_model::Event::ProviderStateChanged
+                    | language_model::Event::AddedProvider(_)
+                    | language_model::Event::RemovedProvider(_) => {
+                        this.ensure_authenticated(cx);
+                    }
                 },
             ),
         ];
@@ -588,6 +595,11 @@ impl AssistantPanel {
     }
 
     fn ensure_authenticated(&mut self, cx: &mut ViewContext<Self>) {
+        if self.is_authenticated(cx) {
+            self.set_authentication_prompt(None, cx);
+            return;
+        }
+
         let Some(provider_id) = LanguageModelRegistry::read_global(cx)
             .active_provider()
             .map(|p| p.id())
@@ -596,29 +608,35 @@ impl AssistantPanel {
         };
 
         let load_credentials = self.authenticate(cx);
-        let task = cx.spawn(|this, mut cx| async move {
-            let _ = load_credentials.await;
-            this.update(&mut cx, |this, cx| {
-                this.show_authentication_prompt(cx);
-            })
-            .log_err();
-        });
 
-        self.authenticate_provider_task = Some((provider_id, task));
+        self.authenticate_provider_task = Some((
+            provider_id,
+            cx.spawn(|this, mut cx| async move {
+                let _ = load_credentials.await;
+                this.update(&mut cx, |this, cx| {
+                    this.show_authentication_prompt(cx);
+                    this.authenticate_provider_task = None;
+                })
+                .log_err();
+            }),
+        ));
     }
 
     fn show_authentication_prompt(&mut self, cx: &mut ViewContext<Self>) {
+        let prompt = Self::authentication_prompt(cx);
+        self.set_authentication_prompt(prompt, cx);
+    }
+
+    fn set_authentication_prompt(&mut self, prompt: Option<AnyView>, cx: &mut ViewContext<Self>) {
         if self.active_context_editor(cx).is_none() {
             self.new_context(cx);
         }
 
-        let authentication_prompt = Self::authentication_prompt(cx);
         for context_editor in self.context_editors(cx) {
             context_editor.update(cx, |editor, cx| {
-                editor.set_authentication_prompt(authentication_prompt.clone(), cx);
+                editor.set_authentication_prompt(prompt.clone(), cx);
             });
         }
-
         cx.notify();
     }
 

crates/language_model/src/language_model.rs 🔗

@@ -89,7 +89,20 @@ pub trait LanguageModelProvider: 'static {
 }
 
 pub trait LanguageModelProviderState: 'static {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
+    type ObservableEntity;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
+
+    fn subscribe<T: 'static>(
+        &self,
+        cx: &mut gpui::ModelContext<T>,
+        callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
+    ) -> Option<gpui::Subscription> {
+        let entity = self.observable_entity()?;
+        Some(cx.observe(&entity, move |this, _, cx| {
+            callback(this, cx);
+        }))
+    }
 }
 
 #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]

crates/language_model/src/provider/anthropic.rs 🔗

@@ -44,7 +44,7 @@ pub struct AnthropicLanguageModelProvider {
     state: gpui::Model<State>,
 }
 
-struct State {
+pub struct State {
     api_key: Option<String>,
     _subscription: Subscription,
 }
@@ -61,11 +61,12 @@ impl AnthropicLanguageModelProvider {
         Self { http_client, state }
     }
 }
+
 impl LanguageModelProviderState for AnthropicLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/cloud.rs 🔗

@@ -8,7 +8,7 @@ use anyhow::{anyhow, Context as _, Result};
 use client::Client;
 use collections::BTreeMap;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
-use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
+use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
@@ -50,16 +50,19 @@ pub struct CloudLanguageModelProvider {
     _maintain_client_status: Task<()>,
 }
 
-struct State {
+pub struct State {
     client: Arc<Client>,
     status: client::Status,
     _subscription: Subscription,
 }
 
 impl State {
-    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
         let client = self.client.clone();
-        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
+        cx.spawn(move |this, mut cx| async move {
+            client.authenticate_and_connect(true, &cx).await?;
+            this.update(&mut cx, |_, cx| cx.notify())
+        })
     }
 }
 
@@ -99,10 +102,10 @@ impl CloudLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for CloudLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/copilot_chat.rs 🔗

@@ -11,8 +11,8 @@ use futures::future::BoxFuture;
 use futures::stream::BoxStream;
 use futures::{FutureExt, StreamExt};
 use gpui::{
-    percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model,
-    ModelContext, Render, Subscription, Task, Transformation,
+    percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render,
+    Subscription, Task, Transformation,
 };
 use settings::{Settings, SettingsStore};
 use std::time::Duration;
@@ -67,10 +67,10 @@ impl CopilotChatLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut ModelContext<T>) -> Option<Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/fake.rs 🔗

@@ -36,7 +36,9 @@ pub struct FakeLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for FakeLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
+    type ObservableEntity = ();
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
         None
     }
 }

crates/language_model/src/provider/google.rs 🔗

@@ -44,7 +44,7 @@ pub struct GoogleLanguageModelProvider {
     state: gpui::Model<State>,
 }
 
-struct State {
+pub struct State {
     api_key: Option<String>,
     _subscription: Subscription,
 }
@@ -63,10 +63,10 @@ impl GoogleLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for GoogleLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/ollama.rs 🔗

@@ -32,7 +32,7 @@ pub struct OllamaLanguageModelProvider {
     state: gpui::Model<State>,
 }
 
-struct State {
+pub struct State {
     http_client: Arc<dyn HttpClient>,
     available_models: Vec<ollama::Model>,
     _subscription: Subscription,
@@ -87,10 +87,10 @@ impl OllamaLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for OllamaLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/open_ai.rs 🔗

@@ -45,7 +45,7 @@ pub struct OpenAiLanguageModelProvider {
     state: gpui::Model<State>,
 }
 
-struct State {
+pub struct State {
     api_key: Option<String>,
     _subscription: Subscription,
 }
@@ -64,10 +64,10 @@ impl OpenAiLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for OpenAiLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/registry.rs 🔗

@@ -54,9 +54,7 @@ fn register_language_model_providers(
                 registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
             } else {
                 registry.unregister_provider(
-                    &LanguageModelProviderId::from(
-                        crate::provider::cloud::PROVIDER_NAME.to_string(),
-                    ),
+                    LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
                     cx,
                 );
             }
@@ -80,9 +78,14 @@ pub struct ActiveModel {
     model: Option<Arc<dyn LanguageModel>>,
 }
 
-pub struct ActiveModelChanged;
+pub enum Event {
+    ActiveModelChanged,
+    ProviderStateChanged,
+    AddedProvider(LanguageModelProviderId),
+    RemovedProvider(LanguageModelProviderId),
+}
 
-impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
+impl EventEmitter<Event> for LanguageModelRegistry {}
 
 impl LanguageModelRegistry {
     pub fn global(cx: &AppContext) -> Model<Self> {
@@ -112,23 +115,26 @@ impl LanguageModelRegistry {
         provider: T,
         cx: &mut ModelContext<Self>,
     ) {
-        let name = provider.id();
+        let id = provider.id();
 
-        if let Some(subscription) = provider.subscribe(cx) {
+        let subscription = provider.subscribe(cx, |_, cx| {
+            cx.emit(Event::ProviderStateChanged);
+        });
+        if let Some(subscription) = subscription {
             subscription.detach();
         }
 
-        self.providers.insert(name, Arc::new(provider));
-        cx.notify();
+        self.providers.insert(id.clone(), Arc::new(provider));
+        cx.emit(Event::AddedProvider(id));
     }
 
     pub fn unregister_provider(
         &mut self,
-        name: &LanguageModelProviderId,
+        id: LanguageModelProviderId,
         cx: &mut ModelContext<Self>,
     ) {
-        if self.providers.remove(name).is_some() {
-            cx.notify();
+        if self.providers.remove(&id).is_some() {
+            cx.emit(Event::RemovedProvider(id));
         }
     }
 
@@ -187,7 +193,7 @@ impl LanguageModelRegistry {
             provider,
             model: None,
         });
-        cx.emit(ActiveModelChanged);
+        cx.emit(Event::ActiveModelChanged);
     }
 
     pub fn set_active_model(
@@ -202,13 +208,13 @@ impl LanguageModelRegistry {
                     provider,
                     model: Some(model),
                 });
-                cx.emit(ActiveModelChanged);
+                cx.emit(Event::ActiveModelChanged);
             } else {
                 log::warn!("Active model's provider not found in registry");
             }
         } else {
             self.active_model = None;
-            cx.emit(ActiveModelChanged);
+            cx.emit(Event::ActiveModelChanged);
         }
     }
 
@@ -239,7 +245,7 @@ mod tests {
         assert_eq!(providers[0].id(), crate::provider::fake::provider_id());
 
         registry.update(cx, |registry, cx| {
-            registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
+            registry.unregister_provider(crate::provider::fake::provider_id(), cx);
         });
 
         let providers = registry.read(cx).providers();