From 5abf9687480c91c074efb443e827858481a96457 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Sun, 14 Dec 2025 21:47:17 -0500 Subject: [PATCH] Make Copilot login flow look like the builtin one --- Cargo.lock | 1 + crates/extension_api/src/extension_api.rs | 32 +- .../wit/since_v0.8.0/extension.wit | 14 +- .../wit/since_v0.8.0/llm-provider.wit | 18 ++ crates/extension_host/Cargo.toml | 1 + .../src/wasm_host/llm_provider.rs | 278 +++++++++--------- crates/extension_host/src/wasm_host/wit.rs | 21 +- .../workspace/src/oauth_device_flow_modal.rs | 271 +++++++++++++++++ crates/workspace/src/workspace.rs | 1 + extensions/copilot-chat/src/copilot_chat.rs | 19 +- 10 files changed, 487 insertions(+), 169 deletions(-) create mode 100644 crates/workspace/src/oauth_device_flow_modal.rs diff --git a/Cargo.lock b/Cargo.lock index fb8ace072250c1cccb08257f4f62913512a7b172..9430380944e573b5b9f7fdfc1e83df0d760d5957 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5889,6 +5889,7 @@ dependencies = [ "wasmparser 0.221.3", "wasmtime", "wasmtime-wasi", + "workspace", "zlog", ] diff --git a/crates/extension_api/src/extension_api.rs b/crates/extension_api/src/extension_api.rs index 38bff3adad3295c911c65f1a39168044bc2ec0b9..14acdfd66597eca040102e429bf0a6def73b6ec6 100644 --- a/crates/extension_api/src/extension_api.rs +++ b/crates/extension_api/src/extension_api.rs @@ -31,17 +31,18 @@ pub use wit::{ }, zed::extension::llm_provider::{ CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent, - CompletionRequest as LlmCompletionRequest, ImageData as LlmImageData, - MessageContent as LlmMessageContent, MessageRole as LlmMessageRole, - ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo, - OauthHttpRequest as LlmOauthHttpRequest, OauthHttpResponse as LlmOauthHttpResponse, - OauthWebAuthConfig as LlmOauthWebAuthConfig, OauthWebAuthResult as LlmOauthWebAuthResult, - ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage, - StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent, - TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice, - ToolDefinition as LlmToolDefinition, ToolInputFormat as LlmToolInputFormat, - ToolResult as LlmToolResult, ToolResultContent as LlmToolResultContent, - ToolUse as LlmToolUse, ToolUseJsonParseError as LlmToolUseJsonParseError, + CompletionRequest as LlmCompletionRequest, DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo, + ImageData as LlmImageData, MessageContent as LlmMessageContent, + MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities, + ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest, + OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig, + OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo, + RequestMessage as LlmRequestMessage, StopReason as LlmStopReason, + ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage, + ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition, + ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult, + ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse, + ToolUseJsonParseError as LlmToolUseJsonParseError, delete_credential as llm_delete_credential, get_credential as llm_get_credential, get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser, oauth_send_http_request as llm_oauth_send_http_request, @@ -301,12 +302,11 @@ pub trait Extension: Send + Sync { /// Start an OAuth device flow sign-in. /// This is called when the user explicitly clicks "Sign in with GitHub" or similar. - /// Opens the browser to the verification URL and returns the user code that should - /// be displayed to the user. + /// Returns information needed to display the device flow prompt modal to the user. fn llm_provider_start_device_flow_sign_in( &mut self, _provider_id: &str, - ) -> Result { + ) -> Result { Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string()) } @@ -641,7 +641,9 @@ impl wit::Guest for Component { extension().llm_provider_is_authenticated(&provider_id) } - fn llm_provider_start_device_flow_sign_in(provider_id: String) -> Result { + fn llm_provider_start_device_flow_sign_in( + provider_id: String, + ) -> Result { extension().llm_provider_start_device_flow_sign_in(&provider_id) } diff --git a/crates/extension_api/wit/since_v0.8.0/extension.wit b/crates/extension_api/wit/since_v0.8.0/extension.wit index f1df34eaa15465b2f95b80e99bcac3bf59fc45b9..7440984f5d171ccfadd212760d41e15ce7325535 100644 --- a/crates/extension_api/wit/since_v0.8.0/extension.wit +++ b/crates/extension_api/wit/since_v0.8.0/extension.wit @@ -18,7 +18,8 @@ world extension { use slash-command.{slash-command, slash-command-argument-completion, slash-command-output}; use llm-provider.{ provider-info, model-info, completion-request, - cache-configuration, completion-event, token-usage + cache-configuration, completion-event, token-usage, + device-flow-prompt-info }; /// Initializes the extension. @@ -188,15 +189,14 @@ world extension { /// /// The device flow works as follows: /// 1. Extension requests a device code from the OAuth provider - /// 2. Extension opens the verification URL in the browser - /// 3. Extension returns the user code to display to the user - /// 4. Host displays the user code and calls llm-provider-poll-device-flow-sign-in + /// 2. Extension returns prompt info including user code and verification URL + /// 3. Host displays a modal with the prompt info + /// 4. Host calls llm-provider-poll-device-flow-sign-in /// 5. Extension polls for the access token while user authorizes in browser /// 6. Once authorized, extension stores the credential and returns success /// - /// Returns the user code that should be displayed to the user while they - /// complete authorization in the browser. - export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result; + /// Returns information needed to display the device flow prompt modal. + export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result; /// Poll for device flow sign-in completion. /// This is called after llm-provider-start-device-flow-sign-in returns the user code. diff --git a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit index 696d085a22b2dafa5fb6348318c124d2af36108f..6ce2a344d245a7297cae002ab529fd39937e00cc 100644 --- a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit +++ b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit @@ -328,4 +328,22 @@ interface llm-provider { /// Useful for OAuth flows that need to open a browser but handle the /// callback differently (e.g., polling-based flows). oauth-open-browser: func(url: string) -> result<_, string>; + + /// Information needed to display the device flow prompt modal to the user. + record device-flow-prompt-info { + /// The user code to display (e.g., "ABC-123"). + user-code: string, + /// The URL the user needs to visit to authorize (for the "Connect" button). + verification-url: string, + /// The headline text for the modal (e.g., "Use GitHub Copilot in Zed."). + headline: string, + /// A description to show below the headline (e.g., "Using Copilot requires an active subscription on GitHub."). + description: string, + /// Label for the connect button (e.g., "Connect to GitHub"). + connect-button-label: string, + /// Success headline shown when authorization completes. + success-headline: string, + /// Success message shown when authorization completes. + success-message: string, + } } diff --git a/crates/extension_host/Cargo.toml b/crates/extension_host/Cargo.toml index 0f3d1eefee9e04e77ea6cbbea3249f44c4efd504..a0fb499013ac096c741af055d2ecd66f0e28be4a 100644 --- a/crates/extension_host/Cargo.toml +++ b/crates/extension_host/Cargo.toml @@ -57,6 +57,7 @@ theme.workspace = true toml.workspace = true ui.workspace = true url.workspace = true +workspace.workspace = true util.workspace = true wasmparser.workspace = true wasmtime-wasi.workspace = true diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs index 0b57b7375ef6f3547c687051f40ab4826ea6c4c3..68e48e825f3430345ba6398011816a8448ef40f6 100644 --- a/crates/extension_host/src/wasm_host/llm_provider.rs +++ b/crates/extension_host/src/wasm_host/llm_provider.rs @@ -1,6 +1,7 @@ use crate::ExtensionSettings; use crate::LEGACY_LLM_EXTENSION_IDS; use crate::wasm_host::WasmExtension; +use crate::wasm_host::wit::LlmDeviceFlowPromptInfo; use collections::HashSet; use crate::wasm_host::wit::{ @@ -18,8 +19,8 @@ use futures::stream::BoxStream; use futures::{FutureExt, StreamExt}; use gpui::Focusable; use gpui::{ - AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter, - MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px, + AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, + TextStyleRefinement, UnderlineStyle, Window, px, }; use language_model::tool_schema::LanguageModelToolSchemaFormat; use language_model::{ @@ -35,6 +36,10 @@ use std::sync::Arc; use theme::ThemeSettings; use ui::{Label, LabelSize, prelude::*}; use util::ResultExt as _; +use workspace::Workspace; +use workspace::oauth_device_flow_modal::{ + OAuthDeviceFlowModal, OAuthDeviceFlowModalConfig, OAuthDeviceFlowState, OAuthDeviceFlowStatus, +}; /// An extension-based language model provider. pub struct ExtensionLanguageModelProvider { @@ -259,6 +264,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { let state = self.state.clone(); let auth_config = self.auth_config.clone(); + let icon_path = self.icon_path.clone(); cx.new(|cx| { ExtensionProviderConfigurationView::new( credential_key, @@ -267,6 +273,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { full_provider_id, auth_config, state, + icon_path, window, cx, ) @@ -348,7 +355,7 @@ struct ExtensionProviderConfigurationView { loading_credentials: bool, oauth_in_progress: bool, oauth_error: Option, - device_user_code: Option, + icon_path: Option, _subscriptions: Vec, } @@ -360,6 +367,7 @@ impl ExtensionProviderConfigurationView { full_provider_id: String, auth_config: Option, state: Entity, + icon_path: Option, window: &mut Window, cx: &mut Context, ) -> Self { @@ -388,7 +396,7 @@ impl ExtensionProviderConfigurationView { loading_credentials: true, oauth_in_progress: false, oauth_error: None, - device_user_code: None, + icon_path, _subscriptions: vec![state_subscription], }; @@ -488,7 +496,6 @@ impl ExtensionProviderConfigurationView { // Update settings file settings::update_settings_file(::global(cx), cx, { - let settings_key = settings_key.clone(); move |settings, _| { let allowed = settings .extension @@ -510,22 +517,21 @@ impl ExtensionProviderConfigurationView { // Update local state let new_allowed = !currently_allowed; - let env_var_name_clone = env_var_name.clone(); state.update(cx, |state, cx| { if new_allowed { - state.allowed_env_vars.insert(env_var_name_clone.clone()); + state.allowed_env_vars.insert(env_var_name.clone()); // Check if this env var is set and update env_var_name_used - if let Ok(value) = std::env::var(&env_var_name_clone) { + if let Ok(value) = std::env::var(&env_var_name) { if !value.is_empty() && state.env_var_name_used.is_none() { - state.env_var_name_used = Some(env_var_name_clone); + state.env_var_name_used = Some(env_var_name.clone()); state.is_authenticated = true; } } } else { - state.allowed_env_vars.remove(&env_var_name_clone); + state.allowed_env_vars.remove(&env_var_name); // If this was the env var being used, clear it and find another - if state.env_var_name_used.as_ref() == Some(&env_var_name_clone) { + if state.env_var_name_used.as_ref() == Some(&env_var_name) { state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| { if let Ok(value) = std::env::var(var) { if !value.is_empty() { @@ -637,22 +643,33 @@ impl ExtensionProviderConfigurationView { .detach(); } - fn start_oauth_sign_in(&mut self, cx: &mut Context) { + fn start_oauth_sign_in(&mut self, window: &mut Window, cx: &mut Context) { if self.oauth_in_progress { return; } self.oauth_in_progress = true; self.oauth_error = None; - self.device_user_code = None; cx.notify(); let extension = self.extension.clone(); let provider_id = self.extension_provider_id.clone(); let state = self.state.clone(); + let icon_path = self.icon_path.clone(); + let this_handle = cx.weak_entity(); - cx.spawn(async move |this, cx| { - // Step 1: Start device flow - opens browser and returns user code + // Get workspace to show modal + let Some(workspace) = window.root::().flatten() else { + self.oauth_in_progress = false; + self.oauth_error = Some("Could not access workspace to show sign-in modal".to_string()); + cx.notify(); + return; + }; + + let workspace = workspace.downgrade(); + let state = state.downgrade(); + cx.spawn_in(window, async move |_this, cx| { + // Step 1: Start device flow - get prompt info from extension let start_result = extension .call({ let provider_id = provider_id.clone(); @@ -666,38 +683,67 @@ impl ExtensionProviderConfigurationView { }) .await; - let user_code = match start_result { - Ok(Ok(Ok(code))) => code, + let prompt_info: LlmDeviceFlowPromptInfo = match start_result { + Ok(Ok(Ok(info))) => info, Ok(Ok(Err(e))) => { log::error!("Device flow start failed: {}", e); - this.update(cx, |this, cx| { - this.oauth_in_progress = false; - this.oauth_error = Some(e); - cx.notify(); - }) - .log_err(); + this_handle + .update_in(cx, |this, _window, cx| { + this.oauth_in_progress = false; + this.oauth_error = Some(e); + cx.notify(); + }) + .log_err(); return; } Ok(Err(e)) | Err(e) => { log::error!("Device flow start error: {}", e); - this.update(cx, |this, cx| { + this_handle + .update_in(cx, |this, _window, cx| { + this.oauth_in_progress = false; + this.oauth_error = Some(e.to_string()); + cx.notify(); + }) + .log_err(); + return; + } + }; + + // Step 2: Create state entity and show the modal + let modal_config = OAuthDeviceFlowModalConfig { + user_code: prompt_info.user_code, + verification_url: prompt_info.verification_url, + headline: prompt_info.headline, + description: prompt_info.description, + connect_button_label: prompt_info.connect_button_label, + success_headline: prompt_info.success_headline, + success_message: prompt_info.success_message, + icon_path, + }; + + let flow_state: Option> = workspace + .update_in(cx, |workspace, window, cx| { + let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config)); + let flow_state_clone = flow_state.clone(); + workspace.toggle_modal(window, cx, |_window, cx| { + OAuthDeviceFlowModal::new(flow_state_clone, cx) + }); + flow_state + }) + .ok(); + + let Some(flow_state) = flow_state else { + this_handle + .update_in(cx, |this, _window, cx| { this.oauth_in_progress = false; - this.oauth_error = Some(e.to_string()); + this.oauth_error = Some("Failed to show sign-in modal".to_string()); cx.notify(); }) .log_err(); - return; - } + return; }; - // Update UI to show the user code before polling - this.update(cx, |this, cx| { - this.device_user_code = Some(user_code); - cx.notify(); - }) - .log_err(); - - // Step 2: Poll for authentication completion + // Step 3: Poll for authentication completion let poll_result = extension .call({ let provider_id = provider_id.clone(); @@ -711,7 +757,7 @@ impl ExtensionProviderConfigurationView { }) .await; - let error_message = match poll_result { + match poll_result { Ok(Ok(Ok(()))) => { // After successful auth, refresh the models list let models_result = extension @@ -731,33 +777,61 @@ impl ExtensionProviderConfigurationView { _ => Vec::new(), }; - cx.update(|cx| { - state.update(cx, |state, cx| { + state + .update_in(cx, |state, _window, cx| { state.is_authenticated = true; state.available_models = new_models; cx.notify(); - }); - }) - .log_err(); - None + }) + .log_err(); + + // Update flow state to show success + flow_state + .update_in(cx, |state, _window, cx| { + state.set_status(OAuthDeviceFlowStatus::Authorized, cx); + }) + .log_err(); } Ok(Ok(Err(e))) => { log::error!("Device flow poll failed: {}", e); - Some(e) + flow_state + .update_in(cx, |state, _window, cx| { + state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx); + }) + .log_err(); + this_handle + .update_in(cx, |this, _window, cx| { + this.oauth_error = Some(e); + cx.notify(); + }) + .log_err(); } Ok(Err(e)) | Err(e) => { log::error!("Device flow poll error: {}", e); - Some(e.to_string()) + let error_string = e.to_string(); + flow_state + .update_in(cx, |state, _window, cx| { + state.set_status( + OAuthDeviceFlowStatus::Failed(error_string.clone()), + cx, + ); + }) + .log_err(); + this_handle + .update_in(cx, |this, _window, cx| { + this.oauth_error = Some(error_string); + cx.notify(); + }) + .log_err(); } }; - this.update(cx, |this, cx| { - this.oauth_in_progress = false; - this.oauth_error = error_message; - this.device_user_code = None; - cx.notify(); - }) - .log_err(); + this_handle + .update_in(cx, |this, _window, cx| { + this.oauth_in_progress = false; + cx.notify(); + }) + .log_err(); }) .detach(); } @@ -784,7 +858,7 @@ impl ExtensionProviderConfigurationView { } impl gpui::Render for ExtensionProviderConfigurationView { - fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let is_loading = self.loading_settings || self.loading_credentials; let is_authenticated = self.is_authenticated(cx); let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone(); @@ -803,7 +877,7 @@ impl gpui::Render for ExtensionProviderConfigurationView { // Render settings markdown if available if let Some(markdown) = &self.settings_markdown { - let style = settings_markdown_style(_window, cx); + let style = settings_markdown_style(window, cx); content = content.child(MarkdownElement::new(markdown.clone(), style)); } @@ -948,90 +1022,30 @@ impl gpui::Render for ExtensionProviderConfigurationView { let oauth_error = self.oauth_error.clone(); + let mut button = ui::Button::new("oauth-sign-in", button_label) + .full_width() + .style(ui::ButtonStyle::Outlined) + .disabled(oauth_in_progress) + .on_click(cx.listener(|this, _, window, cx| { + this.start_oauth_sign_in(window, cx); + })); + if let Some(icon) = button_icon { + button = button + .icon(icon) + .icon_position(ui::IconPosition::Start) + .icon_size(ui::IconSize::Small) + .icon_color(Color::Muted); + } + content = content.child( v_flex() .gap_2() - .child({ - let mut button = ui::Button::new("oauth-sign-in", button_label) - .full_width() - .style(ui::ButtonStyle::Outlined) - .disabled(oauth_in_progress) - .on_click(cx.listener(|this, _, _window, cx| { - this.start_oauth_sign_in(cx); - })); - if let Some(icon) = button_icon { - button = button - .icon(icon) - .icon_position(ui::IconPosition::Start) - .icon_size(ui::IconSize::Small) - .icon_color(Color::Muted); - } - button - }) + .child(button) .when(oauth_in_progress, |this| { - let user_code = self.device_user_code.clone(); this.child( - v_flex() - .gap_1() - .when_some(user_code, |this, code| { - let copied = cx - .read_from_clipboard() - .map(|item| item.text().as_ref() == Some(&code)) - .unwrap_or(false); - let code_for_click = code.clone(); - this.child( - h_flex() - .gap_1() - .child( - Label::new("Enter code:") - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child( - h_flex() - .gap_1() - .px_1() - .border_1() - .border_color(cx.theme().colors().border) - .rounded_sm() - .cursor_pointer() - .on_mouse_down( - MouseButton::Left, - move |_, window, cx| { - cx.write_to_clipboard( - ClipboardItem::new_string( - code_for_click.clone(), - ), - ); - window.refresh(); - }, - ) - .child( - Label::new(code) - .size(LabelSize::Small) - .color(Color::Accent), - ) - .child( - ui::Icon::new(if copied { - ui::IconName::Check - } else { - ui::IconName::Copy - }) - .size(ui::IconSize::Small) - .color(if copied { - Color::Success - } else { - Color::Muted - }), - ), - ), - ) - }) - .child( - Label::new("Waiting for authorization in browser...") - .size(LabelSize::Small) - .color(Color::Muted), - ), + Label::new("Sign-in in progress...") + .size(LabelSize::Small) + .color(Color::Muted), ) }) .when_some(oauth_error, |this, error| { diff --git a/crates/extension_host/src/wasm_host/wit.rs b/crates/extension_host/src/wasm_host/wit.rs index eeb28e59ebbee18bd7acda4cbcc5c8e04c63c05a..84fe2af71c317bab0f944041f5d5be2fee9fa462 100644 --- a/crates/extension_host/src/wasm_host/wit.rs +++ b/crates/extension_host/src/wasm_host/wit.rs @@ -35,15 +35,16 @@ pub use latest::{ zed::extension::context_server::ContextServerConfiguration, zed::extension::llm_provider::{ CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent, - CompletionRequest as LlmCompletionRequest, ImageData as LlmImageData, - MessageContent as LlmMessageContent, MessageRole as LlmMessageRole, - ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo, - ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage, - StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent, - TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice, - ToolDefinition as LlmToolDefinition, ToolInputFormat as LlmToolInputFormat, - ToolResult as LlmToolResult, ToolResultContent as LlmToolResultContent, - ToolUse as LlmToolUse, ToolUseJsonParseError as LlmToolUseJsonParseError, + CompletionRequest as LlmCompletionRequest, DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo, + ImageData as LlmImageData, MessageContent as LlmMessageContent, + MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities, + ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo, + RequestMessage as LlmRequestMessage, StopReason as LlmStopReason, + ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage, + ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition, + ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult, + ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse, + ToolUseJsonParseError as LlmToolUseJsonParseError, }, zed::extension::lsp::{ Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind, @@ -1233,7 +1234,7 @@ impl Extension { &self, store: &mut Store, provider_id: &str, - ) -> Result> { + ) -> Result> { match self { Extension::V0_8_0(ext) => { ext.call_llm_provider_start_device_flow_sign_in(store, provider_id) diff --git a/crates/workspace/src/oauth_device_flow_modal.rs b/crates/workspace/src/oauth_device_flow_modal.rs new file mode 100644 index 0000000000000000000000000000000000000000..5a5ad4ddd424e54abca468ef11219ae1ce6ff239 --- /dev/null +++ b/crates/workspace/src/oauth_device_flow_modal.rs @@ -0,0 +1,271 @@ +use gpui::{ + Animation, AnimationExt, App, ClipboardItem, Context, DismissEvent, Element, Entity, + EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, MouseDownEvent, + ParentElement, Render, SharedString, Styled, Subscription, Transformation, Window, div, + percentage, svg, +}; +use menu; +use std::time::Duration; +use ui::{Button, Icon, IconName, Label, prelude::*}; + +use crate::ModalView; + +/// Configuration for the OAuth device flow modal. +/// This allows extensions to specify the text and appearance of the modal. +#[derive(Clone)] +pub struct OAuthDeviceFlowModalConfig { + /// The user code to display (e.g., "ABC-123"). + pub user_code: String, + /// The URL the user needs to visit to authorize (for the "Connect" button). + pub verification_url: String, + /// The headline text for the modal (e.g., "Use GitHub Copilot in Zed."). + pub headline: String, + /// A description to show below the headline. + pub description: String, + /// Label for the connect button (e.g., "Connect to GitHub"). + pub connect_button_label: String, + /// Success headline shown when authorization completes. + pub success_headline: String, + /// Success message shown when authorization completes. + pub success_message: String, + /// Optional path to an SVG icon file (absolute path on disk). + pub icon_path: Option, +} + +/// The current status of the OAuth device flow. +#[derive(Clone, Debug)] +pub enum OAuthDeviceFlowStatus { + /// Waiting for user to click connect and authorize. + Prompting, + /// User clicked connect, waiting for authorization. + WaitingForAuthorization, + /// Successfully authorized. + Authorized, + /// Authorization failed with an error message. + Failed(String), +} + +/// Shared state for the OAuth device flow that can be observed by the modal. +pub struct OAuthDeviceFlowState { + pub config: OAuthDeviceFlowModalConfig, + pub status: OAuthDeviceFlowStatus, +} + +impl EventEmitter<()> for OAuthDeviceFlowState {} + +impl OAuthDeviceFlowState { + pub fn new(config: OAuthDeviceFlowModalConfig) -> Self { + Self { + config, + status: OAuthDeviceFlowStatus::Prompting, + } + } + + /// Update the status of the OAuth flow. + pub fn set_status(&mut self, status: OAuthDeviceFlowStatus, cx: &mut Context) { + self.status = status; + cx.emit(()); + cx.notify(); + } +} + +/// A generic OAuth device flow modal that can be used by extensions. +pub struct OAuthDeviceFlowModal { + state: Entity, + connect_clicked: bool, + focus_handle: FocusHandle, + _subscription: Subscription, +} + +impl Focusable for OAuthDeviceFlowModal { + fn focus_handle(&self, _: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for OAuthDeviceFlowModal {} + +impl ModalView for OAuthDeviceFlowModal {} + +impl OAuthDeviceFlowModal { + pub fn new(state: Entity, cx: &mut Context) -> Self { + let subscription = cx.observe(&state, |_, _, cx| { + cx.notify(); + }); + + Self { + state, + connect_clicked: false, + focus_handle: cx.focus_handle(), + _subscription: subscription, + } + } + + fn render_icon(&self, cx: &mut Context) -> impl IntoElement { + let state = self.state.read(cx); + if let Some(icon_path) = &state.config.icon_path { + Icon::from_external_svg(icon_path.clone()) + .size(ui::IconSize::XLarge) + .color(Color::Custom(cx.theme().colors().icon)) + .into_any_element() + } else { + div().into_any_element() + } + } + + fn render_device_code(&self, cx: &mut Context) -> impl IntoElement { + let state = self.state.read(cx); + let user_code = state.config.user_code.clone(); + let copied = cx + .read_from_clipboard() + .map(|item| item.text().as_ref() == Some(&user_code)) + .unwrap_or(false); + let user_code_for_click = user_code.clone(); + + h_flex() + .w_full() + .p_1() + .border_1() + .border_muted(cx) + .rounded_sm() + .cursor_pointer() + .justify_between() + .on_mouse_down(gpui::MouseButton::Left, move |_, window, cx| { + cx.write_to_clipboard(ClipboardItem::new_string(user_code_for_click.clone())); + window.refresh(); + }) + .child(div().flex_1().child(Label::new(user_code))) + .child(div().flex_none().px_1().child(Label::new(if copied { + "Copied!" + } else { + "Copy" + }))) + } + + fn render_prompting_modal(&self, cx: &mut Context) -> impl Element { + let (connect_button_label, verification_url, headline, description) = { + let state = self.state.read(cx); + let label = if self.connect_clicked { + "Waiting for connection...".to_string() + } else { + state.config.connect_button_label.clone() + }; + ( + label, + state.config.verification_url.clone(), + state.config.headline.clone(), + state.config.description.clone(), + ) + }; + + v_flex() + .flex_1() + .gap_2() + .items_center() + .child(Headline::new(headline).size(HeadlineSize::Large)) + .child(Label::new(description).color(Color::Muted)) + .child(self.render_device_code(cx)) + .child( + Label::new("Paste this code into GitHub after clicking the button below.") + .size(ui::LabelSize::Small), + ) + .child( + Button::new("connect-button", connect_button_label) + .on_click(cx.listener(move |this, _, _window, cx| { + cx.open_url(&verification_url); + this.connect_clicked = true; + })) + .full_width() + .style(ButtonStyle::Filled), + ) + .child( + Button::new("cancel-button", "Cancel") + .full_width() + .on_click(cx.listener(|_, _, _, cx| { + cx.emit(DismissEvent); + })), + ) + } + + fn render_authorized_modal(&self, cx: &mut Context) -> impl Element { + let state = self.state.read(cx); + let success_headline = state.config.success_headline.clone(); + let success_message = state.config.success_message.clone(); + + v_flex() + .gap_2() + .child(Headline::new(success_headline).size(HeadlineSize::Large)) + .child(Label::new(success_message)) + .child( + Button::new("done-button", "Done") + .full_width() + .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + ) + } + + fn render_failed_modal(&self, error: &str, cx: &mut Context) -> impl Element { + v_flex() + .gap_2() + .child(Headline::new("Authorization Failed").size(HeadlineSize::Large)) + .child(Label::new(error.to_string()).color(Color::Error)) + .child( + Button::new("close-button", "Close") + .full_width() + .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))), + ) + } + + fn render_loading(window: &mut Window, _cx: &mut Context) -> impl Element { + let loading_icon = svg() + .size_8() + .path(IconName::ArrowCircle.path()) + .text_color(window.text_style().color) + .with_animation( + "icon_circle_arrow", + Animation::new(Duration::from_secs(2)).repeat(), + |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))), + ); + + h_flex().justify_center().child(loading_icon) + } +} + +impl Render for OAuthDeviceFlowModal { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { + let status = self.state.read(cx).status.clone(); + + let prompt = match &status { + OAuthDeviceFlowStatus::Prompting => self.render_prompting_modal(cx).into_any_element(), + OAuthDeviceFlowStatus::WaitingForAuthorization => { + if self.connect_clicked { + self.render_prompting_modal(cx).into_any_element() + } else { + Self::render_loading(window, cx).into_any_element() + } + } + OAuthDeviceFlowStatus::Authorized => { + self.render_authorized_modal(cx).into_any_element() + } + OAuthDeviceFlowStatus::Failed(error) => { + self.render_failed_modal(error, cx).into_any_element() + } + }; + + v_flex() + .id("oauth-device-flow-modal") + .track_focus(&self.focus_handle(cx)) + .elevation_3(cx) + .w_96() + .items_center() + .p_4() + .gap_2() + .on_action(cx.listener(|_, _: &menu::Cancel, _, cx| { + cx.emit(DismissEvent); + })) + .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _| { + window.focus(&this.focus_handle); + })) + .child(self.render_icon(cx)) + .child(prompt) + } +} diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index d2a9ef71fc7fc2aacb1fc2f9be41ce001f5cef5e..ec557187e0852355dbd9f083397fda17d2ddbf42 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -4,6 +4,7 @@ pub mod invalid_item_view; pub mod item; mod modal_layer; pub mod notifications; +pub mod oauth_device_flow_modal; pub mod pane; pub mod pane_group; mod path_list; diff --git a/extensions/copilot-chat/src/copilot_chat.rs b/extensions/copilot-chat/src/copilot_chat.rs index bdeb59b01d810709a6c73af8b513a8ede00a8dfd..dfd0e9aa570ecc94f9cda627b6da995af297d693 100644 --- a/extensions/copilot-chat/src/copilot_chat.rs +++ b/extensions/copilot-chat/src/copilot_chat.rs @@ -454,7 +454,7 @@ impl zed::Extension for CopilotChatProvider { fn llm_provider_start_device_flow_sign_in( &mut self, _provider_id: &str, - ) -> Result { + ) -> Result { // Step 1: Request device and user verification codes let device_code_response = llm_oauth_send_http_request(&LlmOauthHttpRequest { url: GITHUB_DEVICE_CODE_URL.to_string(), @@ -497,7 +497,7 @@ impl zed::Extension for CopilotChatProvider { expires_in: device_info.expires_in, }); - // Step 2: Open browser to verification URL + // Step 2: Construct verification URL // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| { format!( @@ -505,10 +505,19 @@ impl zed::Extension for CopilotChatProvider { device_info.verification_uri, &device_info.user_code ) }); - llm_oauth_open_browser(&verification_url)?; - // Return the user code for the host to display - Ok(device_info.user_code) + // Return prompt info for the host to display in the modal + Ok(LlmDeviceFlowPromptInfo { + user_code: device_info.user_code, + verification_url, + headline: "Use GitHub Copilot in Zed.".to_string(), + description: "Using Copilot requires an active subscription on GitHub.".to_string(), + connect_button_label: "Connect to GitHub".to_string(), + success_headline: "Copilot Enabled!".to_string(), + success_message: + "You can update your settings or sign out from the Copilot menu in the status bar." + .to_string(), + }) } fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {