diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index 60b755e39457a6be99bf37c28e1074388365d3e5..b52ba1bd0ac46db08b6b985a7eca25571caea334 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -347,6 +347,9 @@ pub struct OAuthConfig { /// The Zed icon path to display on the sign-in button (e.g. "github"). #[serde(default)] pub sign_in_button_icon: Option, + /// The description text shown next to the sign-in button in edit prediction settings. + #[serde(default)] + pub sign_in_description: Option, } impl ExtensionManifest { diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs index d4f38a66977274914d6f1ecc06e0ad8ba1b68763..435bbb0909f0ce99e99816b9f1df788999d14aad 100644 --- a/crates/extension_host/src/wasm_host/llm_provider.rs +++ b/crates/extension_host/src/wasm_host/llm_provider.rs @@ -19,8 +19,9 @@ use futures::stream::BoxStream; use futures::{FutureExt, StreamExt}; use gpui::Focusable; use gpui::{ - AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, - TextStyleRefinement, UnderlineStyle, Window, px, + AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, DismissEvent, Entity, + EventEmitter, FocusHandle, MouseDownEvent, Subscription, Task, TextStyleRefinement, + UnderlineStyle, Window, WindowBounds, WindowOptions, point, px, rems, }; use language_model::tool_schema::LanguageModelToolSchemaFormat; use language_model::{ @@ -34,7 +35,10 @@ use markdown::{Markdown, MarkdownElement, MarkdownStyle}; use settings::Settings; use std::sync::Arc; use theme::ThemeSettings; -use ui::{ConfiguredApiCard, Label, LabelSize, prelude::*}; +use ui::{ + Button, ButtonLike, ButtonSize, ButtonStyle, ConfiguredApiCard, Headline, HeadlineSize, Icon, + Label, LabelSize, Vector, VectorName, prelude::*, +}; use util::ResultExt as _; use workspace::Workspace; use workspace::oauth_device_flow_modal::{ @@ -253,7 +257,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { fn configuration_view( &self, - _target_agent: ConfigurationViewTargetAgent, + target_agent: ConfigurationViewTargetAgent, window: &mut Window, cx: &mut App, ) -> AnyView { @@ -274,6 +278,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { auth_config, state, icon_path, + target_agent, window, cx, ) @@ -356,6 +361,7 @@ struct ExtensionProviderConfigurationView { oauth_in_progress: bool, oauth_error: Option, icon_path: Option, + target_agent: ConfigurationViewTargetAgent, _subscriptions: Vec, } @@ -368,6 +374,7 @@ impl ExtensionProviderConfigurationView { auth_config: Option, state: Entity, icon_path: Option, + target_agent: ConfigurationViewTargetAgent, window: &mut Window, cx: &mut Context, ) -> Self { @@ -397,6 +404,7 @@ impl ExtensionProviderConfigurationView { oauth_in_progress: false, oauth_error: None, icon_path, + target_agent, _subscriptions: vec![state_subscription], }; @@ -657,24 +665,42 @@ impl ExtensionProviderConfigurationView { let state = self.state.clone(); let icon_path = self.icon_path.clone(); let this_handle = cx.weak_entity(); - - // Get workspace window handle to show modal - try current window first, then find any workspace window - log::info!("OAuth: Looking for workspace window"); - let workspace_window = window.window_handle().downcast::().or_else(|| { - log::info!("OAuth: Current window is not a workspace, searching other windows"); - cx.windows() - .into_iter() - .find_map(|window_handle| window_handle.downcast::()) - }); - - let Some(workspace_window) = workspace_window else { - log::error!("OAuth: Could not find any workspace window"); - self.oauth_in_progress = false; - self.oauth_error = Some("Could not access workspace to show sign-in modal".to_string()); - cx.notify(); - return; + let use_popup_window = self.is_edit_prediction_mode(); + + // Get current window bounds for positioning popup + let current_window_center = window.bounds().center(); + + // For workspace modal mode, find the workspace window + let workspace_window = if !use_popup_window { + log::info!("OAuth: Looking for workspace window"); + let ws = window.window_handle().downcast::().or_else(|| { + log::info!("OAuth: Current window is not a workspace, searching other windows"); + cx.windows() + .into_iter() + .find_map(|window_handle| window_handle.downcast::()) + }); + + if ws.is_none() { + log::error!("OAuth: Could not find any workspace window"); + self.oauth_in_progress = false; + self.oauth_error = + Some("Could not access workspace to show sign-in modal".to_string()); + cx.notify(); + return; + } + ws + } else { + None }; - log::info!("OAuth: Found workspace window"); + + log::info!( + "OAuth: Using {} mode", + if use_popup_window { + "popup window" + } else { + "workspace modal" + } + ); let state = state.downgrade(); cx.spawn(async move |_this, cx| { // Step 1: Start device flow - get prompt info from extension @@ -727,7 +753,7 @@ impl ExtensionProviderConfigurationView { } }; - // Step 2: Create state entity and show the modal + // Step 2: Create state entity and show the modal/window let modal_config = OAuthDeviceFlowModalConfig { user_code: prompt_info.user_code, verification_url: prompt_info.verification_url, @@ -739,29 +765,71 @@ impl ExtensionProviderConfigurationView { icon_path, }; - log::info!("OAuth: Attempting to show modal in workspace window"); - let flow_state: Option> = workspace_window - .update(cx, |workspace, window, cx| { - log::info!("OAuth: Inside workspace.update, creating modal"); - window.activate_window(); - let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config)); - let flow_state_clone = flow_state.clone(); - workspace.toggle_modal(window, cx, |_window, cx| { - log::info!("OAuth: Inside toggle_modal callback"); - OAuthDeviceFlowModal::new(flow_state_clone, cx) - }); - flow_state + let flow_state: Option> = if use_popup_window { + // Open a popup window like Copilot does + log::info!("OAuth: Opening popup window"); + cx.update(|cx| { + let height = px(450.); + let width = px(350.); + let window_bounds = WindowBounds::Windowed(gpui::bounds( + current_window_center - point(height / 2.0, width / 2.0), + gpui::size(height, width), + )); + + let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config.clone())); + let flow_state_for_window = flow_state.clone(); + + cx.open_window( + WindowOptions { + kind: gpui::WindowKind::PopUp, + window_bounds: Some(window_bounds), + is_resizable: false, + is_movable: true, + titlebar: Some(gpui::TitlebarOptions { + appears_transparent: true, + ..Default::default() + }), + ..Default::default() + }, + |window, cx| { + cx.new(|cx| { + OAuthCodeVerificationWindow::new( + modal_config, + flow_state_for_window, + window, + cx, + ) + }) + }, + ) + .log_err(); + + Some(flow_state) + }) + .ok() + .flatten() + } else { + // Use workspace modal + log::info!("OAuth: Attempting to show modal in workspace window"); + workspace_window.as_ref().and_then(|ws| { + ws.update(cx, |workspace, window, cx| { + log::info!("OAuth: Inside workspace.update, creating modal"); + window.activate_window(); + let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config)); + let flow_state_clone = flow_state.clone(); + workspace.toggle_modal(window, cx, |_window, cx| { + log::info!("OAuth: Inside toggle_modal callback"); + OAuthDeviceFlowModal::new(flow_state_clone, cx) + }); + flow_state + }) + .ok() }) - .ok(); + }; - log::info!( - "OAuth: workspace_window.update result: {:?}", - flow_state.is_some() - ); + log::info!("OAuth: flow_state created: {:?}", flow_state.is_some()); let Some(flow_state) = flow_state else { - log::error!( - "OAuth: Failed to show sign-in modal - workspace_window.update returned None" - ); + log::error!("OAuth: Failed to show sign-in modal/window"); this_handle .update(cx, |this, cx| { this.oauth_in_progress = false; @@ -771,7 +839,7 @@ impl ExtensionProviderConfigurationView { .log_err(); return; }; - log::info!("OAuth: Modal shown successfully, starting poll"); + log::info!("OAuth: Modal/window shown successfully, starting poll"); // Step 3: Poll for authentication completion let poll_result = extension @@ -885,10 +953,129 @@ impl ExtensionProviderConfigurationView { .map(|c| c.credential_label.is_some() || c.oauth.is_none()) .unwrap_or(true) } + + fn is_edit_prediction_mode(&self) -> bool { + self.target_agent == ConfigurationViewTargetAgent::EditPrediction + } + + fn render_for_edit_prediction( + &mut self, + _window: &mut Window, + cx: &mut Context, + ) -> impl IntoElement { + let is_loading = self.loading_settings || self.loading_credentials; + let is_authenticated = self.is_authenticated(cx); + let has_oauth = self.has_oauth_config(); + + // Helper to create the horizontal container layout matching Copilot + let container = |description: SharedString, action: AnyElement| { + h_flex() + .pt_2p5() + .w_full() + .justify_between() + .child( + v_flex() + .w_full() + .max_w_1_2() + .child(Label::new("Authenticate To Use")) + .child( + Label::new(description) + .color(Color::Muted) + .size(LabelSize::Small), + ), + ) + .child(action) + }; + + // Get the description from OAuth config or use a default + let oauth_config = self.oauth_config(); + let description: SharedString = oauth_config + .and_then(|c| c.sign_in_description.clone()) + .unwrap_or_else(|| "Sign in to authenticate with this provider.".to_string()) + .into(); + + if is_loading { + return container( + description, + Button::new("loading", "Loading...") + .style(ButtonStyle::Outlined) + .disabled(true) + .into_any_element(), + ) + .into_any_element(); + } + + // If authenticated, show the configured card + if is_authenticated { + let (status_label, button_label) = if has_oauth { + ("Authorized", "Sign Out") + } else { + ("API key configured", "Reset Key") + }; + + return ConfiguredApiCard::new(status_label) + .button_label(button_label) + .on_click(cx.listener(|this, _, window, cx| { + this.reset_api_key(window, cx); + })) + .into_any_element(); + } + + // Not authenticated - show sign in button + if has_oauth { + let button_label = oauth_config + .and_then(|c| c.sign_in_button_label.clone()) + .unwrap_or_else(|| "Sign In".to_string()); + let button_icon = oauth_config + .and_then(|c| c.sign_in_button_icon.as_ref()) + .and_then(|icon_name| match icon_name.as_str() { + "github" => Some(ui::IconName::Github), + _ => None, + }); + + let oauth_in_progress = self.oauth_in_progress; + + let mut button = Button::new("oauth-sign-in", button_label) + .size(ButtonSize::Medium) + .style(ButtonStyle::Outlined) + .disabled(oauth_in_progress) + .on_click(cx.listener(|this, _, window, cx| { + this.start_oauth_sign_in(window, cx); + })); + + if let Some(icon) = button_icon { + button = button + .icon(icon) + .icon_position(ui::IconPosition::Start) + .icon_size(ui::IconSize::Small) + .icon_color(Color::Muted); + } + + return container(description, button.into_any_element()).into_any_element(); + } + + // Fallback for API key only providers - show a simple message + container( + description, + Button::new("configure", "Configure") + .size(ButtonSize::Medium) + .style(ButtonStyle::Outlined) + .disabled(true) + .into_any_element(), + ) + .into_any_element() + } } impl gpui::Render for ExtensionProviderConfigurationView { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + // Use simplified horizontal layout for edit prediction mode + if self.is_edit_prediction_mode() { + return self + .render_for_edit_prediction(window, cx) + .into_any_element(); + } + let is_loading = self.loading_settings || self.loading_credentials; let is_authenticated = self.is_authenticated(cx); let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone(); @@ -1158,6 +1345,230 @@ impl Focusable for ExtensionProviderConfigurationView { } } +/// A popup window for OAuth device flow, similar to CopilotCodeVerification. +/// This is used when in edit prediction mode to avoid moving the settings panel behind. +pub struct OAuthCodeVerificationWindow { + config: OAuthDeviceFlowModalConfig, + status: OAuthDeviceFlowStatus, + connect_clicked: bool, + focus_handle: FocusHandle, + _subscription: Option, +} + +impl Focusable for OAuthCodeVerificationWindow { + fn focus_handle(&self, _: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for OAuthCodeVerificationWindow {} + +impl OAuthCodeVerificationWindow { + pub fn new( + config: OAuthDeviceFlowModalConfig, + state: Entity, + window: &mut Window, + cx: &mut Context, + ) -> Self { + window.on_window_should_close(cx, |window, cx| { + if let Some(this) = window.root::().flatten() { + this.update(cx, |_, cx| { + cx.emit(DismissEvent); + }); + } + true + }); + cx.subscribe_in( + &cx.entity(), + window, + |_, _, _: &DismissEvent, window, _cx| { + window.remove_window(); + }, + ) + .detach(); + + let subscription = cx.observe(&state, |this, state, cx| { + let status = state.read(cx).status.clone(); + this.status = status; + cx.notify(); + }); + + Self { + config, + status: state.read(cx).status.clone(), + connect_clicked: false, + focus_handle: cx.focus_handle(), + _subscription: Some(subscription), + } + } + + fn render_icon(&self, cx: &mut Context) -> impl IntoElement { + let icon_color = Color::Custom(cx.theme().colors().icon); + let icon_size = rems(2.5); + let plus_size = rems(0.875); + let plus_color = cx.theme().colors().icon.opacity(0.5); + + if let Some(icon_path) = &self.config.icon_path { + h_flex() + .gap_2() + .items_center() + .child( + Icon::from_external_svg(icon_path.clone()) + .size(ui::IconSize::Custom(icon_size)) + .color(icon_color), + ) + .child( + gpui::svg() + .size(plus_size) + .path("icons/plus.svg") + .text_color(plus_color), + ) + .child(Vector::new(VectorName::ZedLogo, icon_size, icon_size).color(icon_color)) + .into_any_element() + } else { + Vector::new(VectorName::ZedLogo, icon_size, icon_size) + .color(icon_color) + .into_any_element() + } + } + + fn render_device_code(&self, cx: &mut Context) -> impl IntoElement { + let user_code = self.config.user_code.clone(); + let copied = cx + .read_from_clipboard() + .map(|item| item.text().as_ref() == Some(&user_code)) + .unwrap_or(false); + let user_code_for_click = user_code.clone(); + + ButtonLike::new("copy-button") + .full_width() + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .size(ButtonSize::Medium) + .child( + h_flex() + .w_full() + .p_1() + .justify_between() + .child(Label::new(user_code)) + .child(Label::new(if copied { "Copied!" } else { "Copy" })), + ) + .on_click(move |_, window, cx| { + cx.write_to_clipboard(ClipboardItem::new_string(user_code_for_click.clone())); + window.refresh(); + }) + } + + fn render_prompting_modal(&self, cx: &mut Context) -> impl IntoElement { + let connect_button_label: String = if self.connect_clicked { + "Waiting for connection…".to_string() + } else { + self.config.connect_button_label.clone() + }; + let verification_url = self.config.verification_url.clone(); + + v_flex() + .flex_1() + .gap_2p5() + .items_center() + .text_center() + .child(Headline::new(self.config.headline.clone()).size(HeadlineSize::Large)) + .child(Label::new(self.config.description.clone()).color(Color::Muted)) + .child(self.render_device_code(cx)) + .child( + Label::new("Paste this code after clicking the button below.").color(Color::Muted), + ) + .child( + v_flex() + .w_full() + .gap_1() + .child( + Button::new("connect-button", connect_button_label) + .full_width() + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .on_click(cx.listener(move |this, _, _window, cx| { + cx.open_url(&verification_url); + this.connect_clicked = true; + })), + ) + .child( + Button::new("cancel-button", "Cancel") + .full_width() + .size(ButtonSize::Medium) + .on_click(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })), + ), + ) + } + + fn render_authorized_modal(&self, cx: &mut Context) -> impl IntoElement { + v_flex() + .gap_2() + .text_center() + .justify_center() + .child(Headline::new(self.config.success_headline.clone()).size(HeadlineSize::Large)) + .child(Label::new(self.config.success_message.clone()).color(Color::Muted)) + .child( + Button::new("done-button", "Done") + .full_width() + .style(ButtonStyle::Outlined) + .size(ButtonSize::Medium) + .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + ) + } + + fn render_failed_modal(&self, error: &str, cx: &mut Context) -> impl IntoElement { + v_flex() + .gap_2() + .text_center() + .justify_center() + .child(Headline::new("Authorization Failed").size(HeadlineSize::Large)) + .child(Label::new(error.to_string()).color(Color::Error)) + .child( + Button::new("close-button", "Close") + .full_width() + .size(ButtonSize::Medium) + .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + ) + } +} + +impl gpui::Render for OAuthCodeVerificationWindow { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let prompt = match &self.status { + OAuthDeviceFlowStatus::Prompting | OAuthDeviceFlowStatus::WaitingForAuthorization => { + self.render_prompting_modal(cx).into_any_element() + } + OAuthDeviceFlowStatus::Authorized => { + self.render_authorized_modal(cx).into_any_element() + } + OAuthDeviceFlowStatus::Failed(error) => { + self.render_failed_modal(error, cx).into_any_element() + } + }; + + v_flex() + .id("oauth_code_verification") + .track_focus(&self.focus_handle(cx)) + .size_full() + .px_4() + .py_8() + .gap_2() + .items_center() + .justify_center() + .elevation_3(cx) + .on_action(cx.listener(|_, _: &menu::Cancel, _, cx| { + cx.emit(DismissEvent); + })) + .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _| { + window.focus(&this.focus_handle); + })) + .child(self.render_icon(cx)) + .child(prompt) + } +} + fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { let theme_settings = ThemeSettings::get_global(cx); let colors = cx.theme().colors(); diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 21756517a9c48cc72c7cc17bf62dff810ef0dd92..7307bebcbf2950e719186db345050420fe36d7a3 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -825,10 +825,11 @@ pub trait LanguageModelProvider: 'static { fn reset_credentials(&self, cx: &mut App) -> Task>; } -#[derive(Default, Clone)] +#[derive(Default, Clone, PartialEq, Eq)] pub enum ConfigurationViewTargetAgent { #[default] ZedAgent, + EditPrediction, Other(SharedString), } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 25ba7615dc23e2561648e173588be6d93c28e295..299e94484cf41f1bcb38e50656128c38d954d612 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -924,6 +924,7 @@ impl Render for ConfigurationView { .on_action(cx.listener(Self::save_api_key)) .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent { ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(), + ConfigurationViewTargetAgent::EditPrediction => "Anthropic for edit predictions".into(), ConfigurationViewTargetAgent::Other(agent) => agent.clone(), }))) .child( diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 0a7df23fa04222cb087d56b5034af2efdbdcbd3a..b3e03ce9d3a7be24996a074abff3f421d5bad0fd 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -841,6 +841,7 @@ impl Render for ConfigurationView { .on_action(cx.listener(Self::save_api_key)) .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent { ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(), + ConfigurationViewTargetAgent::EditPrediction => "Google AI for edit predictions".into(), ConfigurationViewTargetAgent::Other(agent) => agent.clone(), }))) .child( diff --git a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs index ca37b990afaac3d03c811f6620f226db288dc12d..46089653dadcc74ba0688db1aa5264cba9351851 100644 --- a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs +++ b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs @@ -85,11 +85,11 @@ impl EditPredictionSetupPage { return; }; - let provider_name = provider.name().0.clone(); + let provider_name = provider.name().0; let provider_icon = provider.icon(); let provider_icon_path = provider.icon_path(); let configuration_view = - provider.configuration_view(ConfigurationViewTargetAgent::ZedAgent, window, cx); + provider.configuration_view(ConfigurationViewTargetAgent::EditPrediction, window, cx); self.extension_oauth_views.insert( provider_id.clone(),