From e1fe0b32870db4ce12a9b0666cc699ac2bae7a80 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Mon, 8 Dec 2025 16:25:41 -0500 Subject: [PATCH] Restore providers, deduplicate if extensions are present --- Cargo.lock | 11 + crates/agent_ui/src/agent_configuration.rs | 20 +- crates/agent_ui/src/agent_panel.rs | 2 +- .../agent_ui/src/language_model_selector.rs | 6 +- .../src/agent_api_keys_onboarding.rs | 2 +- .../src/agent_panel_onboarding_content.rs | 2 +- crates/icons/src/icons.rs | 1 + crates/language_model/src/registry.rs | 219 ++- crates/language_models/Cargo.toml | 5 + crates/language_models/src/api_key.rs | 16 +- crates/language_models/src/extension.rs | 26 +- crates/language_models/src/language_models.rs | 99 +- crates/language_models/src/provider.rs | 3 + .../language_models/src/provider/anthropic.rs | 1045 +++++++++++ .../src/provider/copilot_chat.rs | 1565 +++++++++++++++++ crates/language_models/src/provider/google.rs | 708 ++++++-- .../language_models/src/provider/open_ai.rs | 541 +++++- .../src/provider/open_router.rs | 1095 ++++++++++++ crates/language_models/src/settings.rs | 25 +- .../src/settings_content/language_model.rs | 30 + 20 files changed, 5269 insertions(+), 152 deletions(-) create mode 100644 crates/language_models/src/provider/anthropic.rs create mode 100644 crates/language_models/src/provider/copilot_chat.rs create mode 100644 crates/language_models/src/provider/open_router.rs diff --git a/Cargo.lock b/Cargo.lock index da60b0b2318f630b2f64f03d535c09d11bd4938d..bfe1ab0dd14709a423735c35842d0b14cac7628d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8854,16 +8854,19 @@ dependencies = [ "collections", "component", "convert_case 0.8.0", + "copilot", "credentials_provider", "deepseek", "editor", "extension", + "extension_host", "fs", "futures 0.3.31", "google_ai", "gpui", "gpui_tokio", "http_client", + "language", "language_model", "lmstudio", "log", @@ -8871,6 +8874,8 @@ dependencies = [ "mistral", "ollama", "open_ai", + "open_router", + "partial-json-fixer", "project", "release_channel", "schemars", @@ -11219,6 +11224,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "partial-json-fixer" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13" + [[package]] name = "password-hash" version = "0.4.2" diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 617aea8fd553cb8e9c7cd5c9814bc420adec4df3..9f214e9429d701069b41604cbf93ebfbb0cd58fb 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -83,14 +83,24 @@ impl AgentConfiguration { window, |this, _, event: &language_model::Event, window, cx| match event { language_model::Event::AddedProvider(provider_id) => { - let provider = LanguageModelRegistry::read_global(cx).provider(provider_id); - if let Some(provider) = provider { - this.add_provider_configuration_view(&provider, window, cx); + let registry = LanguageModelRegistry::read_global(cx); + // Only add if the provider is visible + if let Some(provider) = registry.provider(provider_id) { + if !registry.should_hide_provider(provider_id) { + this.add_provider_configuration_view(&provider, window, cx); + } } } language_model::Event::RemovedProvider(provider_id) => { this.remove_provider_configuration_view(provider_id); } + language_model::Event::ProvidersChanged => { + // Rebuild all provider views when visibility changes + 
this.configuration_views_by_provider.clear(); + this.expanded_provider_configurations.clear(); + this.build_provider_configuration_views(window, cx); + cx.notify(); + } _ => {} }, ); @@ -117,7 +127,7 @@ impl AgentConfiguration { } fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context) { - let providers = LanguageModelRegistry::read_global(cx).providers(); + let providers = LanguageModelRegistry::read_global(cx).visible_providers(); for provider in providers { self.add_provider_configuration_view(&provider, window, cx); } @@ -420,7 +430,7 @@ impl AgentConfiguration { &mut self, cx: &mut Context, ) -> impl IntoElement { - let providers = LanguageModelRegistry::read_global(cx).providers(); + let providers = LanguageModelRegistry::read_global(cx).visible_providers(); let popover_menu = PopoverMenu::new("add-provider-popover") .trigger( diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 18e8f1e731defa82e865dd45e66389634992037c..abbf0725594b7ee1a45d227261d0f5e05398673c 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -2291,7 +2291,7 @@ impl AgentPanel { let history_is_empty = self.history_store.read(cx).is_empty(cx); let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx) - .providers() + .visible_providers() .iter() .any(|provider| { provider.is_authenticated(cx) diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 9fd717a597e14918c3a3adc909ff53d2bb8de740..312a2293a072da34c8ffd7e6985c0a4b4c919fdf 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -46,7 +46,9 @@ pub fn language_model_selector( } fn all_models(cx: &App) -> GroupedModels { - let providers = LanguageModelRegistry::global(cx).read(cx).providers(); + let providers = LanguageModelRegistry::global(cx) + .read(cx) + .visible_providers(); let recommended = providers .iter() @@ -423,7 +425,7 @@ impl PickerDelegate for LanguageModelPickerDelegate { let configured_providers = language_model_registry .read(cx) - .providers() + .visible_providers() .into_iter() .filter(|provider| provider.is_authenticated(cx)) .collect::>(); diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index bdf1ce3640bf5041b63d952625429156814dadfb..84c9695f4cb48b281e558da1b79e1c7aaae46192 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -44,7 +44,7 @@ impl ApiKeysWithProviders { fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> { LanguageModelRegistry::read_global(cx) - .providers() + .visible_providers() .iter() .filter(|provider| { provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index ae92268ff4db459e748b806e47f6f89851783bd9..831d97e5d4b7289f19aef40a3be5df8d967eb7a2 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -45,7 +45,7 @@ impl AgentPanelOnboarding { fn has_configured_providers(cx: &App) -> bool { LanguageModelRegistry::read_global(cx) - .providers() + .visible_providers() .iter() .any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID) } diff 
--git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index 6cea6db9079a38ec5c539d9bcaf9e9b84428420b..cc84129250cfdbe968aa3d86f1d00d0789d01480 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -9,6 +9,7 @@ use strum::{EnumIter, EnumString, IntoStaticStr}; #[strum(serialize_all = "snake_case")] pub enum IconName { Ai, + AiAnthropic, AiBedrock, AiClaude, AiDeepSeek, diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 27b8309810962981d3c0ec78e6e67dfdfba122bf..e14ebb0d3b82569a5a30be0b5601a478294490e8 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -2,12 +2,16 @@ use crate::{ LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, }; -use collections::BTreeMap; +use collections::{BTreeMap, HashSet}; use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*}; use std::{str::FromStr, sync::Arc}; use thiserror::Error; use util::maybe; +/// Function type for checking if a built-in provider should be hidden. +/// Returns Some(extension_id) if the provider should be hidden when that extension is installed. +pub type BuiltinProviderHidingFn = Box Option<&'static str> + Send + Sync>; + pub fn init(cx: &mut App) { let registry = cx.new(|_cx| LanguageModelRegistry::default()); cx.set_global(GlobalLanguageModelRegistry(registry)); @@ -48,6 +52,11 @@ pub struct LanguageModelRegistry { thread_summary_model: Option, providers: BTreeMap>, inline_alternatives: Vec>, + /// Set of installed extension IDs that provide language models. + /// Used to determine which built-in providers should be hidden. + installed_llm_extension_ids: HashSet>, + /// Function to check if a built-in provider should be hidden by an extension. + builtin_provider_hiding_fn: Option, } #[derive(Debug)] @@ -104,6 +113,8 @@ pub enum Event { ProviderStateChanged(LanguageModelProviderId), AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), + /// Emitted when provider visibility changes due to extension install/uninstall. + ProvidersChanged, } impl EventEmitter for LanguageModelRegistry {} @@ -183,6 +194,65 @@ impl LanguageModelRegistry { providers } + /// Returns providers, filtering out hidden built-in providers. + pub fn visible_providers(&self) -> Vec> { + self.providers() + .into_iter() + .filter(|p| !self.should_hide_provider(&p.id())) + .collect() + } + + /// Sets the function used to check if a built-in provider should be hidden. + pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) { + self.builtin_provider_hiding_fn = Some(hiding_fn); + } + + /// Called when an extension is installed/loaded. + /// If the extension provides language models, track it so we can hide the corresponding built-in. + pub fn extension_installed(&mut self, extension_id: Arc, cx: &mut Context) { + if self.installed_llm_extension_ids.insert(extension_id) { + cx.emit(Event::ProvidersChanged); + cx.notify(); + } + } + + /// Called when an extension is uninstalled/unloaded. + pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context) { + if self.installed_llm_extension_ids.remove(extension_id) { + cx.emit(Event::ProvidersChanged); + cx.notify(); + } + } + + /// Sync the set of installed LLM extension IDs. 
+ pub fn sync_installed_llm_extensions( + &mut self, + extension_ids: HashSet>, + cx: &mut Context, + ) { + if extension_ids != self.installed_llm_extension_ids { + self.installed_llm_extension_ids = extension_ids; + cx.emit(Event::ProvidersChanged); + cx.notify(); + } + } + + /// Returns the set of installed LLM extension IDs. + pub fn installed_llm_extension_ids(&self) -> &HashSet> { + &self.installed_llm_extension_ids + } + + /// Returns true if a provider should be hidden from the UI. + /// Built-in providers are hidden when their corresponding extension is installed. + pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool { + if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn { + if let Some(extension_id) = hiding_fn(&provider_id.0) { + return self.installed_llm_extension_ids.contains(extension_id); + } + } + false + } + pub fn configuration_error( &self, model: Option, @@ -416,4 +486,151 @@ mod tests { let providers = registry.read(cx).providers(); assert!(providers.is_empty()); } + + #[gpui::test] + fn test_provider_hiding_on_extension_install(cx: &mut App) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + let provider = Arc::new(FakeLanguageModelProvider::default()); + let provider_id = provider.id(); + + registry.update(cx, |registry, cx| { + registry.register_provider(provider.clone(), cx); + + // Set up a hiding function that hides the fake provider when "fake-extension" is installed + registry.set_builtin_provider_hiding_fn(Box::new(|id| { + if id == "fake" { + Some("fake-extension") + } else { + None + } + })); + }); + + // Provider should be visible initially + let visible = registry.read(cx).visible_providers(); + assert_eq!(visible.len(), 1); + assert_eq!(visible[0].id(), provider_id); + + // Install the extension + registry.update(cx, |registry, cx| { + registry.extension_installed("fake-extension".into(), cx); + }); + + // Provider should now be hidden + let visible = registry.read(cx).visible_providers(); + assert!(visible.is_empty()); + + // But still in providers() + let all = registry.read(cx).providers(); + assert_eq!(all.len(), 1); + } + + #[gpui::test] + fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + let provider = Arc::new(FakeLanguageModelProvider::default()); + let provider_id = provider.id(); + + registry.update(cx, |registry, cx| { + registry.register_provider(provider.clone(), cx); + + // Set up hiding function + registry.set_builtin_provider_hiding_fn(Box::new(|id| { + if id == "fake" { + Some("fake-extension") + } else { + None + } + })); + + // Start with extension installed + registry.extension_installed("fake-extension".into(), cx); + }); + + // Provider should be hidden + let visible = registry.read(cx).visible_providers(); + assert!(visible.is_empty()); + + // Uninstall the extension + registry.update(cx, |registry, cx| { + registry.extension_uninstalled("fake-extension", cx); + }); + + // Provider should now be visible again + let visible = registry.read(cx).visible_providers(); + assert_eq!(visible.len(), 1); + assert_eq!(visible[0].id(), provider_id); + } + + #[gpui::test] + fn test_should_hide_provider(cx: &mut App) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + registry.update(cx, |registry, cx| { + // Set up hiding function + registry.set_builtin_provider_hiding_fn(Box::new(|id| { + if id == "anthropic" { + Some("anthropic") + } else if id == "openai" { + Some("openai") + } else { + 
None + } + })); + + // Install only anthropic extension + registry.extension_installed("anthropic".into(), cx); + }); + + let registry_read = registry.read(cx); + + // Anthropic should be hidden + assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into()))); + + // OpenAI should not be hidden (extension not installed) + assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into()))); + + // Unknown provider should not be hidden + assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into()))); + } + + #[gpui::test] + fn test_sync_installed_llm_extensions(cx: &mut App) { + let registry = cx.new(|_| LanguageModelRegistry::default()); + + let provider = Arc::new(FakeLanguageModelProvider::default()); + + registry.update(cx, |registry, cx| { + registry.register_provider(provider.clone(), cx); + + registry.set_builtin_provider_hiding_fn(Box::new(|id| { + if id == "fake" { + Some("fake-extension") + } else { + None + } + })); + }); + + // Sync with a set containing the extension + let mut extension_ids = HashSet::default(); + extension_ids.insert(Arc::from("fake-extension")); + + registry.update(cx, |registry, cx| { + registry.sync_installed_llm_extensions(extension_ids, cx); + }); + + // Provider should be hidden + assert!(registry.read(cx).visible_providers().is_empty()); + + // Sync with empty set + registry.update(cx, |registry, cx| { + registry.sync_installed_llm_extensions(HashSet::default(), cx); + }); + + // Provider should be visible again + assert_eq!(registry.read(cx).visible_providers().len(), 1); + } } diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index b9b354f2fad7ea9bc7573a6bb2af880f773d15e3..4aaf9dcec5d33a8625297ffba98e1dbbc1c57fa8 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -25,15 +25,18 @@ cloud_llm_client.workspace = true collections.workspace = true component.workspace = true convert_case.workspace = true +copilot.workspace = true credentials_provider.workspace = true deepseek = { workspace = true, features = ["schemars"] } extension.workspace = true +extension_host.workspace = true fs.workspace = true futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true gpui_tokio.workspace = true http_client.workspace = true +language.workspace = true language_model.workspace = true lmstudio = { workspace = true, features = ["schemars"] } log.workspace = true @@ -41,6 +44,8 @@ menu.workspace = true mistral = { workspace = true, features = ["schemars"] } ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } +open_router = { workspace = true, features = ["schemars"] } +partial-json-fixer.workspace = true release_channel.workspace = true schemars.workspace = true semver.workspace = true diff --git a/crates/language_models/src/api_key.rs b/crates/language_models/src/api_key.rs index 20d83c9d95e90380f99731a1bcfb903bb8ab93e9..122234b6ced6d0bf1b7a0d684683c841824ccd2d 100644 --- a/crates/language_models/src/api_key.rs +++ b/crates/language_models/src/api_key.rs @@ -223,13 +223,27 @@ impl ApiKeyState { } impl ApiKey { - fn from_env(env_var_name: SharedString, key: &str) -> Self { + pub fn key(&self) -> &str { + &self.key + } + + pub fn from_env(env_var_name: SharedString, key: &str) -> Self { Self { source: ApiKeySource::EnvVar(env_var_name), key: key.into(), } } + pub async fn load_from_system_keychain( + url: &str, + 
        credentials_provider: &dyn CredentialsProvider,
+        cx: &AsyncApp,
+    ) -> Result<Self, AuthenticateError> {
+        Self::load_from_system_keychain_impl(url, credentials_provider, cx)
+            .await
+            .into_authenticate_result()
+    }
+
     async fn load_from_system_keychain_impl(
         url: &str,
         credentials_provider: &dyn CredentialsProvider,
diff --git a/crates/language_models/src/extension.rs b/crates/language_models/src/extension.rs
index 9af6f41bd59955ade4b8030ef0689a9b5952d727..59dc98f211c19520c92124655ddace799d189158 100644
--- a/crates/language_models/src/extension.rs
+++ b/crates/language_models/src/extension.rs
@@ -1,7 +1,31 @@
+use collections::HashMap;
 use extension::{ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration};
 use gpui::{App, Entity};
 use language_model::{LanguageModelProviderId, LanguageModelRegistry};
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
+
+/// Maps built-in provider IDs to their corresponding extension IDs.
+/// When an extension with this ID is installed, the built-in provider should be hidden.
+pub static BUILTIN_TO_EXTENSION_MAP: LazyLock<HashMap<&'static str, &'static str>> =
+    LazyLock::new(|| {
+        let mut map = HashMap::default();
+        map.insert("anthropic", "anthropic");
+        map.insert("openai", "openai");
+        map.insert("google", "google-ai");
+        map.insert("open_router", "open-router");
+        map.insert("copilot_chat", "copilot-chat");
+        map
+    });
+
+/// Returns the extension ID that should hide the given built-in provider.
+pub fn extension_for_builtin_provider(provider_id: &str) -> Option<&'static str> {
+    BUILTIN_TO_EXTENSION_MAP.get(provider_id).copied()
+}
+
+/// Returns true if the given provider ID is a built-in provider that can be hidden by an extension.
+pub fn is_hideable_builtin_provider(provider_id: &str) -> bool {
+    BUILTIN_TO_EXTENSION_MAP.contains_key(provider_id)
+}
 
 /// Proxy implementation that registers extension-based language model providers
 /// with the LanguageModelRegistry.
diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index b07fa39033159fc7ba2b33d168cc3bdf9217ef28..2e72539768c7f08192ef677ad19a36a416f2c3ed 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use ::extension::ExtensionHostProxy; use ::settings::{Settings, SettingsStore}; use client::{Client, UserStore}; use collections::HashSet; @@ -9,20 +8,25 @@ use language_model::{LanguageModelProviderId, LanguageModelRegistry}; use provider::deepseek::DeepSeekLanguageModelProvider; mod api_key; -mod extension; +pub mod extension; mod google_ai_api_key; pub mod provider; mod settings; pub mod ui; -pub use google_ai_api_key::api_key_for_gemini_cli; - +pub use crate::extension::{extension_for_builtin_provider, is_hideable_builtin_provider}; +pub use crate::google_ai_api_key::api_key_for_gemini_cli; +use crate::provider::anthropic::AnthropicLanguageModelProvider; use crate::provider::bedrock::BedrockLanguageModelProvider; use crate::provider::cloud::CloudLanguageModelProvider; +use crate::provider::copilot_chat::CopilotChatLanguageModelProvider; +use crate::provider::google::GoogleLanguageModelProvider; use crate::provider::lmstudio::LmStudioLanguageModelProvider; pub use crate::provider::mistral::MistralLanguageModelProvider; use crate::provider::ollama::OllamaLanguageModelProvider; +use crate::provider::open_ai::OpenAiLanguageModelProvider; use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider; +use crate::provider::open_router::OpenRouterLanguageModelProvider; use crate::provider::vercel::VercelLanguageModelProvider; use crate::provider::x_ai::XAiLanguageModelProvider; pub use crate::settings::*; @@ -33,11 +37,65 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) { register_language_model_providers(registry, user_store, client.clone(), cx); }); - // Register the extension language model provider proxy - let extension_proxy = ExtensionHostProxy::default_global(cx); - extension_proxy.register_language_model_provider_proxy( - extension::ExtensionLanguageModelProxy::new(registry.clone()), - ); + // Set up the provider hiding function + registry.update(cx, |registry, _cx| { + registry.set_builtin_provider_hiding_fn(Box::new(extension_for_builtin_provider)); + }); + + // Subscribe to extension store events to track LLM extension installations + if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) { + cx.subscribe(&extension_store, { + let registry = registry.clone(); + move |extension_store, event, cx| { + match event { + extension_host::Event::ExtensionInstalled(extension_id) => { + // Check if this extension has language_model_providers + if let Some(manifest) = extension_store + .read(cx) + .extension_manifest_for_id(extension_id) + { + if !manifest.language_model_providers.is_empty() { + registry.update(cx, |registry, cx| { + registry.extension_installed(extension_id.clone(), cx); + }); + } + } + } + extension_host::Event::ExtensionUninstalled(extension_id) => { + registry.update(cx, |registry, cx| { + registry.extension_uninstalled(extension_id, cx); + }); + } + extension_host::Event::ExtensionsUpdated => { + // Re-sync installed extensions on bulk updates + let mut new_ids = HashSet::default(); + for (extension_id, entry) in extension_store.read(cx).installed_extensions() + { + if !entry.manifest.language_model_providers.is_empty() { + new_ids.insert(extension_id.clone()); + } + } + 
registry.update(cx, |registry, cx| { + registry.sync_installed_llm_extensions(new_ids, cx); + }); + } + _ => {} + } + } + }) + .detach(); + + // Initialize with currently installed extensions + registry.update(cx, |registry, cx| { + let mut initial_ids = HashSet::default(); + for (extension_id, entry) in extension_store.read(cx).installed_extensions() { + if !entry.manifest.language_model_providers.is_empty() { + initial_ids.insert(extension_id.clone()); + } + } + registry.sync_installed_llm_extensions(initial_ids, cx); + }); + } let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx) .openai_compatible @@ -117,6 +175,17 @@ fn register_language_model_providers( )), cx, ); + registry.register_provider( + Arc::new(AnthropicLanguageModelProvider::new( + client.http_client(), + cx, + )), + cx, + ); + registry.register_provider( + Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)), + cx, + ); registry.register_provider( Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)), cx, @@ -129,6 +198,10 @@ fn register_language_model_providers( Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)), cx, ); + registry.register_provider( + Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)), + cx, + ); registry.register_provider( MistralLanguageModelProvider::global(client.http_client(), cx), cx, @@ -137,6 +210,13 @@ fn register_language_model_providers( Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)), cx, ); + registry.register_provider( + Arc::new(OpenRouterLanguageModelProvider::new( + client.http_client(), + cx, + )), + cx, + ); registry.register_provider( Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)), cx, @@ -145,4 +225,5 @@ fn register_language_model_providers( Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)), cx, ); + registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx); } diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index b5d10c1ede3d70e4d7dc8725131cf9e19a216ca3..d780195c66ec0d19c2b7d53e62b5e3629baa8a43 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -1,5 +1,7 @@ +pub mod anthropic; pub mod bedrock; pub mod cloud; +pub mod copilot_chat; pub mod deepseek; pub mod google; pub mod lmstudio; @@ -7,5 +9,6 @@ pub mod mistral; pub mod ollama; pub mod open_ai; pub mod open_ai_compatible; +pub mod open_router; pub mod vercel; pub mod x_ai; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs new file mode 100644 index 0000000000000000000000000000000000000000..1affe38a08d22e2aaed8c1207513ce41a13b8e59 --- /dev/null +++ b/crates/language_models/src/provider/anthropic.rs @@ -0,0 +1,1045 @@ +use anthropic::{ + ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, + ToolResultContent, ToolResultPart, Usage, +}; +use anyhow::{Result, anyhow}; +use collections::{BTreeMap, HashMap}; +use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Task}; +use http_client::HttpClient; +use language_model::{ + AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, + LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId, + LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + 
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, + LanguageModelToolResultContent, MessageContent, RateLimiter, Role, +}; +use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; +use settings::{Settings, SettingsStore}; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::{Arc, LazyLock}; +use strum::IntoEnumIterator; +use ui::{List, prelude::*}; +use ui_input::InputField; +use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; + +use crate::api_key::ApiKeyState; +use crate::ui::{ConfiguredApiCard, InstructionListItem}; + +pub use settings::AnthropicAvailableModel as AvailableModel; + +const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct AnthropicSettings { + pub api_url: String, + /// Extend Zed's list of Anthropic models. + pub available_models: Vec, +} + +pub struct AnthropicLanguageModelProvider { + http_client: Arc, + state: Entity, +} + +const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + +pub struct State { + api_key_state: ApiKeyState, +} + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key_state.has_key() + } + + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) + } + + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) + } +} + +impl AnthropicLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); + cx.notify(); + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: anthropic::Model) -> Arc { + Arc::new(AnthropicModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } + + fn settings(cx: &App) -> &AnthropicSettings { + &crate::AllLanguageModelSettings::get_global(cx).anthropic + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + ANTHROPIC_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } +} + +impl LanguageModelProviderState for AnthropicLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for AnthropicLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn icon(&self) -> IconName { + IconName::AiAnthropic + } + + fn default_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(anthropic::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + 
Some(self.create_language_model(anthropic::Model::default_fast())) + } + + fn recommended_models(&self, _cx: &App) -> Vec> { + [ + anthropic::Model::ClaudeSonnet4_5, + anthropic::Model::ClaudeSonnet4_5Thinking, + ] + .into_iter() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models = BTreeMap::default(); + + // Add base models from anthropic::Model::iter() + for model in anthropic::Model::iter() { + if !matches!(model, anthropic::Model::Custom { .. }) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &AnthropicLanguageModelProvider::settings(cx).available_models { + models.insert( + model.name.clone(), + anthropic::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + tool_override: model.tool_override.clone(), + cache_configuration: model.cache_configuration.as_ref().map(|config| { + anthropic::AnthropicModelCacheConfiguration { + max_cache_anchors: config.max_cache_anchors, + should_speculate: config.should_speculate, + min_total_token: config.min_total_token, + } + }), + max_output_tokens: model.max_output_tokens, + default_temperature: model.default_temperature, + extra_beta_headers: model.extra_beta_headers.clone(), + mode: model.mode.unwrap_or_default().into(), + }, + ); + } + + models + .into_values() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view( + &self, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) + } +} + +pub struct AnthropicModel { + id: LanguageModelId, + model: anthropic::Model, + state: Entity, + http_client: Arc, + request_limiter: RateLimiter, +} + +pub fn count_anthropic_tokens( + request: LanguageModelRequest, + cx: &App, +) -> BoxFuture<'static, Result> { + cx.background_spawn(async move { + let messages = request.messages; + let mut tokens_from_images = 0; + let mut string_messages = Vec::with_capacity(messages.len()); + + for message in messages { + use language_model::MessageContent; + + let mut string_contents = String::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + string_contents.push_str(&text); + } + MessageContent::Thinking { .. } => { + // Thinking blocks are not included in the input token count. + } + MessageContent::RedactedThinking(_) => { + // Thinking blocks are not included in the input token count. + } + MessageContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + MessageContent::ToolUse(_tool_use) => { + // TODO: Estimate token usage from tool uses. 
+ } + MessageContent::ToolResult(tool_result) => match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + string_contents.push_str(text); + } + LanguageModelToolResultContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + }, + } + } + + if !string_contents.is_empty() { + string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(string_contents), + name: None, + function_call: None, + }); + } + } + + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. + tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) + .map(|tokens| (tokens + tokens_from_images) as u64) + }) + .boxed() +} + +impl AnthropicModel { + fn stream_completion( + &self, + request: anthropic::Request, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { + let http_client = self.http_client.clone(); + + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) + }) else { + return future::ready(Err(anyhow!("App state dropped").into())).boxed(); + }; + + let beta_headers = self.model.beta_headers(); + + async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; + let request = anthropic::stream_completion( + http_client.as_ref(), + &api_url, + &api_key, + request, + beta_headers, + ); + request.await.map_err(Into::into) + } + .boxed() + } +} + +impl LanguageModel for AnthropicModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn provider_name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn supports_tools(&self) -> bool { + true + } + + fn supports_images(&self) -> bool { + true + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => true, + } + } + + fn telemetry_id(&self) -> String { + format!("anthropic/{}", self.model.id()) + } + + fn api_key(&self, cx: &App) -> Option { + self.state.read_with(cx, |state, cx| { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + state.api_key_state.key(&api_url).map(|key| key.to_string()) + }) + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + Some(self.model.max_output_tokens()) + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + count_anthropic_tokens(request, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { + let request = into_anthropic( + request, + self.model.request_id().into(), + self.model.default_temperature(), + self.model.max_output_tokens(), + self.model.mode(), + ); + let request = self.stream_completion(request, cx); + let future = self.request_limiter.stream(async move { + let 
response = request.await?; + Ok(AnthropicEventMapper::new().map_stream(response)) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } + + fn cache_configuration(&self) -> Option { + self.model + .cache_configuration() + .map(|config| LanguageModelCacheConfiguration { + max_cache_anchors: config.max_cache_anchors, + should_speculate: config.should_speculate, + min_total_token: config.min_total_token, + }) + } +} + +pub fn into_anthropic( + request: LanguageModelRequest, + model: String, + default_temperature: f32, + max_output_tokens: u64, + mode: AnthropicModelMode, +) -> anthropic::Request { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let mut anthropic_message_content: Vec = message + .content + .into_iter() + .filter_map(|content| match content { + MessageContent::Text(text) => { + let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { + text.trim_end().to_string() + } else { + text + }; + if !text.is_empty() { + Some(anthropic::RequestContent::Text { + text, + cache_control: None, + }) + } else { + None + } + } + MessageContent::Thinking { + text: thinking, + signature, + } => { + if !thinking.is_empty() { + Some(anthropic::RequestContent::Thinking { + thinking, + signature: signature.unwrap_or_default(), + cache_control: None, + }) + } else { + None + } + } + MessageContent::RedactedThinking(data) => { + if !data.is_empty() { + Some(anthropic::RequestContent::RedactedThinking { data }) + } else { + None + } + } + MessageContent::Image(image) => Some(anthropic::RequestContent::Image { + source: anthropic::ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + cache_control: None, + }), + MessageContent::ToolUse(tool_use) => { + Some(anthropic::RequestContent::ToolUse { + id: tool_use.id.to_string(), + name: tool_use.name.to_string(), + input: tool_use.input, + cache_control: None, + }) + } + MessageContent::ToolResult(tool_result) => { + Some(anthropic::RequestContent::ToolResult { + tool_use_id: tool_result.tool_use_id.to_string(), + is_error: tool_result.is_error, + content: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ToolResultContent::Plain(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ToolResultContent::Multipart(vec![ToolResultPart::Image { + source: anthropic::ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }]) + } + }, + cache_control: None, + }) + } + }) + .collect(); + let anthropic_role = match message.role { + Role::User => anthropic::Role::User, + Role::Assistant => anthropic::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if let Some(last_message) = new_messages.last_mut() + && last_message.role == anthropic_role + { + last_message.content.extend(anthropic_message_content); + continue; + } + + // Mark the last segment of the message as cached + if message.cache { + let cache_control_value = Some(anthropic::CacheControl { + cache_type: anthropic::CacheControlType::Ephemeral, + }); + for message_content in anthropic_message_content.iter_mut().rev() { + match message_content { + anthropic::RequestContent::RedactedThinking { .. 
} => { + // Caching is not possible, fallback to next message + } + anthropic::RequestContent::Text { cache_control, .. } + | anthropic::RequestContent::Thinking { cache_control, .. } + | anthropic::RequestContent::Image { cache_control, .. } + | anthropic::RequestContent::ToolUse { cache_control, .. } + | anthropic::RequestContent::ToolResult { cache_control, .. } => { + *cache_control = cache_control_value; + break; + } + } + } + } + + new_messages.push(anthropic::Message { + role: anthropic_role, + content: anthropic_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + anthropic::Request { + model, + messages: new_messages, + max_tokens: max_output_tokens, + system: if system_message.is_empty() { + None + } else { + Some(anthropic::StringOrContents::String(system_message)) + }, + thinking: if request.thinking_allowed + && let AnthropicModelMode::Thinking { budget_tokens } = mode + { + Some(anthropic::Thinking::Enabled { budget_tokens }) + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| anthropic::Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto, + LanguageModelToolChoice::Any => anthropic::ToolChoice::Any, + LanguageModelToolChoice::None => anthropic::ToolChoice::None, + }), + metadata: None, + stop_sequences: Vec::new(), + temperature: request.temperature.or(Some(default_temperature)), + top_k: None, + top_p: None, + } +} + +pub struct AnthropicEventMapper { + tool_uses_by_index: HashMap, + usage: Usage, + stop_reason: StopReason, +} + +impl AnthropicEventMapper { + pub fn new() -> Self { + Self { + tool_uses_by_index: HashMap::default(), + usage: Usage::default(), + stop_reason: StopReason::EndTurn, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(error.into())], + }) + }) + } + + pub fn map_event( + &mut self, + event: Event, + ) -> Vec> { + match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ResponseContent::Thinking { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ResponseContent::RedactedThinking { data } => { + vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] + } + ResponseContent::ToolUse { id, name, .. 
} => { + self.tool_uses_by_index.insert( + index, + RawToolUse { + id, + name, + input_json: String::new(), + }, + ); + Vec::new() + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ContentDelta::ThinkingDelta { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ContentDelta::SignatureDelta { signature } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "".to_string(), + signature: Some(signature), + })] + } + ContentDelta::InputJsonDelta { partial_json } => { + if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { + tool_use.input_json.push_str(&partial_json); + + // Try to convert invalid (incomplete) JSON into + // valid JSON that serde can accept, e.g. by closing + // unclosed delimiters. This way, we can update the + // UI with whatever has been streamed back so far. + if let Ok(input) = serde_json::Value::from_str( + &partial_json_fixer::fix_json(&tool_use.input_json), + ) { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.clone().into(), + name: tool_use.name.clone().into(), + is_input_complete: false, + raw_input: tool_use.input_json.clone(), + input, + thought_signature: None, + }, + ))]; + } + } + vec![] + } + }, + Event::ContentBlockStop { index } => { + if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { + let input_json = tool_use.input_json.trim(); + let input_value = if input_json.is_empty() { + Ok(serde_json::Value::Object(serde_json::Map::default())) + } else { + serde_json::Value::from_str(input_json) + }; + let event_result = match input_value { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.into(), + name: tool_use.name.into(), + is_input_complete: true, + input, + raw_input: tool_use.input_json.clone(), + thought_signature: None, + }, + )), + Err(json_parse_err) => { + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_use.id.into(), + tool_name: tool_use.name.into(), + raw_input: input_json.into(), + json_parse_error: json_parse_err.to_string(), + }) + } + }; + + vec![event_result] + } else { + Vec::new() + } + } + Event::MessageStart { message } => { + update_usage(&mut self.usage, &message.usage); + vec![ + Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( + &self.usage, + ))), + Ok(LanguageModelCompletionEvent::StartMessage { + message_id: message.id, + }), + ] + } + Event::MessageDelta { delta, usage } => { + update_usage(&mut self.usage, &usage); + if let Some(stop_reason) = delta.stop_reason.as_deref() { + self.stop_reason = match stop_reason { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + "refusal" => StopReason::Refusal, + _ => { + log::error!("Unexpected anthropic stop_reason: {stop_reason}"); + StopReason::EndTurn + } + }; + } + vec![Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))] + } + Event::MessageStop => { + vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] + } + Event::Error { error } => { + vec![Err(error.into())] + } + _ => Vec::new(), + } + } +} + +struct RawToolUse { + id: String, + name: String, + input_json: String, +} + +/// Updates usage data by preferring counts from `new`. 
+fn update_usage(usage: &mut Usage, new: &Usage) { + if let Some(input_tokens) = new.input_tokens { + usage.input_tokens = Some(input_tokens); + } + if let Some(output_tokens) = new.output_tokens { + usage.output_tokens = Some(output_tokens); + } + if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens { + usage.cache_creation_input_tokens = Some(cache_creation_input_tokens); + } + if let Some(cache_read_input_tokens) = new.cache_read_input_tokens { + usage.cache_read_input_tokens = Some(cache_read_input_tokens); + } +} + +fn convert_usage(usage: &Usage) -> language_model::TokenUsage { + language_model::TokenUsage { + input_tokens: usage.input_tokens.unwrap_or(0), + output_tokens: usage.output_tokens.unwrap_or(0), + cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), + cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), + } +} + +struct ConfigurationView { + api_key_editor: Entity, + state: Entity, + load_credentials_task: Option>, + target_agent: ConfigurationViewTargetAgent, +} + +impl ConfigurationView { + const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + + fn new( + state: Entity, + target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut Context, + ) -> Self { + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn({ + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor: cx.new(|cx| InputField::new(window, cx, Self::PLACEHOLDER_TEXT)), + state, + load_credentials_task, + target_agent, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? + .await + }) + .detach_and_log_err(cx); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(None, cx))? 
+ .await + }) + .detach_and_log_err(cx); + } + + fn should_render_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); + let configured_card_label = if env_var_set { + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") + } else { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + if api_url == ANTHROPIC_API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", api_url) + } + }; + + if self.load_credentials_task.is_some() { + div() + .child(Label::new("Loading credentials...")) + .into_any_element() + } else if self.should_render_editor(cx) { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent { + ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(), + ConfigurationViewTargetAgent::Other(agent) => agent.clone(), + }))) + .child( + List::new() + .child( + InstructionListItem::new( + "Create one by visiting", + Some("Anthropic's settings"), + Some("https://console.anthropic.com/settings/keys") + ) + ) + .child( + InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent") + ) + ) + .child(self.api_key_editor.clone()) + .child( + Label::new( + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), + ) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any_element() + } else { + ConfiguredApiCard::new(configured_card_label) + .disabled(env_var_set) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))) + .when(env_var_set, |this| { + this.tooltip_label(format!( + "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable." + )) + }) + .into_any_element() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use anthropic::AnthropicModelMode; + use language_model::{LanguageModelRequestMessage, MessageContent}; + + #[test] + fn test_cache_control_only_on_last_segment() { + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![ + MessageContent::Text("Some prompt".to_string()), + MessageContent::Image(language_model::LanguageModelImage::empty()), + MessageContent::Image(language_model::LanguageModelImage::empty()), + MessageContent::Image(language_model::LanguageModelImage::empty()), + MessageContent::Image(language_model::LanguageModelImage::empty()), + ], + cache: true, + reasoning_details: None, + }], + thread_id: None, + prompt_id: None, + intent: None, + mode: None, + stop: vec![], + temperature: None, + tools: vec![], + tool_choice: None, + thinking_allowed: true, + }; + + let anthropic_request = into_anthropic( + request, + "claude-3-5-sonnet".to_string(), + 0.7, + 4096, + AnthropicModelMode::Default, + ); + + assert_eq!(anthropic_request.messages.len(), 1); + + let message = &anthropic_request.messages[0]; + assert_eq!(message.content.len(), 5); + + assert!(matches!( + message.content[0], + anthropic::RequestContent::Text { + cache_control: None, + .. + } + )); + for i in 1..3 { + assert!(matches!( + message.content[i], + anthropic::RequestContent::Image { + cache_control: None, + .. 
+ } + )); + } + + assert!(matches!( + message.content[4], + anthropic::RequestContent::Image { + cache_control: Some(anthropic::CacheControl { + cache_type: anthropic::CacheControlType::Ephemeral, + }), + .. + } + )); + } +} diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs new file mode 100644 index 0000000000000000000000000000000000000000..92ac342a39ff04ae42f5b01b5777a5d16563c37f --- /dev/null +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -0,0 +1,1565 @@ +use std::pin::Pin; +use std::str::FromStr as _; +use std::sync::Arc; + +use anyhow::{Result, anyhow}; +use cloud_llm_client::CompletionIntent; +use collections::HashMap; +use copilot::copilot_chat::{ + ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl, + Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool, + ToolCall, +}; +use copilot::{Copilot, Status}; +use futures::future::BoxFuture; +use futures::stream::BoxStream; +use futures::{FutureExt, Stream, StreamExt}; +use gpui::{Action, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, svg}; +use http_client::StatusCode; +use language::language_settings::all_language_settings; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent, + LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role, + StopReason, TokenUsage, +}; +use settings::SettingsStore; +use ui::{CommonAnimationExt, prelude::*}; +use util::debug_panic; + +use crate::ui::ConfiguredApiCard; + +const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat"); +const PROVIDER_NAME: LanguageModelProviderName = + LanguageModelProviderName::new("GitHub Copilot Chat"); + +pub struct CopilotChatLanguageModelProvider { + state: Entity, +} + +pub struct State { + _copilot_chat_subscription: Option, + _settings_subscription: Subscription, +} + +impl State { + fn is_authenticated(&self, cx: &App) -> bool { + CopilotChat::global(cx) + .map(|m| m.read(cx).is_authenticated()) + .unwrap_or(false) + } +} + +impl CopilotChatLanguageModelProvider { + pub fn new(cx: &mut App) -> Self { + let state = cx.new(|cx| { + let copilot_chat_subscription = CopilotChat::global(cx) + .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify())); + State { + _copilot_chat_subscription: copilot_chat_subscription, + _settings_subscription: cx.observe_global::(|_, cx| { + if let Some(copilot_chat) = CopilotChat::global(cx) { + let language_settings = all_language_settings(None, cx); + let configuration = copilot::copilot_chat::CopilotChatConfiguration { + enterprise_uri: language_settings + .edit_predictions + .copilot + .enterprise_uri + .clone(), + }; + copilot_chat.update(cx, |chat, cx| { + chat.set_configuration(configuration, cx); + }); + } + cx.notify(); + }), + } + }); + + Self { state } + } + + fn create_language_model(&self, model: CopilotChatModel) -> Arc { + Arc::new(CopilotChatLanguageModel { + model, + request_limiter: RateLimiter::new(4), + }) + } +} + +impl LanguageModelProviderState for CopilotChatLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + 
Some(self.state.clone()) + } +} + +impl LanguageModelProvider for CopilotChatLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn icon(&self) -> IconName { + IconName::Copilot + } + + fn default_model(&self, cx: &App) -> Option> { + let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?; + models + .first() + .map(|model| self.create_language_model(model.clone())) + } + + fn default_fast_model(&self, cx: &App) -> Option> { + // The default model should be Copilot Chat's 'base model', which is likely a relatively fast + // model (e.g. 4o) and a sensible choice when considering premium requests + self.default_model(cx) + } + + fn provided_models(&self, cx: &App) -> Vec> { + let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else { + return Vec::new(); + }; + models + .iter() + .map(|model| self.create_language_model(model.clone())) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated(cx) + } + + fn authenticate(&self, cx: &mut App) -> Task> { + if self.is_authenticated(cx) { + return Task::ready(Ok(())); + }; + + let Some(copilot) = Copilot::global(cx) else { + return Task::ready(Err(anyhow!(concat!( + "Copilot must be enabled for Copilot Chat to work. ", + "Please enable Copilot and try again." + )) + .into())); + }; + + let err = match copilot.read(cx).status() { + Status::Authorized => return Task::ready(Ok(())), + Status::Disabled => anyhow!( + "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again." + ), + Status::Error(err) => anyhow!(format!( + "Received the following error while signing into Copilot: {err}" + )), + Status::Starting { task: _ } => anyhow!( + "Copilot is still starting, please wait for Copilot to start then try again" + ), + Status::Unauthorized => anyhow!( + "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription." + ), + Status::SignedOut { .. } => { + anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.") + } + Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."), + }; + + Task::ready(Err(err.into())) + } + + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + _: &mut Window, + cx: &mut App, + ) -> AnyView { + let state = self.state.clone(); + cx.new(|cx| ConfigurationView::new(state, cx)).into() + } + + fn reset_credentials(&self, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow!( + "Signing out of GitHub Copilot Chat is currently not supported." 
+ ))) + } +} + +fn collect_tiktoken_messages( + request: LanguageModelRequest, +) -> Vec { + request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>() +} + +pub struct CopilotChatLanguageModel { + model: CopilotChatModel, + request_limiter: RateLimiter, +} + +impl LanguageModel for CopilotChatLanguageModel { + fn id(&self) -> LanguageModelId { + LanguageModelId::from(self.model.id().to_string()) + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn provider_name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn supports_tools(&self) -> bool { + self.model.supports_tools() + } + + fn supports_images(&self) -> bool { + self.model.supports_vision() + } + + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + match self.model.vendor() { + ModelVendor::OpenAI | ModelVendor::Anthropic => { + LanguageModelToolSchemaFormat::JsonSchema + } + ModelVendor::Google | ModelVendor::XAI | ModelVendor::Unknown => { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } + } + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => self.supports_tools(), + } + } + + fn telemetry_id(&self) -> String { + format!("copilot_chat/{}", self.model.id()) + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + let model = self.model.clone(); + cx.background_spawn(async move { + let messages = collect_tiktoken_messages(request); + // Copilot uses OpenAI tiktoken tokenizer for all it's model irrespective of the underlying provider(vendor). 
+ let tokenizer_model = match model.tokenizer() { + Some("o200k_base") => "gpt-4o", + Some("cl100k_base") => "gpt-4", + _ => "gpt-4o", + }; + + tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages) + .map(|tokens| tokens as u64) + }) + .boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { + let is_user_initiated = request.intent.is_none_or(|intent| match intent { + CompletionIntent::UserPrompt + | CompletionIntent::ThreadContextSummarization + | CompletionIntent::InlineAssist + | CompletionIntent::TerminalInlineAssist + | CompletionIntent::GenerateGitCommitMessage => true, + + CompletionIntent::ToolResults + | CompletionIntent::ThreadSummarization + | CompletionIntent::CreateFile + | CompletionIntent::EditFile => false, + }); + + if self.model.supports_response() { + let responses_request = into_copilot_responses(&self.model, request); + let request_limiter = self.request_limiter.clone(); + let future = cx.spawn(async move |cx| { + let request = + CopilotChat::stream_response(responses_request, is_user_initiated, cx.clone()); + request_limiter + .stream(async move { + let stream = request.await?; + let mapper = CopilotResponsesEventMapper::new(); + Ok(mapper.map_stream(stream).boxed()) + }) + .await + }); + return async move { Ok(future.await?.boxed()) }.boxed(); + } + + let copilot_request = match into_copilot_chat(&self.model, request) { + Ok(request) => request, + Err(err) => return futures::future::ready(Err(err.into())).boxed(), + }; + let is_streaming = copilot_request.stream; + + let request_limiter = self.request_limiter.clone(); + let future = cx.spawn(async move |cx| { + let request = + CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone()); + request_limiter + .stream(async move { + let response = request.await?; + Ok(map_to_language_model_completion_events( + response, + is_streaming, + )) + }) + .await + }); + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +pub fn map_to_language_model_completion_events( + events: Pin>>>, + is_streaming: bool, +) -> impl Stream> { + #[derive(Default)] + struct RawToolCall { + id: String, + name: String, + arguments: String, + thought_signature: Option, + } + + struct State { + events: Pin>>>, + tool_calls_by_index: HashMap, + reasoning_opaque: Option, + reasoning_text: Option, + } + + futures::stream::unfold( + State { + events, + tool_calls_by_index: HashMap::default(), + reasoning_opaque: None, + reasoning_text: None, + }, + move |mut state| async move { + if let Some(event) = state.events.next().await { + match event { + Ok(event) => { + let Some(choice) = event.choices.first() else { + return Some(( + vec![Err(anyhow!("Response contained no choices").into())], + state, + )); + }; + + let delta = if is_streaming { + choice.delta.as_ref() + } else { + choice.message.as_ref() + }; + + let Some(delta) = delta else { + return Some(( + vec![Err(anyhow!("Response contained no delta").into())], + state, + )); + }; + + let mut events = Vec::new(); + if let Some(content) = delta.content.clone() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + + // Capture reasoning data from the delta (e.g. 
for Gemini 3) + if let Some(opaque) = delta.reasoning_opaque.clone() { + state.reasoning_opaque = Some(opaque); + } + if let Some(text) = delta.reasoning_text.clone() { + state.reasoning_text = Some(text); + } + + for (index, tool_call) in delta.tool_calls.iter().enumerate() { + let tool_index = tool_call.index.unwrap_or(index); + let entry = state.tool_calls_by_index.entry(tool_index).or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; + } + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); + } + + if let Some(thought_signature) = function.thought_signature.clone() + { + entry.thought_signature = Some(thought_signature); + } + } + } + + if let Some(usage) = event.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + TokenUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + ))); + } + + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop( + StopReason::EndTurn, + ))); + } + Some("tool_calls") => { + // Gemini 3 models send reasoning_opaque/reasoning_text that must + // be preserved and sent back in subsequent requests. Emit as + // ReasoningDetails so the agent stores it in the message. + if state.reasoning_opaque.is_some() + || state.reasoning_text.is_some() + { + let mut details = serde_json::Map::new(); + if let Some(opaque) = state.reasoning_opaque.take() { + details.insert( + "reasoning_opaque".to_string(), + serde_json::Value::String(opaque), + ); + } + if let Some(text) = state.reasoning_text.take() { + details.insert( + "reasoning_text".to_string(), + serde_json::Value::String(text), + ); + } + events.push(Ok( + LanguageModelCompletionEvent::ReasoningDetails( + serde_json::Value::Object(details), + ), + )); + } + + events.extend(state.tool_calls_by_index.drain().map( + |(_, tool_call)| { + // The model can output an empty string + // to indicate the absence of arguments. + // When that happens, create an empty + // object instead. 
+ let arguments = if tool_call.arguments.is_empty() { + Ok(serde_json::Value::Object(Default::default())) + } else { + serde_json::Value::from_str(&tool_call.arguments) + }; + match arguments { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments, + thought_signature: tool_call.thought_signature, + }, + )), + Err(error) => Ok( + LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_call.id.into(), + tool_name: tool_call.name.as_str().into(), + raw_input: tool_call.arguments.into(), + json_parse_error: error.to_string(), + }, + ), + } + }, + )); + + events.push(Ok(LanguageModelCompletionEvent::Stop( + StopReason::ToolUse, + ))); + } + Some(stop_reason) => { + log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}"); + events.push(Ok(LanguageModelCompletionEvent::Stop( + StopReason::EndTurn, + ))); + } + None => {} + } + + return Some((events, state)); + } + Err(err) => return Some((vec![Err(anyhow!(err).into())], state)), + } + } + + None + }, + ) + .flat_map(futures::stream::iter) +} + +pub struct CopilotResponsesEventMapper { + pending_stop_reason: Option, +} + +impl CopilotResponsesEventMapper { + pub fn new() -> Self { + Self { + pending_stop_reason: None, + } + } + + pub fn map_stream( + mut self, + events: Pin>>>, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))], + }) + }) + } + + fn map_event( + &mut self, + event: copilot::copilot_responses::StreamEvent, + ) -> Vec> { + match event { + copilot::copilot_responses::StreamEvent::OutputItemAdded { item, .. } => match item { + copilot::copilot_responses::ResponseOutputItem::Message { id, .. } => { + vec![Ok(LanguageModelCompletionEvent::StartMessage { + message_id: id, + })] + } + _ => Vec::new(), + }, + + copilot::copilot_responses::StreamEvent::OutputTextDelta { delta, .. } => { + if delta.is_empty() { + Vec::new() + } else { + vec![Ok(LanguageModelCompletionEvent::Text(delta))] + } + } + + copilot::copilot_responses::StreamEvent::OutputItemDone { item, .. } => match item { + copilot::copilot_responses::ResponseOutputItem::Message { .. } => Vec::new(), + copilot::copilot_responses::ResponseOutputItem::FunctionCall { + call_id, + name, + arguments, + thought_signature, + .. + } => { + let mut events = Vec::new(); + match serde_json::from_str::(&arguments) { + Ok(input) => events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: call_id.into(), + name: name.as_str().into(), + is_input_complete: true, + input, + raw_input: arguments.clone(), + thought_signature, + }, + ))), + Err(error) => { + events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: call_id.into(), + tool_name: name.as_str().into(), + raw_input: arguments.clone().into(), + json_parse_error: error.to_string(), + })) + } + } + // Record that we already emitted a tool-use stop so we can avoid duplicating + // a Stop event on Completed. + self.pending_stop_reason = Some(StopReason::ToolUse); + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + events + } + copilot::copilot_responses::ResponseOutputItem::Reasoning { + summary, + encrypted_content, + .. 
+ } => { + let mut events = Vec::new(); + + if let Some(blocks) = summary { + let mut text = String::new(); + for block in blocks { + text.push_str(&block.text); + } + if !text.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text, + signature: None, + })); + } + } + + if let Some(data) = encrypted_content { + events.push(Ok(LanguageModelCompletionEvent::RedactedThinking { data })); + } + + events + } + }, + + copilot::copilot_responses::StreamEvent::Completed { response } => { + let mut events = Vec::new(); + if let Some(usage) = response.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.input_tokens.unwrap_or(0), + output_tokens: usage.output_tokens.unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }))); + } + if self.pending_stop_reason.take() != Some(StopReason::ToolUse) { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + events + } + + copilot::copilot_responses::StreamEvent::Incomplete { response } => { + let reason = response + .incomplete_details + .as_ref() + .and_then(|details| details.reason.as_ref()); + let stop_reason = match reason { + Some(copilot::copilot_responses::IncompleteReason::MaxOutputTokens) => { + StopReason::MaxTokens + } + Some(copilot::copilot_responses::IncompleteReason::ContentFilter) => { + StopReason::Refusal + } + _ => self + .pending_stop_reason + .take() + .unwrap_or(StopReason::EndTurn), + }; + + let mut events = Vec::new(); + if let Some(usage) = response.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.input_tokens.unwrap_or(0), + output_tokens: usage.output_tokens.unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }))); + } + events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason))); + events + } + + copilot::copilot_responses::StreamEvent::Failed { response } => { + let provider = PROVIDER_NAME; + let (status_code, message) = match response.error { + Some(error) => { + let status_code = StatusCode::from_str(&error.code) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + (status_code, error.message) + } + None => ( + StatusCode::INTERNAL_SERVER_ERROR, + "response.failed".to_string(), + ), + }; + vec![Err(LanguageModelCompletionError::HttpResponseError { + provider, + status_code, + message, + })] + } + + copilot::copilot_responses::StreamEvent::GenericError { error } => vec![Err( + LanguageModelCompletionError::Other(anyhow!(format!("{error:?}"))), + )], + + copilot::copilot_responses::StreamEvent::Created { .. 
} + | copilot::copilot_responses::StreamEvent::Unknown => Vec::new(), + } + } +} + +fn into_copilot_chat( + model: &copilot::copilot_chat::Model, + request: LanguageModelRequest, +) -> Result { + let mut request_messages: Vec = Vec::new(); + for message in request.messages { + if let Some(last_message) = request_messages.last_mut() { + if last_message.role == message.role { + last_message.content.extend(message.content); + } else { + request_messages.push(message); + } + } else { + request_messages.push(message); + } + } + + let mut messages: Vec = Vec::new(); + for message in request_messages { + match message.role { + Role::User => { + for content in &message.content { + if let MessageContent::ToolResult(tool_result) = content { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => text.to_string().into(), + LanguageModelToolResultContent::Image(image) => { + if model.supports_vision() { + ChatMessageContent::Multipart(vec![ChatMessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + }, + }]) + } else { + debug_panic!( + "This should be caught at {} level", + tool_result.tool_name + ); + "[Tool responded with an image, but this model does not support vision]".to_string().into() + } + } + }; + + messages.push(ChatMessage::Tool { + tool_call_id: tool_result.tool_use_id.to_string(), + content, + }); + } + } + + let mut content_parts = Vec::new(); + for content in &message.content { + match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. } + if !text.is_empty() => + { + if let Some(ChatMessagePart::Text { text: text_content }) = + content_parts.last_mut() + { + text_content.push_str(text); + } else { + content_parts.push(ChatMessagePart::Text { + text: text.to_string(), + }); + } + } + MessageContent::Image(image) if model.supports_vision() => { + content_parts.push(ChatMessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + }, + }); + } + _ => {} + } + } + + if !content_parts.is_empty() { + messages.push(ChatMessage::User { + content: content_parts.into(), + }); + } + } + Role::Assistant => { + let mut tool_calls = Vec::new(); + for content in &message.content { + if let MessageContent::ToolUse(tool_use) = content { + tool_calls.push(ToolCall { + id: tool_use.id.to_string(), + content: copilot::copilot_chat::ToolCallContent::Function { + function: copilot::copilot_chat::FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input)?, + thought_signature: tool_use.thought_signature.clone(), + }, + }, + }); + } + } + + let text_content = { + let mut buffer = String::new(); + for string in message.content.iter().filter_map(|content| match content { + MessageContent::Text(text) | MessageContent::Thinking { text, .. 
} => { + Some(text.as_str()) + } + MessageContent::ToolUse(_) + | MessageContent::RedactedThinking(_) + | MessageContent::ToolResult(_) + | MessageContent::Image(_) => None, + }) { + buffer.push_str(string); + } + + buffer + }; + + // Extract reasoning_opaque and reasoning_text from reasoning_details + let (reasoning_opaque, reasoning_text) = + if let Some(details) = &message.reasoning_details { + let opaque = details + .get("reasoning_opaque") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let text = details + .get("reasoning_text") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + (opaque, text) + } else { + (None, None) + }; + + messages.push(ChatMessage::Assistant { + content: if text_content.is_empty() { + ChatMessageContent::empty() + } else { + text_content.into() + }, + tool_calls, + reasoning_opaque, + reasoning_text, + }); + } + Role::System => messages.push(ChatMessage::System { + content: message.string_contents(), + }), + } + } + + let tools = request + .tools + .iter() + .map(|tool| Tool::Function { + function: copilot::copilot_chat::Function { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.input_schema.clone(), + }, + }) + .collect::>(); + + Ok(CopilotChatRequest { + intent: true, + n: 1, + stream: model.uses_streaming(), + temperature: 0.1, + model: model.id().to_string(), + messages, + tools, + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto, + LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any, + LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None, + }), + }) +} + +fn into_copilot_responses( + model: &copilot::copilot_chat::Model, + request: LanguageModelRequest, +) -> copilot::copilot_responses::Request { + use copilot::copilot_responses as responses; + + let LanguageModelRequest { + thread_id: _, + prompt_id: _, + intent: _, + mode: _, + messages, + tools, + tool_choice, + stop: _, + temperature, + thinking_allowed: _, + } = request; + + let mut input_items: Vec = Vec::new(); + + for message in messages { + match message.role { + Role::User => { + for content in &message.content { + if let MessageContent::ToolResult(tool_result) = content { + let output = if let Some(out) = &tool_result.output { + match out { + serde_json::Value::String(s) => { + responses::ResponseFunctionOutput::Text(s.clone()) + } + serde_json::Value::Null => { + responses::ResponseFunctionOutput::Text(String::new()) + } + other => responses::ResponseFunctionOutput::Text(other.to_string()), + } + } else { + match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + responses::ResponseFunctionOutput::Text(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + if model.supports_vision() { + responses::ResponseFunctionOutput::Content(vec![ + responses::ResponseInputContent::InputImage { + image_url: Some(image.to_base64_url()), + detail: Default::default(), + }, + ]) + } else { + debug_panic!( + "This should be caught at {} level", + tool_result.tool_name + ); + responses::ResponseFunctionOutput::Text( + "[Tool responded with an image, but this model does not support vision]".into(), + ) + } + } + } + }; + + input_items.push(responses::ResponseInputItem::FunctionCallOutput { + call_id: tool_result.tool_use_id.to_string(), + output, + status: None, + }); + } + } + + let mut parts: Vec = Vec::new(); + for content in &message.content { + match content { + MessageContent::Text(text) 
=> { + parts.push(responses::ResponseInputContent::InputText { + text: text.clone(), + }); + } + + MessageContent::Image(image) => { + if model.supports_vision() { + parts.push(responses::ResponseInputContent::InputImage { + image_url: Some(image.to_base64_url()), + detail: Default::default(), + }); + } + } + _ => {} + } + } + + if !parts.is_empty() { + input_items.push(responses::ResponseInputItem::Message { + role: "user".into(), + content: Some(parts), + status: None, + }); + } + } + + Role::Assistant => { + for content in &message.content { + if let MessageContent::ToolUse(tool_use) = content { + input_items.push(responses::ResponseInputItem::FunctionCall { + call_id: tool_use.id.to_string(), + name: tool_use.name.to_string(), + arguments: tool_use.raw_input.clone(), + status: None, + thought_signature: tool_use.thought_signature.clone(), + }); + } + } + + for content in &message.content { + if let MessageContent::RedactedThinking(data) = content { + input_items.push(responses::ResponseInputItem::Reasoning { + id: None, + summary: Vec::new(), + encrypted_content: data.clone(), + }); + } + } + + let mut parts: Vec = Vec::new(); + for content in &message.content { + match content { + MessageContent::Text(text) => { + parts.push(responses::ResponseInputContent::OutputText { + text: text.clone(), + }); + } + MessageContent::Image(_) => { + parts.push(responses::ResponseInputContent::OutputText { + text: "[image omitted]".to_string(), + }); + } + _ => {} + } + } + + if !parts.is_empty() { + input_items.push(responses::ResponseInputItem::Message { + role: "assistant".into(), + content: Some(parts), + status: Some("completed".into()), + }); + } + } + + Role::System => { + let mut parts: Vec = Vec::new(); + for content in &message.content { + if let MessageContent::Text(text) = content { + parts.push(responses::ResponseInputContent::InputText { + text: text.clone(), + }); + } + } + + if !parts.is_empty() { + input_items.push(responses::ResponseInputItem::Message { + role: "system".into(), + content: Some(parts), + status: None, + }); + } + } + } + } + + let converted_tools: Vec = tools + .into_iter() + .map(|tool| responses::ToolDefinition::Function { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + strict: None, + }) + .collect(); + + let mapped_tool_choice = tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => responses::ToolChoice::Auto, + LanguageModelToolChoice::Any => responses::ToolChoice::Any, + LanguageModelToolChoice::None => responses::ToolChoice::None, + }); + + responses::Request { + model: model.id().to_string(), + input: input_items, + stream: model.uses_streaming(), + temperature, + tools: converted_tools, + tool_choice: mapped_tool_choice, + reasoning: None, // We would need to add support for setting from user settings. 
+ include: Some(vec![ + copilot::copilot_responses::ResponseIncludable::ReasoningEncryptedContent, + ]), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use copilot::copilot_responses as responses; + use futures::StreamExt; + + fn map_events(events: Vec) -> Vec { + futures::executor::block_on(async { + CopilotResponsesEventMapper::new() + .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok)))) + .collect::>() + .await + .into_iter() + .map(Result::unwrap) + .collect() + }) + } + + #[test] + fn responses_stream_maps_text_and_usage() { + let events = vec![ + responses::StreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: responses::ResponseOutputItem::Message { + id: "msg_1".into(), + role: "assistant".into(), + content: Some(Vec::new()), + }, + }, + responses::StreamEvent::OutputTextDelta { + item_id: "msg_1".into(), + output_index: 0, + delta: "Hello".into(), + }, + responses::StreamEvent::Completed { + response: responses::Response { + usage: Some(responses::ResponseUsage { + input_tokens: Some(5), + output_tokens: Some(3), + total_tokens: Some(8), + }), + ..Default::default() + }, + }, + ]; + + let mapped = map_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_1" + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Text(ref text) if text == "Hello" + )); + assert!(matches!( + mapped[2], + LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: 5, + output_tokens: 3, + .. + }) + )); + assert!(matches!( + mapped[3], + LanguageModelCompletionEvent::Stop(StopReason::EndTurn) + )); + } + + #[test] + fn responses_stream_maps_tool_calls() { + let events = vec![responses::StreamEvent::OutputItemDone { + output_index: 0, + sequence_number: None, + item: responses::ResponseOutputItem::FunctionCall { + id: Some("fn_1".into()), + call_id: "call_1".into(), + name: "do_it".into(), + arguments: "{\"x\":1}".into(), + status: None, + thought_signature: None, + }, + }]; + + let mapped = map_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUse(ref use_) if use_.id.to_string() == "call_1" && use_.name.as_ref() == "do_it" + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_handles_json_parse_error() { + let events = vec![responses::StreamEvent::OutputItemDone { + output_index: 0, + sequence_number: None, + item: responses::ResponseOutputItem::FunctionCall { + id: Some("fn_1".into()), + call_id: "call_1".into(), + name: "do_it".into(), + arguments: "{not json}".into(), + status: None, + thought_signature: None, + }, + }]; + + let mapped = map_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::ToolUseJsonParseError { ref id, ref tool_name, .. 
} + if id.to_string() == "call_1" && tool_name.as_ref() == "do_it" + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) + )); + } + + #[test] + fn responses_stream_maps_reasoning_summary_and_encrypted_content() { + let events = vec![responses::StreamEvent::OutputItemDone { + output_index: 0, + sequence_number: None, + item: responses::ResponseOutputItem::Reasoning { + id: "r1".into(), + summary: Some(vec![responses::ResponseReasoningItem { + kind: "summary_text".into(), + text: "Chain".into(), + }]), + encrypted_content: Some("ENC".into()), + }, + }]; + + let mapped = map_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::Thinking { ref text, signature: None } if text == "Chain" + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::RedactedThinking { ref data } if data == "ENC" + )); + } + + #[test] + fn responses_stream_handles_incomplete_max_tokens() { + let events = vec![responses::StreamEvent::Incomplete { + response: responses::Response { + usage: Some(responses::ResponseUsage { + input_tokens: Some(10), + output_tokens: Some(0), + total_tokens: Some(10), + }), + incomplete_details: Some(responses::IncompleteDetails { + reason: Some(responses::IncompleteReason::MaxOutputTokens), + }), + ..Default::default() + }, + }]; + + let mapped = map_events(events); + assert!(matches!( + mapped[0], + LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: 10, + output_tokens: 0, + .. + }) + )); + assert!(matches!( + mapped[1], + LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) + )); + } + + #[test] + fn responses_stream_handles_incomplete_content_filter() { + let events = vec![responses::StreamEvent::Incomplete { + response: responses::Response { + usage: None, + incomplete_details: Some(responses::IncompleteDetails { + reason: Some(responses::IncompleteReason::ContentFilter), + }), + ..Default::default() + }, + }]; + + let mapped = map_events(events); + assert!(matches!( + mapped.last().unwrap(), + LanguageModelCompletionEvent::Stop(StopReason::Refusal) + )); + } + + #[test] + fn responses_stream_completed_no_duplicate_after_tool_use() { + let events = vec![ + responses::StreamEvent::OutputItemDone { + output_index: 0, + sequence_number: None, + item: responses::ResponseOutputItem::FunctionCall { + id: Some("fn_1".into()), + call_id: "call_1".into(), + name: "do_it".into(), + arguments: "{}".into(), + status: None, + thought_signature: None, + }, + }, + responses::StreamEvent::Completed { + response: responses::Response::default(), + }, + ]; + + let mapped = map_events(events); + + let mut stop_count = 0usize; + let mut saw_tool_use_stop = false; + for event in mapped { + if let LanguageModelCompletionEvent::Stop(reason) = event { + stop_count += 1; + if matches!(reason, StopReason::ToolUse) { + saw_tool_use_stop = true; + } + } + } + assert_eq!(stop_count, 1, "should emit exactly one Stop event"); + assert!(saw_tool_use_stop, "Stop reason should be ToolUse"); + } + + #[test] + fn responses_stream_failed_maps_http_response_error() { + let events = vec![responses::StreamEvent::Failed { + response: responses::Response { + error: Some(responses::ResponseError { + code: "429".into(), + message: "too many requests".into(), + }), + ..Default::default() + }, + }]; + + let mapped_results = futures::executor::block_on(async { + CopilotResponsesEventMapper::new() + .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok)))) + .collect::>() + .await + }); + + 
assert_eq!(mapped_results.len(), 1); + match &mapped_results[0] { + Err(LanguageModelCompletionError::HttpResponseError { + status_code, + message, + .. + }) => { + assert_eq!(*status_code, http_client::StatusCode::TOO_MANY_REQUESTS); + assert_eq!(message, "too many requests"); + } + other => panic!("expected HttpResponseError, got {:?}", other), + } + } + + #[test] + fn chat_completions_stream_maps_reasoning_data() { + use copilot::copilot_chat::ResponseEvent; + + let events = vec![ + ResponseEvent { + choices: vec![copilot::copilot_chat::ResponseChoice { + index: Some(0), + finish_reason: None, + delta: Some(copilot::copilot_chat::ResponseDelta { + content: None, + role: Some(copilot::copilot_chat::Role::Assistant), + tool_calls: vec![copilot::copilot_chat::ToolCallChunk { + index: Some(0), + id: Some("call_abc123".to_string()), + function: Some(copilot::copilot_chat::FunctionChunk { + name: Some("list_directory".to_string()), + arguments: Some("{\"path\":\"test\"}".to_string()), + thought_signature: None, + }), + }], + reasoning_opaque: Some("encrypted_reasoning_token_xyz".to_string()), + reasoning_text: Some("Let me check the directory".to_string()), + }), + message: None, + }], + id: "chatcmpl-123".to_string(), + usage: None, + }, + ResponseEvent { + choices: vec![copilot::copilot_chat::ResponseChoice { + index: Some(0), + finish_reason: Some("tool_calls".to_string()), + delta: Some(copilot::copilot_chat::ResponseDelta { + content: None, + role: None, + tool_calls: vec![], + reasoning_opaque: None, + reasoning_text: None, + }), + message: None, + }], + id: "chatcmpl-123".to_string(), + usage: None, + }, + ]; + + let mapped = futures::executor::block_on(async { + map_to_language_model_completion_events( + Box::pin(futures::stream::iter(events.into_iter().map(Ok))), + true, + ) + .collect::>() + .await + }); + + let mut has_reasoning_details = false; + let mut has_tool_use = false; + let mut reasoning_opaque_value: Option = None; + let mut reasoning_text_value: Option = None; + + for event_result in mapped { + match event_result { + Ok(LanguageModelCompletionEvent::ReasoningDetails(details)) => { + has_reasoning_details = true; + reasoning_opaque_value = details + .get("reasoning_opaque") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + reasoning_text_value = details + .get("reasoning_text") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + } + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => { + has_tool_use = true; + assert_eq!(tool_use.id.to_string(), "call_abc123"); + assert_eq!(tool_use.name.as_ref(), "list_directory"); + } + _ => {} + } + } + + assert!( + has_reasoning_details, + "Should emit ReasoningDetails event for Gemini 3 reasoning" + ); + assert!(has_tool_use, "Should emit ToolUse event"); + assert_eq!( + reasoning_opaque_value, + Some("encrypted_reasoning_token_xyz".to_string()), + "Should capture reasoning_opaque" + ); + assert_eq!( + reasoning_text_value, + Some("Let me check the directory".to_string()), + "Should capture reasoning_text" + ); + } +} +struct ConfigurationView { + copilot_status: Option, + state: Entity, + _subscription: Option, +} + +impl ConfigurationView { + pub fn new(state: Entity, cx: &mut Context) -> Self { + let copilot = Copilot::global(cx); + + Self { + copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()), + state, + _subscription: copilot.as_ref().map(|copilot| { + cx.observe(copilot, |this, model, cx| { + this.copilot_status = Some(model.read(cx).status()); + cx.notify(); + }) + }), + } + } +} + +impl 
Render for ConfigurationView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + if self.state.read(cx).is_authenticated(cx) { + ConfiguredApiCard::new("Authorized") + .button_label("Sign Out") + .on_click(|_, window, cx| { + window.dispatch_action(copilot::SignOut.boxed_clone(), cx); + }) + .into_any_element() + } else { + let loading_icon = Icon::new(IconName::ArrowCircle).with_rotate_animation(4); + + const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider."; + + match &self.copilot_status { + Some(status) => match status { + Status::Starting { task: _ } => h_flex() + .gap_2() + .child(loading_icon) + .child(Label::new("Starting Copilot…")) + .into_any_element(), + Status::SigningIn { prompt: _ } + | Status::SignedOut { + awaiting_signing_in: true, + } => h_flex() + .gap_2() + .child(loading_icon) + .child(Label::new("Signing into Copilot…")) + .into_any_element(), + Status::Error(_) => { + const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot."; + v_flex() + .gap_6() + .child(Label::new(LABEL)) + .child(svg().size_8().path(IconName::CopilotError.path())) + .into_any_element() + } + _ => { + const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription."; + + v_flex() + .gap_2() + .child(Label::new(LABEL)) + .child( + Button::new("sign_in", "Sign in to use GitHub Copilot") + .full_width() + .style(ButtonStyle::Outlined) + .icon_color(Color::Muted) + .icon(IconName::Github) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .on_click(|_, window, cx| { + copilot::initiate_sign_in(window, cx) + }), + ) + .into_any_element() + } + }, + None => v_flex() + .gap_6() + .child(Label::new(ERROR_LABEL)) + .into_any_element(), + } + } + } +} diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index fdea1aabc013085f09c930b5cfa4a283d92f1a8b..c5a5affcd3d9e8c34f6306f86cb5348f86397892 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -1,22 +1,44 @@ -use anyhow::Result; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; +use anyhow::{Context as _, Result, anyhow}; +use collections::BTreeMap; +use credentials_provider::CredentialsProvider; +use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture}; use google_ai::{ FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction, ThinkingConfig, UsageMetadata, }; -use gpui::{App, AppContext as _}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; +use http_client::HttpClient; use language_model::{ - LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, - StopReason, + AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat, + LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason, +}; +use language_model::{ + LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, 
+ LanguageModelRequest, RateLimiter, Role, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; pub use settings::GoogleAvailableModel as AvailableModel; -use std::{ - pin::Pin, - sync::atomic::{self, AtomicU64}, +use settings::{Settings, SettingsStore}; +use std::pin::Pin; +use std::sync::{ + Arc, LazyLock, + atomic::{self, AtomicU64}, }; +use strum::IntoEnumIterator; +use ui::{List, prelude::*}; +use ui_input::InputField; +use util::ResultExt; +use zed_env_vars::EnvVar; + +use crate::api_key::ApiKey; +use crate::api_key::ApiKeyState; +use crate::ui::{ConfiguredApiCard, InstructionListItem}; + +const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME; #[derive(Default, Clone, Debug, PartialEq)] pub struct GoogleSettings { @@ -35,6 +57,346 @@ pub enum ModelMode { }, } +pub struct GoogleLanguageModelProvider { + http_client: Arc, + state: Entity, +} + +pub struct State { + api_key_state: ApiKeyState, +} + +const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY"; +const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY"; + +static API_KEY_ENV_VAR: LazyLock = LazyLock::new(|| { + // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY + EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into())) +}); + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key_state.has_key() + } + + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = GoogleLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) + } + + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = GoogleLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) + } +} + +impl GoogleLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); + cx.notify(); + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: google_ai::Model) -> Arc { + Arc::new(GoogleLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } + + pub fn api_key_for_gemini_cli(cx: &mut App) -> Task> { + if let Some(key) = API_KEY_ENV_VAR.value.clone() { + return Task::ready(Ok(key)); + } + let credentials_provider = ::global(cx); + let api_url = Self::api_url(cx).to_string(); + cx.spawn(async move |cx| { + Ok( + ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx) + .await? 
+ .key() + .to_string(), + ) + }) + } + + fn settings(cx: &App) -> &GoogleSettings { + &crate::AllLanguageModelSettings::get_global(cx).google + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + google_ai::API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } +} + +impl LanguageModelProviderState for GoogleLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for GoogleLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn icon(&self) -> IconName { + IconName::AiGoogle + } + + fn default_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(google_ai::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(google_ai::Model::default_fast())) + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models = BTreeMap::default(); + + // Add base models from google_ai::Model::iter() + for model in google_ai::Model::iter() { + if !matches!(model, google_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &GoogleLanguageModelProvider::settings(cx).available_models { + models.insert( + model.name.clone(), + google_ai::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + mode: model.mode.unwrap_or_default(), + }, + ); + } + + models + .into_values() + .map(|model| { + Arc::new(GoogleLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view( + &self, + target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) + } +} + +pub struct GoogleLanguageModel { + id: LanguageModelId, + model: google_ai::Model, + state: Entity, + http_client: Arc, + request_limiter: RateLimiter, +} + +impl GoogleLanguageModel { + fn stream_completion( + &self, + request: google_ai::GenerateContentRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result>>, + > { + let http_client = self.http_client.clone(); + + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = GoogleLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) + }) else { + return future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + async move { + let api_key = api_key.context("Missing Google API key")?; + let request = google_ai::stream_generate_content( + http_client.as_ref(), + &api_url, + &api_key, + request, + ); + request.await.context("failed to stream completion") + } + .boxed() + } +} + +impl LanguageModel for GoogleLanguageModel { + fn id(&self) -> 
LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn provider_name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn supports_tools(&self) -> bool { + self.model.supports_tools() + } + + fn supports_images(&self) -> bool { + self.model.supports_images() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto + | LanguageModelToolChoice::Any + | LanguageModelToolChoice::None => true, + } + } + + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } + + fn telemetry_id(&self) -> String { + format!("google/{}", self.model.request_id()) + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + self.model.max_output_tokens() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + let model_id = self.model.request_id().to_string(); + let request = into_google(request, model_id, self.model.mode()); + let http_client = self.http_client.clone(); + let api_url = GoogleLanguageModelProvider::api_url(cx); + let api_key = self.state.read(cx).api_key_state.key(&api_url); + + async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + } + .into()); + }; + let response = google_ai::count_tokens( + http_client.as_ref(), + &api_url, + &api_key, + google_ai::CountTokensRequest { + generate_content_request: request, + }, + ) + .await?; + Ok(response.total_tokens) + } + .boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + let request = into_google( + request, + self.model.request_id().to_string(), + self.model.mode(), + ); + let request = self.stream_completion(request, cx); + let future = self.request_limiter.stream(async move { + let response = request.await.map_err(LanguageModelCompletionError::from)?; + Ok(GoogleEventMapper::new().map_stream(response)) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } +} + pub fn into_google( mut request: LanguageModelRequest, model_id: String, @@ -77,6 +439,7 @@ pub fn into_google( })] } language_model::MessageContent::ToolUse(tool_use) => { + // Normalize empty string signatures to None let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty()); vec![Part::FunctionCallPart(google_ai::FunctionCallPart { @@ -94,6 +457,7 @@ pub fn into_google( google_ai::FunctionResponsePart { function_response: google_ai::FunctionResponse { name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object response: serde_json::json!({ "output": text }), @@ -106,6 +470,7 @@ pub fn into_google( Part::FunctionResponsePart(google_ai::FunctionResponsePart { function_response: google_ai::FunctionResponse { name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object response: serde_json::json!({ "output": "Tool responded with an image" }), @@ -154,7 +519,7 @@ pub fn into_google( role: match message.role { Role::User => google_ai::Role::User, Role::Assistant => google_ai::Role::Model, - Role::System => google_ai::Role::User, + 
Role::System => google_ai::Role::User, // Google AI doesn't have a system role }, }) } @@ -288,13 +653,13 @@ impl GoogleEventMapper { Part::InlineDataPart(_) => {} Part::FunctionCallPart(function_call_part) => { wants_to_use_tool = true; - let name: std::sync::Arc = - function_call_part.function_call.name.into(); + let name: Arc = function_call_part.function_call.name.into(); let next_tool_id = TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst); let id: LanguageModelToolUseId = format!("{}-{}", name, next_tool_id).into(); + // Normalize empty string signatures to None let thought_signature = function_call_part .thought_signature .filter(|s| !s.is_empty()); @@ -313,7 +678,7 @@ impl GoogleEventMapper { Part::FunctionResponsePart(_) => {} Part::ThoughtPart(part) => { events.push(Ok(LanguageModelCompletionEvent::Thinking { - text: "(Encrypted thought)".to_string(), + text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries? signature: Some(part.thought_signature), })); } @@ -321,6 +686,8 @@ impl GoogleEventMapper { } } + // Even when Gemini wants to use a Tool, the API + // responds with `finish_reason: STOP` if wants_to_use_tool { self.stop_reason = StopReason::ToolUse; events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); @@ -333,6 +700,8 @@ pub fn count_google_tokens( request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { + // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly. + // So we have to use tokenizer from tiktoken_rs to count tokens. cx.background_spawn(async move { let messages = request .messages @@ -349,6 +718,8 @@ pub fn count_google_tokens( }) .collect::>(); + // Tiktoken doesn't yet support these models, so we manually use the + // same tokenizer as GPT-4. tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64) }) .boxed() @@ -389,6 +760,148 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage { } } +struct ConfigurationView { + api_key_editor: Entity, + state: Entity, + target_agent: language_model::ConfigurationViewTargetAgent, + load_credentials_task: Option>, +} + +impl ConfigurationView { + fn new( + state: Entity, + target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut Context, + ) -> Self { + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")), + target_agent, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { + return; + } + + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? 
+ .await + }) + .detach_and_log_err(cx); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await + }) + .detach_and_log_err(cx); + } + + fn should_render_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); + let configured_card_label = if env_var_set { + format!( + "API key set in {} environment variable", + API_KEY_ENV_VAR.name + ) + } else { + let api_url = GoogleLanguageModelProvider::api_url(cx); + if api_url == google_ai::API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", api_url) + } + }; + + if self.load_credentials_task.is_some() { + div() + .child(Label::new("Loading credentials...")) + .into_any_element() + } else if self.should_render_editor(cx) { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent { + ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(), + ConfigurationViewTargetAgent::Other(agent) => agent.clone(), + }))) + .child( + List::new() + .child(InstructionListItem::new( + "Create one by visiting", + Some("Google AI's console"), + Some("https://aistudio.google.com/app/apikey"), + )) + .child(InstructionListItem::text_only( + "Paste your API key below and hit enter to start using the assistant", + )), + ) + .child(self.api_key_editor.clone()) + .child( + Label::new( + format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."), + ) + .size(LabelSize::Small).color(Color::Muted), + ) + .into_any_element() + } else { + ConfiguredApiCard::new(configured_card_label) + .disabled(env_var_set) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))) + .when(env_var_set, |this| { + this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset.")) + }) + .into_any_element() + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -427,7 +940,7 @@ mod tests { let events = mapper.map_event(response); - assert_eq!(events.len(), 2); + assert_eq!(events.len(), 2); // ToolUse event + Stop event if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { assert_eq!(tool_use.name.as_ref(), "test_function"); @@ -521,25 +1034,18 @@ mod tests { parts: vec![ Part::FunctionCallPart(FunctionCallPart { function_call: FunctionCall { - name: "function_a".to_string(), - args: json!({}), + name: "function_1".to_string(), + args: json!({"arg": "value1"}), }, - thought_signature: Some("sig_a".to_string()), + thought_signature: Some("signature_1".to_string()), }), Part::FunctionCallPart(FunctionCallPart { function_call: FunctionCall { - name: "function_b".to_string(), - args: json!({}), + name: "function_2".to_string(), + args: json!({"arg": "value2"}), }, thought_signature: None, }), - Part::FunctionCallPart(FunctionCallPart { - function_call: FunctionCall { - name: "function_c".to_string(), - args: json!({}), - }, - thought_signature: Some("sig_c".to_string()), 
- }), ], role: GoogleRole::Model, }, @@ -554,35 +1060,35 @@ mod tests { let events = mapper.map_event(response); - let tool_uses: Vec<_> = events - .iter() - .filter_map(|e| { - if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = e { - Some(tool_use) - } else { - None - } - }) - .collect(); + assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event - assert_eq!(tool_uses.len(), 3); - assert_eq!(tool_uses[0].thought_signature.as_deref(), Some("sig_a")); - assert_eq!(tool_uses[1].thought_signature, None); - assert_eq!(tool_uses[2].thought_signature.as_deref(), Some("sig_c")); + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert_eq!(tool_use.name.as_ref(), "function_1"); + assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1")); + } else { + panic!("Expected ToolUse event for function_1"); + } + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { + assert_eq!(tool_use.name.as_ref(), "function_2"); + assert_eq!(tool_use.thought_signature, None); + } else { + panic!("Expected ToolUse event for function_2"); + } } #[test] fn test_tool_use_with_signature_converts_to_function_call_part() { let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test-id"), - name: "test_tool".into(), - input: json!({"key": "value"}), - raw_input: r#"{"key": "value"}"#.to_string(), + id: LanguageModelToolUseId::from("test_id"), + name: "test_function".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), is_input_complete: true, - thought_signature: Some("test_sig".to_string()), + thought_signature: Some("test_signature_456".to_string()), }; - let request = into_google( + let request = super::into_google( LanguageModelRequest { messages: vec![language_model::LanguageModelRequestMessage { role: Role::Assistant, @@ -596,11 +1102,13 @@ mod tests { GoogleModelMode::Default, ); - let parts = &request.contents[0].parts; - assert_eq!(parts.len(), 1); - - if let Part::FunctionCallPart(fcp) = &parts[0] { - assert_eq!(fcp.thought_signature.as_deref(), Some("test_sig")); + assert_eq!(request.contents[0].parts.len(), 1); + if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { + assert_eq!(fc_part.function_call.name, "test_function"); + assert_eq!( + fc_part.thought_signature.as_deref(), + Some("test_signature_456") + ); } else { panic!("Expected FunctionCallPart"); } @@ -609,15 +1117,15 @@ mod tests { #[test] fn test_tool_use_without_signature_omits_field() { let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test-id"), - name: "test_tool".into(), - input: json!({"key": "value"}), - raw_input: r#"{"key": "value"}"#.to_string(), + id: LanguageModelToolUseId::from("test_id"), + name: "test_function".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), is_input_complete: true, thought_signature: None, }; - let request = into_google( + let request = super::into_google( LanguageModelRequest { messages: vec![language_model::LanguageModelRequestMessage { role: Role::Assistant, @@ -631,10 +1139,9 @@ mod tests { GoogleModelMode::Default, ); - let parts = &request.contents[0].parts; - - if let Part::FunctionCallPart(fcp) = &parts[0] { - assert_eq!(fcp.thought_signature, None); + assert_eq!(request.contents[0].parts.len(), 1); + if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { + assert_eq!(fc_part.thought_signature, None); } else { panic!("Expected 
FunctionCallPart"); } @@ -643,15 +1150,15 @@ mod tests { #[test] fn test_empty_signature_in_tool_use_normalized_to_none() { let tool_use = language_model::LanguageModelToolUse { - id: LanguageModelToolUseId::from("test-id"), - name: "test_tool".into(), - input: json!({}), - raw_input: "{}".to_string(), + id: LanguageModelToolUseId::from("test_id"), + name: "test_function".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), is_input_complete: true, thought_signature: Some("".to_string()), }; - let request = into_google( + let request = super::into_google( LanguageModelRequest { messages: vec![language_model::LanguageModelRequestMessage { role: Role::Assistant, @@ -665,10 +1172,8 @@ mod tests { GoogleModelMode::Default, ); - let parts = &request.contents[0].parts; - - if let Part::FunctionCallPart(fcp) = &parts[0] { - assert_eq!(fcp.thought_signature, None); + if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { + assert_eq!(fc_part.thought_signature, None); } else { panic!("Expected FunctionCallPart"); } @@ -676,8 +1181,9 @@ mod tests { #[test] fn test_round_trip_preserves_signature() { - let original_signature = "original_thought_signature_abc123"; + let mut mapper = GoogleEventMapper::new(); + // Simulate receiving a response from Google with a signature let response = GenerateContentResponse { candidates: Some(vec![GenerateContentCandidate { index: Some(0), @@ -687,7 +1193,7 @@ mod tests { name: "test_function".to_string(), args: json!({"arg": "value"}), }, - thought_signature: Some(original_signature.to_string()), + thought_signature: Some("round_trip_sig".to_string()), })], role: GoogleRole::Model, }, @@ -700,7 +1206,6 @@ mod tests { usage_metadata: None, }; - let mut mapper = GoogleEventMapper::new(); let events = mapper.map_event(response); let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { @@ -709,7 +1214,8 @@ mod tests { panic!("Expected ToolUse event"); }; - let request = into_google( + // Convert back to Google format + let request = super::into_google( LanguageModelRequest { messages: vec![language_model::LanguageModelRequestMessage { role: Role::Assistant, @@ -723,9 +1229,9 @@ mod tests { GoogleModelMode::Default, ); - let parts = &request.contents[0].parts; - if let Part::FunctionCallPart(fcp) = &parts[0] { - assert_eq!(fcp.thought_signature.as_deref(), Some(original_signature)); + // Verify signature is preserved + if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { + assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig")); } else { panic!("Expected FunctionCallPart"); } @@ -741,14 +1247,14 @@ mod tests { content: Content { parts: vec![ Part::TextPart(TextPart { - text: "Let me help you with that.".to_string(), + text: "I'll help with that.".to_string(), }), Part::FunctionCallPart(FunctionCallPart { function_call: FunctionCall { - name: "search".to_string(), - args: json!({"query": "test"}), + name: "helper_function".to_string(), + args: json!({"query": "help"}), }, - thought_signature: Some("thinking_sig".to_string()), + thought_signature: Some("mixed_sig".to_string()), }), ], role: GoogleRole::Model, @@ -764,46 +1270,38 @@ mod tests { let events = mapper.map_event(response); - let mut found_text = false; - let mut found_tool_with_sig = false; + assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event - for event in events { - match event { - Ok(LanguageModelCompletionEvent::Text(text)) => { - assert_eq!(text, "Let me help you 
with that."); - found_text = true; - } - Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => { - assert_eq!(tool_use.thought_signature.as_deref(), Some("thinking_sig")); - found_tool_with_sig = true; - } - _ => {} - } + if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] { + assert_eq!(text, "I'll help with that."); + } else { + panic!("Expected Text event"); } - assert!(found_text, "Should have found text event"); - assert!( - found_tool_with_sig, - "Should have found tool use with signature" - ); + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { + assert_eq!(tool_use.name.as_ref(), "helper_function"); + assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig")); + } else { + panic!("Expected ToolUse event"); + } } #[test] fn test_special_characters_in_signature_preserved() { - let special_signature = "sig/with+special=chars&more%stuff"; - let mut mapper = GoogleEventMapper::new(); + let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string(); + let response = GenerateContentResponse { candidates: Some(vec![GenerateContentCandidate { index: Some(0), content: Content { parts: vec![Part::FunctionCallPart(FunctionCallPart { function_call: FunctionCall { - name: "test".to_string(), - args: json!({}), + name: "test_function".to_string(), + args: json!({"arg": "value"}), }, - thought_signature: Some(special_signature.to_string()), + thought_signature: Some(signature_with_special_chars.clone()), })], role: GoogleRole::Model, }, @@ -821,7 +1319,7 @@ mod tests { if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { assert_eq!( tool_use.thought_signature.as_deref(), - Some(special_signature) + Some(signature_with_special_chars.as_str()) ); } else { panic!("Expected ToolUse event"); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 6a4f42ab36a6c356b4738c6633bc6d9913a7d338..32ee95ce9bd423bf7f66efc1bc7440455380ab5c 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,17 +1,38 @@ use anyhow::{Result, anyhow}; -use collections::HashMap; -use futures::{FutureExt, Stream, future::BoxFuture}; -use gpui::{App, AppContext as _}; +use collections::{BTreeMap, HashMap}; +use futures::Stream; +use futures::{FutureExt, StreamExt, future, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; +use http_client::HttpClient; use language_model::{ - LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage, + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, + RateLimiter, Role, StopReason, TokenUsage, }; -use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent}; -pub use settings::OpenAiAvailableModel as AvailableModel; +use menu; +use open_ai::{ + ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion, +}; +use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore}; use std::pin::Pin; -use std::str::FromStr; +use std::str::FromStr as _; +use std::sync::{Arc, LazyLock}; +use 
strum::IntoEnumIterator; +use ui::{List, prelude::*}; +use ui_input::InputField; +use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; + +use crate::ui::ConfiguredApiCard; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; + +const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID; +const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME; -use language_model::LanguageModelToolResultContent; +const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY"; +static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME); #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenAiSettings { @@ -19,6 +40,314 @@ pub struct OpenAiSettings { pub available_models: Vec<AvailableModel>, } +pub struct OpenAiLanguageModelProvider { + http_client: Arc<dyn HttpClient>, + state: Entity<State>, +} + +pub struct State { + api_key_state: ApiKeyState, +} + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key_state.has_key() + } + + fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) + } + + fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) + } +} + +impl OpenAiLanguageModelProvider { + pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self { + let state = cx.new(|cx| { + cx.observe_global::<SettingsStore>(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); + cx.notify(); + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> { + Arc::new(OpenAiLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } + + fn settings(cx: &App) -> &OpenAiSettings { + &crate::AllLanguageModelSettings::get_global(cx).openai + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + open_ai::OPEN_AI_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } +} + +impl LanguageModelProviderState for OpenAiLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for OpenAiLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn icon(&self) -> IconName { + IconName::AiOpenAi + } + + fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> { + Some(self.create_language_model(open_ai::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> { + Some(self.create_language_model(open_ai::Model::default_fast())) + } + + fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> { + let mut models = BTreeMap::default(); + + // Add base models from open_ai::Model::iter() + for model in open_ai::Model::iter() { + if !matches!(model, open_ai::Model::Custom { .. 
}) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &OpenAiLanguageModelProvider::settings(cx).available_models { + models.insert( + model.name.clone(), + open_ai::Model::Custom { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + max_output_tokens: model.max_output_tokens, + max_completion_tokens: model.max_completion_tokens, + reasoning_effort: model.reasoning_effort.clone(), + }, + ); + } + + models + .into_values() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> { + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) + } +} + +pub struct OpenAiLanguageModel { + id: LanguageModelId, + model: open_ai::Model, + state: Entity<State>, + http_client: Arc<dyn HttpClient>, + request_limiter: RateLimiter, +} + +impl OpenAiLanguageModel { + fn stream_completion( + &self, + request: open_ai::Request, + cx: &AsyncApp, + ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>> + { + let http_client = self.http_client.clone(); + + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) + }) else { + return future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + let future = self.request_limiter.stream(async move { + let provider = PROVIDER_NAME; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { provider }); + }; + let request = stream_completion( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); + let response = request.await?; + Ok(response) + }); + + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +impl LanguageModel for OpenAiLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn provider_name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn supports_tools(&self) -> bool { + true + } + + fn supports_images(&self) -> bool { + use open_ai::Model; + match &self.model { + Model::FourOmni + | Model::FourOmniMini + | Model::FourPointOne + | Model::FourPointOneMini + | Model::FourPointOneNano + | Model::Five + | Model::FiveMini + | Model::FiveNano + | Model::FivePointOne + | Model::O1 + | Model::O3 + | Model::O4Mini => true, + Model::ThreePointFiveTurbo + | Model::Four + | Model::FourTurbo + | Model::O3Mini + | Model::Custom { .. 
} => false, + } + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto => true, + LanguageModelToolChoice::Any => true, + LanguageModelToolChoice::None => true, + } + } + + fn telemetry_id(&self) -> String { + format!("openai/{}", self.model.id()) + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option<u64> { + self.model.max_output_tokens() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result<u64>> { + count_open_ai_tokens(request, self.model.clone(), cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result<LanguageModelCompletionEvent, LanguageModelCompletionError>, + >, + LanguageModelCompletionError, + >, + > { + let request = into_open_ai( + request, + self.model.id(), + self.model.supports_parallel_tool_calls(), + self.model.supports_prompt_cache_key(), + self.max_output_tokens(), + self.model.reasoning_effort(), + ); + let completions = self.stream_completion(request, cx); + async move { + let mapper = OpenAiEventMapper::new(); + Ok(mapper.map_stream(completions.await?).boxed()) + } + .boxed() + } +} + pub fn into_open_ai( request: LanguageModelRequest, model_id: &str, @@ -112,6 +441,7 @@ pub fn into_open_ai( temperature: request.temperature.or(Some(1.0)), max_completion_tokens: max_output_tokens, parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() { + // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn. Some(false) } else { None @@ -191,7 +521,6 @@ impl OpenAiEventMapper { events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>, ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { - use futures::StreamExt; events.flat_map(move |event| { futures::stream::iter(match event { Ok(event) => self.map_event(event), Err(error) => vec![Err(error.into())], @@ -319,12 +648,19 @@ pub fn count_open_ai_tokens( match model { Model::Custom { max_tokens, .. } => { let model = if max_tokens >= 100_000 { + // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt-4o "gpt-4o" } else { + // Otherwise fall back to gpt-4, since only cl100k_base and o200k_base are + // supported with this tiktoken method "gpt-4" }; tiktoken_rs::num_tokens_from_messages(model, &messages) } + // Currently supported by tiktoken_rs + // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch + // arm with an override. We enumerate all supported models here so that we can check if new + // models are supported yet or not. 
Model::ThreePointFiveTurbo | Model::Four | Model::FourTurbo @@ -339,7 +675,7 @@ pub fn count_open_ai_tokens( | Model::O4Mini | Model::Five | Model::FiveMini - | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), + | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), // GPT-5.1 doesn't have tiktoken support yet; fall back on gpt-4o tokenizer Model::FivePointOne => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages), } .map(|tokens| tokens as u64) @@ -347,11 +683,191 @@ } .boxed() } +struct ConfigurationView { + api_key_editor: Entity<InputField>, + state: Entity<State>, + load_credentials_task: Option<Task<()>>, +} + +impl ConfigurationView { + fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self { + let api_key_editor = cx.new(|cx| { + InputField::new( + window, + cx, + "sk-000000000000000000000000000000000000000000000000", + ) + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + // We don't log an error, because "not signed in" is also an error. + let _ = task.await; + } + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { + return; + } + + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? + .await + }) + .detach_and_log_err(cx); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) { + self.api_key_editor + .update(cx, |input, cx| input.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await + }) + .detach_and_log_err(cx); + } + + fn should_render_editor(&self, cx: &mut Context<Self>) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); + let configured_card_label = if env_var_set { + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") + } else { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + if api_url == OPEN_AI_API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", api_url) + } + }; + + let api_key_section = if self.should_render_editor(cx) { + v_flex() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. 
Follow these steps:")) + .child( + List::new() + .child(InstructionListItem::new( + "Create one by visiting", + Some("OpenAI's console"), + Some("https://platform.openai.com/api-keys"), + )) + .child(InstructionListItem::text_only( + "Ensure your OpenAI account has credits", + )) + .child(InstructionListItem::text_only( + "Paste your API key below and hit enter to start using the assistant", + )), + ) + .child(self.api_key_editor.clone()) + .child( + Label::new(format!( + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." + )) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + Label::new( + "Note that having a subscription for another service like GitHub Copilot won't work.", + ) + .size(LabelSize::Small).color(Color::Muted), + ) + .into_any_element() + } else { + ConfiguredApiCard::new(configured_card_label) + .disabled(env_var_set) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))) + .when(env_var_set, |this| { + this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")) + }) + .into_any_element() + }; + + let compatible_api_section = h_flex() + .mt_1p5() + .gap_0p5() + .flex_wrap() + .when(self.should_render_editor(cx), |this| { + this.pt_1p5() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + }) + .child( + h_flex() + .gap_2() + .child( + Icon::new(IconName::Info) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child(Label::new("Zed also supports OpenAI-compatible models.")), + ) + .child( + Button::new("docs", "Learn More") + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .on_click(move |_, _window, cx| { + cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible") + }), + ); + + if self.load_credentials_task.is_some() { + div().child(Label::new("Loading credentials…")).into_any() + } else { + v_flex() + .size_full() + .child(api_key_section) + .child(compatible_api_section) + .into_any() + } + } +} + #[cfg(test)] mod tests { use gpui::TestAppContext; use language_model::LanguageModelRequestMessage; - use strum::IntoEnumIterator; use super::*; @@ -375,6 +891,7 @@ mod tests { thinking_allowed: true, }; + // Validate that all models are supported by tiktoken-rs for model in Model::iter() { let count = cx .executor() diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs new file mode 100644 index 0000000000000000000000000000000000000000..7b10ebf963033603ede691fa72d2fa523bcdbab9 --- /dev/null +++ b/crates/language_models/src/provider/open_router.rs @@ -0,0 +1,1095 @@ +use anyhow::{Result, anyhow}; +use collections::HashMap; +use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task}; +use http_client::HttpClient; +use language_model::{ + AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, + LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, +}; +use open_router::{ + Model, ModelMode as OpenRouterModelMode, OPEN_ROUTER_API_URL, ResponseStreamEvent, list_models, +}; +use settings::{OpenRouterAvailableModel as 
AvailableModel, Settings, SettingsStore}; +use std::pin::Pin; +use std::str::FromStr as _; +use std::sync::{Arc, LazyLock}; +use ui::{List, prelude::*}; +use ui_input::InputField; +use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; + +use crate::ui::ConfiguredApiCard; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; + +const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); +const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); + +const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY"; +static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME); + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct OpenRouterSettings { + pub api_url: String, + pub available_models: Vec<AvailableModel>, +} + +pub struct OpenRouterLanguageModelProvider { + http_client: Arc<dyn HttpClient>, + state: Entity<State>, +} + +pub struct State { + api_key_state: ApiKeyState, + http_client: Arc<dyn HttpClient>, + available_models: Vec<open_router::Model>, + fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>, +} + +impl State { + fn is_authenticated(&self) -> bool { + self.api_key_state.has_key() + } + + fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) + } + + fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let task = self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); + + cx.spawn(async move |this, cx| { + let result = task.await; + this.update(cx, |this, cx| this.restart_fetch_models_task(cx)) + .ok(); + result + }) + } + + fn fetch_models( + &mut self, + cx: &mut Context<Self>, + ) -> Task<Result<(), LanguageModelCompletionError>> { + let http_client = self.http_client.clone(); + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let Some(api_key) = self.api_key_state.key(&api_url) else { + return Task::ready(Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + })); + }; + cx.spawn(async move |this, cx| { + let models = list_models(http_client.as_ref(), &api_url, &api_key) + .await + .map_err(|e| { + LanguageModelCompletionError::Other(anyhow::anyhow!( + "OpenRouter error: {:?}", + e + )) + })?; + + this.update(cx, |this, cx| { + this.available_models = models; + cx.notify(); + }) + .map_err(|e| LanguageModelCompletionError::Other(e))?; + + Ok(()) + }) + } + + fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) { + if self.is_authenticated() { + let task = self.fetch_models(cx); + self.fetch_models_task.replace(task); + } else { + self.available_models = Vec::new(); + } + } +} + +impl OpenRouterLanguageModelProvider { + pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self { + let state = cx.new(|cx| { + cx.observe_global::<SettingsStore>({ + let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone(); + move |this: &mut State, cx| { + let current_settings = OpenRouterLanguageModelProvider::settings(cx); + let settings_changed = current_settings != &last_settings; + if settings_changed { + last_settings = current_settings.clone(); + this.authenticate(cx).detach(); + cx.notify(); + } + } + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + http_client: http_client.clone(), + available_models: Vec::new(), + fetch_models_task: None, + } + }); + + Self { http_client, state } + } + + fn settings(cx: &App) -> &OpenRouterSettings { + &crate::AllLanguageModelSettings::get_global(cx).open_router + } 
+ + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + OPEN_ROUTER_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } + + fn create_language_model(&self, model: open_router::Model) -> Arc { + Arc::new(OpenRouterLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } +} + +impl LanguageModelProviderState for OpenRouterLanguageModelProvider { + type ObservableEntity = State; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } +} + +impl LanguageModelProvider for OpenRouterLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn icon(&self) -> IconName { + IconName::AiOpenRouter + } + + fn default_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(open_router::Model::default())) + } + + fn default_fast_model(&self, _cx: &App) -> Option> { + Some(self.create_language_model(open_router::Model::default_fast())) + } + + fn provided_models(&self, cx: &App) -> Vec> { + let mut models_from_api = self.state.read(cx).available_models.clone(); + let mut settings_models = Vec::new(); + + for model in &Self::settings(cx).available_models { + settings_models.push(open_router::Model { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + supports_tools: model.supports_tools, + supports_images: model.supports_images, + mode: model.mode.unwrap_or_default(), + provider: model.provider.clone(), + }); + } + + for settings_model in &settings_models { + if let Some(pos) = models_from_api + .iter() + .position(|m| m.name == settings_model.name) + { + models_from_api[pos] = settings_model.clone(); + } else { + models_from_api.push(settings_model.clone()); + } + } + + models_from_api + .into_iter() + .map(|model| self.create_language_model(model)) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated() + } + + fn authenticate(&self, cx: &mut App) -> Task> { + self.state.update(cx, |state, cx| state.authenticate(cx)) + } + + fn configuration_view( + &self, + _target_agent: language_model::ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { + cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx)) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) + } +} + +pub struct OpenRouterLanguageModel { + id: LanguageModelId, + model: open_router::Model, + state: Entity, + http_client: Arc, + request_limiter: RateLimiter, +} + +impl OpenRouterLanguageModel { + fn stream_completion( + &self, + request: open_router::Request, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) + }) else { + return future::ready(Err(anyhow!("App state dropped").into())).boxed(); + }; + + async move { + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; 
+ let request = + open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + request.await.map_err(Into::into) + } + .boxed() + } +} + +impl LanguageModel for OpenRouterLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + PROVIDER_ID + } + + fn provider_name(&self) -> LanguageModelProviderName { + PROVIDER_NAME + } + + fn supports_tools(&self) -> bool { + self.model.supports_tool_calls() + } + + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + let model_id = self.model.id().trim().to_lowercase(); + if model_id.contains("gemini") || model_id.contains("grok") { + LanguageModelToolSchemaFormat::JsonSchemaSubset + } else { + LanguageModelToolSchemaFormat::JsonSchema + } + } + + fn telemetry_id(&self) -> String { + format!("openrouter/{}", self.model.id()) + } + + fn max_token_count(&self) -> u64 { + self.model.max_token_count() + } + + fn max_output_tokens(&self) -> Option { + self.model.max_output_tokens() + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto => true, + LanguageModelToolChoice::Any => true, + LanguageModelToolChoice::None => true, + } + } + + fn supports_images(&self) -> bool { + self.model.supports_images.unwrap_or(false) + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + count_open_router_tokens(request, self.model.clone(), cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + let request = into_open_router(request, &self.model, self.max_output_tokens()); + let request = self.stream_completion(request, cx); + let future = self.request_limiter.stream(async move { + let response = request.await?; + Ok(OpenRouterEventMapper::new().map_stream(response)) + }); + async move { Ok(future.await?.boxed()) }.boxed() + } +} + +pub fn into_open_router( + request: LanguageModelRequest, + model: &Model, + max_output_tokens: Option, +) -> open_router::Request { + let mut messages = Vec::new(); + for message in request.messages { + let reasoning_details = message.reasoning_details.clone(); + for content in message.content { + match content { + MessageContent::Text(text) => add_message_content_part( + open_router::MessagePart::Text { text }, + message.role, + &mut messages, + ), + MessageContent::Thinking { .. } => {} + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(image) => { + add_message_content_part( + open_router::MessagePart::Image { + image_url: image.to_base64_url(), + }, + message.role, + &mut messages, + ); + } + MessageContent::ToolUse(tool_use) => { + let tool_call = open_router::ToolCall { + id: tool_use.id.to_string(), + content: open_router::ToolCallContent::Function { + function: open_router::FunctionContent { + name: tool_use.name.to_string(), + arguments: serde_json::to_string(&tool_use.input) + .unwrap_or_default(), + thought_signature: tool_use.thought_signature.clone(), + }, + }, + }; + + if let Some(open_router::RequestMessage::Assistant { + tool_calls, + reasoning_details: existing_reasoning, + .. 
+ }) = messages.last_mut() + { + tool_calls.push(tool_call); + if existing_reasoning.is_none() && reasoning_details.is_some() { + *existing_reasoning = reasoning_details.clone(); + } + } else { + messages.push(open_router::RequestMessage::Assistant { + content: None, + tool_calls: vec![tool_call], + reasoning_details: reasoning_details.clone(), + }); + } + } + MessageContent::ToolResult(tool_result) => { + let content = match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + vec![open_router::MessagePart::Text { + text: text.to_string(), + }] + } + LanguageModelToolResultContent::Image(image) => { + vec![open_router::MessagePart::Image { + image_url: image.to_base64_url(), + }] + } + }; + + messages.push(open_router::RequestMessage::Tool { + content: content.into(), + tool_call_id: tool_result.tool_use_id.to_string(), + }); + } + } + } + } + + open_router::Request { + model: model.id().into(), + messages, + stream: true, + stop: request.stop, + temperature: request.temperature.unwrap_or(0.4), + max_tokens: max_output_tokens, + parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() { + Some(false) + } else { + None + }, + usage: open_router::RequestUsage { include: true }, + reasoning: if request.thinking_allowed + && let OpenRouterModelMode::Thinking { budget_tokens } = model.mode + { + Some(open_router::Reasoning { + effort: None, + max_tokens: budget_tokens, + exclude: Some(false), + enabled: Some(true), + }) + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| open_router::ToolDefinition::Function { + function: open_router::FunctionDefinition { + name: tool.name, + description: Some(tool.description), + parameters: Some(tool.input_schema), + }, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto, + LanguageModelToolChoice::Any => open_router::ToolChoice::Required, + LanguageModelToolChoice::None => open_router::ToolChoice::None, + }), + provider: model.provider.clone(), + } +} + +fn add_message_content_part( + new_part: open_router::MessagePart, + role: Role, + messages: &mut Vec, +) { + match (role, messages.last_mut()) { + (Role::User, Some(open_router::RequestMessage::User { content })) + | (Role::System, Some(open_router::RequestMessage::System { content })) => { + content.push_part(new_part); + } + ( + Role::Assistant, + Some(open_router::RequestMessage::Assistant { + content: Some(content), + .. 
+ }), + ) => { + content.push_part(new_part); + } + _ => { + messages.push(match role { + Role::User => open_router::RequestMessage::User { + content: open_router::MessageContent::from(vec![new_part]), + }, + Role::Assistant => open_router::RequestMessage::Assistant { + content: Some(open_router::MessageContent::from(vec![new_part])), + tool_calls: Vec::new(), + reasoning_details: None, + }, + Role::System => open_router::RequestMessage::System { + content: open_router::MessageContent::from(vec![new_part]), + }, + }); + } + } +} + +pub struct OpenRouterEventMapper { + tool_calls_by_index: HashMap, + reasoning_details: Option, +} + +impl OpenRouterEventMapper { + pub fn new() -> Self { + Self { + tool_calls_by_index: HashMap::default(), + reasoning_details: None, + } + } + + pub fn map_stream( + mut self, + events: Pin< + Box< + dyn Send + Stream>, + >, + >, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(error.into())], + }) + }) + } + + pub fn map_event( + &mut self, + event: ResponseStreamEvent, + ) -> Vec> { + let Some(choice) = event.choices.first() else { + return vec![Err(LanguageModelCompletionError::from(anyhow!( + "Response contained no choices" + )))]; + }; + + let mut events = Vec::new(); + + if let Some(details) = choice.delta.reasoning_details.clone() { + // Emit reasoning_details immediately + events.push(Ok(LanguageModelCompletionEvent::ReasoningDetails( + details.clone(), + ))); + self.reasoning_details = Some(details); + } + + if let Some(reasoning) = choice.delta.reasoning.clone() { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text: reasoning, + signature: None, + })); + } + + if let Some(content) = choice.delta.content.clone() { + // OpenRouter send empty content string with the reasoning content + // This is a workaround for the OpenRouter API bug + if !content.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } + } + + if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { + for tool_call in tool_calls { + let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } + + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; + } + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); + } + + if let Some(signature) = function.thought_signature.clone() { + entry.thought_signature = Some(signature); + } + } + } + } + + if let Some(usage) = event.usage { + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }))); + } + + match choice.finish_reason.as_deref() { + Some("stop") => { + // Don't emit reasoning_details here - already emitted immediately when captured + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + Some("tool_calls") => { + events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { + match serde_json::Value::from_str(&tool_call.arguments) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + thought_signature: 
tool_call.thought_signature.clone(), + }, + )), + Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: tool_call.id.clone().into(), + tool_name: tool_call.name.as_str().into(), + raw_input: tool_call.arguments.clone().into(), + json_parse_error: error.to_string(), + }), + } + })); + + // Don't emit reasoning_details here - already emitted immediately when captured + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + Some(stop_reason) => { + log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",); + // Don't emit reasoning_details here - already emitted immediately when captured + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + None => {} + } + + events + } +} + +#[derive(Default)] +struct RawToolCall { + id: String, + name: String, + arguments: String, + thought_signature: Option, +} + +pub fn count_open_router_tokens( + request: LanguageModelRequest, + _model: open_router::Model, + cx: &App, +) -> BoxFuture<'static, Result> { + cx.background_spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.string_contents()), + name: None, + function_call: None, + }) + .collect::>(); + + tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64) + }) + .boxed() +} + +struct ConfigurationView { + api_key_editor: Entity, + state: Entity, + load_credentials_task: Option>, +} + +impl ConfigurationView { + fn new(state: Entity, window: &mut Window, cx: &mut Context) -> Self { + let api_key_editor = cx.new(|cx| { + InputField::new( + window, + cx, + "sk_or_000000000000000000000000000000000000000000000000", + ) + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + + let load_credentials_task = Some(cx.spawn_in(window, { + let state = state.clone(); + async move |this, cx| { + if let Some(task) = state + .update(cx, |state, cx| state.authenticate(cx)) + .log_err() + { + let _ = task.await; + } + + this.update(cx, |this, cx| { + this.load_credentials_task = None; + cx.notify(); + }) + .log_err(); + } + })); + + Self { + api_key_editor, + state, + load_credentials_task, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { + return; + } + + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? + .await + }) + .detach_and_log_err(cx); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(None, cx))? 
+ .await + }) + .detach_and_log_err(cx); + } + + fn should_render_editor(&self, cx: &mut Context) -> bool { + !self.state.read(cx).is_authenticated() + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); + let configured_card_label = if env_var_set { + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") + } else { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + if api_url == OPEN_ROUTER_API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", api_url) + } + }; + + if self.load_credentials_task.is_some() { + div() + .child(Label::new("Loading credentials...")) + .into_any_element() + } else if self.should_render_editor(cx) { + v_flex() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:")) + .child( + List::new() + .child(InstructionListItem::new( + "Create an API key by visiting", + Some("OpenRouter's console"), + Some("https://openrouter.ai/keys"), + )) + .child(InstructionListItem::text_only( + "Ensure your OpenRouter account has credits", + )) + .child(InstructionListItem::text_only( + "Paste your API key below and hit enter to start using the assistant", + )), + ) + .child(self.api_key_editor.clone()) + .child( + Label::new( + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), + ) + .size(LabelSize::Small).color(Color::Muted), + ) + .into_any_element() + } else { + ConfiguredApiCard::new(configured_card_label) + .disabled(env_var_set) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))) + .when(env_var_set, |this| { + this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")) + }) + .into_any_element() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use open_router::{ChoiceDelta, FunctionChunk, ResponseMessageDelta, ToolCallChunk}; + + #[gpui::test] + async fn test_reasoning_details_preservation_with_tool_calls() { + // This test verifies that reasoning_details are properly captured and preserved + // when a model uses tool calling with reasoning/thinking tokens. 
+ // + // The key regression this prevents: + // - OpenRouter sends multiple reasoning_details updates during streaming + // - First with actual content (encrypted reasoning data) + // - Then with empty array on completion + // - We must NOT overwrite the real data with the empty array + + let mut mapper = OpenRouterEventMapper::new(); + + // Simulate the streaming events as they come from OpenRouter/Gemini + let events = vec![ + // Event 1: Initial reasoning details with text + ResponseStreamEvent { + id: Some("response_123".into()), + created: 1234567890, + model: "google/gemini-3-pro-preview".into(), + choices: vec![ChoiceDelta { + index: 0, + delta: ResponseMessageDelta { + role: None, + content: None, + reasoning: None, + tool_calls: None, + reasoning_details: Some(serde_json::json!([ + { + "type": "reasoning.text", + "text": "Let me analyze this request...", + "format": "google-gemini-v1", + "index": 0 + } + ])), + }, + finish_reason: None, + }], + usage: None, + }, + // Event 2: More reasoning details + ResponseStreamEvent { + id: Some("response_123".into()), + created: 1234567890, + model: "google/gemini-3-pro-preview".into(), + choices: vec![ChoiceDelta { + index: 0, + delta: ResponseMessageDelta { + role: None, + content: None, + reasoning: None, + tool_calls: None, + reasoning_details: Some(serde_json::json!([ + { + "type": "reasoning.encrypted", + "data": "EtgDCtUDAdHtim9OF5jm4aeZSBAtl/randomized123", + "format": "google-gemini-v1", + "index": 0, + "id": "tool_call_abc123" + } + ])), + }, + finish_reason: None, + }], + usage: None, + }, + // Event 3: Tool call starts + ResponseStreamEvent { + id: Some("response_123".into()), + created: 1234567890, + model: "google/gemini-3-pro-preview".into(), + choices: vec![ChoiceDelta { + index: 0, + delta: ResponseMessageDelta { + role: None, + content: None, + reasoning: None, + tool_calls: Some(vec![ToolCallChunk { + index: 0, + id: Some("tool_call_abc123".into()), + function: Some(FunctionChunk { + name: Some("list_directory".into()), + arguments: Some("{\"path\":\"test\"}".into()), + thought_signature: Some("sha256:test_signature_xyz789".into()), + }), + }]), + reasoning_details: None, + }, + finish_reason: None, + }], + usage: None, + }, + // Event 4: Empty reasoning_details on tool_calls finish + // This is the critical event - we must not overwrite with this empty array! 
+ ResponseStreamEvent { + id: Some("response_123".into()), + created: 1234567890, + model: "google/gemini-3-pro-preview".into(), + choices: vec![ChoiceDelta { + index: 0, + delta: ResponseMessageDelta { + role: None, + content: None, + reasoning: None, + tool_calls: None, + reasoning_details: Some(serde_json::json!([])), + }, + finish_reason: Some("tool_calls".into()), + }], + usage: None, + }, + ]; + + // Process all events + let mut collected_events = Vec::new(); + for event in events { + let mapped = mapper.map_event(event); + collected_events.extend(mapped); + } + + // Verify we got the expected events + let mut has_tool_use = false; + let mut reasoning_details_events = Vec::new(); + let mut thought_signature_value = None; + + for event_result in collected_events { + match event_result { + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => { + has_tool_use = true; + assert_eq!(tool_use.id.to_string(), "tool_call_abc123"); + assert_eq!(tool_use.name.as_ref(), "list_directory"); + thought_signature_value = tool_use.thought_signature.clone(); + } + Ok(LanguageModelCompletionEvent::ReasoningDetails(details)) => { + reasoning_details_events.push(details); + } + _ => {} + } + } + + // Assertions + assert!(has_tool_use, "Should have emitted ToolUse event"); + assert!( + !reasoning_details_events.is_empty(), + "Should have emitted ReasoningDetails events" + ); + + // We should have received multiple reasoning_details events (text, encrypted, empty) + // The agent layer is responsible for keeping only the first non-empty one + assert!( + reasoning_details_events.len() >= 2, + "Should have multiple reasoning_details events from streaming" + ); + + // Verify at least one contains the encrypted data + let has_encrypted = reasoning_details_events.iter().any(|details| { + if let serde_json::Value::Array(arr) = details { + arr.iter().any(|item| { + item["type"] == "reasoning.encrypted" + && item["data"] + .as_str() + .map_or(false, |s| s.contains("EtgDCtUDAdHtim9OF5jm4aeZSBAtl")) + }) + } else { + false + } + }); + assert!( + has_encrypted, + "Should have at least one reasoning_details with encrypted data" + ); + + // Verify thought_signature was captured + assert!( + thought_signature_value.is_some(), + "Tool use should have thought_signature" + ); + assert_eq!( + thought_signature_value.unwrap(), + "sha256:test_signature_xyz789" + ); + } + + #[gpui::test] + async fn test_agent_prevents_empty_reasoning_details_overwrite() { + // This test verifies that the agent layer prevents empty reasoning_details + // from overwriting non-empty ones, even though the mapper emits all events. 
+ + // Simulate what the agent does when it receives multiple ReasoningDetails events + let mut agent_reasoning_details: Option = None; + + let events = vec![ + // First event: non-empty reasoning_details + serde_json::json!([ + { + "type": "reasoning.encrypted", + "data": "real_data_here", + "format": "google-gemini-v1" + } + ]), + // Second event: empty array (should not overwrite) + serde_json::json!([]), + ]; + + for details in events { + // This mimics the agent's logic: only store if we don't already have it + if agent_reasoning_details.is_none() { + agent_reasoning_details = Some(details); + } + } + + // Verify the agent kept the first non-empty reasoning_details + assert!(agent_reasoning_details.is_some()); + let final_details = agent_reasoning_details.unwrap(); + if let serde_json::Value::Array(arr) = &final_details { + assert!( + !arr.is_empty(), + "Agent should have kept the non-empty reasoning_details" + ); + assert_eq!(arr[0]["data"], "real_data_here"); + } else { + panic!("Expected array"); + } + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index 9c029b6aa7a8588250fb95b1f429da20164ce7cd..43a8e7334a744c84d6edfae3ffc97115eb8f51b2 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -4,22 +4,16 @@ use collections::HashMap; use settings::RegisterSetting; use crate::provider::{ - bedrock::AmazonBedrockSettings, cloud::ZedDotDevSettings, deepseek::DeepSeekSettings, - google::GoogleSettings, lmstudio::LmStudioSettings, mistral::MistralSettings, - ollama::OllamaSettings, open_ai::OpenAiSettings, open_ai_compatible::OpenAiCompatibleSettings, + anthropic::AnthropicSettings, bedrock::AmazonBedrockSettings, cloud::ZedDotDevSettings, + deepseek::DeepSeekSettings, google::GoogleSettings, lmstudio::LmStudioSettings, + mistral::MistralSettings, ollama::OllamaSettings, open_ai::OpenAiSettings, + open_ai_compatible::OpenAiCompatibleSettings, open_router::OpenRouterSettings, vercel::VercelSettings, x_ai::XAiSettings, }; -pub use settings::OpenRouterAvailableModel as AvailableModel; - -#[derive(Default, Clone, Debug, PartialEq)] -pub struct OpenRouterSettings { - pub api_url: String, - pub available_models: Vec, -} - #[derive(Debug, RegisterSetting)] pub struct AllLanguageModelSettings { + pub anthropic: AnthropicSettings, pub bedrock: AmazonBedrockSettings, pub deepseek: DeepSeekSettings, pub google: GoogleSettings, @@ -39,6 +33,7 @@ impl settings::Settings for AllLanguageModelSettings { fn from_settings(content: &settings::SettingsContent) -> Self { let language_models = content.language_models.clone().unwrap(); + let anthropic = language_models.anthropic.unwrap(); let bedrock = language_models.bedrock.unwrap(); let deepseek = language_models.deepseek.unwrap(); let google = language_models.google.unwrap(); @@ -52,12 +47,16 @@ impl settings::Settings for AllLanguageModelSettings { let x_ai = language_models.x_ai.unwrap(); let zed_dot_dev = language_models.zed_dot_dev.unwrap(); Self { + anthropic: AnthropicSettings { + api_url: anthropic.api_url.unwrap(), + available_models: anthropic.available_models.unwrap_or_default(), + }, bedrock: AmazonBedrockSettings { available_models: bedrock.available_models.unwrap_or_default(), region: bedrock.region, - endpoint: bedrock.endpoint_url, + endpoint: bedrock.endpoint_url, // todo(should be api_url) profile_name: bedrock.profile, - role_arn: None, + role_arn: None, // todo(was never a setting for this...) 
authentication_method: bedrock.authentication_method.map(Into::into), + allow_global: bedrock.allow_global, }, diff --git a/crates/settings/src/settings_content/language_model.rs b/crates/settings/src/settings_content/language_model.rs index f99e1687130d8046b812700cbb2dc33b00f8d881..48f5a463a4b8d896885d9ba5b7d804d16ecb5b6b 100644 --- a/crates/settings/src/settings_content/language_model.rs +++ b/crates/settings/src/settings_content/language_model.rs @@ -8,6 +8,7 @@ use std::sync::Arc; #[with_fallible_options] #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)] pub struct AllLanguageModelSettingsContent { + pub anthropic: Option<AnthropicSettingsContent>, pub bedrock: Option<AmazonBedrockSettingsContent>, pub deepseek: Option<DeepseekSettingsContent>, pub google: Option<GoogleSettingsContent>, @@ -23,6 +24,35 @@ pub struct AllLanguageModelSettingsContent { pub zed_dot_dev: Option<ZedDotDevSettingsContent>, } +#[with_fallible_options] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)] +pub struct AnthropicSettingsContent { + pub api_url: Option<String>, + pub available_models: Option<Vec<AnthropicAvailableModel>>, +} + +#[with_fallible_options] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom)] +pub struct AnthropicAvailableModel { + /// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-latest, claude-3-opus-20240229, etc + pub name: String, + /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel. + pub display_name: Option<String>, + /// The model's context window size. + pub max_tokens: u64, + /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling. + pub tool_override: Option<String>, + /// Configuration of Anthropic's caching API. + pub cache_configuration: Option<LanguageModelCacheConfiguration>, + pub max_output_tokens: Option<u64>, + #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")] + pub default_temperature: Option<f32>, + #[serde(default)] + pub extra_beta_headers: Vec<String>, + /// The model's mode (e.g. thinking) + pub mode: Option<ModelMode>, +} + #[with_fallible_options] #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)] pub struct AmazonBedrockSettingsContent {