Cargo.lock 🔗
@@ -5889,6 +5889,7 @@ dependencies = [
"wasmparser 0.221.3",
"wasmtime",
"wasmtime-wasi",
+ "workspace",
"zlog",
]
Richard Feldman created
Cargo.lock | 1
crates/extension_api/src/extension_api.rs | 32
crates/extension_api/wit/since_v0.8.0/extension.wit | 14
crates/extension_api/wit/since_v0.8.0/llm-provider.wit | 18
crates/extension_host/Cargo.toml | 1
crates/extension_host/src/wasm_host/llm_provider.rs | 278 ++++++-----
crates/extension_host/src/wasm_host/wit.rs | 21
crates/workspace/src/oauth_device_flow_modal.rs | 271 +++++++++++
crates/workspace/src/workspace.rs | 1
extensions/copilot-chat/src/copilot_chat.rs | 19
10 files changed, 487 insertions(+), 169 deletions(-)
@@ -5889,6 +5889,7 @@ dependencies = [
"wasmparser 0.221.3",
"wasmtime",
"wasmtime-wasi",
+ "workspace",
"zlog",
]
@@ -31,17 +31,18 @@ pub use wit::{
},
zed::extension::llm_provider::{
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
- CompletionRequest as LlmCompletionRequest, ImageData as LlmImageData,
- MessageContent as LlmMessageContent, MessageRole as LlmMessageRole,
- ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo,
- OauthHttpRequest as LlmOauthHttpRequest, OauthHttpResponse as LlmOauthHttpResponse,
- OauthWebAuthConfig as LlmOauthWebAuthConfig, OauthWebAuthResult as LlmOauthWebAuthResult,
- ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage,
- StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent,
- TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice,
- ToolDefinition as LlmToolDefinition, ToolInputFormat as LlmToolInputFormat,
- ToolResult as LlmToolResult, ToolResultContent as LlmToolResultContent,
- ToolUse as LlmToolUse, ToolUseJsonParseError as LlmToolUseJsonParseError,
+ CompletionRequest as LlmCompletionRequest, DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo,
+ ImageData as LlmImageData, MessageContent as LlmMessageContent,
+ MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
+ ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest,
+ OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig,
+ OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
+ RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
+ ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
+ ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
+ ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
+ ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
+ ToolUseJsonParseError as LlmToolUseJsonParseError,
delete_credential as llm_delete_credential, get_credential as llm_get_credential,
get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser,
oauth_send_http_request as llm_oauth_send_http_request,
@@ -301,12 +302,11 @@ pub trait Extension: Send + Sync {
/// Start an OAuth device flow sign-in.
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
- /// Opens the browser to the verification URL and returns the user code that should
- /// be displayed to the user.
+ /// Returns information needed to display the device flow prompt modal to the user.
fn llm_provider_start_device_flow_sign_in(
&mut self,
_provider_id: &str,
- ) -> Result<String, String> {
+ ) -> Result<LlmDeviceFlowPromptInfo, String> {
Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string())
}
@@ -641,7 +641,9 @@ impl wit::Guest for Component {
extension().llm_provider_is_authenticated(&provider_id)
}
- fn llm_provider_start_device_flow_sign_in(provider_id: String) -> Result<String, String> {
+ fn llm_provider_start_device_flow_sign_in(
+ provider_id: String,
+ ) -> Result<LlmDeviceFlowPromptInfo, String> {
extension().llm_provider_start_device_flow_sign_in(&provider_id)
}
@@ -18,7 +18,8 @@ world extension {
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
use llm-provider.{
provider-info, model-info, completion-request,
- cache-configuration, completion-event, token-usage
+ cache-configuration, completion-event, token-usage,
+ device-flow-prompt-info
};
/// Initializes the extension.
@@ -188,15 +189,14 @@ world extension {
///
/// The device flow works as follows:
/// 1. Extension requests a device code from the OAuth provider
- /// 2. Extension opens the verification URL in the browser
- /// 3. Extension returns the user code to display to the user
- /// 4. Host displays the user code and calls llm-provider-poll-device-flow-sign-in
+ /// 2. Extension returns prompt info including user code and verification URL
+ /// 3. Host displays a modal with the prompt info
+ /// 4. Host calls llm-provider-poll-device-flow-sign-in
/// 5. Extension polls for the access token while user authorizes in browser
/// 6. Once authorized, extension stores the credential and returns success
///
- /// Returns the user code that should be displayed to the user while they
- /// complete authorization in the browser.
- export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result<string, string>;
+ /// Returns information needed to display the device flow prompt modal.
+ export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result<device-flow-prompt-info, string>;
/// Poll for device flow sign-in completion.
/// This is called after llm-provider-start-device-flow-sign-in returns the user code.
@@ -328,4 +328,22 @@ interface llm-provider {
/// Useful for OAuth flows that need to open a browser but handle the
/// callback differently (e.g., polling-based flows).
oauth-open-browser: func(url: string) -> result<_, string>;
+
+ /// Information needed to display the device flow prompt modal to the user.
+ record device-flow-prompt-info {
+ /// The user code to display (e.g., "ABC-123").
+ user-code: string,
+ /// The URL the user needs to visit to authorize (for the "Connect" button).
+ verification-url: string,
+ /// The headline text for the modal (e.g., "Use GitHub Copilot in Zed.").
+ headline: string,
+ /// A description to show below the headline (e.g., "Using Copilot requires an active subscription on GitHub.").
+ description: string,
+ /// Label for the connect button (e.g., "Connect to GitHub").
+ connect-button-label: string,
+ /// Success headline shown when authorization completes.
+ success-headline: string,
+ /// Success message shown when authorization completes.
+ success-message: string,
+ }
}
@@ -57,6 +57,7 @@ theme.workspace = true
toml.workspace = true
ui.workspace = true
url.workspace = true
+workspace.workspace = true
util.workspace = true
wasmparser.workspace = true
wasmtime-wasi.workspace = true
@@ -1,6 +1,7 @@
use crate::ExtensionSettings;
use crate::LEGACY_LLM_EXTENSION_IDS;
use crate::wasm_host::WasmExtension;
+use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
use collections::HashSet;
use crate::wasm_host::wit::{
@@ -18,8 +19,8 @@ use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt};
use gpui::Focusable;
use gpui::{
- AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
- MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
+ AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
+ TextStyleRefinement, UnderlineStyle, Window, px,
};
use language_model::tool_schema::LanguageModelToolSchemaFormat;
use language_model::{
@@ -35,6 +36,10 @@ use std::sync::Arc;
use theme::ThemeSettings;
use ui::{Label, LabelSize, prelude::*};
use util::ResultExt as _;
+use workspace::Workspace;
+use workspace::oauth_device_flow_modal::{
+ OAuthDeviceFlowModal, OAuthDeviceFlowModalConfig, OAuthDeviceFlowState, OAuthDeviceFlowStatus,
+};
/// An extension-based language model provider.
pub struct ExtensionLanguageModelProvider {
@@ -259,6 +264,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
let state = self.state.clone();
let auth_config = self.auth_config.clone();
+ let icon_path = self.icon_path.clone();
cx.new(|cx| {
ExtensionProviderConfigurationView::new(
credential_key,
@@ -267,6 +273,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
full_provider_id,
auth_config,
state,
+ icon_path,
window,
cx,
)
@@ -348,7 +355,7 @@ struct ExtensionProviderConfigurationView {
loading_credentials: bool,
oauth_in_progress: bool,
oauth_error: Option<String>,
- device_user_code: Option<String>,
+ icon_path: Option<SharedString>,
_subscriptions: Vec<Subscription>,
}
@@ -360,6 +367,7 @@ impl ExtensionProviderConfigurationView {
full_provider_id: String,
auth_config: Option<LanguageModelAuthConfig>,
state: Entity<ExtensionLlmProviderState>,
+ icon_path: Option<SharedString>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -388,7 +396,7 @@ impl ExtensionProviderConfigurationView {
loading_credentials: true,
oauth_in_progress: false,
oauth_error: None,
- device_user_code: None,
+ icon_path,
_subscriptions: vec![state_subscription],
};
@@ -488,7 +496,6 @@ impl ExtensionProviderConfigurationView {
// Update settings file
settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
- let settings_key = settings_key.clone();
move |settings, _| {
let allowed = settings
.extension
@@ -510,22 +517,21 @@ impl ExtensionProviderConfigurationView {
// Update local state
let new_allowed = !currently_allowed;
- let env_var_name_clone = env_var_name.clone();
state.update(cx, |state, cx| {
if new_allowed {
- state.allowed_env_vars.insert(env_var_name_clone.clone());
+ state.allowed_env_vars.insert(env_var_name.clone());
// Check if this env var is set and update env_var_name_used
- if let Ok(value) = std::env::var(&env_var_name_clone) {
+ if let Ok(value) = std::env::var(&env_var_name) {
if !value.is_empty() && state.env_var_name_used.is_none() {
- state.env_var_name_used = Some(env_var_name_clone);
+ state.env_var_name_used = Some(env_var_name.clone());
state.is_authenticated = true;
}
}
} else {
- state.allowed_env_vars.remove(&env_var_name_clone);
+ state.allowed_env_vars.remove(&env_var_name);
// If this was the env var being used, clear it and find another
- if state.env_var_name_used.as_ref() == Some(&env_var_name_clone) {
+ if state.env_var_name_used.as_ref() == Some(&env_var_name) {
state.env_var_name_used = state.allowed_env_vars.iter().find_map(|var| {
if let Ok(value) = std::env::var(var) {
if !value.is_empty() {
@@ -637,22 +643,33 @@ impl ExtensionProviderConfigurationView {
.detach();
}
- fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
+ fn start_oauth_sign_in(&mut self, window: &mut Window, cx: &mut Context<Self>) {
if self.oauth_in_progress {
return;
}
self.oauth_in_progress = true;
self.oauth_error = None;
- self.device_user_code = None;
cx.notify();
let extension = self.extension.clone();
let provider_id = self.extension_provider_id.clone();
let state = self.state.clone();
+ let icon_path = self.icon_path.clone();
+ let this_handle = cx.weak_entity();
- cx.spawn(async move |this, cx| {
- // Step 1: Start device flow - opens browser and returns user code
+ // Get workspace to show modal
+ let Some(workspace) = window.root::<Workspace>().flatten() else {
+ self.oauth_in_progress = false;
+ self.oauth_error = Some("Could not access workspace to show sign-in modal".to_string());
+ cx.notify();
+ return;
+ };
+
+ let workspace = workspace.downgrade();
+ let state = state.downgrade();
+ cx.spawn_in(window, async move |_this, cx| {
+ // Step 1: Start device flow - get prompt info from extension
let start_result = extension
.call({
let provider_id = provider_id.clone();
@@ -666,38 +683,67 @@ impl ExtensionProviderConfigurationView {
})
.await;
- let user_code = match start_result {
- Ok(Ok(Ok(code))) => code,
+ let prompt_info: LlmDeviceFlowPromptInfo = match start_result {
+ Ok(Ok(Ok(info))) => info,
Ok(Ok(Err(e))) => {
log::error!("Device flow start failed: {}", e);
- this.update(cx, |this, cx| {
- this.oauth_in_progress = false;
- this.oauth_error = Some(e);
- cx.notify();
- })
- .log_err();
+ this_handle
+ .update_in(cx, |this, _window, cx| {
+ this.oauth_in_progress = false;
+ this.oauth_error = Some(e);
+ cx.notify();
+ })
+ .log_err();
return;
}
Ok(Err(e)) | Err(e) => {
log::error!("Device flow start error: {}", e);
- this.update(cx, |this, cx| {
+ this_handle
+ .update_in(cx, |this, _window, cx| {
+ this.oauth_in_progress = false;
+ this.oauth_error = Some(e.to_string());
+ cx.notify();
+ })
+ .log_err();
+ return;
+ }
+ };
+
+ // Step 2: Create state entity and show the modal
+ let modal_config = OAuthDeviceFlowModalConfig {
+ user_code: prompt_info.user_code,
+ verification_url: prompt_info.verification_url,
+ headline: prompt_info.headline,
+ description: prompt_info.description,
+ connect_button_label: prompt_info.connect_button_label,
+ success_headline: prompt_info.success_headline,
+ success_message: prompt_info.success_message,
+ icon_path,
+ };
+
+ let flow_state: Option<Entity<OAuthDeviceFlowState>> = workspace
+ .update_in(cx, |workspace, window, cx| {
+ let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config));
+ let flow_state_clone = flow_state.clone();
+ workspace.toggle_modal(window, cx, |_window, cx| {
+ OAuthDeviceFlowModal::new(flow_state_clone, cx)
+ });
+ flow_state
+ })
+ .ok();
+
+ let Some(flow_state) = flow_state else {
+ this_handle
+ .update_in(cx, |this, _window, cx| {
this.oauth_in_progress = false;
- this.oauth_error = Some(e.to_string());
+ this.oauth_error = Some("Failed to show sign-in modal".to_string());
cx.notify();
})
.log_err();
- return;
- }
+ return;
};
- // Update UI to show the user code before polling
- this.update(cx, |this, cx| {
- this.device_user_code = Some(user_code);
- cx.notify();
- })
- .log_err();
-
- // Step 2: Poll for authentication completion
+ // Step 3: Poll for authentication completion
let poll_result = extension
.call({
let provider_id = provider_id.clone();
@@ -711,7 +757,7 @@ impl ExtensionProviderConfigurationView {
})
.await;
- let error_message = match poll_result {
+ match poll_result {
Ok(Ok(Ok(()))) => {
// After successful auth, refresh the models list
let models_result = extension
@@ -731,33 +777,61 @@ impl ExtensionProviderConfigurationView {
_ => Vec::new(),
};
- cx.update(|cx| {
- state.update(cx, |state, cx| {
+ state
+ .update_in(cx, |state, _window, cx| {
state.is_authenticated = true;
state.available_models = new_models;
cx.notify();
- });
- })
- .log_err();
- None
+ })
+ .log_err();
+
+ // Update flow state to show success
+ flow_state
+ .update_in(cx, |state, _window, cx| {
+ state.set_status(OAuthDeviceFlowStatus::Authorized, cx);
+ })
+ .log_err();
}
Ok(Ok(Err(e))) => {
log::error!("Device flow poll failed: {}", e);
- Some(e)
+ flow_state
+ .update_in(cx, |state, _window, cx| {
+ state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx);
+ })
+ .log_err();
+ this_handle
+ .update_in(cx, |this, _window, cx| {
+ this.oauth_error = Some(e);
+ cx.notify();
+ })
+ .log_err();
}
Ok(Err(e)) | Err(e) => {
log::error!("Device flow poll error: {}", e);
- Some(e.to_string())
+ let error_string = e.to_string();
+ flow_state
+ .update_in(cx, |state, _window, cx| {
+ state.set_status(
+ OAuthDeviceFlowStatus::Failed(error_string.clone()),
+ cx,
+ );
+ })
+ .log_err();
+ this_handle
+ .update_in(cx, |this, _window, cx| {
+ this.oauth_error = Some(error_string);
+ cx.notify();
+ })
+ .log_err();
}
};
- this.update(cx, |this, cx| {
- this.oauth_in_progress = false;
- this.oauth_error = error_message;
- this.device_user_code = None;
- cx.notify();
- })
- .log_err();
+ this_handle
+ .update_in(cx, |this, _window, cx| {
+ this.oauth_in_progress = false;
+ cx.notify();
+ })
+ .log_err();
})
.detach();
}
@@ -784,7 +858,7 @@ impl ExtensionProviderConfigurationView {
}
impl gpui::Render for ExtensionProviderConfigurationView {
- fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_loading = self.loading_settings || self.loading_credentials;
let is_authenticated = self.is_authenticated(cx);
let allowed_env_vars = self.state.read(cx).allowed_env_vars.clone();
@@ -803,7 +877,7 @@ impl gpui::Render for ExtensionProviderConfigurationView {
// Render settings markdown if available
if let Some(markdown) = &self.settings_markdown {
- let style = settings_markdown_style(_window, cx);
+ let style = settings_markdown_style(window, cx);
content = content.child(MarkdownElement::new(markdown.clone(), style));
}
@@ -948,90 +1022,30 @@ impl gpui::Render for ExtensionProviderConfigurationView {
let oauth_error = self.oauth_error.clone();
+ let mut button = ui::Button::new("oauth-sign-in", button_label)
+ .full_width()
+ .style(ui::ButtonStyle::Outlined)
+ .disabled(oauth_in_progress)
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.start_oauth_sign_in(window, cx);
+ }));
+ if let Some(icon) = button_icon {
+ button = button
+ .icon(icon)
+ .icon_position(ui::IconPosition::Start)
+ .icon_size(ui::IconSize::Small)
+ .icon_color(Color::Muted);
+ }
+
content = content.child(
v_flex()
.gap_2()
- .child({
- let mut button = ui::Button::new("oauth-sign-in", button_label)
- .full_width()
- .style(ui::ButtonStyle::Outlined)
- .disabled(oauth_in_progress)
- .on_click(cx.listener(|this, _, _window, cx| {
- this.start_oauth_sign_in(cx);
- }));
- if let Some(icon) = button_icon {
- button = button
- .icon(icon)
- .icon_position(ui::IconPosition::Start)
- .icon_size(ui::IconSize::Small)
- .icon_color(Color::Muted);
- }
- button
- })
+ .child(button)
.when(oauth_in_progress, |this| {
- let user_code = self.device_user_code.clone();
this.child(
- v_flex()
- .gap_1()
- .when_some(user_code, |this, code| {
- let copied = cx
- .read_from_clipboard()
- .map(|item| item.text().as_ref() == Some(&code))
- .unwrap_or(false);
- let code_for_click = code.clone();
- this.child(
- h_flex()
- .gap_1()
- .child(
- Label::new("Enter code:")
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .child(
- h_flex()
- .gap_1()
- .px_1()
- .border_1()
- .border_color(cx.theme().colors().border)
- .rounded_sm()
- .cursor_pointer()
- .on_mouse_down(
- MouseButton::Left,
- move |_, window, cx| {
- cx.write_to_clipboard(
- ClipboardItem::new_string(
- code_for_click.clone(),
- ),
- );
- window.refresh();
- },
- )
- .child(
- Label::new(code)
- .size(LabelSize::Small)
- .color(Color::Accent),
- )
- .child(
- ui::Icon::new(if copied {
- ui::IconName::Check
- } else {
- ui::IconName::Copy
- })
- .size(ui::IconSize::Small)
- .color(if copied {
- Color::Success
- } else {
- Color::Muted
- }),
- ),
- ),
- )
- })
- .child(
- Label::new("Waiting for authorization in browser...")
- .size(LabelSize::Small)
- .color(Color::Muted),
- ),
+ Label::new("Sign-in in progress...")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
)
})
.when_some(oauth_error, |this, error| {
@@ -35,15 +35,16 @@ pub use latest::{
zed::extension::context_server::ContextServerConfiguration,
zed::extension::llm_provider::{
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
- CompletionRequest as LlmCompletionRequest, ImageData as LlmImageData,
- MessageContent as LlmMessageContent, MessageRole as LlmMessageRole,
- ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo,
- ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage,
- StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent,
- TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice,
- ToolDefinition as LlmToolDefinition, ToolInputFormat as LlmToolInputFormat,
- ToolResult as LlmToolResult, ToolResultContent as LlmToolResultContent,
- ToolUse as LlmToolUse, ToolUseJsonParseError as LlmToolUseJsonParseError,
+ CompletionRequest as LlmCompletionRequest, DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo,
+ ImageData as LlmImageData, MessageContent as LlmMessageContent,
+ MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
+ ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo,
+ RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
+ ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
+ ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
+ ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
+ ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
+ ToolUseJsonParseError as LlmToolUseJsonParseError,
},
zed::extension::lsp::{
Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind,
@@ -1233,7 +1234,7 @@ impl Extension {
&self,
store: &mut Store<WasmState>,
provider_id: &str,
- ) -> Result<Result<String, String>> {
+ ) -> Result<Result<LlmDeviceFlowPromptInfo, String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_provider_start_device_flow_sign_in(store, provider_id)
@@ -0,0 +1,271 @@
+use gpui::{
+ Animation, AnimationExt, App, ClipboardItem, Context, DismissEvent, Element, Entity,
+ EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, MouseDownEvent,
+ ParentElement, Render, SharedString, Styled, Subscription, Transformation, Window, div,
+ percentage, svg,
+};
+use menu;
+use std::time::Duration;
+use ui::{Button, Icon, IconName, Label, prelude::*};
+
+use crate::ModalView;
+
+/// Configuration for the OAuth device flow modal.
+/// This allows extensions to specify the text and appearance of the modal.
+#[derive(Clone)]
+pub struct OAuthDeviceFlowModalConfig {
+ /// The user code to display (e.g., "ABC-123").
+ pub user_code: String,
+ /// The URL the user needs to visit to authorize (for the "Connect" button).
+ pub verification_url: String,
+ /// The headline text for the modal (e.g., "Use GitHub Copilot in Zed.").
+ pub headline: String,
+ /// A description to show below the headline.
+ pub description: String,
+ /// Label for the connect button (e.g., "Connect to GitHub").
+ pub connect_button_label: String,
+ /// Success headline shown when authorization completes.
+ pub success_headline: String,
+ /// Success message shown when authorization completes.
+ pub success_message: String,
+ /// Optional path to an SVG icon file (absolute path on disk).
+ pub icon_path: Option<SharedString>,
+}
+
+/// The current status of the OAuth device flow.
+#[derive(Clone, Debug)]
+pub enum OAuthDeviceFlowStatus {
+ /// Waiting for user to click connect and authorize.
+ Prompting,
+ /// User clicked connect, waiting for authorization.
+ WaitingForAuthorization,
+ /// Successfully authorized.
+ Authorized,
+ /// Authorization failed with an error message.
+ Failed(String),
+}
+
+/// Shared state for the OAuth device flow that can be observed by the modal.
+pub struct OAuthDeviceFlowState {
+ pub config: OAuthDeviceFlowModalConfig,
+ pub status: OAuthDeviceFlowStatus,
+}
+
+impl EventEmitter<()> for OAuthDeviceFlowState {}
+
+impl OAuthDeviceFlowState {
+ pub fn new(config: OAuthDeviceFlowModalConfig) -> Self {
+ Self {
+ config,
+ status: OAuthDeviceFlowStatus::Prompting,
+ }
+ }
+
+ /// Update the status of the OAuth flow.
+ pub fn set_status(&mut self, status: OAuthDeviceFlowStatus, cx: &mut Context<Self>) {
+ self.status = status;
+ cx.emit(());
+ cx.notify();
+ }
+}
+
+/// A generic OAuth device flow modal that can be used by extensions.
+pub struct OAuthDeviceFlowModal {
+ state: Entity<OAuthDeviceFlowState>,
+ connect_clicked: bool,
+ focus_handle: FocusHandle,
+ _subscription: Subscription,
+}
+
+impl Focusable for OAuthDeviceFlowModal {
+ fn focus_handle(&self, _: &App) -> FocusHandle {
+ self.focus_handle.clone()
+ }
+}
+
+impl EventEmitter<DismissEvent> for OAuthDeviceFlowModal {}
+
+impl ModalView for OAuthDeviceFlowModal {}
+
+impl OAuthDeviceFlowModal {
+ pub fn new(state: Entity<OAuthDeviceFlowState>, cx: &mut Context<Self>) -> Self {
+ let subscription = cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ });
+
+ Self {
+ state,
+ connect_clicked: false,
+ focus_handle: cx.focus_handle(),
+ _subscription: subscription,
+ }
+ }
+
+ fn render_icon(&self, cx: &mut Context<Self>) -> impl IntoElement {
+ let state = self.state.read(cx);
+ if let Some(icon_path) = &state.config.icon_path {
+ Icon::from_external_svg(icon_path.clone())
+ .size(ui::IconSize::XLarge)
+ .color(Color::Custom(cx.theme().colors().icon))
+ .into_any_element()
+ } else {
+ div().into_any_element()
+ }
+ }
+
+ fn render_device_code(&self, cx: &mut Context<Self>) -> impl IntoElement {
+ let state = self.state.read(cx);
+ let user_code = state.config.user_code.clone();
+ let copied = cx
+ .read_from_clipboard()
+ .map(|item| item.text().as_ref() == Some(&user_code))
+ .unwrap_or(false);
+ let user_code_for_click = user_code.clone();
+
+ h_flex()
+ .w_full()
+ .p_1()
+ .border_1()
+ .border_muted(cx)
+ .rounded_sm()
+ .cursor_pointer()
+ .justify_between()
+ .on_mouse_down(gpui::MouseButton::Left, move |_, window, cx| {
+ cx.write_to_clipboard(ClipboardItem::new_string(user_code_for_click.clone()));
+ window.refresh();
+ })
+ .child(div().flex_1().child(Label::new(user_code)))
+ .child(div().flex_none().px_1().child(Label::new(if copied {
+ "Copied!"
+ } else {
+ "Copy"
+ })))
+ }
+
+ fn render_prompting_modal(&self, cx: &mut Context<Self>) -> impl Element {
+ let (connect_button_label, verification_url, headline, description) = {
+ let state = self.state.read(cx);
+ let label = if self.connect_clicked {
+ "Waiting for connection...".to_string()
+ } else {
+ state.config.connect_button_label.clone()
+ };
+ (
+ label,
+ state.config.verification_url.clone(),
+ state.config.headline.clone(),
+ state.config.description.clone(),
+ )
+ };
+
+ v_flex()
+ .flex_1()
+ .gap_2()
+ .items_center()
+ .child(Headline::new(headline).size(HeadlineSize::Large))
+ .child(Label::new(description).color(Color::Muted))
+ .child(self.render_device_code(cx))
+ .child(
+ Label::new("Paste this code into GitHub after clicking the button below.")
+ .size(ui::LabelSize::Small),
+ )
+ .child(
+ Button::new("connect-button", connect_button_label)
+ .on_click(cx.listener(move |this, _, _window, cx| {
+ cx.open_url(&verification_url);
+ this.connect_clicked = true;
+ }))
+ .full_width()
+ .style(ButtonStyle::Filled),
+ )
+ .child(
+ Button::new("cancel-button", "Cancel")
+ .full_width()
+ .on_click(cx.listener(|_, _, _, cx| {
+ cx.emit(DismissEvent);
+ })),
+ )
+ }
+
+ fn render_authorized_modal(&self, cx: &mut Context<Self>) -> impl Element {
+ let state = self.state.read(cx);
+ let success_headline = state.config.success_headline.clone();
+ let success_message = state.config.success_message.clone();
+
+ v_flex()
+ .gap_2()
+ .child(Headline::new(success_headline).size(HeadlineSize::Large))
+ .child(Label::new(success_message))
+ .child(
+ Button::new("done-button", "Done")
+ .full_width()
+ .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
+ )
+ }
+
+ fn render_failed_modal(&self, error: &str, cx: &mut Context<Self>) -> impl Element {
+ v_flex()
+ .gap_2()
+ .child(Headline::new("Authorization Failed").size(HeadlineSize::Large))
+ .child(Label::new(error.to_string()).color(Color::Error))
+ .child(
+ Button::new("close-button", "Close")
+ .full_width()
+ .on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
+ )
+ }
+
+ fn render_loading(window: &mut Window, _cx: &mut Context<Self>) -> impl Element {
+ let loading_icon = svg()
+ .size_8()
+ .path(IconName::ArrowCircle.path())
+ .text_color(window.text_style().color)
+ .with_animation(
+ "icon_circle_arrow",
+ Animation::new(Duration::from_secs(2)).repeat(),
+ |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
+ );
+
+ h_flex().justify_center().child(loading_icon)
+ }
+}
+
+impl Render for OAuthDeviceFlowModal {
+ fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let status = self.state.read(cx).status.clone();
+
+ let prompt = match &status {
+ OAuthDeviceFlowStatus::Prompting => self.render_prompting_modal(cx).into_any_element(),
+ OAuthDeviceFlowStatus::WaitingForAuthorization => {
+ if self.connect_clicked {
+ self.render_prompting_modal(cx).into_any_element()
+ } else {
+ Self::render_loading(window, cx).into_any_element()
+ }
+ }
+ OAuthDeviceFlowStatus::Authorized => {
+ self.render_authorized_modal(cx).into_any_element()
+ }
+ OAuthDeviceFlowStatus::Failed(error) => {
+ self.render_failed_modal(error, cx).into_any_element()
+ }
+ };
+
+ v_flex()
+ .id("oauth-device-flow-modal")
+ .track_focus(&self.focus_handle(cx))
+ .elevation_3(cx)
+ .w_96()
+ .items_center()
+ .p_4()
+ .gap_2()
+ .on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
+ cx.emit(DismissEvent);
+ }))
+ .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _| {
+ window.focus(&this.focus_handle);
+ }))
+ .child(self.render_icon(cx))
+ .child(prompt)
+ }
+}
@@ -4,6 +4,7 @@ pub mod invalid_item_view;
pub mod item;
mod modal_layer;
pub mod notifications;
+pub mod oauth_device_flow_modal;
pub mod pane;
pub mod pane_group;
mod path_list;
@@ -454,7 +454,7 @@ impl zed::Extension for CopilotChatProvider {
fn llm_provider_start_device_flow_sign_in(
&mut self,
_provider_id: &str,
- ) -> Result<String, String> {
+ ) -> Result<LlmDeviceFlowPromptInfo, String> {
// Step 1: Request device and user verification codes
let device_code_response = llm_oauth_send_http_request(&LlmOauthHttpRequest {
url: GITHUB_DEVICE_CODE_URL.to_string(),
@@ -497,7 +497,7 @@ impl zed::Extension for CopilotChatProvider {
expires_in: device_info.expires_in,
});
- // Step 2: Open browser to verification URL
+ // Step 2: Construct verification URL
// Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
format!(
@@ -505,10 +505,19 @@ impl zed::Extension for CopilotChatProvider {
device_info.verification_uri, &device_info.user_code
)
});
- llm_oauth_open_browser(&verification_url)?;
- // Return the user code for the host to display
- Ok(device_info.user_code)
+ // Return prompt info for the host to display in the modal
+ Ok(LlmDeviceFlowPromptInfo {
+ user_code: device_info.user_code,
+ verification_url,
+ headline: "Use GitHub Copilot in Zed.".to_string(),
+ description: "Using Copilot requires an active subscription on GitHub.".to_string(),
+ connect_button_label: "Connect to GitHub".to_string(),
+ success_headline: "Copilot Enabled!".to_string(),
+ success_message:
+ "You can update your settings or sign out from the Copilot menu in the status bar."
+ .to_string(),
+ })
}
fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {