Detailed changes
@@ -223,9 +223,17 @@ fn init_language_model_settings(cx: &mut AppContext) {
cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
.detach();
- cx.observe(&LanguageModelRegistry::global(cx), |_, cx| {
- update_active_language_model_from_settings(cx)
- })
+ cx.subscribe(
+ &LanguageModelRegistry::global(cx),
+ |_, event: &language_model::Event, cx| match event {
+ language_model::Event::ProviderStateChanged
+ | language_model::Event::AddedProvider(_)
+ | language_model::Event::RemovedProvider(_) => {
+ update_active_language_model_from_settings(cx);
+ }
+ _ => {}
+ },
+ )
.detach();
}
@@ -394,8 +394,15 @@ impl AssistantPanel {
cx.subscribe(&context_store, Self::handle_context_store_event),
cx.subscribe(
&LanguageModelRegistry::global(cx),
- |this, _, _: &language_model::ActiveModelChanged, cx| {
- this.completion_provider_changed(cx);
+ |this, _, event: &language_model::Event, cx| match event {
+ language_model::Event::ActiveModelChanged => {
+ this.completion_provider_changed(cx);
+ }
+ language_model::Event::ProviderStateChanged
+ | language_model::Event::AddedProvider(_)
+ | language_model::Event::RemovedProvider(_) => {
+ this.ensure_authenticated(cx);
+ }
},
),
];
@@ -588,6 +595,11 @@ impl AssistantPanel {
}
fn ensure_authenticated(&mut self, cx: &mut ViewContext<Self>) {
+ if self.is_authenticated(cx) {
+ self.set_authentication_prompt(None, cx);
+ return;
+ }
+
let Some(provider_id) = LanguageModelRegistry::read_global(cx)
.active_provider()
.map(|p| p.id())
@@ -596,29 +608,35 @@ impl AssistantPanel {
};
let load_credentials = self.authenticate(cx);
- let task = cx.spawn(|this, mut cx| async move {
- let _ = load_credentials.await;
- this.update(&mut cx, |this, cx| {
- this.show_authentication_prompt(cx);
- })
- .log_err();
- });
- self.authenticate_provider_task = Some((provider_id, task));
+ self.authenticate_provider_task = Some((
+ provider_id,
+ cx.spawn(|this, mut cx| async move {
+ let _ = load_credentials.await;
+ this.update(&mut cx, |this, cx| {
+ this.show_authentication_prompt(cx);
+ this.authenticate_provider_task = None;
+ })
+ .log_err();
+ }),
+ ));
}
fn show_authentication_prompt(&mut self, cx: &mut ViewContext<Self>) {
+ let prompt = Self::authentication_prompt(cx);
+ self.set_authentication_prompt(prompt, cx);
+ }
+
+ fn set_authentication_prompt(&mut self, prompt: Option<AnyView>, cx: &mut ViewContext<Self>) {
if self.active_context_editor(cx).is_none() {
self.new_context(cx);
}
- let authentication_prompt = Self::authentication_prompt(cx);
for context_editor in self.context_editors(cx) {
context_editor.update(cx, |editor, cx| {
- editor.set_authentication_prompt(authentication_prompt.clone(), cx);
+ editor.set_authentication_prompt(prompt.clone(), cx);
});
}
-
cx.notify();
}
@@ -89,7 +89,20 @@ pub trait LanguageModelProvider: 'static {
}
pub trait LanguageModelProviderState: 'static {
- fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
+ type ObservableEntity;
+
+ fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
+
+ fn subscribe<T: 'static>(
+ &self,
+ cx: &mut gpui::ModelContext<T>,
+ callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
+ ) -> Option<gpui::Subscription> {
+ let entity = self.observable_entity()?;
+ Some(cx.observe(&entity, move |this, _, cx| {
+ callback(this, cx);
+ }))
+ }
}
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
@@ -44,7 +44,7 @@ pub struct AnthropicLanguageModelProvider {
state: gpui::Model<State>,
}
-struct State {
+pub struct State {
api_key: Option<String>,
_subscription: Subscription,
}
@@ -61,11 +61,12 @@ impl AnthropicLanguageModelProvider {
Self { http_client, state }
}
}
+
impl LanguageModelProviderState for AnthropicLanguageModelProvider {
- fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
- Some(cx.observe(&self.state, |_, _, cx| {
- cx.notify();
- }))
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+ Some(self.state.clone())
}
}
@@ -8,7 +8,7 @@ use anyhow::{anyhow, Context as _, Result};
use client::Client;
use collections::BTreeMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
-use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
+use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@@ -50,16 +50,19 @@ pub struct CloudLanguageModelProvider {
_maintain_client_status: Task<()>,
}
-struct State {
+pub struct State {
client: Arc<Client>,
status: client::Status,
_subscription: Subscription,
}
impl State {
- fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
+ fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let client = self.client.clone();
- cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
+ cx.spawn(move |this, mut cx| async move {
+ client.authenticate_and_connect(true, &cx).await?;
+ this.update(&mut cx, |_, cx| cx.notify())
+ })
}
}
@@ -99,10 +102,10 @@ impl CloudLanguageModelProvider {
}
impl LanguageModelProviderState for CloudLanguageModelProvider {
- fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
- Some(cx.observe(&self.state, |_, _, cx| {
- cx.notify();
- }))
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+ Some(self.state.clone())
}
}
@@ -11,8 +11,8 @@ use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt};
use gpui::{
- percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model,
- ModelContext, Render, Subscription, Task, Transformation,
+ percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, Render,
+ Subscription, Task, Transformation,
};
use settings::{Settings, SettingsStore};
use std::time::Duration;
@@ -67,10 +67,10 @@ impl CopilotChatLanguageModelProvider {
}
impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
- fn subscribe<T: 'static>(&self, cx: &mut ModelContext<T>) -> Option<Subscription> {
- Some(cx.observe(&self.state, |_, _, cx| {
- cx.notify();
- }))
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+ Some(self.state.clone())
}
}
@@ -36,7 +36,9 @@ pub struct FakeLanguageModelProvider {
}
impl LanguageModelProviderState for FakeLanguageModelProvider {
- fn subscribe<T: 'static>(&self, _: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
+ type ObservableEntity = ();
+
+ fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
None
}
}
@@ -44,7 +44,7 @@ pub struct GoogleLanguageModelProvider {
state: gpui::Model<State>,
}
-struct State {
+pub struct State {
api_key: Option<String>,
_subscription: Subscription,
}
@@ -63,10 +63,10 @@ impl GoogleLanguageModelProvider {
}
impl LanguageModelProviderState for GoogleLanguageModelProvider {
- fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
- Some(cx.observe(&self.state, |_, _, cx| {
- cx.notify();
- }))
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+ Some(self.state.clone())
}
}
@@ -32,7 +32,7 @@ pub struct OllamaLanguageModelProvider {
state: gpui::Model<State>,
}
-struct State {
+pub struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
_subscription: Subscription,
@@ -87,10 +87,10 @@ impl OllamaLanguageModelProvider {
}
impl LanguageModelProviderState for OllamaLanguageModelProvider {
- fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
- Some(cx.observe(&self.state, |_, _, cx| {
- cx.notify();
- }))
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+ Some(self.state.clone())
}
}
@@ -45,7 +45,7 @@ pub struct OpenAiLanguageModelProvider {
state: gpui::Model<State>,
}
-struct State {
+pub struct State {
api_key: Option<String>,
_subscription: Subscription,
}
@@ -64,10 +64,10 @@ impl OpenAiLanguageModelProvider {
}
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
- fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
- Some(cx.observe(&self.state, |_, _, cx| {
- cx.notify();
- }))
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
+ Some(self.state.clone())
}
}
@@ -54,9 +54,7 @@ fn register_language_model_providers(
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
} else {
registry.unregister_provider(
- &LanguageModelProviderId::from(
- crate::provider::cloud::PROVIDER_NAME.to_string(),
- ),
+ LanguageModelProviderId::from(crate::provider::cloud::PROVIDER_ID.to_string()),
cx,
);
}
@@ -80,9 +78,14 @@ pub struct ActiveModel {
model: Option<Arc<dyn LanguageModel>>,
}
-pub struct ActiveModelChanged;
+pub enum Event {
+ ActiveModelChanged,
+ ProviderStateChanged,
+ AddedProvider(LanguageModelProviderId),
+ RemovedProvider(LanguageModelProviderId),
+}
-impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
+impl EventEmitter<Event> for LanguageModelRegistry {}
impl LanguageModelRegistry {
pub fn global(cx: &AppContext) -> Model<Self> {
@@ -112,23 +115,26 @@ impl LanguageModelRegistry {
provider: T,
cx: &mut ModelContext<Self>,
) {
- let name = provider.id();
+ let id = provider.id();
- if let Some(subscription) = provider.subscribe(cx) {
+ let subscription = provider.subscribe(cx, |_, cx| {
+ cx.emit(Event::ProviderStateChanged);
+ });
+ if let Some(subscription) = subscription {
subscription.detach();
}
- self.providers.insert(name, Arc::new(provider));
- cx.notify();
+ self.providers.insert(id.clone(), Arc::new(provider));
+ cx.emit(Event::AddedProvider(id));
}
pub fn unregister_provider(
&mut self,
- name: &LanguageModelProviderId,
+ id: LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
- if self.providers.remove(name).is_some() {
- cx.notify();
+ if self.providers.remove(&id).is_some() {
+ cx.emit(Event::RemovedProvider(id));
}
}
@@ -187,7 +193,7 @@ impl LanguageModelRegistry {
provider,
model: None,
});
- cx.emit(ActiveModelChanged);
+ cx.emit(Event::ActiveModelChanged);
}
pub fn set_active_model(
@@ -202,13 +208,13 @@ impl LanguageModelRegistry {
provider,
model: Some(model),
});
- cx.emit(ActiveModelChanged);
+ cx.emit(Event::ActiveModelChanged);
} else {
log::warn!("Active model's provider not found in registry");
}
} else {
self.active_model = None;
- cx.emit(ActiveModelChanged);
+ cx.emit(Event::ActiveModelChanged);
}
}
@@ -239,7 +245,7 @@ mod tests {
assert_eq!(providers[0].id(), crate::provider::fake::provider_id());
registry.update(cx, |registry, cx| {
- registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
+ registry.unregister_provider(crate::provider::fake::provider_id(), cx);
});
let providers = registry.read(cx).providers();