Cargo.lock 🔗
@@ -14831,6 +14831,7 @@ dependencies = [
"gpui",
"heck 0.5.0",
"language",
+ "language_model",
"language_models",
"log",
"menu",
Richard Feldman created
Cargo.lock | 1
crates/extension_host/src/wasm_host/llm_provider.rs | 116
crates/settings_ui/Cargo.toml | 1
crates/settings_ui/src/page_data.rs | 4
crates/settings_ui/src/pages/edit_prediction_provider_setup.rs | 237 +++
5 files changed, 263 insertions(+), 96 deletions(-)
@@ -14831,6 +14831,7 @@ dependencies = [
"gpui",
"heck 0.5.0",
"language",
+ "language_model",
"language_models",
"log",
"menu",
@@ -34,7 +34,7 @@ use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use settings::Settings;
use std::sync::Arc;
use theme::ThemeSettings;
-use ui::{Label, LabelSize, prelude::*};
+use ui::{ConfiguredApiCard, Label, LabelSize, prelude::*};
use util::ResultExt as _;
use workspace::Workspace;
use workspace::oauth_device_flow_modal::{
@@ -658,17 +658,25 @@ impl ExtensionProviderConfigurationView {
let icon_path = self.icon_path.clone();
let this_handle = cx.weak_entity();
- // Get workspace to show modal
- let Some(workspace) = window.root::<Workspace>().flatten() else {
+ // Get workspace window handle to show modal - try current window first, then find any workspace window
+ log::info!("OAuth: Looking for workspace window");
+ let workspace_window = window.window_handle().downcast::<Workspace>().or_else(|| {
+ log::info!("OAuth: Current window is not a workspace, searching other windows");
+ cx.windows()
+ .into_iter()
+ .find_map(|window_handle| window_handle.downcast::<Workspace>())
+ });
+
+ let Some(workspace_window) = workspace_window else {
+ log::error!("OAuth: Could not find any workspace window");
self.oauth_in_progress = false;
self.oauth_error = Some("Could not access workspace to show sign-in modal".to_string());
cx.notify();
return;
};
-
- let workspace = workspace.downgrade();
+ log::info!("OAuth: Found workspace window");
let state = state.downgrade();
- cx.spawn_in(window, async move |_this, cx| {
+ cx.spawn(async move |_this, cx| {
// Step 1: Start device flow - get prompt info from extension
let start_result = extension
.call({
@@ -683,12 +691,22 @@ impl ExtensionProviderConfigurationView {
})
.await;
+ log::info!(
+ "OAuth: Device flow start result: {:?}",
+ start_result.is_ok()
+ );
let prompt_info: LlmDeviceFlowPromptInfo = match start_result {
- Ok(Ok(Ok(info))) => info,
+ Ok(Ok(Ok(info))) => {
+ log::info!(
+ "OAuth: Got device flow prompt info, user_code: {}",
+ info.user_code
+ );
+ info
+ }
Ok(Ok(Err(e))) => {
- log::error!("Device flow start failed: {}", e);
+ log::error!("OAuth: Device flow start failed: {}", e);
this_handle
- .update_in(cx, |this, _window, cx| {
+ .update(cx, |this, cx| {
this.oauth_in_progress = false;
this.oauth_error = Some(e);
cx.notify();
@@ -697,9 +715,9 @@ impl ExtensionProviderConfigurationView {
return;
}
Ok(Err(e)) | Err(e) => {
- log::error!("Device flow start error: {}", e);
+ log::error!("OAuth: Device flow start error: {}", e);
this_handle
- .update_in(cx, |this, _window, cx| {
+ .update(cx, |this, cx| {
this.oauth_in_progress = false;
this.oauth_error = Some(e.to_string());
cx.notify();
@@ -721,20 +739,31 @@ impl ExtensionProviderConfigurationView {
icon_path,
};
- let flow_state: Option<Entity<OAuthDeviceFlowState>> = workspace
- .update_in(cx, |workspace, window, cx| {
+ log::info!("OAuth: Attempting to show modal in workspace window");
+ let flow_state: Option<Entity<OAuthDeviceFlowState>> = workspace_window
+ .update(cx, |workspace, window, cx| {
+ log::info!("OAuth: Inside workspace.update, creating modal");
+ window.activate_window();
let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config));
let flow_state_clone = flow_state.clone();
workspace.toggle_modal(window, cx, |_window, cx| {
+ log::info!("OAuth: Inside toggle_modal callback");
OAuthDeviceFlowModal::new(flow_state_clone, cx)
});
flow_state
})
.ok();
+ log::info!(
+ "OAuth: workspace_window.update result: {:?}",
+ flow_state.is_some()
+ );
let Some(flow_state) = flow_state else {
+ log::error!(
+ "OAuth: Failed to show sign-in modal - workspace_window.update returned None"
+ );
this_handle
- .update_in(cx, |this, _window, cx| {
+ .update(cx, |this, cx| {
this.oauth_in_progress = false;
this.oauth_error = Some("Failed to show sign-in modal".to_string());
cx.notify();
@@ -742,6 +771,7 @@ impl ExtensionProviderConfigurationView {
.log_err();
return;
};
+ log::info!("OAuth: Modal shown successfully, starting poll");
// Step 3: Poll for authentication completion
let poll_result = extension
@@ -778,7 +808,7 @@ impl ExtensionProviderConfigurationView {
};
state
- .update_in(cx, |state, _window, cx| {
+ .update(cx, |state, cx| {
state.is_authenticated = true;
state.available_models = new_models;
cx.notify();
@@ -787,7 +817,7 @@ impl ExtensionProviderConfigurationView {
// Update flow state to show success
flow_state
- .update_in(cx, |state, _window, cx| {
+ .update(cx, |state, cx| {
state.set_status(OAuthDeviceFlowStatus::Authorized, cx);
})
.log_err();
@@ -795,12 +825,12 @@ impl ExtensionProviderConfigurationView {
Ok(Ok(Err(e))) => {
log::error!("Device flow poll failed: {}", e);
flow_state
- .update_in(cx, |state, _window, cx| {
+ .update(cx, |state, cx| {
state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx);
})
.log_err();
this_handle
- .update_in(cx, |this, _window, cx| {
+ .update(cx, |this, cx| {
this.oauth_error = Some(e);
cx.notify();
})
@@ -810,7 +840,7 @@ impl ExtensionProviderConfigurationView {
log::error!("Device flow poll error: {}", e);
let error_string = e.to_string();
flow_state
- .update_in(cx, |state, _window, cx| {
+ .update(cx, |state, cx| {
state.set_status(
OAuthDeviceFlowStatus::Failed(error_string.clone()),
cx,
@@ -818,7 +848,7 @@ impl ExtensionProviderConfigurationView {
})
.log_err();
this_handle
- .update_in(cx, |this, _window, cx| {
+ .update(cx, |this, cx| {
this.oauth_error = Some(error_string);
cx.notify();
})
@@ -827,7 +857,7 @@ impl ExtensionProviderConfigurationView {
};
this_handle
- .update_in(cx, |this, _window, cx| {
+ .update(cx, |this, cx| {
this.oauth_in_progress = false;
cx.notify();
})
@@ -958,46 +988,18 @@ impl gpui::Render for ExtensionProviderConfigurationView {
// If authenticated, show success state with sign out option
if is_authenticated && env_var_name_used.is_none() {
- let reset_label = if has_oauth && !has_api_key {
- "Sign Out"
+ let (status_label, button_label) = if has_oauth && !has_api_key {
+ ("Signed in", "Sign Out")
} else {
- "Reset Key"
- };
-
- let status_label = if has_oauth && !has_api_key {
- "Signed in"
- } else {
- "API key configured"
+ ("API key configured", "Reset Key")
};
content = content.child(
- h_flex()
- .mt_0p5()
- .p_1()
- .justify_between()
- .rounded_md()
- .border_1()
- .border_color(cx.theme().colors().border)
- .bg(cx.theme().colors().background)
- .child(
- h_flex()
- .flex_1()
- .min_w_0()
- .gap_1()
- .child(ui::Icon::new(ui::IconName::Check).color(Color::Success))
- .child(Label::new(status_label).truncate()),
- )
- .child(
- ui::Button::new("reset-key", reset_label)
- .label_size(LabelSize::Small)
- .icon(ui::IconName::Undo)
- .icon_size(ui::IconSize::Small)
- .icon_color(Color::Muted)
- .icon_position(ui::IconPosition::Start)
- .on_click(cx.listener(|this, _, window, cx| {
- this.reset_api_key(window, cx);
- })),
- ),
+ ConfiguredApiCard::new(status_label)
+ .button_label(button_label)
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.reset_api_key(window, cx);
+ })),
);
return content.into_any_element();
@@ -21,6 +21,7 @@ bm25 = "2.3.2"
copilot.workspace = true
edit_prediction.workspace = true
extension_host.workspace = true
+language_model.workspace = true
language_models.workspace = true
editor.workspace = true
feature_flags.workspace = true
@@ -7479,8 +7479,8 @@ fn edit_prediction_language_settings_section() -> Vec<SettingsPageItem> {
files: USER,
render: Arc::new(|_, window, cx| {
let settings_window = cx.entity();
- let page = window.use_state(cx, |_, _| {
- crate::pages::EditPredictionSetupPage::new(settings_window)
+ let page = window.use_state(cx, |window, cx| {
+ crate::pages::EditPredictionSetupPage::new(settings_window, window, cx)
});
page.into_any_element()
}),
@@ -5,9 +5,13 @@ use edit_prediction::{
};
use extension_host::ExtensionStore;
use feature_flags::FeatureFlagAppExt as _;
-use gpui::{Entity, ScrollHandle, prelude::*};
+use gpui::{AnyView, Entity, ScrollHandle, Subscription, prelude::*};
+use language_model::{
+ ConfigurationViewTargetAgent, LanguageModelProviderId, LanguageModelRegistry,
+};
use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
-use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*};
+use std::collections::HashMap;
+use ui::{ButtonLink, ConfiguredApiCard, Icon, WithScrollbar, prelude::*};
use crate::{
SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
@@ -17,14 +21,85 @@ use crate::{
pub struct EditPredictionSetupPage {
settings_window: Entity<SettingsWindow>,
scroll_handle: ScrollHandle,
+ extension_oauth_views: HashMap<LanguageModelProviderId, ExtensionOAuthProviderView>,
+ _registry_subscription: Subscription,
+}
+
+struct ExtensionOAuthProviderView {
+ provider_name: SharedString,
+ provider_icon: IconName,
+ provider_icon_path: Option<SharedString>,
+ configuration_view: AnyView,
}
impl EditPredictionSetupPage {
- pub fn new(settings_window: Entity<SettingsWindow>) -> Self {
- Self {
+ pub fn new(
+ settings_window: Entity<SettingsWindow>,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let registry_subscription = cx.subscribe_in(
+ &LanguageModelRegistry::global(cx),
+ window,
+ |this, _, event: &language_model::Event, window, cx| match event {
+ language_model::Event::AddedProvider(provider_id) => {
+ this.maybe_add_extension_oauth_view(provider_id, window, cx);
+ }
+ language_model::Event::RemovedProvider(provider_id) => {
+ this.extension_oauth_views.remove(provider_id);
+ }
+ _ => {}
+ },
+ );
+
+ let mut this = Self {
settings_window,
scroll_handle: ScrollHandle::new(),
+ extension_oauth_views: HashMap::default(),
+ _registry_subscription: registry_subscription,
+ };
+ this.build_extension_oauth_views(window, cx);
+ this
+ }
+
+ fn build_extension_oauth_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ let oauth_provider_ids = get_extension_oauth_provider_ids(cx);
+ for provider_id in oauth_provider_ids {
+ self.maybe_add_extension_oauth_view(&provider_id, window, cx);
+ }
+ }
+
+ fn maybe_add_extension_oauth_view(
+ &mut self,
+ provider_id: &LanguageModelProviderId,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ // Check if this provider has OAuth configured in the extension manifest
+ if !is_extension_oauth_provider(provider_id, cx) {
+ return;
}
+
+ let registry = LanguageModelRegistry::global(cx).read(cx);
+ let Some(provider) = registry.provider(provider_id) else {
+ return;
+ };
+
+ let provider_name = provider.name().0.clone();
+ let provider_icon = provider.icon();
+ let provider_icon_path = provider.icon_path();
+ let configuration_view =
+ provider.configuration_view(ConfigurationViewTargetAgent::ZedAgent, window, cx);
+
+ self.extension_oauth_views.insert(
+ provider_id.clone(),
+ ExtensionOAuthProviderView {
+ provider_name,
+ provider_icon,
+ provider_icon_path,
+ configuration_view,
+ },
+ );
}
}
@@ -37,10 +112,43 @@ impl Render for EditPredictionSetupPage {
.installed_extensions()
.contains_key("copilot-chat");
- let providers = [
- (!copilot_extension_installed)
- .then(|| render_github_copilot_provider(window, cx).into_any_element()),
- cx.has_flag::<Zeta2FeatureFlag>().then(|| {
+ let mut providers: Vec<AnyElement> = Vec::new();
+
+ // Built-in Copilot (hidden if copilot-chat extension is installed)
+ if !copilot_extension_installed {
+ providers.push(render_github_copilot_provider(window, cx).into_any_element());
+ }
+
+ // Extension providers with OAuth support
+ for (provider_id, view) in &self.extension_oauth_views {
+ let icon_element: AnyElement = if let Some(icon_path) = &view.provider_icon_path {
+ Icon::from_external_svg(icon_path.clone())
+ .size(ui::IconSize::Medium)
+ .into_any_element()
+ } else {
+ Icon::new(view.provider_icon)
+ .size(ui::IconSize::Medium)
+ .into_any_element()
+ };
+
+ providers.push(
+ v_flex()
+ .id(SharedString::from(provider_id.0.to_string()))
+ .min_w_0()
+ .gap_1p5()
+ .child(
+ h_flex().gap_2().items_center().child(icon_element).child(
+ Headline::new(view.provider_name.clone()).size(HeadlineSize::Small),
+ ),
+ )
+ .child(view.configuration_view.clone())
+ .into_any_element(),
+ );
+ }
+
+ // Mercury (feature flagged)
+ if cx.has_flag::<Zeta2FeatureFlag>() {
+ providers.push(
render_api_key_provider(
IconName::Inception,
"Mercury",
@@ -51,9 +159,13 @@ impl Render for EditPredictionSetupPage {
window,
cx,
)
- .into_any_element()
- }),
- cx.has_flag::<Zeta2FeatureFlag>().then(|| {
+ .into_any_element(),
+ );
+ }
+
+ // Sweep (feature flagged)
+ if cx.has_flag::<Zeta2FeatureFlag>() {
+ providers.push(
render_api_key_provider(
IconName::SweepAi,
"Sweep",
@@ -64,32 +176,34 @@ impl Render for EditPredictionSetupPage {
window,
cx,
)
- .into_any_element()
- }),
- Some(
- render_api_key_provider(
- IconName::AiMistral,
- "Codestral",
- "https://console.mistral.ai/codestral".into(),
- codestral_api_key(cx),
- |cx| language_models::MistralLanguageModelProvider::api_url(cx),
- Some(settings_window.update(cx, |settings_window, cx| {
- let codestral_settings = codestral_settings();
- settings_window
- .render_sub_page_items_section(
- codestral_settings.iter().enumerate(),
- None,
- window,
- cx,
- )
- .into_any_element()
- })),
- window,
- cx,
- )
.into_any_element(),
- ),
- ];
+ );
+ }
+
+ // Codestral
+ providers.push(
+ render_api_key_provider(
+ IconName::AiMistral,
+ "Codestral",
+ "https://console.mistral.ai/codestral".into(),
+ codestral_api_key(cx),
+ |cx| language_models::MistralLanguageModelProvider::api_url(cx),
+ Some(settings_window.update(cx, |settings_window, cx| {
+ let codestral_settings = codestral_settings();
+ settings_window
+ .render_sub_page_items_section(
+ codestral_settings.iter().enumerate(),
+ None,
+ window,
+ cx,
+ )
+ .into_any_element()
+ })),
+ window,
+ cx,
+ )
+ .into_any_element(),
+ );
div()
.size_full()
@@ -103,11 +217,60 @@ impl Render for EditPredictionSetupPage {
.pb_16()
.overflow_y_scroll()
.track_scroll(&self.scroll_handle)
- .children(providers.into_iter().flatten()),
+ .children(providers),
)
}
}
+/// Get extension provider IDs that have OAuth configured.
+fn get_extension_oauth_provider_ids(cx: &App) -> Vec<LanguageModelProviderId> {
+ let extension_store = ExtensionStore::global(cx).read(cx);
+
+ extension_store
+ .installed_extensions()
+ .iter()
+ .flat_map(|(extension_id, entry)| {
+ entry.manifest.language_model_providers.iter().filter_map(
+ move |(provider_id, provider_entry)| {
+ // Check if this provider has OAuth configured
+ let has_oauth = provider_entry
+ .auth
+ .as_ref()
+ .is_some_and(|auth| auth.oauth.is_some());
+
+ if has_oauth {
+ Some(LanguageModelProviderId(
+ format!("{}:{}", extension_id, provider_id).into(),
+ ))
+ } else {
+ None
+ }
+ },
+ )
+ })
+ .collect()
+}
+
+/// Check if a provider ID corresponds to an extension with OAuth configured.
+fn is_extension_oauth_provider(provider_id: &LanguageModelProviderId, cx: &App) -> bool {
+ // Extension provider IDs are in the format "extension_id:provider_id"
+ let Some((extension_id, local_provider_id)) = provider_id.0.split_once(':') else {
+ return false;
+ };
+
+ let extension_store = ExtensionStore::global(cx).read(cx);
+ let Some(entry) = extension_store.installed_extensions().get(extension_id) else {
+ return false;
+ };
+
+ entry
+ .manifest
+ .language_model_providers
+ .get(local_provider_id)
+ .and_then(|p| p.auth.as_ref())
+ .is_some_and(|auth| auth.oauth.is_some())
+}
+
fn render_api_key_provider(
icon: IconName,
title: &'static str,