Detailed changes
@@ -293,12 +293,15 @@ impl PickerDelegate for AcpModelPickerDelegate {
.w_full()
.gap_1p5()
.map(|this| match &model_info.icon {
- Some(icon) => this.child(match icon {
- AgentModelIcon::Path(path) => Icon::from_path(path.clone()),
- AgentModelIcon::Named(name) => Icon::new(*name),
- }
- .color(model_icon_color)
- .size(IconSize::Small)
+ Some(AgentModelIcon::Path(path)) => this.child(
+ Icon::from_path(path.clone())
+ .color(model_icon_color)
+ .size(IconSize::Small),
+ ),
+ Some(AgentModelIcon::Named(icon)) => this.child(
+ Icon::new(*icon)
+ .color(model_icon_color)
+ .size(IconSize::Small),
),
None => this,
})
@@ -78,15 +78,13 @@ impl Render for AcpModelSelectorPopover {
self.selector.clone(),
ButtonLike::new("active-model")
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
- .when_some(model_icon, |this, icon| {
- this.child(
- match icon {
- AgentModelIcon::Path(path) => Icon::from_path(path),
- AgentModelIcon::Named(icon_name) => Icon::new(icon_name),
- }
- .color(color)
- .size(IconSize::XSmall),
- )
+ .when_some(model_icon, |this, icon| match icon {
+ AgentModelIcon::Path(path) => {
+ this.child(Icon::from_path(path).color(color).size(IconSize::XSmall))
+ }
+ AgentModelIcon::Named(icon_name) => {
+ this.child(Icon::new(icon_name).color(color).size(IconSize::XSmall))
+ }
})
.child(
Label::new(model_name)
@@ -36,7 +36,7 @@ use settings::{Settings, SettingsStore, update_settings_file};
use ui::{
Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure,
Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize,
- PopoverMenu, Switch, Tooltip, WithScrollbar, prelude::*,
+ PopoverMenu, Switch, SwitchColor, Tooltip, WithScrollbar, prelude::*,
};
use util::ResultExt as _;
use workspace::{Workspace, create_and_open_local_file};
@@ -83,24 +83,14 @@ impl AgentConfiguration {
window,
|this, _, event: &language_model::Event, window, cx| match event {
language_model::Event::AddedProvider(provider_id) => {
- let registry = LanguageModelRegistry::read_global(cx);
- // Only add if the provider is visible
- if let Some(provider) = registry.provider(provider_id) {
- if !registry.should_hide_provider(provider_id) {
- this.add_provider_configuration_view(&provider, window, cx);
- }
+ let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
+ if let Some(provider) = provider {
+ this.add_provider_configuration_view(&provider, window, cx);
}
}
language_model::Event::RemovedProvider(provider_id) => {
this.remove_provider_configuration_view(provider_id);
}
- language_model::Event::ProvidersChanged => {
- // Rebuild all provider views when visibility changes
- this.configuration_views_by_provider.clear();
- this.expanded_provider_configurations.clear();
- this.build_provider_configuration_views(window, cx);
- cx.notify();
- }
_ => {}
},
);
@@ -127,7 +117,7 @@ impl AgentConfiguration {
}
fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let providers = LanguageModelRegistry::read_global(cx).visible_providers();
+ let providers = LanguageModelRegistry::read_global(cx).providers();
for provider in providers {
self.add_provider_configuration_view(&provider, window, cx);
}
@@ -270,15 +260,15 @@ impl AgentConfiguration {
h_flex()
.w_full()
.gap_1p5()
- .child(
- if let Some(icon_path) = provider.icon_path() {
- Icon::from_external_svg(icon_path)
- } else {
- Icon::new(provider.icon())
- }
- .size(IconSize::Small)
- .color(Color::Muted),
- )
+ .child(if let Some(icon_path) = provider.icon_path() {
+ Icon::from_external_svg(icon_path)
+ .size(IconSize::Small)
+ .color(Color::Muted)
+ } else {
+ Icon::new(provider.icon())
+ .size(IconSize::Small)
+ .color(Color::Muted)
+ })
.child(
h_flex()
.w_full()
@@ -430,7 +420,7 @@ impl AgentConfiguration {
&mut self,
cx: &mut Context<Self>,
) -> impl IntoElement {
- let providers = LanguageModelRegistry::read_global(cx).visible_providers();
+ let providers = LanguageModelRegistry::read_global(cx).providers();
let popover_menu = PopoverMenu::new("add-provider-popover")
.trigger(
@@ -893,6 +883,7 @@ impl AgentConfiguration {
.child(context_server_configuration_menu)
.child(
Switch::new("context-server-switch", is_running.into())
+ .color(SwitchColor::Accent)
.on_click({
let context_server_manager = self.context_server_store.clone();
let fs = self.fs.clone();
@@ -108,7 +108,7 @@ impl Render for AgentModelSelector {
.child(
Icon::new(IconName::ChevronDown)
.color(color)
- .size(IconSize::Small),
+ .size(IconSize::XSmall),
),
move |_window, cx| {
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)
@@ -46,9 +46,7 @@ pub fn language_model_selector(
}
fn all_models(cx: &App) -> GroupedModels {
- let providers = LanguageModelRegistry::global(cx)
- .read(cx)
- .visible_providers();
+ let providers = LanguageModelRegistry::global(cx).read(cx).providers();
let recommended = providers
.iter()
@@ -140,9 +138,13 @@ impl LanguageModelPickerDelegate {
// Subscribe to registry events and send refresh signals through the channel
let registry = LanguageModelRegistry::global(cx);
cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
- language_model::Event::ProviderStateChanged(_)
- | language_model::Event::AddedProvider(_)
- | language_model::Event::RemovedProvider(_) => {
+ language_model::Event::ProviderStateChanged(_) => {
+ refresh_tx.unbounded_send(()).ok();
+ }
+ language_model::Event::AddedProvider(_) => {
+ refresh_tx.unbounded_send(()).ok();
+ }
+ language_model::Event::RemovedProvider(_) => {
refresh_tx.unbounded_send(()).ok();
}
_ => {}
@@ -421,7 +423,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let configured_providers = language_model_registry
.read(cx)
- .visible_providers()
+ .providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@@ -1682,98 +1682,6 @@ impl TextThreadEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let editor_clipboard_selections = cx
- .read_from_clipboard()
- .and_then(|item| item.entries().first().cloned())
- .and_then(|entry| match entry {
- ClipboardEntry::String(text) => {
- text.metadata_json::<Vec<editor::ClipboardSelection>>()
- }
- _ => None,
- });
-
- let has_file_context = editor_clipboard_selections
- .as_ref()
- .is_some_and(|selections| {
- selections
- .iter()
- .any(|sel| sel.file_path.is_some() && sel.line_range.is_some())
- });
-
- if has_file_context {
- if let Some(clipboard_item) = cx.read_from_clipboard() {
- if let Some(ClipboardEntry::String(clipboard_text)) =
- clipboard_item.entries().first()
- {
- if let Some(selections) = editor_clipboard_selections {
- cx.stop_propagation();
-
- let text = clipboard_text.text();
- self.editor.update(cx, |editor, cx| {
- let mut current_offset = 0;
- let weak_editor = cx.entity().downgrade();
-
- for selection in selections {
- if let (Some(file_path), Some(line_range)) =
- (selection.file_path, selection.line_range)
- {
- let selected_text =
- &text[current_offset..current_offset + selection.len];
- let fence = assistant_slash_commands::codeblock_fence_for_path(
- file_path.to_str(),
- Some(line_range.clone()),
- );
- let formatted_text = format!("{fence}{selected_text}\n```");
-
- let insert_point = editor
- .selections
- .newest::<Point>(&editor.display_snapshot(cx))
- .head();
- let start_row = MultiBufferRow(insert_point.row);
-
- editor.insert(&formatted_text, window, cx);
-
- let snapshot = editor.buffer().read(cx).snapshot(cx);
- let anchor_before = snapshot.anchor_after(insert_point);
- let anchor_after = editor
- .selections
- .newest_anchor()
- .head()
- .bias_left(&snapshot);
-
- editor.insert("\n", window, cx);
-
- let crease_text = acp_thread::selection_name(
- Some(file_path.as_ref()),
- &line_range,
- );
-
- let fold_placeholder = quote_selection_fold_placeholder(
- crease_text,
- weak_editor.clone(),
- );
- let crease = Crease::inline(
- anchor_before..anchor_after,
- fold_placeholder,
- render_quote_selection_output_toggle,
- |_, _, _, _| Empty.into_any(),
- );
- editor.insert_creases(vec![crease], cx);
- editor.fold_at(start_row, window, cx);
-
- current_offset += selection.len;
- if !selection.is_entire_line && current_offset < text.len() {
- current_offset += 1;
- }
- }
- }
- });
- return;
- }
- }
- }
- }
-
cx.stop_propagation();
let mut images = if let Some(item) = cx.read_from_clipboard() {
@@ -2204,11 +2112,13 @@ impl TextThreadEditor {
let provider_icon_element = if let Some(icon_path) = provider_icon_path {
Icon::from_external_svg(icon_path)
+ .color(color)
+ .size(IconSize::XSmall)
} else {
Icon::new(provider_icon_name)
- }
- .color(color)
- .size(IconSize::XSmall);
+ .color(color)
+ .size(IconSize::XSmall)
+ };
PickerPopoverMenu::new(
self.language_model_selector.clone(),
@@ -44,7 +44,7 @@ impl ApiKeysWithProviders {
fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> {
LanguageModelRegistry::read_global(cx)
- .visible_providers()
+ .providers()
.iter()
.filter(|provider| {
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
@@ -68,14 +68,14 @@ impl Render for ApiKeysWithProviders {
.map(|(icon, name)| {
h_flex()
.gap_1p5()
- .child(
- match icon {
- ProviderIcon::Name(icon_name) => Icon::new(icon_name),
- ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path),
- }
- .size(IconSize::XSmall)
- .color(Color::Muted),
- )
+ .child(match icon {
+ ProviderIcon::Name(icon_name) => Icon::new(icon_name)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ })
.child(Label::new(name))
});
div()
@@ -45,7 +45,7 @@ impl AgentPanelOnboarding {
fn has_configured_providers(cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
- .visible_providers()
+ .providers()
.iter()
.any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID)
}
@@ -339,6 +339,17 @@ pub struct LanguageModelAuthConfig {
/// Human-readable name for the credential shown in the UI input field (e.g., "API Key", "Access Token").
#[serde(default)]
pub credential_label: Option<String>,
+ /// OAuth configuration for web-based authentication flows.
+ #[serde(default)]
+ pub oauth: Option<OAuthConfig>,
+}
+
+/// OAuth configuration for web-based authentication.
+#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
+pub struct OAuthConfig {
+ /// The text to display on the sign-in button (e.g., "Sign in with GitHub").
+ #[serde(default)]
+ pub sign_in_button_label: Option<String>,
}
impl ExtensionManifest {
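
For reference, an extension opts into this from its manifest. A minimal sketch using a hypothetical provider id; the copilot-chat extension.toml change near the end of this diff has the same shape:

[language_model_providers.my-provider.auth]
env_var = "MY_PROVIDER_TOKEN"
credential_label = "Access Token"

[language_model_providers.my-provider.auth.oauth]
sign_in_button_label = "Sign in with My Provider"
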
@@ -31,20 +31,22 @@ pub use wit::{
},
zed::extension::llm_provider::{
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
- CompletionRequest as LlmCompletionRequest, ImageData as LlmImageData,
- MessageContent as LlmMessageContent, MessageRole as LlmMessageRole,
- ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo,
- OauthHttpRequest as LlmOauthHttpRequest, OauthHttpResponse as LlmOauthHttpResponse,
- OauthWebAuthConfig as LlmOauthWebAuthConfig, OauthWebAuthResult as LlmOauthWebAuthResult,
- ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage,
- StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent,
- TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice,
- ToolDefinition as LlmToolDefinition, ToolInputFormat as LlmToolInputFormat,
- ToolResult as LlmToolResult, ToolResultContent as LlmToolResultContent,
- ToolUse as LlmToolUse, ToolUseJsonParseError as LlmToolUseJsonParseError,
+ CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
+ ImageData as LlmImageData, MessageContent as LlmMessageContent,
+ MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
+ ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest,
+ OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig,
+ OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
+ RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
+ ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
+ ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
+ ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
+ ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
+ ToolUseJsonParseError as LlmToolUseJsonParseError,
delete_credential as llm_delete_credential, get_credential as llm_get_credential,
get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser,
oauth_start_web_auth as llm_oauth_start_web_auth,
+ request_credential as llm_request_credential,
send_oauth_http_request as llm_oauth_http_request,
store_credential as llm_store_credential,
},
@@ -300,6 +302,31 @@ pub trait Extension: Send + Sync {
false
}
+ /// Attempt to authenticate the provider.
+ /// This is called for background credential checks - it should check for
+ /// existing credentials and return Ok if found, or an error if not.
+ fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
+ Err("`llm_provider_authenticate` not implemented".to_string())
+ }
+
+ /// Start an OAuth device flow sign-in.
+ /// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
+ /// Opens the browser to the verification URL and returns the user code that should
+ /// be displayed to the user.
+ fn llm_provider_start_device_flow_sign_in(
+ &mut self,
+ _provider_id: &str,
+ ) -> Result<String, String> {
+ Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string())
+ }
+
+ /// Poll for device flow sign-in completion.
+ /// This is called after llm_provider_start_device_flow_sign_in returns the user code.
+ /// The extension should poll the OAuth provider until the user authorizes or the flow times out.
+ fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
+ Err("`llm_provider_poll_device_flow_sign_in` not implemented".to_string())
+ }
+
/// Reset credentials for the provider.
fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
Err("`llm_provider_reset_credentials` not implemented".to_string())
@@ -624,6 +651,18 @@ impl wit::Guest for Component {
extension().llm_provider_is_authenticated(&provider_id)
}
+ fn llm_provider_authenticate(provider_id: String) -> Result<(), String> {
+ extension().llm_provider_authenticate(&provider_id)
+ }
+
+ fn llm_provider_start_device_flow_sign_in(provider_id: String) -> Result<String, String> {
+ extension().llm_provider_start_device_flow_sign_in(&provider_id)
+ }
+
+ fn llm_provider_poll_device_flow_sign_in(provider_id: String) -> Result<(), String> {
+ extension().llm_provider_poll_device_flow_sign_in(&provider_id)
+ }
+
fn llm_provider_reset_credentials(provider_id: String) -> Result<(), String> {
extension().llm_provider_reset_credentials(&provider_id)
}
@@ -18,7 +18,7 @@ world extension {
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
use llm-provider.{
provider-info, model-info, completion-request,
- cache-configuration, completion-event, token-usage
+ credential-type, cache-configuration, completion-event, token-usage
};
/// Initializes the extension.
@@ -183,6 +183,33 @@ world extension {
/// Check if the provider is authenticated.
export llm-provider-is-authenticated: func(provider-id: string) -> bool;
+ /// Attempt to authenticate the provider.
+ /// This is called for background credential checks - it should check for
+ /// existing credentials and return Ok if found, or an error if not.
+ /// For interactive OAuth flows, use the device flow functions instead.
+ export llm-provider-authenticate: func(provider-id: string) -> result<_, string>;
+
+ /// Start an OAuth device flow sign-in.
+ /// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
+ ///
+ /// The device flow works as follows:
+ /// 1. Extension requests a device code from the OAuth provider
+ /// 2. Extension opens the verification URL in the browser
+ /// 3. Extension returns the user code to display to the user
+ /// 4. Host displays the user code and calls llm-provider-poll-device-flow-sign-in
+ /// 5. Extension polls for the access token while user authorizes in browser
+ /// 6. Once authorized, extension stores the credential and returns success
+ ///
+ /// Returns the user code that should be displayed to the user while they
+ /// complete authorization in the browser.
+ export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result<string, string>;
+
+ /// Poll for device flow sign-in completion.
+ /// This is called after llm-provider-start-device-flow-sign-in returns the user code.
+ /// The extension should poll the OAuth provider until the user authorizes or the flow times out.
+ /// Returns Ok(()) on successful authentication, or an error message on failure.
+ export llm-provider-poll-device-flow-sign-in: func(provider-id: string) -> result<_, string>;
+
/// Reset credentials for the provider.
export llm-provider-reset-credentials: func(provider-id: string) -> result<_, string>;
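
Taken together, the extension-side shape of these new hooks is roughly the sketch below. It condenses what the Copilot Chat implementation at the end of this diff does, with a hypothetical provider id ("my-provider") and the provider-specific HTTP calls reduced to placeholder helpers; only the new methods are shown.

use zed_extension_api::{self as zed, *};

struct MyProvider {
    // Device code saved between the start and poll calls.
    pending_device_code: Option<String>,
}

// Placeholder helpers standing in for the provider-specific OAuth HTTP calls;
// the Copilot Chat code at the end of this diff shows real versions built on
// llm_oauth_http_request.
fn request_device_code() -> Result<(String, String, String), String> {
    unimplemented!("POST to the provider's device-code endpoint")
}

fn poll_for_access_token(_device_code: &str) -> Result<String, String> {
    unimplemented!("poll the provider's token endpoint until authorized")
}

impl zed::Extension for MyProvider {
    fn new() -> Self {
        Self { pending_device_code: None }
    }

    // Background check: succeed only if a credential is already stored.
    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        if llm_get_credential("my-provider").is_some() {
            Ok(())
        } else {
            Err("CredentialsNotFound".to_string())
        }
    }

    // Interactive step: obtain a device code, open the browser, return the user code for display.
    fn llm_provider_start_device_flow_sign_in(
        &mut self,
        _provider_id: &str,
    ) -> Result<String, String> {
        let (device_code, user_code, verification_uri) = request_device_code()?;
        self.pending_device_code = Some(device_code);
        llm_oauth_open_browser(&verification_uri)?;
        Ok(user_code)
    }

    // Poll until the user authorizes in the browser, then persist the token.
    fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
        let device_code = self
            .pending_device_code
            .take()
            .ok_or("No device flow in progress")?;
        let access_token = poll_for_access_token(&device_code)?;
        llm_store_credential("my-provider", &access_token)?;
        Ok(())
    }

    // ...remaining Extension methods (llm_providers, llm_provider_models, completions, ...) elided.
}
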
@@ -235,6 +235,14 @@ interface llm-provider {
cache-read-input-tokens: option<u64>,
}
+ /// Credential types that can be requested.
+ enum credential-type {
+ /// An API key.
+ api-key,
+ /// An OAuth token.
+ oauth-token,
+ }
+
/// Cache configuration for prompt caching.
record cache-configuration {
/// Maximum number of cache anchors.
@@ -249,6 +257,9 @@ interface llm-provider {
record oauth-web-auth-config {
/// The URL to open in the user's browser to start authentication.
/// This should include client_id, redirect_uri, scope, state, etc.
+ /// Use `{port}` as a placeholder in the URL - it will be replaced with
+ /// the actual localhost port before opening the browser.
+ /// Example: "https://example.com/oauth?redirect_uri=http://127.0.0.1:{port}/callback"
auth-url: string,
/// The path to listen on for the OAuth callback (e.g., "/callback").
/// A localhost server will be started to receive the redirect.
@@ -288,6 +299,15 @@ interface llm-provider {
body: string,
}
+ /// Request a credential from the user.
+ /// Returns true if the credential was provided, false if the user cancelled.
+ request-credential: func(
+ provider-id: string,
+ credential-type: credential-type,
+ label: string,
+ placeholder: string
+ ) -> result<bool, string>;
+
/// Get a stored credential for this provider.
get-credential: func(provider-id: string) -> option<string>;
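
On the extension side, a call to the new request-credential import might look like the sketch below. The exact generated binding is an assumption here (string parameters as &str, the WIT result as Result<bool, String>, variant name LlmCredentialType::ApiKey), and the host implementation added later in this diff currently always answers false, so extensions still need the env-var and manual-storage fallbacks:

// Sketch only: ask the host to prompt the user for an API key, assuming the host
// persists an accepted credential so it can be re-read with llm_get_credential.
fn ensure_api_key() -> Result<String, String> {
    if let Some(key) = llm_get_credential("my-provider") {
        return Ok(key);
    }
    let provided = llm_request_credential(
        "my-provider",
        LlmCredentialType::ApiKey,
        "API Key",
        "sk-...",
    )?;
    if provided {
        llm_get_credential("my-provider")
            .ok_or_else(|| "Credential prompt did not store a key".to_string())
    } else {
        Err("Credential prompt was cancelled".to_string())
    }
}
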
@@ -10,14 +10,14 @@ use crate::wasm_host::wit::{
use anyhow::{Result, anyhow};
use credentials_provider::CredentialsProvider;
use editor::Editor;
-use extension::LanguageModelAuthConfig;
+use extension::{LanguageModelAuthConfig, OAuthConfig};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt};
use gpui::Focusable;
use gpui::{
- AnyView, App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task,
- TextStyleRefinement, UnderlineStyle, Window, px,
+ AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter,
+ MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px,
};
use language_model::tool_schema::LanguageModelToolSchemaFormat;
use language_model::{
@@ -182,9 +182,37 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
self.state.read(cx).is_authenticated
}
- fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
- // Authentication is handled via the configuration view UI
- Task::ready(Ok(()))
+ fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ let extension = self.extension.clone();
+ let provider_id = self.provider_info.id.clone();
+ let state = self.state.clone();
+
+ cx.spawn(async move |cx| {
+ let result = extension
+ .call(|extension, store| {
+ async move {
+ extension
+ .call_llm_provider_authenticate(store, &provider_id)
+ .await
+ }
+ .boxed()
+ })
+ .await;
+
+ match result {
+ Ok(Ok(Ok(()))) => {
+ cx.update(|cx| {
+ state.update(cx, |state, _| {
+ state.is_authenticated = true;
+ });
+ })?;
+ Ok(())
+ }
+ Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))),
+ Ok(Err(e)) => Err(AuthenticateError::Other(e)),
+ Err(e) => Err(AuthenticateError::Other(e)),
+ }
+ })
}
fn configuration_view(
@@ -287,6 +315,9 @@ struct ExtensionProviderConfigurationView {
api_key_editor: Entity<Editor>,
loading_settings: bool,
loading_credentials: bool,
+ oauth_in_progress: bool,
+ oauth_error: Option<String>,
+ device_user_code: Option<String>,
_subscriptions: Vec<Subscription>,
}
@@ -324,6 +355,9 @@ impl ExtensionProviderConfigurationView {
api_key_editor,
loading_settings: true,
loading_credentials: true,
+ oauth_in_progress: false,
+ oauth_error: None,
+ device_user_code: None,
_subscriptions: vec![state_subscription],
};
@@ -572,9 +606,130 @@ impl ExtensionProviderConfigurationView {
.detach();
}
+ fn start_oauth_sign_in(&mut self, cx: &mut Context<Self>) {
+ if self.oauth_in_progress {
+ return;
+ }
+
+ self.oauth_in_progress = true;
+ self.oauth_error = None;
+ self.device_user_code = None;
+ cx.notify();
+
+ let extension = self.extension.clone();
+ let provider_id = self.extension_provider_id.clone();
+ let state = self.state.clone();
+
+ cx.spawn(async move |this, cx| {
+ // Step 1: Start device flow - opens browser and returns user code
+ let start_result = extension
+ .call({
+ let provider_id = provider_id.clone();
+ |ext, store| {
+ async move {
+ ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id)
+ .await
+ }
+ .boxed()
+ }
+ })
+ .await;
+
+ let user_code = match start_result {
+ Ok(Ok(Ok(code))) => code,
+ Ok(Ok(Err(e))) => {
+ log::error!("Device flow start failed: {}", e);
+ this.update(cx, |this, cx| {
+ this.oauth_in_progress = false;
+ this.oauth_error = Some(e);
+ cx.notify();
+ })
+ .log_err();
+ return;
+ }
+ Ok(Err(e)) | Err(e) => {
+ log::error!("Device flow start error: {}", e);
+ this.update(cx, |this, cx| {
+ this.oauth_in_progress = false;
+ this.oauth_error = Some(e.to_string());
+ cx.notify();
+ })
+ .log_err();
+ return;
+ }
+ };
+
+ // Update UI to show the user code before polling
+ this.update(cx, |this, cx| {
+ this.device_user_code = Some(user_code);
+ cx.notify();
+ })
+ .log_err();
+
+ // Step 2: Poll for authentication completion
+ let poll_result = extension
+ .call({
+ let provider_id = provider_id.clone();
+ |ext, store| {
+ async move {
+ ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id)
+ .await
+ }
+ .boxed()
+ }
+ })
+ .await;
+
+ let error_message = match poll_result {
+ Ok(Ok(Ok(()))) => {
+ let _ = cx.update(|cx| {
+ state.update(cx, |state, cx| {
+ state.is_authenticated = true;
+ cx.notify();
+ });
+ });
+ None
+ }
+ Ok(Ok(Err(e))) => {
+ log::error!("Device flow poll failed: {}", e);
+ Some(e)
+ }
+ Ok(Err(e)) | Err(e) => {
+ log::error!("Device flow poll error: {}", e);
+ Some(e.to_string())
+ }
+ };
+
+ this.update(cx, |this, cx| {
+ this.oauth_in_progress = false;
+ this.oauth_error = error_message;
+ this.device_user_code = None;
+ cx.notify();
+ })
+ .log_err();
+ })
+ .detach();
+ }
+
fn is_authenticated(&self, cx: &Context<Self>) -> bool {
self.state.read(cx).is_authenticated
}
+
+ fn has_oauth_config(&self) -> bool {
+ self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some())
+ }
+
+ fn oauth_config(&self) -> Option<&OAuthConfig> {
+ self.auth_config.as_ref().and_then(|c| c.oauth.as_ref())
+ }
+
+ fn has_api_key_config(&self) -> bool {
+        // API key entry is available if a credential_label is set or there is no OAuth config
+ self.auth_config
+ .as_ref()
+ .map(|c| c.credential_label.is_some() || c.oauth.is_none())
+ .unwrap_or(true)
+ }
}
impl gpui::Render for ExtensionProviderConfigurationView {
@@ -583,6 +738,8 @@ impl gpui::Render for ExtensionProviderConfigurationView {
let is_authenticated = self.is_authenticated(cx);
let env_var_allowed = self.state.read(cx).env_var_allowed;
let api_key_from_env = self.state.read(cx).api_key_from_env;
+ let has_oauth = self.has_oauth_config();
+ let has_api_key = self.has_api_key_config();
if is_loading {
return v_flex()
@@ -652,7 +809,7 @@ impl gpui::Render for ExtensionProviderConfigurationView {
)
.child(
Label::new(format!(
- "{} is not set or empty. You can set it and restart Zed, or enter an API key below.",
+ "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.",
env_var_name
))
.color(Color::Warning)
@@ -664,8 +821,20 @@ impl gpui::Render for ExtensionProviderConfigurationView {
}
}
- // Render API key section
+ // If authenticated, show success state with sign out option
if is_authenticated && !api_key_from_env {
+ let reset_label = if has_oauth && !has_api_key {
+ "Sign Out"
+ } else {
+ "Reset Credentials"
+ };
+
+ let status_label = if has_oauth && !has_api_key {
+ "Signed in"
+ } else {
+ "Authenticated"
+ };
+
content = content.child(
v_flex()
.gap_2()
@@ -677,39 +846,176 @@ impl gpui::Render for ExtensionProviderConfigurationView {
.color(Color::Success)
.size(ui::IconSize::Small),
)
- .child(Label::new("API key configured").color(Color::Success)),
+ .child(Label::new(status_label).color(Color::Success)),
)
.child(
- ui::Button::new("reset-api-key", "Reset API Key")
+ ui::Button::new("reset-credentials", reset_label)
.style(ui::ButtonStyle::Subtle)
.on_click(cx.listener(|this, _, window, cx| {
this.reset_api_key(window, cx);
})),
),
);
- } else if !api_key_from_env {
- let credential_label = self
- .auth_config
- .as_ref()
- .and_then(|c| c.credential_label.clone())
- .unwrap_or_else(|| "API Key".to_string());
- content = content.child(
- v_flex()
- .gap_2()
- .on_action(cx.listener(Self::save_api_key))
- .child(
- Label::new(credential_label)
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .child(self.api_key_editor.clone())
- .child(
- Label::new("Enter your API key and press Enter to save")
- .size(LabelSize::Small)
- .color(Color::Muted),
- ),
- );
+ return content.into_any_element();
+ }
+
+ // Not authenticated - show available auth options
+ if !api_key_from_env {
+ // Render OAuth sign-in button if configured
+ if has_oauth {
+ let oauth_config = self.oauth_config();
+ let button_label = oauth_config
+ .and_then(|c| c.sign_in_button_label.clone())
+ .unwrap_or_else(|| "Sign In".to_string());
+
+ let oauth_in_progress = self.oauth_in_progress;
+
+ let oauth_error = self.oauth_error.clone();
+
+ content = content.child(
+ v_flex()
+ .gap_2()
+ .child(
+ ui::Button::new("oauth-sign-in", button_label)
+ .style(ui::ButtonStyle::Filled)
+ .disabled(oauth_in_progress)
+ .on_click(cx.listener(|this, _, _window, cx| {
+ this.start_oauth_sign_in(cx);
+ })),
+ )
+ .when(oauth_in_progress, |this| {
+ let user_code = self.device_user_code.clone();
+ this.child(
+ v_flex()
+ .gap_1()
+ .when_some(user_code, |this, code| {
+ let copied = cx
+ .read_from_clipboard()
+ .map(|item| item.text().as_ref() == Some(&code))
+ .unwrap_or(false);
+ let code_for_click = code.clone();
+ this.child(
+ h_flex()
+ .gap_1()
+ .child(
+ Label::new("Enter code:")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ h_flex()
+ .gap_1()
+ .px_1()
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .rounded_sm()
+ .cursor_pointer()
+ .on_mouse_down(
+ MouseButton::Left,
+ move |_, window, cx| {
+ cx.write_to_clipboard(
+ ClipboardItem::new_string(
+ code_for_click.clone(),
+ ),
+ );
+ window.refresh();
+ },
+ )
+ .child(
+ Label::new(code)
+ .size(LabelSize::Small)
+ .color(Color::Accent),
+ )
+ .child(
+ ui::Icon::new(if copied {
+ ui::IconName::Check
+ } else {
+ ui::IconName::Copy
+ })
+ .size(ui::IconSize::Small)
+ .color(if copied {
+ Color::Success
+ } else {
+ Color::Muted
+ }),
+ ),
+ ),
+ )
+ })
+ .child(
+ Label::new("Waiting for authorization in browser...")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ })
+ .when_some(oauth_error, |this, error| {
+ this.child(
+ v_flex()
+ .gap_1()
+ .child(
+ h_flex()
+ .gap_2()
+ .child(
+ ui::Icon::new(ui::IconName::Warning)
+ .color(Color::Error)
+ .size(ui::IconSize::Small),
+ )
+ .child(
+ Label::new("Authentication failed")
+ .color(Color::Error)
+ .size(LabelSize::Small),
+ ),
+ )
+ .child(
+ div().pl_6().child(
+ Label::new(error)
+ .color(Color::Error)
+ .size(LabelSize::Small),
+ ),
+ ),
+ )
+ }),
+ );
+ }
+
+            // Render API key input if configured (when both auth options exist, show a separator between them)
+ if has_api_key {
+ if has_oauth {
+ content = content.child(
+ h_flex()
+ .gap_2()
+ .items_center()
+ .child(div().h_px().flex_1().bg(cx.theme().colors().border))
+ .child(Label::new("or").size(LabelSize::Small).color(Color::Muted))
+ .child(div().h_px().flex_1().bg(cx.theme().colors().border)),
+ );
+ }
+
+ let credential_label = self
+ .auth_config
+ .as_ref()
+ .and_then(|c| c.credential_label.clone())
+ .unwrap_or_else(|| "API Key".to_string());
+
+ content = content.child(
+ v_flex()
+ .gap_2()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(
+ Label::new(credential_label)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(self.api_key_editor.clone())
+ .child(
+ Label::new("Enter your API key and press Enter to save")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ );
+ }
}
content.into_any_element()
@@ -770,8 +1076,7 @@ impl LanguageModel for ExtensionLanguageModel {
}
fn name(&self) -> LanguageModelName {
- // HACK: Add "(Extension)" prefix to help distinguish extension models during debugging
- LanguageModelName::from(format!("(Extension) {}", self.model_info.name))
+ LanguageModelName::from(self.model_info.name.clone())
}
fn provider_id(&self) -> LanguageModelProviderId {
@@ -35,15 +35,16 @@ pub use latest::{
zed::extension::context_server::ContextServerConfiguration,
zed::extension::llm_provider::{
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
- CompletionRequest as LlmCompletionRequest, ImageData as LlmImageData,
- MessageContent as LlmMessageContent, MessageRole as LlmMessageRole,
- ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo,
- ProviderInfo as LlmProviderInfo, RequestMessage as LlmRequestMessage,
- StopReason as LlmStopReason, ThinkingContent as LlmThinkingContent,
- TokenUsage as LlmTokenUsage, ToolChoice as LlmToolChoice,
- ToolDefinition as LlmToolDefinition, ToolInputFormat as LlmToolInputFormat,
- ToolResult as LlmToolResult, ToolResultContent as LlmToolResultContent,
- ToolUse as LlmToolUse, ToolUseJsonParseError as LlmToolUseJsonParseError,
+ CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
+ ImageData as LlmImageData, MessageContent as LlmMessageContent,
+ MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
+ ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo,
+ RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
+ ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
+ ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
+ ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
+ ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
+ ToolUseJsonParseError as LlmToolUseJsonParseError,
},
zed::extension::lsp::{
Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind,
@@ -1229,6 +1230,53 @@ impl Extension {
}
}
+ pub async fn call_llm_provider_authenticate(
+ &self,
+ store: &mut Store<WasmState>,
+ provider_id: &str,
+ ) -> Result<Result<(), String>> {
+ match self {
+ Extension::V0_8_0(ext) => ext.call_llm_provider_authenticate(store, provider_id).await,
+ _ => anyhow::bail!("`llm_provider_authenticate` not available prior to v0.8.0"),
+ }
+ }
+
+ pub async fn call_llm_provider_start_device_flow_sign_in(
+ &self,
+ store: &mut Store<WasmState>,
+ provider_id: &str,
+ ) -> Result<Result<String, String>> {
+ match self {
+ Extension::V0_8_0(ext) => {
+ ext.call_llm_provider_start_device_flow_sign_in(store, provider_id)
+ .await
+ }
+ _ => {
+ anyhow::bail!(
+ "`llm_provider_start_device_flow_sign_in` not available prior to v0.8.0"
+ )
+ }
+ }
+ }
+
+ pub async fn call_llm_provider_poll_device_flow_sign_in(
+ &self,
+ store: &mut Store<WasmState>,
+ provider_id: &str,
+ ) -> Result<Result<(), String>> {
+ match self {
+ Extension::V0_8_0(ext) => {
+ ext.call_llm_provider_poll_device_flow_sign_in(store, provider_id)
+ .await
+ }
+ _ => {
+ anyhow::bail!(
+ "`llm_provider_poll_device_flow_sign_in` not available prior to v0.8.0"
+ )
+ }
+ }
+ }
+
pub async fn call_llm_provider_reset_credentials(
&self,
store: &mut Store<WasmState>,
@@ -1112,6 +1112,20 @@ impl ExtensionImports for WasmState {
}
impl llm_provider::Host for WasmState {
+ async fn request_credential(
+ &mut self,
+ _provider_id: String,
+ _credential_type: llm_provider::CredentialType,
+ _label: String,
+ _placeholder: String,
+ ) -> wasmtime::Result<Result<bool, String>> {
+        // For now, credential requests always return false (not provided).
+        // Extensions should use get_env_var to check for environment variables first,
+        // then store_credential/get_credential for manual storage.
+        // Full UI credential prompting will be added in a future phase.
+ Ok(Ok(false))
+ }
+
async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result<Option<String>> {
let extension_id = self.manifest.id.clone();
@@ -1288,8 +1302,9 @@ impl llm_provider::Host for WasmState {
.map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))?
.port();
+ let auth_url_with_port = auth_url.replace("{port}", &port.to_string());
cx.update(|cx| {
- cx.open_url(&auth_url);
+ cx.open_url(&auth_url_with_port);
})?;
let accept_future = async {
@@ -10,4 +10,7 @@ repository = "https://github.com/zed-industries/zed"
name = "Copilot Chat"
[language_model_providers.copilot-chat.auth]
-env_var = "GH_COPILOT_TOKEN"
+env_var = "GH_COPILOT_TOKEN"
+
+[language_model_providers.copilot-chat.auth.oauth]
+sign_in_button_label = "Sign in with GitHub"
@@ -1,13 +1,86 @@
use std::collections::HashMap;
use std::sync::Mutex;
+use std::thread;
+use std::time::Duration;
use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};
+const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
+const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
+const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
+const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
+
+struct DeviceFlowState {
+ device_code: String,
+ interval: u64,
+ expires_in: u64,
+}
+
+#[derive(Clone)]
+struct ApiToken {
+ api_key: String,
+ api_endpoint: String,
+}
+
+#[derive(Clone, Deserialize)]
+struct CopilotModel {
+ id: String,
+ name: String,
+ #[serde(default)]
+ is_chat_default: bool,
+ #[serde(default)]
+ is_chat_fallback: bool,
+ #[serde(default)]
+ model_picker_enabled: bool,
+ #[serde(default)]
+ capabilities: ModelCapabilities,
+ #[serde(default)]
+ policy: Option<ModelPolicy>,
+}
+
+#[derive(Clone, Default, Deserialize)]
+struct ModelCapabilities {
+ #[serde(default)]
+ family: String,
+ #[serde(default)]
+ limits: ModelLimits,
+ #[serde(default)]
+ supports: ModelSupportedFeatures,
+ #[serde(rename = "type", default)]
+ model_type: String,
+}
+
+#[derive(Clone, Default, Deserialize)]
+struct ModelLimits {
+ #[serde(default)]
+ max_context_window_tokens: u64,
+ #[serde(default)]
+ max_output_tokens: u64,
+}
+
+#[derive(Clone, Default, Deserialize)]
+struct ModelSupportedFeatures {
+ #[serde(default)]
+ streaming: bool,
+ #[serde(default)]
+ tool_calls: bool,
+ #[serde(default)]
+ vision: bool,
+}
+
+#[derive(Clone, Deserialize)]
+struct ModelPolicy {
+ state: String,
+}
+
struct CopilotChatProvider {
streams: Mutex<HashMap<String, StreamState>>,
next_stream_id: Mutex<u64>,
+ device_flow_state: Mutex<Option<DeviceFlowState>>,
+ api_token: Mutex<Option<ApiToken>>,
+ cached_models: Mutex<Option<Vec<CopilotModel>>>,
}
struct StreamState {
@@ -25,95 +98,6 @@ struct AccumulatedToolCall {
arguments: String,
}
-struct ModelDefinition {
- id: &'static str,
- display_name: &'static str,
- max_tokens: u64,
- max_output_tokens: Option<u64>,
- supports_images: bool,
- is_default: bool,
- is_default_fast: bool,
-}
-
-const MODELS: &[ModelDefinition] = &[
- ModelDefinition {
- id: "gpt-4o",
- display_name: "GPT-4o",
- max_tokens: 128_000,
- max_output_tokens: Some(16_384),
- supports_images: true,
- is_default: true,
- is_default_fast: false,
- },
- ModelDefinition {
- id: "gpt-4o-mini",
- display_name: "GPT-4o Mini",
- max_tokens: 128_000,
- max_output_tokens: Some(16_384),
- supports_images: true,
- is_default: false,
- is_default_fast: true,
- },
- ModelDefinition {
- id: "gpt-4.1",
- display_name: "GPT-4.1",
- max_tokens: 1_000_000,
- max_output_tokens: Some(32_768),
- supports_images: true,
- is_default: false,
- is_default_fast: false,
- },
- ModelDefinition {
- id: "o1",
- display_name: "o1",
- max_tokens: 200_000,
- max_output_tokens: Some(100_000),
- supports_images: true,
- is_default: false,
- is_default_fast: false,
- },
- ModelDefinition {
- id: "o3-mini",
- display_name: "o3-mini",
- max_tokens: 200_000,
- max_output_tokens: Some(100_000),
- supports_images: false,
- is_default: false,
- is_default_fast: false,
- },
- ModelDefinition {
- id: "claude-3.5-sonnet",
- display_name: "Claude 3.5 Sonnet",
- max_tokens: 200_000,
- max_output_tokens: Some(8_192),
- supports_images: true,
- is_default: false,
- is_default_fast: false,
- },
- ModelDefinition {
- id: "claude-3.7-sonnet",
- display_name: "Claude 3.7 Sonnet",
- max_tokens: 200_000,
- max_output_tokens: Some(8_192),
- supports_images: true,
- is_default: false,
- is_default_fast: false,
- },
- ModelDefinition {
- id: "gemini-2.0-flash-001",
- display_name: "Gemini 2.0 Flash",
- max_tokens: 1_000_000,
- max_output_tokens: Some(8_192),
- supports_images: true,
- is_default: false,
- is_default_fast: false,
- },
-];
-
-fn get_model_definition(model_id: &str) -> Option<&'static ModelDefinition> {
- MODELS.iter().find(|m| m.id == model_id)
-}
-
#[derive(Serialize)]
struct OpenAiRequest {
model: String,
@@ -389,10 +373,7 @@ fn convert_request(
LlmToolChoice::None => "none".to_string(),
});
- let model_def = get_model_definition(model_id);
- let max_tokens = request
- .max_tokens
- .or(model_def.and_then(|m| m.max_output_tokens));
+ let max_tokens = request.max_tokens;
Ok(OpenAiRequest {
model: model_id.to_string(),
@@ -422,42 +403,46 @@ impl zed::Extension for CopilotChatProvider {
Self {
streams: Mutex::new(HashMap::new()),
next_stream_id: Mutex::new(0),
+ device_flow_state: Mutex::new(None),
+ api_token: Mutex::new(None),
+ cached_models: Mutex::new(None),
}
}
fn llm_providers(&self) -> Vec<LlmProviderInfo> {
vec![LlmProviderInfo {
- id: "copilot_chat".into(),
+ id: "copilot-chat".into(),
name: "Copilot Chat".into(),
icon: Some("icons/copilot.svg".into()),
}]
}
fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
- Ok(MODELS
- .iter()
- .map(|m| LlmModelInfo {
- id: m.id.to_string(),
- name: m.display_name.to_string(),
- max_token_count: m.max_tokens,
- max_output_tokens: m.max_output_tokens,
- capabilities: LlmModelCapabilities {
- supports_images: m.supports_images,
- supports_tools: true,
- supports_tool_choice_auto: true,
- supports_tool_choice_any: true,
- supports_tool_choice_none: true,
- supports_thinking: false,
- tool_input_format: LlmToolInputFormat::JsonSchema,
- },
- is_default: m.is_default,
- is_default_fast: m.is_default_fast,
- })
- .collect())
+ // Try to get models from cache first
+ if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
+ return Ok(convert_models_to_llm_info(models));
+ }
+
+ // Need to fetch models - requires authentication
+ let oauth_token = match llm_get_credential("copilot-chat") {
+ Some(token) => token,
+ None => return Ok(Vec::new()), // Not authenticated, return empty
+ };
+
+ // Get API token
+ let api_token = self.get_api_token(&oauth_token)?;
+
+ // Fetch models from API
+ let models = self.fetch_models(&api_token)?;
+
+ // Cache the models
+ *self.cached_models.lock().unwrap() = Some(models.clone());
+
+ Ok(convert_models_to_llm_info(&models))
}
fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
- llm_get_credential("copilot_chat").is_some()
+ llm_get_credential("copilot-chat").is_some()
}
fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
@@ -466,13 +451,11 @@ impl zed::Extension for CopilotChatProvider {
Welcome to **Copilot Chat**! This extension provides access to GitHub Copilot's chat models.
-## Configuration
+## Authentication
-Enter your GitHub Copilot token below. You need an active GitHub Copilot subscription.
+Click **Sign in with GitHub** to authenticate with your GitHub account. You'll be redirected to GitHub to authorize access. This requires an active GitHub Copilot subscription.
-To get your token:
-1. Ensure you have a GitHub Copilot subscription
-2. Generate a token from your GitHub Copilot settings
+Alternatively, you can set the `GH_COPILOT_TOKEN` environment variable with your token.
## Available Models
@@ -502,8 +485,156 @@ This extension requires an active GitHub Copilot subscription.
)
}
+ fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
+ // Check if we have existing credentials
+ if llm_get_credential("copilot-chat").is_some() {
+ return Ok(());
+ }
+
+ // No credentials found - return error for background auth checks.
+ // The device flow will be triggered by the host when the user clicks
+ // the "Sign in with GitHub" button, which calls llm_provider_start_device_flow_sign_in.
+ Err("CredentialsNotFound".to_string())
+ }
+
+ fn llm_provider_start_device_flow_sign_in(
+ &mut self,
+ _provider_id: &str,
+ ) -> Result<String, String> {
+ // Step 1: Request device and user verification codes
+ let device_code_response = llm_oauth_http_request(&LlmOauthHttpRequest {
+ url: GITHUB_DEVICE_CODE_URL.to_string(),
+ method: "POST".to_string(),
+ headers: vec![
+ ("Accept".to_string(), "application/json".to_string()),
+ (
+ "Content-Type".to_string(),
+ "application/x-www-form-urlencoded".to_string(),
+ ),
+ ],
+ body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
+ })?;
+
+ if device_code_response.status != 200 {
+ return Err(format!(
+ "Failed to get device code: HTTP {}",
+ device_code_response.status
+ ));
+ }
+
+ #[derive(Deserialize)]
+ struct DeviceCodeResponse {
+ device_code: String,
+ user_code: String,
+ verification_uri: String,
+ #[serde(default)]
+ verification_uri_complete: Option<String>,
+ expires_in: u64,
+ interval: u64,
+ }
+
+ let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
+ .map_err(|e| format!("Failed to parse device code response: {}", e))?;
+
+ // Store device flow state for polling
+ *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
+ device_code: device_info.device_code,
+ interval: device_info.interval,
+ expires_in: device_info.expires_in,
+ });
+
+ // Step 2: Open browser to verification URL
+ // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
+ let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
+ format!(
+ "{}?user_code={}",
+ device_info.verification_uri, &device_info.user_code
+ )
+ });
+ llm_oauth_open_browser(&verification_url)?;
+
+ // Return the user code for the host to display
+ Ok(device_info.user_code)
+ }
+
+ fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
+ let state = self
+ .device_flow_state
+ .lock()
+ .unwrap()
+ .take()
+ .ok_or("No device flow in progress")?;
+
+ let poll_interval = Duration::from_secs(state.interval.max(5));
+ let max_attempts = (state.expires_in / state.interval.max(5)) as usize;
+
+ for _ in 0..max_attempts {
+ thread::sleep(poll_interval);
+
+ let token_response = llm_oauth_http_request(&LlmOauthHttpRequest {
+ url: GITHUB_ACCESS_TOKEN_URL.to_string(),
+ method: "POST".to_string(),
+ headers: vec![
+ ("Accept".to_string(), "application/json".to_string()),
+ (
+ "Content-Type".to_string(),
+ "application/x-www-form-urlencoded".to_string(),
+ ),
+ ],
+ body: format!(
+ "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
+ GITHUB_COPILOT_CLIENT_ID, state.device_code
+ ),
+ })?;
+
+ #[derive(Deserialize)]
+ struct TokenResponse {
+ access_token: Option<String>,
+ error: Option<String>,
+ error_description: Option<String>,
+ }
+
+ let token_json: TokenResponse = serde_json::from_str(&token_response.body)
+ .map_err(|e| format!("Failed to parse token response: {}", e))?;
+
+ if let Some(access_token) = token_json.access_token {
+ llm_store_credential("copilot-chat", &access_token)?;
+ return Ok(());
+ }
+
+ if let Some(error) = &token_json.error {
+ match error.as_str() {
+ "authorization_pending" => {
+ // User hasn't authorized yet, keep polling
+ continue;
+ }
+ "slow_down" => {
+ // Need to slow down polling
+ thread::sleep(Duration::from_secs(5));
+ continue;
+ }
+ "expired_token" => {
+ return Err("Device code expired. Please try again.".to_string());
+ }
+ "access_denied" => {
+ return Err("Authorization was denied.".to_string());
+ }
+ _ => {
+ let description = token_json.error_description.unwrap_or_default();
+ return Err(format!("OAuth error: {} - {}", error, description));
+ }
+ }
+ }
+ }
+
+ Err("Authorization timed out. Please try again.".to_string())
+ }
+
fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
- llm_delete_credential("copilot_chat")
+ // Clear cached API token and models
+ *self.api_token.lock().unwrap() = None;
+ *self.cached_models.lock().unwrap() = None;
+ llm_delete_credential("copilot-chat")
}
fn llm_stream_completion_start(
@@ -512,21 +643,29 @@ This extension requires an active GitHub Copilot subscription.
model_id: &str,
request: &LlmCompletionRequest,
) -> Result<String, String> {
- let api_key = llm_get_credential("copilot_chat").ok_or_else(|| {
+ let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
"No token configured. Please add your GitHub Copilot token in settings.".to_string()
})?;
+ // Get or refresh API token
+ let api_token = self.get_api_token(&oauth_token)?;
+
let openai_request = convert_request(model_id, request)?;
let body = serde_json::to_vec(&openai_request)
.map_err(|e| format!("Failed to serialize request: {}", e))?;
+ let completions_url = format!("{}/chat/completions", api_token.api_endpoint);
+
let http_request = HttpRequest {
method: HttpMethod::Post,
- url: "https://api.githubcopilot.com/chat/completions".to_string(),
+ url: completions_url,
headers: vec![
("Content-Type".to_string(), "application/json".to_string()),
- ("Authorization".to_string(), format!("Bearer {}", api_key)),
+ (
+ "Authorization".to_string(),
+ format!("Bearer {}", api_token.api_key),
+ ),
(
"Copilot-Integration-Id".to_string(),
"vscode-chat".to_string(),
@@ -679,4 +818,187 @@ This extension requires an active GitHub Copilot subscription.
}
}
+impl CopilotChatProvider {
+ fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
+ // Check if we have a cached token
+ if let Some(token) = self.api_token.lock().unwrap().clone() {
+ return Ok(token);
+ }
+
+ // Request a new API token
+ let http_request = HttpRequest {
+ method: HttpMethod::Get,
+ url: GITHUB_COPILOT_TOKEN_URL.to_string(),
+ headers: vec![
+ (
+ "Authorization".to_string(),
+ format!("token {}", oauth_token),
+ ),
+ ("Accept".to_string(), "application/json".to_string()),
+ ],
+ body: None,
+ redirect_policy: RedirectPolicy::FollowAll,
+ };
+
+ let response = http_request
+ .fetch()
+ .map_err(|e| format!("Failed to request API token: {}", e))?;
+
+ #[derive(Deserialize)]
+ struct ApiTokenResponse {
+ token: String,
+ endpoints: ApiEndpoints,
+ }
+
+ #[derive(Deserialize)]
+ struct ApiEndpoints {
+ api: String,
+ }
+
+ let token_response: ApiTokenResponse =
+ serde_json::from_slice(&response.body).map_err(|e| {
+ format!(
+ "Failed to parse API token response: {} - body: {}",
+ e,
+ String::from_utf8_lossy(&response.body)
+ )
+ })?;
+
+ let api_token = ApiToken {
+ api_key: token_response.token,
+ api_endpoint: token_response.endpoints.api,
+ };
+
+ // Cache the token
+ *self.api_token.lock().unwrap() = Some(api_token.clone());
+
+ Ok(api_token)
+ }
+
+ fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
+ let models_url = format!("{}/models", api_token.api_endpoint);
+
+ let http_request = HttpRequest {
+ method: HttpMethod::Get,
+ url: models_url,
+ headers: vec![
+ (
+ "Authorization".to_string(),
+ format!("Bearer {}", api_token.api_key),
+ ),
+ ("Content-Type".to_string(), "application/json".to_string()),
+ (
+ "Copilot-Integration-Id".to_string(),
+ "vscode-chat".to_string(),
+ ),
+ ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
+ ("x-github-api-version".to_string(), "2025-05-01".to_string()),
+ ],
+ body: None,
+ redirect_policy: RedirectPolicy::FollowAll,
+ };
+
+ let response = http_request
+ .fetch()
+ .map_err(|e| format!("Failed to fetch models: {}", e))?;
+
+ #[derive(Deserialize)]
+ struct ModelsResponse {
+ data: Vec<CopilotModel>,
+ }
+
+ let models_response: ModelsResponse =
+ serde_json::from_slice(&response.body).map_err(|e| {
+ format!(
+ "Failed to parse models response: {} - body: {}",
+ e,
+ String::from_utf8_lossy(&response.body)
+ )
+ })?;
+
+ // Filter models like the built-in Copilot Chat does
+ let mut models: Vec<CopilotModel> = models_response
+ .data
+ .into_iter()
+ .filter(|model| {
+ model.model_picker_enabled
+ && model.capabilities.model_type == "chat"
+ && model
+ .policy
+ .as_ref()
+ .map(|p| p.state == "enabled")
+ .unwrap_or(true)
+ })
+ .collect();
+
+ // Sort so default model is first
+ if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
+ let default_model = models.remove(pos);
+ models.insert(0, default_model);
+ }
+
+ Ok(models)
+ }
+}
+
+fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
+ models
+ .iter()
+ .map(|m| {
+ let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
+ m.capabilities.limits.max_context_window_tokens
+ } else {
+ 128_000 // Default fallback
+ };
+ let max_output = if m.capabilities.limits.max_output_tokens > 0 {
+ Some(m.capabilities.limits.max_output_tokens)
+ } else {
+ None
+ };
+
+ LlmModelInfo {
+ id: m.id.clone(),
+ name: m.name.clone(),
+ max_token_count: max_tokens,
+ max_output_tokens: max_output,
+ capabilities: LlmModelCapabilities {
+ supports_images: m.capabilities.supports.vision,
+ supports_tools: m.capabilities.supports.tool_calls,
+ supports_tool_choice_auto: m.capabilities.supports.tool_calls,
+ supports_tool_choice_any: m.capabilities.supports.tool_calls,
+ supports_tool_choice_none: m.capabilities.supports.tool_calls,
+ supports_thinking: false,
+ tool_input_format: LlmToolInputFormat::JsonSchema,
+ },
+ is_default: m.is_chat_default,
+ is_default_fast: m.is_chat_fallback,
+ }
+ })
+ .collect()
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_device_flow_request_body() {
+ let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
+ assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
+ assert!(body.contains("scope=read:user"));
+ }
+
+ #[test]
+ fn test_token_poll_request_body() {
+ let device_code = "test_device_code_123";
+ let body = format!(
+ "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
+ GITHUB_COPILOT_CLIENT_ID, device_code
+ );
+ assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
+ assert!(body.contains("device_code=test_device_code_123"));
+ assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
+ }
+}
+
zed::register_extension!(CopilotChatProvider);