diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs index 35e19e54684557cef870d1daf092d5455d5f9a04..6b8e1d87a0934abd1d2e1cc553ac2dafaca98de0 100644 --- a/crates/agent_ui/src/acp/model_selector.rs +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -293,12 +293,15 @@ impl PickerDelegate for AcpModelPickerDelegate { .w_full() .gap_1p5() .map(|this| match &model_info.icon { - Some(icon) => this.child(match icon { - AgentModelIcon::Path(path) => Icon::from_path(path.clone()), - AgentModelIcon::Named(name) => Icon::new(*name), - } - .color(model_icon_color) - .size(IconSize::Small) + Some(AgentModelIcon::Path(path)) => this.child( + Icon::from_path(path.clone()) + .color(model_icon_color) + .size(IconSize::Small), + ), + Some(AgentModelIcon::Named(icon)) => this.child( + Icon::new(*icon) + .color(model_icon_color) + .size(IconSize::Small), ), None => this, }) diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs index f67ee8510d42ffb3b3cd65dce0f970f0e6114090..7fd808bb2059fdb16cec3cca271190202127a176 100644 --- a/crates/agent_ui/src/acp/model_selector_popover.rs +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -78,15 +78,13 @@ impl Render for AcpModelSelectorPopover { self.selector.clone(), ButtonLike::new("active-model") .selected_style(ButtonStyle::Tinted(TintColor::Accent)) - .when_some(model_icon, |this, icon| { - this.child( - match icon { - AgentModelIcon::Path(path) => Icon::from_path(path), - AgentModelIcon::Named(icon_name) => Icon::new(icon_name), - } - .color(color) - .size(IconSize::XSmall), - ) + .when_some(model_icon, |this, icon| match icon { + AgentModelIcon::Path(path) => { + this.child(Icon::from_path(path).color(color).size(IconSize::XSmall)) + } + AgentModelIcon::Named(icon_name) => { + this.child(Icon::new(icon_name).color(color).size(IconSize::XSmall)) + } }) .child( Label::new(model_name) diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index b63087cbf1e0dddc994962edf191d2e72df398be..3533c28caa93f82c96ecacdafdfdd3dc1b1643f1 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -36,7 +36,7 @@ use settings::{Settings, SettingsStore, update_settings_file}; use ui::{ Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure, Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize, - PopoverMenu, Switch, Tooltip, WithScrollbar, prelude::*, + PopoverMenu, Switch, SwitchColor, Tooltip, WithScrollbar, prelude::*, }; use util::ResultExt as _; use workspace::{Workspace, create_and_open_local_file}; @@ -83,24 +83,14 @@ impl AgentConfiguration { window, |this, _, event: &language_model::Event, window, cx| match event { language_model::Event::AddedProvider(provider_id) => { - let registry = LanguageModelRegistry::read_global(cx); - // Only add if the provider is visible - if let Some(provider) = registry.provider(provider_id) { - if !registry.should_hide_provider(provider_id) { - this.add_provider_configuration_view(&provider, window, cx); - } + let provider = LanguageModelRegistry::read_global(cx).provider(provider_id); + if let Some(provider) = provider { + this.add_provider_configuration_view(&provider, window, cx); } } language_model::Event::RemovedProvider(provider_id) => { this.remove_provider_configuration_view(provider_id); } - language_model::Event::ProvidersChanged => { - // Rebuild all provider views when visibility changes - this.configuration_views_by_provider.clear(); - this.expanded_provider_configurations.clear(); - this.build_provider_configuration_views(window, cx); - cx.notify(); - } _ => {} }, ); @@ -127,7 +117,7 @@ impl AgentConfiguration { } fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context) { - let providers = LanguageModelRegistry::read_global(cx).visible_providers(); + let providers = LanguageModelRegistry::read_global(cx).providers(); for provider in providers { self.add_provider_configuration_view(&provider, window, cx); } @@ -270,15 +260,15 @@ impl AgentConfiguration { h_flex() .w_full() .gap_1p5() - .child( - if let Some(icon_path) = provider.icon_path() { - Icon::from_external_svg(icon_path) - } else { - Icon::new(provider.icon()) - } - .size(IconSize::Small) - .color(Color::Muted), - ) + .child(if let Some(icon_path) = provider.icon_path() { + Icon::from_external_svg(icon_path) + .size(IconSize::Small) + .color(Color::Muted) + } else { + Icon::new(provider.icon()) + .size(IconSize::Small) + .color(Color::Muted) + }) .child( h_flex() .w_full() @@ -430,7 +420,7 @@ impl AgentConfiguration { &mut self, cx: &mut Context, ) -> impl IntoElement { - let providers = LanguageModelRegistry::read_global(cx).visible_providers(); + let providers = LanguageModelRegistry::read_global(cx).providers(); let popover_menu = PopoverMenu::new("add-provider-popover") .trigger( @@ -893,6 +883,7 @@ impl AgentConfiguration { .child(context_server_configuration_menu) .child( Switch::new("context-server-switch", is_running.into()) + .color(SwitchColor::Accent) .on_click({ let context_server_manager = self.context_server_store.clone(); let fs = self.fs.clone(); diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index 685bd775424b120f753d3a0d444e5a966f940773..924f37db0440dd1d4ddbdb90bdf73dfe56f0cbad 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -108,7 +108,7 @@ impl Render for AgentModelSelector { .child( Icon::new(IconName::ChevronDown) .color(color) - .size(IconSize::Small), + .size(IconSize::XSmall), ), move |_window, cx| { Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx) diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index fb18c26b0946f6f0b51d80b465d045e0cbb33e7d..9fd717a597e14918c3a3adc909ff53d2bb8de740 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -46,9 +46,7 @@ pub fn language_model_selector( } fn all_models(cx: &App) -> GroupedModels { - let providers = LanguageModelRegistry::global(cx) - .read(cx) - .visible_providers(); + let providers = LanguageModelRegistry::global(cx).read(cx).providers(); let recommended = providers .iter() @@ -140,9 +138,13 @@ impl LanguageModelPickerDelegate { // Subscribe to registry events and send refresh signals through the channel let registry = LanguageModelRegistry::global(cx); cx.subscribe(®istry, move |_picker, _, event, _cx| match event { - language_model::Event::ProviderStateChanged(_) - | language_model::Event::AddedProvider(_) - | language_model::Event::RemovedProvider(_) => { + language_model::Event::ProviderStateChanged(_) => { + refresh_tx.unbounded_send(()).ok(); + } + language_model::Event::AddedProvider(_) => { + refresh_tx.unbounded_send(()).ok(); + } + language_model::Event::RemovedProvider(_) => { refresh_tx.unbounded_send(()).ok(); } _ => {} @@ -421,7 +423,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { let configured_providers = language_model_registry .read(cx) - .visible_providers() + .providers() .into_iter() .filter(|provider| provider.is_authenticated(cx)) .collect::>(); diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 340a3eada3c1c102351ce05bac483f8f6272d925..30538898b28a1d41d6c63b3e910f51c816e299ab 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -1682,98 +1682,6 @@ impl TextThreadEditor { window: &mut Window, cx: &mut Context, ) { - let editor_clipboard_selections = cx - .read_from_clipboard() - .and_then(|item| item.entries().first().cloned()) - .and_then(|entry| match entry { - ClipboardEntry::String(text) => { - text.metadata_json::>() - } - _ => None, - }); - - let has_file_context = editor_clipboard_selections - .as_ref() - .is_some_and(|selections| { - selections - .iter() - .any(|sel| sel.file_path.is_some() && sel.line_range.is_some()) - }); - - if has_file_context { - if let Some(clipboard_item) = cx.read_from_clipboard() { - if let Some(ClipboardEntry::String(clipboard_text)) = - clipboard_item.entries().first() - { - if let Some(selections) = editor_clipboard_selections { - cx.stop_propagation(); - - let text = clipboard_text.text(); - self.editor.update(cx, |editor, cx| { - let mut current_offset = 0; - let weak_editor = cx.entity().downgrade(); - - for selection in selections { - if let (Some(file_path), Some(line_range)) = - (selection.file_path, selection.line_range) - { - let selected_text = - &text[current_offset..current_offset + selection.len]; - let fence = assistant_slash_commands::codeblock_fence_for_path( - file_path.to_str(), - Some(line_range.clone()), - ); - let formatted_text = format!("{fence}{selected_text}\n```"); - - let insert_point = editor - .selections - .newest::(&editor.display_snapshot(cx)) - .head(); - let start_row = MultiBufferRow(insert_point.row); - - editor.insert(&formatted_text, window, cx); - - let snapshot = editor.buffer().read(cx).snapshot(cx); - let anchor_before = snapshot.anchor_after(insert_point); - let anchor_after = editor - .selections - .newest_anchor() - .head() - .bias_left(&snapshot); - - editor.insert("\n", window, cx); - - let crease_text = acp_thread::selection_name( - Some(file_path.as_ref()), - &line_range, - ); - - let fold_placeholder = quote_selection_fold_placeholder( - crease_text, - weak_editor.clone(), - ); - let crease = Crease::inline( - anchor_before..anchor_after, - fold_placeholder, - render_quote_selection_output_toggle, - |_, _, _, _| Empty.into_any(), - ); - editor.insert_creases(vec![crease], cx); - editor.fold_at(start_row, window, cx); - - current_offset += selection.len; - if !selection.is_entire_line && current_offset < text.len() { - current_offset += 1; - } - } - } - }); - return; - } - } - } - } - cx.stop_propagation(); let mut images = if let Some(item) = cx.read_from_clipboard() { @@ -2204,11 +2112,13 @@ impl TextThreadEditor { let provider_icon_element = if let Some(icon_path) = provider_icon_path { Icon::from_external_svg(icon_path) + .color(color) + .size(IconSize::XSmall) } else { Icon::new(provider_icon_name) - } - .color(color) - .size(IconSize::XSmall); + .color(color) + .size(IconSize::XSmall) + }; PickerPopoverMenu::new( self.language_model_selector.clone(), diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index 764ce8c8f4b63be1e96a0e4616629a42c4740eba..bdf1ce3640bf5041b63d952625429156814dadfb 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -44,7 +44,7 @@ impl ApiKeysWithProviders { fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> { LanguageModelRegistry::read_global(cx) - .visible_providers() + .providers() .iter() .filter(|provider| { provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID @@ -68,14 +68,14 @@ impl Render for ApiKeysWithProviders { .map(|(icon, name)| { h_flex() .gap_1p5() - .child( - match icon { - ProviderIcon::Name(icon_name) => Icon::new(icon_name), - ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path), - } - .size(IconSize::XSmall) - .color(Color::Muted), - ) + .child(match icon { + ProviderIcon::Name(icon_name) => Icon::new(icon_name) + .size(IconSize::XSmall) + .color(Color::Muted), + ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path) + .size(IconSize::XSmall) + .color(Color::Muted), + }) .child(Label::new(name)) }); div() diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index 831d97e5d4b7289f19aef40a3be5df8d967eb7a2..ae92268ff4db459e748b806e47f6f89851783bd9 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -45,7 +45,7 @@ impl AgentPanelOnboarding { fn has_configured_providers(cx: &App) -> bool { LanguageModelRegistry::read_global(cx) - .visible_providers() + .providers() .iter() .any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID) } diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index 73747c2997a28a96a33839b8ad96ed58c7ccdae2..3a09a602d5b46180548fb9bcfff8f3b3e7cdae53 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -339,6 +339,17 @@ pub struct LanguageModelAuthConfig { /// Human-readable name for the credential shown in the UI input field (e.g., "API Key", "Access Token"). #[serde(default)] pub credential_label: Option, + /// OAuth configuration for web-based authentication flows. + #[serde(default)] + pub oauth: Option, +} + +/// OAuth configuration for web-based authentication. +#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] +pub struct OAuthConfig { + /// The text to display on the sign-in button (e.g., "Sign in with GitHub"). + #[serde(default)] + pub sign_in_button_label: Option, } impl ExtensionManifest { diff --git a/crates/extension_api/src/extension_api.rs b/crates/extension_api/src/extension_api.rs index 7ce852f6a36208cff2ae9edc682017e018ea0cb7..555ba6dcc260b6c4e9f9952589c080e87820f10b 100644 --- a/crates/extension_api/src/extension_api.rs +++ b/crates/extension_api/src/extension_api.rs @@ -31,20 +31,22 @@ pub use wit::{ }, zed::extension::llm_provider::{ CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent, - CompletionRequest as LlmCompletionRequest, ImageData as LlmImageData, - MessageContent as LlmMessageContent, MessageRole as LlmMessageRole, - ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo, - OauthHttpRequest as LlmOauthHttpRequest, OauthHttpResponse as LlmOauthHttpResponse, - OauthWebAuthConfig as LlmOauthWebAuthConfig, OauthWebAuthResult as LlmOauthWebAuthResult, - ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage, - StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent, - TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice, - ToolDefinition as LlmToolDefinition, ToolInputFormat as LlmToolInputFormat, - ToolResult as LlmToolResult, ToolResultContent as LlmToolResultContent, - ToolUse as LlmToolUse, ToolUseJsonParseError as LlmToolUseJsonParseError, + CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType, + ImageData as LlmImageData, MessageContent as LlmMessageContent, + MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities, + ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest, + OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig, + OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo, + RequestMessage as LlmRequestMessage, StopReason as LlmStopReason, + ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage, + ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition, + ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult, + ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse, + ToolUseJsonParseError as LlmToolUseJsonParseError, delete_credential as llm_delete_credential, get_credential as llm_get_credential, get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser, oauth_start_web_auth as llm_oauth_start_web_auth, + request_credential as llm_request_credential, send_oauth_http_request as llm_oauth_http_request, store_credential as llm_store_credential, }, @@ -300,6 +302,31 @@ pub trait Extension: Send + Sync { false } + /// Attempt to authenticate the provider. + /// This is called for background credential checks - it should check for + /// existing credentials and return Ok if found, or an error if not. + fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> { + Err("`llm_provider_authenticate` not implemented".to_string()) + } + + /// Start an OAuth device flow sign-in. + /// This is called when the user explicitly clicks "Sign in with GitHub" or similar. + /// Opens the browser to the verification URL and returns the user code that should + /// be displayed to the user. + fn llm_provider_start_device_flow_sign_in( + &mut self, + _provider_id: &str, + ) -> Result { + Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string()) + } + + /// Poll for device flow sign-in completion. + /// This is called after llm_provider_start_device_flow_sign_in returns the user code. + /// The extension should poll the OAuth provider until the user authorizes or the flow times out. + fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> { + Err("`llm_provider_poll_device_flow_sign_in` not implemented".to_string()) + } + /// Reset credentials for the provider. fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> { Err("`llm_provider_reset_credentials` not implemented".to_string()) @@ -624,6 +651,18 @@ impl wit::Guest for Component { extension().llm_provider_is_authenticated(&provider_id) } + fn llm_provider_authenticate(provider_id: String) -> Result<(), String> { + extension().llm_provider_authenticate(&provider_id) + } + + fn llm_provider_start_device_flow_sign_in(provider_id: String) -> Result { + extension().llm_provider_start_device_flow_sign_in(&provider_id) + } + + fn llm_provider_poll_device_flow_sign_in(provider_id: String) -> Result<(), String> { + extension().llm_provider_poll_device_flow_sign_in(&provider_id) + } + fn llm_provider_reset_credentials(provider_id: String) -> Result<(), String> { extension().llm_provider_reset_credentials(&provider_id) } diff --git a/crates/extension_api/wit/since_v0.8.0/extension.wit b/crates/extension_api/wit/since_v0.8.0/extension.wit index aad4db30e82aebd7106d2354a7df2458de549888..ef9f464d29d8021a2dc14ac2b6c08c6df006c2e0 100644 --- a/crates/extension_api/wit/since_v0.8.0/extension.wit +++ b/crates/extension_api/wit/since_v0.8.0/extension.wit @@ -18,7 +18,7 @@ world extension { use slash-command.{slash-command, slash-command-argument-completion, slash-command-output}; use llm-provider.{ provider-info, model-info, completion-request, - cache-configuration, completion-event, token-usage + credential-type, cache-configuration, completion-event, token-usage }; /// Initializes the extension. @@ -183,6 +183,33 @@ world extension { /// Check if the provider is authenticated. export llm-provider-is-authenticated: func(provider-id: string) -> bool; + /// Attempt to authenticate the provider. + /// This is called for background credential checks - it should check for + /// existing credentials and return Ok if found, or an error if not. + /// For interactive OAuth flows, use the device flow functions instead. + export llm-provider-authenticate: func(provider-id: string) -> result<_, string>; + + /// Start an OAuth device flow sign-in. + /// This is called when the user explicitly clicks "Sign in with GitHub" or similar. + /// + /// The device flow works as follows: + /// 1. Extension requests a device code from the OAuth provider + /// 2. Extension opens the verification URL in the browser + /// 3. Extension returns the user code to display to the user + /// 4. Host displays the user code and calls llm-provider-poll-device-flow-sign-in + /// 5. Extension polls for the access token while user authorizes in browser + /// 6. Once authorized, extension stores the credential and returns success + /// + /// Returns the user code that should be displayed to the user while they + /// complete authorization in the browser. + export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result; + + /// Poll for device flow sign-in completion. + /// This is called after llm-provider-start-device-flow-sign-in returns the user code. + /// The extension should poll the OAuth provider until the user authorizes or the flow times out. + /// Returns Ok(()) on successful authentication, or an error message on failure. + export llm-provider-poll-device-flow-sign-in: func(provider-id: string) -> result<_, string>; + /// Reset credentials for the provider. export llm-provider-reset-credentials: func(provider-id: string) -> result<_, string>; diff --git a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit index f9c81f13bac734c6a317e1f3a8c68fb5da2c18fc..a3f1258fc78850603a2a71d6aafecbb52b339a16 100644 --- a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit +++ b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit @@ -235,6 +235,14 @@ interface llm-provider { cache-read-input-tokens: option, } + /// Credential types that can be requested. + enum credential-type { + /// An API key. + api-key, + /// An OAuth token. + oauth-token, + } + /// Cache configuration for prompt caching. record cache-configuration { /// Maximum number of cache anchors. @@ -249,6 +257,9 @@ interface llm-provider { record oauth-web-auth-config { /// The URL to open in the user's browser to start authentication. /// This should include client_id, redirect_uri, scope, state, etc. + /// Use `{port}` as a placeholder in the URL - it will be replaced with + /// the actual localhost port before opening the browser. + /// Example: "https://example.com/oauth?redirect_uri=http://127.0.0.1:{port}/callback" auth-url: string, /// The path to listen on for the OAuth callback (e.g., "/callback"). /// A localhost server will be started to receive the redirect. @@ -288,6 +299,15 @@ interface llm-provider { body: string, } + /// Request a credential from the user. + /// Returns true if the credential was provided, false if the user cancelled. + request-credential: func( + provider-id: string, + credential-type: credential-type, + label: string, + placeholder: string + ) -> result; + /// Get a stored credential for this provider. get-credential: func(provider-id: string) -> option; diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs index dff8834956cd834e6fb9f27e94c47f2aae0d6f5b..acec25b8258c1615d97e55d3742a5fff42661b87 100644 --- a/crates/extension_host/src/wasm_host/llm_provider.rs +++ b/crates/extension_host/src/wasm_host/llm_provider.rs @@ -10,14 +10,14 @@ use crate::wasm_host::wit::{ use anyhow::{Result, anyhow}; use credentials_provider::CredentialsProvider; use editor::Editor; -use extension::LanguageModelAuthConfig; +use extension::{LanguageModelAuthConfig, OAuthConfig}; use futures::future::BoxFuture; use futures::stream::BoxStream; use futures::{FutureExt, StreamExt}; use gpui::Focusable; use gpui::{ - AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, - TextStyleRefinement, UnderlineStyle, Window, px, + AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter, + MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px, }; use language_model::tool_schema::LanguageModelToolSchemaFormat; use language_model::{ @@ -182,9 +182,37 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider { self.state.read(cx).is_authenticated } - fn authenticate(&self, _cx: &mut App) -> Task> { - // Authentication is handled via the configuration view UI - Task::ready(Ok(())) + fn authenticate(&self, cx: &mut App) -> Task> { + let extension = self.extension.clone(); + let provider_id = self.provider_info.id.clone(); + let state = self.state.clone(); + + cx.spawn(async move |cx| { + let result = extension + .call(|extension, store| { + async move { + extension + .call_llm_provider_authenticate(store, &provider_id) + .await + } + .boxed() + }) + .await; + + match result { + Ok(Ok(Ok(()))) => { + cx.update(|cx| { + state.update(cx, |state, _| { + state.is_authenticated = true; + }); + })?; + Ok(()) + } + Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))), + Ok(Err(e)) => Err(AuthenticateError::Other(e)), + Err(e) => Err(AuthenticateError::Other(e)), + } + }) } fn configuration_view( @@ -287,6 +315,9 @@ struct ExtensionProviderConfigurationView { api_key_editor: Entity, loading_settings: bool, loading_credentials: bool, + oauth_in_progress: bool, + oauth_error: Option, + device_user_code: Option, _subscriptions: Vec, } @@ -324,6 +355,9 @@ impl ExtensionProviderConfigurationView { api_key_editor, loading_settings: true, loading_credentials: true, + oauth_in_progress: false, + oauth_error: None, + device_user_code: None, _subscriptions: vec![state_subscription], }; @@ -572,9 +606,130 @@ impl ExtensionProviderConfigurationView { .detach(); } + fn start_oauth_sign_in(&mut self, cx: &mut Context) { + if self.oauth_in_progress { + return; + } + + self.oauth_in_progress = true; + self.oauth_error = None; + self.device_user_code = None; + cx.notify(); + + let extension = self.extension.clone(); + let provider_id = self.extension_provider_id.clone(); + let state = self.state.clone(); + + cx.spawn(async move |this, cx| { + // Step 1: Start device flow - opens browser and returns user code + let start_result = extension + .call({ + let provider_id = provider_id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id) + .await + } + .boxed() + } + }) + .await; + + let user_code = match start_result { + Ok(Ok(Ok(code))) => code, + Ok(Ok(Err(e))) => { + log::error!("Device flow start failed: {}", e); + this.update(cx, |this, cx| { + this.oauth_in_progress = false; + this.oauth_error = Some(e); + cx.notify(); + }) + .log_err(); + return; + } + Ok(Err(e)) | Err(e) => { + log::error!("Device flow start error: {}", e); + this.update(cx, |this, cx| { + this.oauth_in_progress = false; + this.oauth_error = Some(e.to_string()); + cx.notify(); + }) + .log_err(); + return; + } + }; + + // Update UI to show the user code before polling + this.update(cx, |this, cx| { + this.device_user_code = Some(user_code); + cx.notify(); + }) + .log_err(); + + // Step 2: Poll for authentication completion + let poll_result = extension + .call({ + let provider_id = provider_id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id) + .await + } + .boxed() + } + }) + .await; + + let error_message = match poll_result { + Ok(Ok(Ok(()))) => { + let _ = cx.update(|cx| { + state.update(cx, |state, cx| { + state.is_authenticated = true; + cx.notify(); + }); + }); + None + } + Ok(Ok(Err(e))) => { + log::error!("Device flow poll failed: {}", e); + Some(e) + } + Ok(Err(e)) | Err(e) => { + log::error!("Device flow poll error: {}", e); + Some(e.to_string()) + } + }; + + this.update(cx, |this, cx| { + this.oauth_in_progress = false; + this.oauth_error = error_message; + this.device_user_code = None; + cx.notify(); + }) + .log_err(); + }) + .detach(); + } + fn is_authenticated(&self, cx: &Context) -> bool { self.state.read(cx).is_authenticated } + + fn has_oauth_config(&self) -> bool { + self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some()) + } + + fn oauth_config(&self) -> Option<&OAuthConfig> { + self.auth_config.as_ref().and_then(|c| c.oauth.as_ref()) + } + + fn has_api_key_config(&self) -> bool { + // API key is available if there's a credential_label or no oauth-only config + self.auth_config + .as_ref() + .map(|c| c.credential_label.is_some() || c.oauth.is_none()) + .unwrap_or(true) + } } impl gpui::Render for ExtensionProviderConfigurationView { @@ -583,6 +738,8 @@ impl gpui::Render for ExtensionProviderConfigurationView { let is_authenticated = self.is_authenticated(cx); let env_var_allowed = self.state.read(cx).env_var_allowed; let api_key_from_env = self.state.read(cx).api_key_from_env; + let has_oauth = self.has_oauth_config(); + let has_api_key = self.has_api_key_config(); if is_loading { return v_flex() @@ -652,7 +809,7 @@ impl gpui::Render for ExtensionProviderConfigurationView { ) .child( Label::new(format!( - "{} is not set or empty. You can set it and restart Zed, or enter an API key below.", + "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.", env_var_name )) .color(Color::Warning) @@ -664,8 +821,20 @@ impl gpui::Render for ExtensionProviderConfigurationView { } } - // Render API key section + // If authenticated, show success state with sign out option if is_authenticated && !api_key_from_env { + let reset_label = if has_oauth && !has_api_key { + "Sign Out" + } else { + "Reset Credentials" + }; + + let status_label = if has_oauth && !has_api_key { + "Signed in" + } else { + "Authenticated" + }; + content = content.child( v_flex() .gap_2() @@ -677,39 +846,176 @@ impl gpui::Render for ExtensionProviderConfigurationView { .color(Color::Success) .size(ui::IconSize::Small), ) - .child(Label::new("API key configured").color(Color::Success)), + .child(Label::new(status_label).color(Color::Success)), ) .child( - ui::Button::new("reset-api-key", "Reset API Key") + ui::Button::new("reset-credentials", reset_label) .style(ui::ButtonStyle::Subtle) .on_click(cx.listener(|this, _, window, cx| { this.reset_api_key(window, cx); })), ), ); - } else if !api_key_from_env { - let credential_label = self - .auth_config - .as_ref() - .and_then(|c| c.credential_label.clone()) - .unwrap_or_else(|| "API Key".to_string()); - content = content.child( - v_flex() - .gap_2() - .on_action(cx.listener(Self::save_api_key)) - .child( - Label::new(credential_label) - .size(LabelSize::Small) - .color(Color::Muted), - ) - .child(self.api_key_editor.clone()) - .child( - Label::new("Enter your API key and press Enter to save") - .size(LabelSize::Small) - .color(Color::Muted), - ), - ); + return content.into_any_element(); + } + + // Not authenticated - show available auth options + if !api_key_from_env { + // Render OAuth sign-in button if configured + if has_oauth { + let oauth_config = self.oauth_config(); + let button_label = oauth_config + .and_then(|c| c.sign_in_button_label.clone()) + .unwrap_or_else(|| "Sign In".to_string()); + + let oauth_in_progress = self.oauth_in_progress; + + let oauth_error = self.oauth_error.clone(); + + content = content.child( + v_flex() + .gap_2() + .child( + ui::Button::new("oauth-sign-in", button_label) + .style(ui::ButtonStyle::Filled) + .disabled(oauth_in_progress) + .on_click(cx.listener(|this, _, _window, cx| { + this.start_oauth_sign_in(cx); + })), + ) + .when(oauth_in_progress, |this| { + let user_code = self.device_user_code.clone(); + this.child( + v_flex() + .gap_1() + .when_some(user_code, |this, code| { + let copied = cx + .read_from_clipboard() + .map(|item| item.text().as_ref() == Some(&code)) + .unwrap_or(false); + let code_for_click = code.clone(); + this.child( + h_flex() + .gap_1() + .child( + Label::new("Enter code:") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + h_flex() + .gap_1() + .px_1() + .border_1() + .border_color(cx.theme().colors().border) + .rounded_sm() + .cursor_pointer() + .on_mouse_down( + MouseButton::Left, + move |_, window, cx| { + cx.write_to_clipboard( + ClipboardItem::new_string( + code_for_click.clone(), + ), + ); + window.refresh(); + }, + ) + .child( + Label::new(code) + .size(LabelSize::Small) + .color(Color::Accent), + ) + .child( + ui::Icon::new(if copied { + ui::IconName::Check + } else { + ui::IconName::Copy + }) + .size(ui::IconSize::Small) + .color(if copied { + Color::Success + } else { + Color::Muted + }), + ), + ), + ) + }) + .child( + Label::new("Waiting for authorization in browser...") + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + }) + .when_some(oauth_error, |this, error| { + this.child( + v_flex() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + ui::Icon::new(ui::IconName::Warning) + .color(Color::Error) + .size(ui::IconSize::Small), + ) + .child( + Label::new("Authentication failed") + .color(Color::Error) + .size(LabelSize::Small), + ), + ) + .child( + div().pl_6().child( + Label::new(error) + .color(Color::Error) + .size(LabelSize::Small), + ), + ), + ) + }), + ); + } + + // Render API key input if configured (and we have both options, show a separator) + if has_api_key { + if has_oauth { + content = content.child( + h_flex() + .gap_2() + .items_center() + .child(div().h_px().flex_1().bg(cx.theme().colors().border)) + .child(Label::new("or").size(LabelSize::Small).color(Color::Muted)) + .child(div().h_px().flex_1().bg(cx.theme().colors().border)), + ); + } + + let credential_label = self + .auth_config + .as_ref() + .and_then(|c| c.credential_label.clone()) + .unwrap_or_else(|| "API Key".to_string()); + + content = content.child( + v_flex() + .gap_2() + .on_action(cx.listener(Self::save_api_key)) + .child( + Label::new(credential_label) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(self.api_key_editor.clone()) + .child( + Label::new("Enter your API key and press Enter to save") + .size(LabelSize::Small) + .color(Color::Muted), + ), + ); + } } content.into_any_element() @@ -770,8 +1076,7 @@ impl LanguageModel for ExtensionLanguageModel { } fn name(&self) -> LanguageModelName { - // HACK: Add "(Extension)" prefix to help distinguish extension models during debugging - LanguageModelName::from(format!("(Extension) {}", self.model_info.name)) + LanguageModelName::from(self.model_info.name.clone()) } fn provider_id(&self) -> LanguageModelProviderId { diff --git a/crates/extension_host/src/wasm_host/wit.rs b/crates/extension_host/src/wasm_host/wit.rs index ab85cdf2429a4fde4eedfdf945c0ee87f339f4c3..c2b22d2ad0227830a424f482696cb85b6e88e708 100644 --- a/crates/extension_host/src/wasm_host/wit.rs +++ b/crates/extension_host/src/wasm_host/wit.rs @@ -35,15 +35,16 @@ pub use latest::{ zed::extension::context_server::ContextServerConfiguration, zed::extension::llm_provider::{ CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent, - CompletionRequest as LlmCompletionRequest, ImageData as LlmImageData, - MessageContent as LlmMessageContent, MessageRole as LlmMessageRole, - ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo, - ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage, - StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent, - TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice, - ToolDefinition as LlmToolDefinition, ToolInputFormat as LlmToolInputFormat, - ToolResult as LlmToolResult, ToolResultContent as LlmToolResultContent, - ToolUse as LlmToolUse, ToolUseJsonParseError as LlmToolUseJsonParseError, + CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType, + ImageData as LlmImageData, MessageContent as LlmMessageContent, + MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities, + ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo, + RequestMessage as LlmRequestMessage, StopReason as LlmStopReason, + ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage, + ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition, + ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult, + ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse, + ToolUseJsonParseError as LlmToolUseJsonParseError, }, zed::extension::lsp::{ Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind, @@ -1229,6 +1230,53 @@ impl Extension { } } + pub async fn call_llm_provider_authenticate( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => ext.call_llm_provider_authenticate(store, provider_id).await, + _ => anyhow::bail!("`llm_provider_authenticate` not available prior to v0.8.0"), + } + } + + pub async fn call_llm_provider_start_device_flow_sign_in( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_provider_start_device_flow_sign_in(store, provider_id) + .await + } + _ => { + anyhow::bail!( + "`llm_provider_start_device_flow_sign_in` not available prior to v0.8.0" + ) + } + } + } + + pub async fn call_llm_provider_poll_device_flow_sign_in( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_provider_poll_device_flow_sign_in(store, provider_id) + .await + } + _ => { + anyhow::bail!( + "`llm_provider_poll_device_flow_sign_in` not available prior to v0.8.0" + ) + } + } + } + pub async fn call_llm_provider_reset_credentials( &self, store: &mut Store, diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs index d17787ad009af13f8b72b75498f89da5f7d62778..a7fc76ffb6d489881dfdc977b0814847e1ea7c00 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs @@ -1112,6 +1112,20 @@ impl ExtensionImports for WasmState { } impl llm_provider::Host for WasmState { + async fn request_credential( + &mut self, + _provider_id: String, + _credential_type: llm_provider::CredentialType, + _label: String, + _placeholder: String, + ) -> wasmtime::Result> { + // For now, credential requests return false (not provided) + // Extensions should use get_env_var to check for env vars first, + // then store_credential/get_credential for manual storage + // Full UI credential prompting will be added in a future phase + Ok(Ok(false)) + } + async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result> { let extension_id = self.manifest.id.clone(); @@ -1288,8 +1302,9 @@ impl llm_provider::Host for WasmState { .map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))? .port(); + let auth_url_with_port = auth_url.replace("{port}", &port.to_string()); cx.update(|cx| { - cx.open_url(&auth_url); + cx.open_url(&auth_url_with_port); })?; let accept_future = async { diff --git a/extensions/copilot-chat/extension.toml b/extensions/copilot-chat/extension.toml index c226a20f3a77244c12aa63087e26d52ee06e726f..5e77c6dda4144f39f4ad904ed1fe6f7276f4845d 100644 --- a/extensions/copilot-chat/extension.toml +++ b/extensions/copilot-chat/extension.toml @@ -10,4 +10,7 @@ repository = "https://github.com/zed-industries/zed" name = "Copilot Chat" [language_model_providers.copilot-chat.auth] -env_var = "GH_COPILOT_TOKEN" \ No newline at end of file +env_var = "GH_COPILOT_TOKEN" + +[language_model_providers.copilot-chat.auth.oauth] +sign_in_button_label = "Sign in with GitHub" \ No newline at end of file diff --git a/extensions/copilot-chat/src/copilot_chat.rs b/extensions/copilot-chat/src/copilot_chat.rs index d7d592e7c2ab4eaffec96f65152fcd35d0d5f6b6..9d5730e85055a2ad3562199335541e986ca4365f 100644 --- a/extensions/copilot-chat/src/copilot_chat.rs +++ b/extensions/copilot-chat/src/copilot_chat.rs @@ -1,13 +1,86 @@ use std::collections::HashMap; use std::sync::Mutex; +use std::thread; +use std::time::Duration; use serde::{Deserialize, Serialize}; use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy}; use zed_extension_api::{self as zed, *}; +const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code"; +const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; +const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token"; +const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98"; + +struct DeviceFlowState { + device_code: String, + interval: u64, + expires_in: u64, +} + +#[derive(Clone)] +struct ApiToken { + api_key: String, + api_endpoint: String, +} + +#[derive(Clone, Deserialize)] +struct CopilotModel { + id: String, + name: String, + #[serde(default)] + is_chat_default: bool, + #[serde(default)] + is_chat_fallback: bool, + #[serde(default)] + model_picker_enabled: bool, + #[serde(default)] + capabilities: ModelCapabilities, + #[serde(default)] + policy: Option, +} + +#[derive(Clone, Default, Deserialize)] +struct ModelCapabilities { + #[serde(default)] + family: String, + #[serde(default)] + limits: ModelLimits, + #[serde(default)] + supports: ModelSupportedFeatures, + #[serde(rename = "type", default)] + model_type: String, +} + +#[derive(Clone, Default, Deserialize)] +struct ModelLimits { + #[serde(default)] + max_context_window_tokens: u64, + #[serde(default)] + max_output_tokens: u64, +} + +#[derive(Clone, Default, Deserialize)] +struct ModelSupportedFeatures { + #[serde(default)] + streaming: bool, + #[serde(default)] + tool_calls: bool, + #[serde(default)] + vision: bool, +} + +#[derive(Clone, Deserialize)] +struct ModelPolicy { + state: String, +} + struct CopilotChatProvider { streams: Mutex>, next_stream_id: Mutex, + device_flow_state: Mutex>, + api_token: Mutex>, + cached_models: Mutex>>, } struct StreamState { @@ -25,95 +98,6 @@ struct AccumulatedToolCall { arguments: String, } -struct ModelDefinition { - id: &'static str, - display_name: &'static str, - max_tokens: u64, - max_output_tokens: Option, - supports_images: bool, - is_default: bool, - is_default_fast: bool, -} - -const MODELS: &[ModelDefinition] = &[ - ModelDefinition { - id: "gpt-4o", - display_name: "GPT-4o", - max_tokens: 128_000, - max_output_tokens: Some(16_384), - supports_images: true, - is_default: true, - is_default_fast: false, - }, - ModelDefinition { - id: "gpt-4o-mini", - display_name: "GPT-4o Mini", - max_tokens: 128_000, - max_output_tokens: Some(16_384), - supports_images: true, - is_default: false, - is_default_fast: true, - }, - ModelDefinition { - id: "gpt-4.1", - display_name: "GPT-4.1", - max_tokens: 1_000_000, - max_output_tokens: Some(32_768), - supports_images: true, - is_default: false, - is_default_fast: false, - }, - ModelDefinition { - id: "o1", - display_name: "o1", - max_tokens: 200_000, - max_output_tokens: Some(100_000), - supports_images: true, - is_default: false, - is_default_fast: false, - }, - ModelDefinition { - id: "o3-mini", - display_name: "o3-mini", - max_tokens: 200_000, - max_output_tokens: Some(100_000), - supports_images: false, - is_default: false, - is_default_fast: false, - }, - ModelDefinition { - id: "claude-3.5-sonnet", - display_name: "Claude 3.5 Sonnet", - max_tokens: 200_000, - max_output_tokens: Some(8_192), - supports_images: true, - is_default: false, - is_default_fast: false, - }, - ModelDefinition { - id: "claude-3.7-sonnet", - display_name: "Claude 3.7 Sonnet", - max_tokens: 200_000, - max_output_tokens: Some(8_192), - supports_images: true, - is_default: false, - is_default_fast: false, - }, - ModelDefinition { - id: "gemini-2.0-flash-001", - display_name: "Gemini 2.0 Flash", - max_tokens: 1_000_000, - max_output_tokens: Some(8_192), - supports_images: true, - is_default: false, - is_default_fast: false, - }, -]; - -fn get_model_definition(model_id: &str) -> Option<&'static ModelDefinition> { - MODELS.iter().find(|m| m.id == model_id) -} - #[derive(Serialize)] struct OpenAiRequest { model: String, @@ -389,10 +373,7 @@ fn convert_request( LlmToolChoice::None => "none".to_string(), }); - let model_def = get_model_definition(model_id); - let max_tokens = request - .max_tokens - .or(model_def.and_then(|m| m.max_output_tokens)); + let max_tokens = request.max_tokens; Ok(OpenAiRequest { model: model_id.to_string(), @@ -422,42 +403,46 @@ impl zed::Extension for CopilotChatProvider { Self { streams: Mutex::new(HashMap::new()), next_stream_id: Mutex::new(0), + device_flow_state: Mutex::new(None), + api_token: Mutex::new(None), + cached_models: Mutex::new(None), } } fn llm_providers(&self) -> Vec { vec![LlmProviderInfo { - id: "copilot_chat".into(), + id: "copilot-chat".into(), name: "Copilot Chat".into(), icon: Some("icons/copilot.svg".into()), }] } fn llm_provider_models(&self, _provider_id: &str) -> Result, String> { - Ok(MODELS - .iter() - .map(|m| LlmModelInfo { - id: m.id.to_string(), - name: m.display_name.to_string(), - max_token_count: m.max_tokens, - max_output_tokens: m.max_output_tokens, - capabilities: LlmModelCapabilities { - supports_images: m.supports_images, - supports_tools: true, - supports_tool_choice_auto: true, - supports_tool_choice_any: true, - supports_tool_choice_none: true, - supports_thinking: false, - tool_input_format: LlmToolInputFormat::JsonSchema, - }, - is_default: m.is_default, - is_default_fast: m.is_default_fast, - }) - .collect()) + // Try to get models from cache first + if let Some(models) = self.cached_models.lock().unwrap().as_ref() { + return Ok(convert_models_to_llm_info(models)); + } + + // Need to fetch models - requires authentication + let oauth_token = match llm_get_credential("copilot-chat") { + Some(token) => token, + None => return Ok(Vec::new()), // Not authenticated, return empty + }; + + // Get API token + let api_token = self.get_api_token(&oauth_token)?; + + // Fetch models from API + let models = self.fetch_models(&api_token)?; + + // Cache the models + *self.cached_models.lock().unwrap() = Some(models.clone()); + + Ok(convert_models_to_llm_info(&models)) } fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool { - llm_get_credential("copilot_chat").is_some() + llm_get_credential("copilot-chat").is_some() } fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option { @@ -466,13 +451,11 @@ impl zed::Extension for CopilotChatProvider { Welcome to **Copilot Chat**! This extension provides access to GitHub Copilot's chat models. -## Configuration +## Authentication -Enter your GitHub Copilot token below. You need an active GitHub Copilot subscription. +Click **Sign in with GitHub** to authenticate with your GitHub account. You'll be redirected to GitHub to authorize access. This requires an active GitHub Copilot subscription. -To get your token: -1. Ensure you have a GitHub Copilot subscription -2. Generate a token from your GitHub Copilot settings +Alternatively, you can set the `GH_COPILOT_TOKEN` environment variable with your token. ## Available Models @@ -502,8 +485,156 @@ This extension requires an active GitHub Copilot subscription. ) } + fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> { + // Check if we have existing credentials + if llm_get_credential("copilot-chat").is_some() { + return Ok(()); + } + + // No credentials found - return error for background auth checks. + // The device flow will be triggered by the host when the user clicks + // the "Sign in with GitHub" button, which calls llm_provider_start_device_flow_sign_in. + Err("CredentialsNotFound".to_string()) + } + + fn llm_provider_start_device_flow_sign_in( + &mut self, + _provider_id: &str, + ) -> Result { + // Step 1: Request device and user verification codes + let device_code_response = llm_oauth_http_request(&LlmOauthHttpRequest { + url: GITHUB_DEVICE_CODE_URL.to_string(), + method: "POST".to_string(), + headers: vec![ + ("Accept".to_string(), "application/json".to_string()), + ( + "Content-Type".to_string(), + "application/x-www-form-urlencoded".to_string(), + ), + ], + body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID), + })?; + + if device_code_response.status != 200 { + return Err(format!( + "Failed to get device code: HTTP {}", + device_code_response.status + )); + } + + #[derive(Deserialize)] + struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + #[serde(default)] + verification_uri_complete: Option, + expires_in: u64, + interval: u64, + } + + let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body) + .map_err(|e| format!("Failed to parse device code response: {}", e))?; + + // Store device flow state for polling + *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState { + device_code: device_info.device_code, + interval: device_info.interval, + expires_in: device_info.expires_in, + }); + + // Step 2: Open browser to verification URL + // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL + let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| { + format!( + "{}?user_code={}", + device_info.verification_uri, &device_info.user_code + ) + }); + llm_oauth_open_browser(&verification_url)?; + + // Return the user code for the host to display + Ok(device_info.user_code) + } + + fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> { + let state = self + .device_flow_state + .lock() + .unwrap() + .take() + .ok_or("No device flow in progress")?; + + let poll_interval = Duration::from_secs(state.interval.max(5)); + let max_attempts = (state.expires_in / state.interval.max(5)) as usize; + + for _ in 0..max_attempts { + thread::sleep(poll_interval); + + let token_response = llm_oauth_http_request(&LlmOauthHttpRequest { + url: GITHUB_ACCESS_TOKEN_URL.to_string(), + method: "POST".to_string(), + headers: vec![ + ("Accept".to_string(), "application/json".to_string()), + ( + "Content-Type".to_string(), + "application/x-www-form-urlencoded".to_string(), + ), + ], + body: format!( + "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code", + GITHUB_COPILOT_CLIENT_ID, state.device_code + ), + })?; + + #[derive(Deserialize)] + struct TokenResponse { + access_token: Option, + error: Option, + error_description: Option, + } + + let token_json: TokenResponse = serde_json::from_str(&token_response.body) + .map_err(|e| format!("Failed to parse token response: {}", e))?; + + if let Some(access_token) = token_json.access_token { + llm_store_credential("copilot-chat", &access_token)?; + return Ok(()); + } + + if let Some(error) = &token_json.error { + match error.as_str() { + "authorization_pending" => { + // User hasn't authorized yet, keep polling + continue; + } + "slow_down" => { + // Need to slow down polling + thread::sleep(Duration::from_secs(5)); + continue; + } + "expired_token" => { + return Err("Device code expired. Please try again.".to_string()); + } + "access_denied" => { + return Err("Authorization was denied.".to_string()); + } + _ => { + let description = token_json.error_description.unwrap_or_default(); + return Err(format!("OAuth error: {} - {}", error, description)); + } + } + } + } + + Err("Authorization timed out. Please try again.".to_string()) + } + fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> { - llm_delete_credential("copilot_chat") + // Clear cached API token and models + *self.api_token.lock().unwrap() = None; + *self.cached_models.lock().unwrap() = None; + llm_delete_credential("copilot-chat") } fn llm_stream_completion_start( @@ -512,21 +643,29 @@ This extension requires an active GitHub Copilot subscription. model_id: &str, request: &LlmCompletionRequest, ) -> Result { - let api_key = llm_get_credential("copilot_chat").ok_or_else(|| { + let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| { "No token configured. Please add your GitHub Copilot token in settings.".to_string() })?; + // Get or refresh API token + let api_token = self.get_api_token(&oauth_token)?; + let openai_request = convert_request(model_id, request)?; let body = serde_json::to_vec(&openai_request) .map_err(|e| format!("Failed to serialize request: {}", e))?; + let completions_url = format!("{}/chat/completions", api_token.api_endpoint); + let http_request = HttpRequest { method: HttpMethod::Post, - url: "https://api.githubcopilot.com/chat/completions".to_string(), + url: completions_url, headers: vec![ ("Content-Type".to_string(), "application/json".to_string()), - ("Authorization".to_string(), format!("Bearer {}", api_key)), + ( + "Authorization".to_string(), + format!("Bearer {}", api_token.api_key), + ), ( "Copilot-Integration-Id".to_string(), "vscode-chat".to_string(), @@ -679,4 +818,187 @@ This extension requires an active GitHub Copilot subscription. } } +impl CopilotChatProvider { + fn get_api_token(&self, oauth_token: &str) -> Result { + // Check if we have a cached token + if let Some(token) = self.api_token.lock().unwrap().clone() { + return Ok(token); + } + + // Request a new API token + let http_request = HttpRequest { + method: HttpMethod::Get, + url: GITHUB_COPILOT_TOKEN_URL.to_string(), + headers: vec![ + ( + "Authorization".to_string(), + format!("token {}", oauth_token), + ), + ("Accept".to_string(), "application/json".to_string()), + ], + body: None, + redirect_policy: RedirectPolicy::FollowAll, + }; + + let response = http_request + .fetch() + .map_err(|e| format!("Failed to request API token: {}", e))?; + + #[derive(Deserialize)] + struct ApiTokenResponse { + token: String, + endpoints: ApiEndpoints, + } + + #[derive(Deserialize)] + struct ApiEndpoints { + api: String, + } + + let token_response: ApiTokenResponse = + serde_json::from_slice(&response.body).map_err(|e| { + format!( + "Failed to parse API token response: {} - body: {}", + e, + String::from_utf8_lossy(&response.body) + ) + })?; + + let api_token = ApiToken { + api_key: token_response.token, + api_endpoint: token_response.endpoints.api, + }; + + // Cache the token + *self.api_token.lock().unwrap() = Some(api_token.clone()); + + Ok(api_token) + } + + fn fetch_models(&self, api_token: &ApiToken) -> Result, String> { + let models_url = format!("{}/models", api_token.api_endpoint); + + let http_request = HttpRequest { + method: HttpMethod::Get, + url: models_url, + headers: vec![ + ( + "Authorization".to_string(), + format!("Bearer {}", api_token.api_key), + ), + ("Content-Type".to_string(), "application/json".to_string()), + ( + "Copilot-Integration-Id".to_string(), + "vscode-chat".to_string(), + ), + ("Editor-Version".to_string(), "Zed/1.0.0".to_string()), + ("x-github-api-version".to_string(), "2025-05-01".to_string()), + ], + body: None, + redirect_policy: RedirectPolicy::FollowAll, + }; + + let response = http_request + .fetch() + .map_err(|e| format!("Failed to fetch models: {}", e))?; + + #[derive(Deserialize)] + struct ModelsResponse { + data: Vec, + } + + let models_response: ModelsResponse = + serde_json::from_slice(&response.body).map_err(|e| { + format!( + "Failed to parse models response: {} - body: {}", + e, + String::from_utf8_lossy(&response.body) + ) + })?; + + // Filter models like the built-in Copilot Chat does + let mut models: Vec = models_response + .data + .into_iter() + .filter(|model| { + model.model_picker_enabled + && model.capabilities.model_type == "chat" + && model + .policy + .as_ref() + .map(|p| p.state == "enabled") + .unwrap_or(true) + }) + .collect(); + + // Sort so default model is first + if let Some(pos) = models.iter().position(|m| m.is_chat_default) { + let default_model = models.remove(pos); + models.insert(0, default_model); + } + + Ok(models) + } +} + +fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec { + models + .iter() + .map(|m| { + let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 { + m.capabilities.limits.max_context_window_tokens + } else { + 128_000 // Default fallback + }; + let max_output = if m.capabilities.limits.max_output_tokens > 0 { + Some(m.capabilities.limits.max_output_tokens) + } else { + None + }; + + LlmModelInfo { + id: m.id.clone(), + name: m.name.clone(), + max_token_count: max_tokens, + max_output_tokens: max_output, + capabilities: LlmModelCapabilities { + supports_images: m.capabilities.supports.vision, + supports_tools: m.capabilities.supports.tool_calls, + supports_tool_choice_auto: m.capabilities.supports.tool_calls, + supports_tool_choice_any: m.capabilities.supports.tool_calls, + supports_tool_choice_none: m.capabilities.supports.tool_calls, + supports_thinking: false, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: m.is_chat_default, + is_default_fast: m.is_chat_fallback, + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_device_flow_request_body() { + let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID); + assert!(body.contains("client_id=Iv1.b507a08c87ecfe98")); + assert!(body.contains("scope=read:user")); + } + + #[test] + fn test_token_poll_request_body() { + let device_code = "test_device_code_123"; + let body = format!( + "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code", + GITHUB_COPILOT_CLIENT_ID, device_code + ); + assert!(body.contains("client_id=Iv1.b507a08c87ecfe98")); + assert!(body.contains("device_code=test_device_code_123")); + assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code")); + } +} + zed::register_extension!(CopilotChatProvider);