assistant panel: Fix entering credentials not updating view (#15527)

Thorsten Ball and Bennet created

Co-authored-by: Bennet <bennet@zed.dev>

Release Notes:

- N/A

Co-authored-by: Bennet <bennet@zed.dev>

Change summary

crates/assistant/src/assistant_panel.rs            | 38 +++++++++++----
crates/language_model/src/language_model.rs        | 17 ++++++
crates/language_model/src/provider/anthropic.rs    | 11 ++--
crates/language_model/src/provider/cloud.rs        | 10 ++--
crates/language_model/src/provider/copilot_chat.rs | 12 ++--
crates/language_model/src/provider/fake.rs         |  4 +
crates/language_model/src/provider/google.rs       | 10 ++--
crates/language_model/src/provider/ollama.rs       | 10 ++--
crates/language_model/src/provider/open_ai.rs      | 10 ++--
crates/language_model/src/registry.rs              | 22 +++++---
10 files changed, 92 insertions(+), 52 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs 🔗

@@ -393,8 +393,13 @@ impl AssistantPanel {
             cx.subscribe(&context_store, Self::handle_context_store_event),
             cx.subscribe(
                 &LanguageModelRegistry::global(cx),
-                |this, _, _: &language_model::ActiveModelChanged, cx| {
-                    this.completion_provider_changed(cx);
+                |this, _, event: &language_model::Event, cx| match event {
+                    language_model::Event::ActiveModelChanged => {
+                        this.completion_provider_changed(cx);
+                    }
+                    language_model::Event::ProviderStateChanged => {
+                        this.ensure_authenticated(cx);
+                    }
                 },
             ),
         ];
@@ -587,6 +592,16 @@ impl AssistantPanel {
     }
 
     fn ensure_authenticated(&mut self, cx: &mut ViewContext<Self>) {
+        if self.is_authenticated(cx) {
+            for context_editor in self.context_editors(cx) {
+                context_editor.update(cx, |editor, cx| {
+                    editor.set_authentication_prompt(None, cx);
+                });
+            }
+            cx.notify();
+            return;
+        }
+
         let Some(provider_id) = LanguageModelRegistry::read_global(cx)
             .active_provider()
             .map(|p| p.id())
@@ -595,15 +610,18 @@ impl AssistantPanel {
         };
 
         let load_credentials = self.authenticate(cx);
-        let task = cx.spawn(|this, mut cx| async move {
-            let _ = load_credentials.await;
-            this.update(&mut cx, |this, cx| {
-                this.show_authentication_prompt(cx);
-            })
-            .log_err();
-        });
 
-        self.authenticate_provider_task = Some((provider_id, task));
+        self.authenticate_provider_task = Some((
+            provider_id,
+            cx.spawn(|this, mut cx| async move {
+                let _ = load_credentials.await;
+                this.update(&mut cx, |this, cx| {
+                    this.show_authentication_prompt(cx);
+                    this.authenticate_provider_task = None;
+                })
+                .log_err();
+            }),
+        ));
     }
 
     fn show_authentication_prompt(&mut self, cx: &mut ViewContext<Self>) {

crates/language_model/src/language_model.rs 🔗

@@ -86,10 +86,25 @@ pub trait LanguageModelProvider: 'static {
     fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
     fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
     fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
+
+    // fn observable_entity(&self) ;
 }
 
 pub trait LanguageModelProviderState: 'static {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
+    type ObservableEntity;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
+
+    fn subscribe<T: 'static>(
+        &self,
+        cx: &mut gpui::ModelContext<T>,
+        callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
+    ) -> Option<gpui::Subscription> {
+        let entity = self.observable_entity()?;
+        Some(cx.observe(&entity, move |this, _, cx| {
+            callback(this, cx);
+        }))
+    }
 }
 
 #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]

crates/language_model/src/provider/anthropic.rs 🔗

@@ -44,7 +44,7 @@ pub struct AnthropicLanguageModelProvider {
     state: gpui::Model<State>,
 }
 
-struct State {
+pub struct State {
     api_key: Option<String>,
     _subscription: Subscription,
 }
@@ -61,11 +61,12 @@ impl AnthropicLanguageModelProvider {
         Self { http_client, state }
     }
 }
+
 impl LanguageModelProviderState for AnthropicLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/cloud.rs 🔗

@@ -50,7 +50,7 @@ pub struct CloudLanguageModelProvider {
     _maintain_client_status: Task<()>,
 }
 
-struct State {
+pub struct State {
     client: Arc<Client>,
     status: client::Status,
     _subscription: Subscription,
@@ -99,10 +99,10 @@ impl CloudLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for CloudLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/copilot_chat.rs 🔗

@@ -11,8 +11,8 @@ use futures::future::BoxFuture;
 use futures::stream::BoxStream;
 use futures::{FutureExt, StreamExt};
 use gpui::{
-    percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model,
-    ModelContext, Render, Subscription, Task, Transformation,
+    percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render,
+    Subscription, Task, Transformation,
 };
 use settings::{Settings, SettingsStore};
 use std::time::Duration;
@@ -67,10 +67,10 @@ impl CopilotChatLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut ModelContext<T>) -> Option<Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/fake.rs 🔗

@@ -36,7 +36,9 @@ pub struct FakeLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for FakeLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
+    type ObservableEntity = ();
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
         None
     }
 }

crates/language_model/src/provider/google.rs 🔗

@@ -44,7 +44,7 @@ pub struct GoogleLanguageModelProvider {
     state: gpui::Model<State>,
 }
 
-struct State {
+pub struct State {
     api_key: Option<String>,
     _subscription: Subscription,
 }
@@ -63,10 +63,10 @@ impl GoogleLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for GoogleLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/ollama.rs 🔗

@@ -32,7 +32,7 @@ pub struct OllamaLanguageModelProvider {
     state: gpui::Model<State>,
 }
 
-struct State {
+pub struct State {
     http_client: Arc<dyn HttpClient>,
     available_models: Vec<ollama::Model>,
     _subscription: Subscription,
@@ -87,10 +87,10 @@ impl OllamaLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for OllamaLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/provider/open_ai.rs 🔗

@@ -45,7 +45,7 @@ pub struct OpenAiLanguageModelProvider {
     state: gpui::Model<State>,
 }
 
-struct State {
+pub struct State {
     api_key: Option<String>,
     _subscription: Subscription,
 }
@@ -64,10 +64,10 @@ impl OpenAiLanguageModelProvider {
 }
 
 impl LanguageModelProviderState for OpenAiLanguageModelProvider {
-    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
-        Some(cx.observe(&self.state, |_, _, cx| {
-            cx.notify();
-        }))
+    type ObservableEntity = State;
+
+    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+        Some(self.state.clone())
     }
 }
 

crates/language_model/src/registry.rs 🔗

@@ -54,9 +54,7 @@ fn register_language_model_providers(
                 registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
             } else {
                 registry.unregister_provider(
-                    &LanguageModelProviderId::from(
-                        crate::provider::cloud::PROVIDER_NAME.to_string(),
-                    ),
+                    &LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
                     cx,
                 );
             }
@@ -80,9 +78,12 @@ pub struct ActiveModel {
     model: Option<Arc<dyn LanguageModel>>,
 }
 
-pub struct ActiveModelChanged;
+pub enum Event {
+    ActiveModelChanged,
+    ProviderStateChanged,
+}
 
-impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
+impl EventEmitter<Event> for LanguageModelRegistry {}
 
 impl LanguageModelRegistry {
     pub fn global(cx: &AppContext) -> Model<Self> {
@@ -114,7 +115,10 @@ impl LanguageModelRegistry {
     ) {
         let name = provider.id();
 
-        if let Some(subscription) = provider.subscribe(cx) {
+        let subscription = provider.subscribe(cx, |_, cx| {
+            cx.emit(Event::ProviderStateChanged);
+        });
+        if let Some(subscription) = subscription {
             subscription.detach();
         }
 
@@ -187,7 +191,7 @@ impl LanguageModelRegistry {
             provider,
             model: None,
         });
-        cx.emit(ActiveModelChanged);
+        cx.emit(Event::ActiveModelChanged);
     }
 
     pub fn set_active_model(
@@ -202,13 +206,13 @@ impl LanguageModelRegistry {
                     provider,
                     model: Some(model),
                 });
-                cx.emit(ActiveModelChanged);
+                cx.emit(Event::ActiveModelChanged);
             } else {
                 log::warn!("Active model's provider not found in registry");
             }
         } else {
             self.active_model = None;
-            cx.emit(ActiveModelChanged);
+            cx.emit(Event::ActiveModelChanged);
         }
     }