Support OAuth extensions in settings panel

Richard Feldman created

Change summary

Cargo.lock                                                     |   1 
crates/extension_host/src/wasm_host/llm_provider.rs            | 116 
crates/settings_ui/Cargo.toml                                  |   1 
crates/settings_ui/src/page_data.rs                            |   4 
crates/settings_ui/src/pages/edit_prediction_provider_setup.rs | 237 +++
5 files changed, 263 insertions(+), 96 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -14831,6 +14831,7 @@ dependencies = [
  "gpui",
  "heck 0.5.0",
  "language",
+ "language_model",
  "language_models",
  "log",
  "menu",

crates/extension_host/src/wasm_host/llm_provider.rs 🔗

@@ -34,7 +34,7 @@ use markdown::{Markdown, MarkdownElement, MarkdownStyle};
 use settings::Settings;
 use std::sync::Arc;
 use theme::ThemeSettings;
-use ui::{Label, LabelSize, prelude::*};
+use ui::{ConfiguredApiCard, Label, LabelSize, prelude::*};
 use util::ResultExt as _;
 use workspace::Workspace;
 use workspace::oauth_device_flow_modal::{
@@ -658,17 +658,25 @@ impl ExtensionProviderConfigurationView {
         let icon_path = self.icon_path.clone();
         let this_handle = cx.weak_entity();
 
-        // Get workspace to show modal
-        let Some(workspace) = window.root::<Workspace>().flatten() else {
+        // Get workspace window handle to show modal - try current window first, then find any workspace window
+        log::info!("OAuth: Looking for workspace window");
+        let workspace_window = window.window_handle().downcast::<Workspace>().or_else(|| {
+            log::info!("OAuth: Current window is not a workspace, searching other windows");
+            cx.windows()
+                .into_iter()
+                .find_map(|window_handle| window_handle.downcast::<Workspace>())
+        });
+
+        let Some(workspace_window) = workspace_window else {
+            log::error!("OAuth: Could not find any workspace window");
             self.oauth_in_progress = false;
             self.oauth_error = Some("Could not access workspace to show sign-in modal".to_string());
             cx.notify();
             return;
         };
-
-        let workspace = workspace.downgrade();
+        log::info!("OAuth: Found workspace window");
         let state = state.downgrade();
-        cx.spawn_in(window, async move |_this, cx| {
+        cx.spawn(async move |_this, cx| {
             // Step 1: Start device flow - get prompt info from extension
             let start_result = extension
                 .call({
@@ -683,12 +691,22 @@ impl ExtensionProviderConfigurationView {
                 })
                 .await;
 
+            log::info!(
+                "OAuth: Device flow start result: {:?}",
+                start_result.is_ok()
+            );
             let prompt_info: LlmDeviceFlowPromptInfo = match start_result {
-                Ok(Ok(Ok(info))) => info,
+                Ok(Ok(Ok(info))) => {
+                    log::info!(
+                        "OAuth: Got device flow prompt info, user_code: {}",
+                        info.user_code
+                    );
+                    info
+                }
                 Ok(Ok(Err(e))) => {
-                    log::error!("Device flow start failed: {}", e);
+                    log::error!("OAuth: Device flow start failed: {}", e);
                     this_handle
-                        .update_in(cx, |this, _window, cx| {
+                        .update(cx, |this, cx| {
                             this.oauth_in_progress = false;
                             this.oauth_error = Some(e);
                             cx.notify();
@@ -697,9 +715,9 @@ impl ExtensionProviderConfigurationView {
                     return;
                 }
                 Ok(Err(e)) | Err(e) => {
-                    log::error!("Device flow start error: {}", e);
+                    log::error!("OAuth: Device flow start error: {}", e);
                     this_handle
-                        .update_in(cx, |this, _window, cx| {
+                        .update(cx, |this, cx| {
                             this.oauth_in_progress = false;
                             this.oauth_error = Some(e.to_string());
                             cx.notify();
@@ -721,20 +739,31 @@ impl ExtensionProviderConfigurationView {
                 icon_path,
             };
 
-            let flow_state: Option<Entity<OAuthDeviceFlowState>> = workspace
-                .update_in(cx, |workspace, window, cx| {
+            log::info!("OAuth: Attempting to show modal in workspace window");
+            let flow_state: Option<Entity<OAuthDeviceFlowState>> = workspace_window
+                .update(cx, |workspace, window, cx| {
+                    log::info!("OAuth: Inside workspace.update, creating modal");
+                    window.activate_window();
                     let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config));
                     let flow_state_clone = flow_state.clone();
                     workspace.toggle_modal(window, cx, |_window, cx| {
+                        log::info!("OAuth: Inside toggle_modal callback");
                         OAuthDeviceFlowModal::new(flow_state_clone, cx)
                     });
                     flow_state
                 })
                 .ok();
 
+            log::info!(
+                "OAuth: workspace_window.update result: {:?}",
+                flow_state.is_some()
+            );
             let Some(flow_state) = flow_state else {
+                log::error!(
+                    "OAuth: Failed to show sign-in modal - workspace_window.update returned None"
+                );
                 this_handle
-                    .update_in(cx, |this, _window, cx| {
+                    .update(cx, |this, cx| {
                         this.oauth_in_progress = false;
                         this.oauth_error = Some("Failed to show sign-in modal".to_string());
                         cx.notify();
@@ -742,6 +771,7 @@ impl ExtensionProviderConfigurationView {
                     .log_err();
                 return;
             };
+            log::info!("OAuth: Modal shown successfully, starting poll");
 
             // Step 3: Poll for authentication completion
             let poll_result = extension
@@ -778,7 +808,7 @@ impl ExtensionProviderConfigurationView {
                     };
 
                     state
-                        .update_in(cx, |state, _window, cx| {
+                        .update(cx, |state, cx| {
                             state.is_authenticated = true;
                             state.available_models = new_models;
                             cx.notify();
@@ -787,7 +817,7 @@ impl ExtensionProviderConfigurationView {
 
                     // Update flow state to show success
                     flow_state
-                        .update_in(cx, |state, _window, cx| {
+                        .update(cx, |state, cx| {
                             state.set_status(OAuthDeviceFlowStatus::Authorized, cx);
                         })
                         .log_err();
@@ -795,12 +825,12 @@ impl ExtensionProviderConfigurationView {
                 Ok(Ok(Err(e))) => {
                     log::error!("Device flow poll failed: {}", e);
                     flow_state
-                        .update_in(cx, |state, _window, cx| {
+                        .update(cx, |state, cx| {
                             state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx);
                         })
                         .log_err();
                     this_handle
-                        .update_in(cx, |this, _window, cx| {
+                        .update(cx, |this, cx| {
                             this.oauth_error = Some(e);
                             cx.notify();
                         })
@@ -810,7 +840,7 @@ impl ExtensionProviderConfigurationView {
                     log::error!("Device flow poll error: {}", e);
                     let error_string = e.to_string();
                     flow_state
-                        .update_in(cx, |state, _window, cx| {
+                        .update(cx, |state, cx| {
                             state.set_status(
                                 OAuthDeviceFlowStatus::Failed(error_string.clone()),
                                 cx,
@@ -818,7 +848,7 @@ impl ExtensionProviderConfigurationView {
                         })
                         .log_err();
                     this_handle
-                        .update_in(cx, |this, _window, cx| {
+                        .update(cx, |this, cx| {
                             this.oauth_error = Some(error_string);
                             cx.notify();
                         })
@@ -827,7 +857,7 @@ impl ExtensionProviderConfigurationView {
             };
 
             this_handle
-                .update_in(cx, |this, _window, cx| {
+                .update(cx, |this, cx| {
                     this.oauth_in_progress = false;
                     cx.notify();
                 })
@@ -958,46 +988,18 @@ impl gpui::Render for ExtensionProviderConfigurationView {
 
         // If authenticated, show success state with sign out option
         if is_authenticated && env_var_name_used.is_none() {
-            let reset_label = if has_oauth && !has_api_key {
-                "Sign Out"
+            let (status_label, button_label) = if has_oauth && !has_api_key {
+                ("Signed in", "Sign Out")
             } else {
-                "Reset Key"
-            };
-
-            let status_label = if has_oauth && !has_api_key {
-                "Signed in"
-            } else {
-                "API key configured"
+                ("API key configured", "Reset Key")
             };
 
             content = content.child(
-                h_flex()
-                    .mt_0p5()
-                    .p_1()
-                    .justify_between()
-                    .rounded_md()
-                    .border_1()
-                    .border_color(cx.theme().colors().border)
-                    .bg(cx.theme().colors().background)
-                    .child(
-                        h_flex()
-                            .flex_1()
-                            .min_w_0()
-                            .gap_1()
-                            .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
-                            .child(Label::new(status_label).truncate()),
-                    )
-                    .child(
-                        ui::Button::new("reset-key", reset_label)
-                            .label_size(LabelSize::Small)
-                            .icon(ui::IconName::Undo)
-                            .icon_size(ui::IconSize::Small)
-                            .icon_color(Color::Muted)
-                            .icon_position(ui::IconPosition::Start)
-                            .on_click(cx.listener(|this, _, window, cx| {
-                                this.reset_api_key(window, cx);
-                            })),
-                    ),
+                ConfiguredApiCard::new(status_label)
+                    .button_label(button_label)
+                    .on_click(cx.listener(|this, _, window, cx| {
+                        this.reset_api_key(window, cx);
+                    })),
             );
 
             return content.into_any_element();

crates/settings_ui/Cargo.toml 🔗

@@ -21,6 +21,7 @@ bm25 = "2.3.2"
 copilot.workspace = true
 edit_prediction.workspace = true
 extension_host.workspace = true
+language_model.workspace = true
 language_models.workspace = true
 editor.workspace = true
 feature_flags.workspace = true

crates/settings_ui/src/page_data.rs 🔗

@@ -7479,8 +7479,8 @@ fn edit_prediction_language_settings_section() -> Vec<SettingsPageItem> {
             files: USER,
             render: Arc::new(|_, window, cx| {
                 let settings_window = cx.entity();
-                let page = window.use_state(cx, |_, _| {
-                    crate::pages::EditPredictionSetupPage::new(settings_window)
+                let page = window.use_state(cx, |window, cx| {
+                    crate::pages::EditPredictionSetupPage::new(settings_window, window, cx)
                 });
                 page.into_any_element()
             }),

crates/settings_ui/src/pages/edit_prediction_provider_setup.rs 🔗

@@ -5,9 +5,13 @@ use edit_prediction::{
 };
 use extension_host::ExtensionStore;
 use feature_flags::FeatureFlagAppExt as _;
-use gpui::{Entity, ScrollHandle, prelude::*};
+use gpui::{AnyView, Entity, ScrollHandle, Subscription, prelude::*};
+use language_model::{
+    ConfigurationViewTargetAgent, LanguageModelProviderId, LanguageModelRegistry,
+};
 use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
-use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*};
+use std::collections::HashMap;
+use ui::{ButtonLink, ConfiguredApiCard, Icon, WithScrollbar, prelude::*};
 
 use crate::{
     SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
@@ -17,14 +21,85 @@ use crate::{
 pub struct EditPredictionSetupPage {
     settings_window: Entity<SettingsWindow>,
     scroll_handle: ScrollHandle,
+    extension_oauth_views: HashMap<LanguageModelProviderId, ExtensionOAuthProviderView>,
+    _registry_subscription: Subscription,
+}
+
+struct ExtensionOAuthProviderView {
+    provider_name: SharedString,
+    provider_icon: IconName,
+    provider_icon_path: Option<SharedString>,
+    configuration_view: AnyView,
 }
 
 impl EditPredictionSetupPage {
-    pub fn new(settings_window: Entity<SettingsWindow>) -> Self {
-        Self {
+    pub fn new(
+        settings_window: Entity<SettingsWindow>,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Self {
+        let registry_subscription = cx.subscribe_in(
+            &LanguageModelRegistry::global(cx),
+            window,
+            |this, _, event: &language_model::Event, window, cx| match event {
+                language_model::Event::AddedProvider(provider_id) => {
+                    this.maybe_add_extension_oauth_view(provider_id, window, cx);
+                }
+                language_model::Event::RemovedProvider(provider_id) => {
+                    this.extension_oauth_views.remove(provider_id);
+                }
+                _ => {}
+            },
+        );
+
+        let mut this = Self {
             settings_window,
             scroll_handle: ScrollHandle::new(),
+            extension_oauth_views: HashMap::default(),
+            _registry_subscription: registry_subscription,
+        };
+        this.build_extension_oauth_views(window, cx);
+        this
+    }
+
+    fn build_extension_oauth_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+        let oauth_provider_ids = get_extension_oauth_provider_ids(cx);
+        for provider_id in oauth_provider_ids {
+            self.maybe_add_extension_oauth_view(&provider_id, window, cx);
+        }
+    }
+
+    fn maybe_add_extension_oauth_view(
+        &mut self,
+        provider_id: &LanguageModelProviderId,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        // Check if this provider has OAuth configured in the extension manifest
+        if !is_extension_oauth_provider(provider_id, cx) {
+            return;
         }
+
+        let registry = LanguageModelRegistry::global(cx).read(cx);
+        let Some(provider) = registry.provider(provider_id) else {
+            return;
+        };
+
+        let provider_name = provider.name().0.clone();
+        let provider_icon = provider.icon();
+        let provider_icon_path = provider.icon_path();
+        let configuration_view =
+            provider.configuration_view(ConfigurationViewTargetAgent::ZedAgent, window, cx);
+
+        self.extension_oauth_views.insert(
+            provider_id.clone(),
+            ExtensionOAuthProviderView {
+                provider_name,
+                provider_icon,
+                provider_icon_path,
+                configuration_view,
+            },
+        );
     }
 }
 
@@ -37,10 +112,43 @@ impl Render for EditPredictionSetupPage {
             .installed_extensions()
             .contains_key("copilot-chat");
 
-        let providers = [
-            (!copilot_extension_installed)
-                .then(|| render_github_copilot_provider(window, cx).into_any_element()),
-            cx.has_flag::<Zeta2FeatureFlag>().then(|| {
+        let mut providers: Vec<AnyElement> = Vec::new();
+
+        // Built-in Copilot (hidden if copilot-chat extension is installed)
+        if !copilot_extension_installed {
+            providers.push(render_github_copilot_provider(window, cx).into_any_element());
+        }
+
+        // Extension providers with OAuth support
+        for (provider_id, view) in &self.extension_oauth_views {
+            let icon_element: AnyElement = if let Some(icon_path) = &view.provider_icon_path {
+                Icon::from_external_svg(icon_path.clone())
+                    .size(ui::IconSize::Medium)
+                    .into_any_element()
+            } else {
+                Icon::new(view.provider_icon)
+                    .size(ui::IconSize::Medium)
+                    .into_any_element()
+            };
+
+            providers.push(
+                v_flex()
+                    .id(SharedString::from(provider_id.0.to_string()))
+                    .min_w_0()
+                    .gap_1p5()
+                    .child(
+                        h_flex().gap_2().items_center().child(icon_element).child(
+                            Headline::new(view.provider_name.clone()).size(HeadlineSize::Small),
+                        ),
+                    )
+                    .child(view.configuration_view.clone())
+                    .into_any_element(),
+            );
+        }
+
+        // Mercury (feature flagged)
+        if cx.has_flag::<Zeta2FeatureFlag>() {
+            providers.push(
                 render_api_key_provider(
                     IconName::Inception,
                     "Mercury",
@@ -51,9 +159,13 @@ impl Render for EditPredictionSetupPage {
                     window,
                     cx,
                 )
-                .into_any_element()
-            }),
-            cx.has_flag::<Zeta2FeatureFlag>().then(|| {
+                .into_any_element(),
+            );
+        }
+
+        // Sweep (feature flagged)
+        if cx.has_flag::<Zeta2FeatureFlag>() {
+            providers.push(
                 render_api_key_provider(
                     IconName::SweepAi,
                     "Sweep",
@@ -64,32 +176,34 @@ impl Render for EditPredictionSetupPage {
                     window,
                     cx,
                 )
-                .into_any_element()
-            }),
-            Some(
-                render_api_key_provider(
-                    IconName::AiMistral,
-                    "Codestral",
-                    "https://console.mistral.ai/codestral".into(),
-                    codestral_api_key(cx),
-                    |cx| language_models::MistralLanguageModelProvider::api_url(cx),
-                    Some(settings_window.update(cx, |settings_window, cx| {
-                        let codestral_settings = codestral_settings();
-                        settings_window
-                            .render_sub_page_items_section(
-                                codestral_settings.iter().enumerate(),
-                                None,
-                                window,
-                                cx,
-                            )
-                            .into_any_element()
-                    })),
-                    window,
-                    cx,
-                )
                 .into_any_element(),
-            ),
-        ];
+            );
+        }
+
+        // Codestral
+        providers.push(
+            render_api_key_provider(
+                IconName::AiMistral,
+                "Codestral",
+                "https://console.mistral.ai/codestral".into(),
+                codestral_api_key(cx),
+                |cx| language_models::MistralLanguageModelProvider::api_url(cx),
+                Some(settings_window.update(cx, |settings_window, cx| {
+                    let codestral_settings = codestral_settings();
+                    settings_window
+                        .render_sub_page_items_section(
+                            codestral_settings.iter().enumerate(),
+                            None,
+                            window,
+                            cx,
+                        )
+                        .into_any_element()
+                })),
+                window,
+                cx,
+            )
+            .into_any_element(),
+        );
 
         div()
             .size_full()
@@ -103,11 +217,60 @@ impl Render for EditPredictionSetupPage {
                     .pb_16()
                     .overflow_y_scroll()
                     .track_scroll(&self.scroll_handle)
-                    .children(providers.into_iter().flatten()),
+                    .children(providers),
             )
     }
 }
 
+/// Get extension provider IDs that have OAuth configured.
+fn get_extension_oauth_provider_ids(cx: &App) -> Vec<LanguageModelProviderId> {
+    let extension_store = ExtensionStore::global(cx).read(cx);
+
+    extension_store
+        .installed_extensions()
+        .iter()
+        .flat_map(|(extension_id, entry)| {
+            entry.manifest.language_model_providers.iter().filter_map(
+                move |(provider_id, provider_entry)| {
+                    // Check if this provider has OAuth configured
+                    let has_oauth = provider_entry
+                        .auth
+                        .as_ref()
+                        .is_some_and(|auth| auth.oauth.is_some());
+
+                    if has_oauth {
+                        Some(LanguageModelProviderId(
+                            format!("{}:{}", extension_id, provider_id).into(),
+                        ))
+                    } else {
+                        None
+                    }
+                },
+            )
+        })
+        .collect()
+}
+
+/// Check if a provider ID corresponds to an extension with OAuth configured.
+fn is_extension_oauth_provider(provider_id: &LanguageModelProviderId, cx: &App) -> bool {
+    // Extension provider IDs are in the format "extension_id:provider_id"
+    let Some((extension_id, local_provider_id)) = provider_id.0.split_once(':') else {
+        return false;
+    };
+
+    let extension_store = ExtensionStore::global(cx).read(cx);
+    let Some(entry) = extension_store.installed_extensions().get(extension_id) else {
+        return false;
+    };
+
+    entry
+        .manifest
+        .language_model_providers
+        .get(local_provider_id)
+        .and_then(|p| p.auth.as_ref())
+        .is_some_and(|auth| auth.oauth.is_some())
+}
+
 fn render_api_key_provider(
     icon: IconName,
     title: &'static str,