diff --git a/Cargo.lock b/Cargo.lock
index 854a74f25adc1337da20667414e0495ff3d78911..9bb6a3322e613d717d0b3b2755c43fba0f7f3941 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8915,7 +8915,6 @@ dependencies = [
  "collections",
  "component",
  "convert_case 0.8.0",
- "copilot",
  "credentials_provider",
  "deepseek",
  "editor",
@@ -8926,7 +8925,6 @@ dependencies = [
  "gpui",
  "gpui_tokio",
  "http_client",
- "language",
  "language_model",
  "lmstudio",
  "log",
@@ -8934,8 +8932,6 @@ dependencies = [
  "mistral",
  "ollama",
  "open_ai",
- "open_router",
- "partial-json-fixer",
  "project",
  "release_channel",
  "schemars",
@@ -11347,12 +11343,6 @@ dependencies = [
  "num-traits",
 ]
 
-[[package]]
-name = "partial-json-fixer"
-version = "0.5.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13"
-
 [[package]]
 name = "password-hash"
 version = "0.4.2"
diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs
index c1b2efb081551f82752dc15a909eec64ff78d94e..a14a486367e6a639b68651e0856f59138983c512 100644
--- a/crates/agent_servers/src/gemini.rs
+++ b/crates/agent_servers/src/gemini.rs
@@ -5,7 +5,7 @@ use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
 use acp_thread::AgentConnection;
 use anyhow::{Context as _, Result};
 use gpui::{App, SharedString, Task};
-use language_models::provider::google::GoogleLanguageModelProvider;
+use language_models::api_key_for_gemini_cli;
 use project::agent_server_store::GEMINI_NAME;
 
 #[derive(Clone)]
@@ -42,11 +42,7 @@ impl AgentServer for Gemini {
         cx.spawn(async move |cx| {
             extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
 
-            if let Some(api_key) = cx
-                .update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
-                .await
-                .ok()
-            {
+            if let Some(api_key) = cx.update(api_key_for_gemini_cli)?.await.ok() {
                 extra_env.insert("GEMINI_API_KEY".into(), api_key);
             }
             let (command, root_dir, login) = store
diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs
index c0b89647a43030538af8d4db6ab80524048a1574..c52ea87df48bb88655b59e676e58d988061c57eb 100644
--- a/crates/extension_host/src/extension_host.rs
+++ b/crates/extension_host/src/extension_host.rs
@@ -4,6 +4,8 @@ mod copilot_migration;
 pub mod extension_settings;
 mod google_ai_migration;
 pub mod headless_host;
+mod open_router_migration;
+mod openai_migration;
 pub mod wasm_host;
 
 #[cfg(test)]
@@ -893,6 +895,11 @@ impl ExtensionStore {
                     copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx);
                     anthropic_migration::migrate_anthropic_credentials_if_needed(&extension_id, cx);
                     google_ai_migration::migrate_google_ai_credentials_if_needed(&extension_id, cx);
+                    openai_migration::migrate_openai_credentials_if_needed(&extension_id, cx);
+                    open_router_migration::migrate_open_router_credentials_if_needed(
+                        &extension_id,
+                        cx,
+                    );
                 })
                 .ok();
         }
diff --git a/crates/extension_host/src/open_router_migration.rs b/crates/extension_host/src/open_router_migration.rs
new file mode 100644
index 0000000000000000000000000000000000000000..e80c13b45ddb97f59ad5213a9390f98cb5d8e0d0
--- /dev/null
+++ b/crates/extension_host/src/open_router_migration.rs
@@ -0,0 +1,160 @@
+use credentials_provider::CredentialsProvider;
+use gpui::App;
+
+const OPEN_ROUTER_EXTENSION_ID: &str = "open-router";
+const OPEN_ROUTER_PROVIDER_ID: &str = "open-router";
+const OPEN_ROUTER_DEFAULT_API_URL: &str = "https://openrouter.ai/api/v1";
+
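+/// Migrates an OpenRouter API key that the built-in provider stored under the
+/// default API URL to the credential key used by the OpenRouter extension.
+/// Existing extension credentials are never overwritten.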
+pub fn migrate_open_router_credentials_if_needed(extension_id: &str, cx: &mut App) {
+    if extension_id != OPEN_ROUTER_EXTENSION_ID {
+        return;
+    }
+
+    let extension_credential_key = format!(
+        "extension-llm-{}:{}",
+        OPEN_ROUTER_EXTENSION_ID, OPEN_ROUTER_PROVIDER_ID
+    );
+
+    let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+    cx.spawn(async move |cx| {
+        let existing_credential = credentials_provider
+            .read_credentials(&extension_credential_key, &cx)
+            .await
+            .ok()
+            .flatten();
+
+        if existing_credential.is_some() {
+            log::debug!("OpenRouter extension already has credentials, skipping migration");
+            return;
+        }
+
+        let old_credential = credentials_provider
+            .read_credentials(OPEN_ROUTER_DEFAULT_API_URL, &cx)
+            .await
+            .ok()
+            .flatten();
+
+        let api_key = match old_credential {
+            Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
+                Ok(key) => key,
+                Err(_) => {
+                    log::error!("Failed to decode OpenRouter API key as UTF-8");
+                    return;
+                }
+            },
+            None => {
+                log::debug!("No existing OpenRouter API key found to migrate");
+                return;
+            }
+        };
+
+        log::info!("Migrating existing OpenRouter API key to OpenRouter extension");
+
+        match credentials_provider
+            .write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
+            .await
+        {
+            Ok(()) => {
+                log::info!("Successfully migrated OpenRouter API key to extension");
+            }
+            Err(err) => {
+                log::error!("Failed to migrate OpenRouter API key: {}", err);
+            }
+        }
+    })
+    .detach();
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use gpui::TestAppContext;
+
+    #[gpui::test]
+    async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
+        let api_key = "sk-or-test-key-12345";
+
+        cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
+
+        cx.update(|cx| {
+            migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
+        });
+
+        cx.run_until_parked();
+
+        let migrated = cx.read_credentials("extension-llm-open-router:open-router");
+        assert!(migrated.is_some(), "Credentials should have been migrated");
+        let (username, password) = migrated.unwrap();
+        assert_eq!(username, "Bearer");
+        assert_eq!(String::from_utf8(password).unwrap(), api_key);
+    }
+
+    #[gpui::test]
+    async fn test_skips_migration_if_extension_already_has_credentials(cx: &mut TestAppContext) {
+        let old_api_key = "sk-or-old-key";
+        let existing_key = "sk-or-existing-key";
+
+        cx.write_credentials(
+            OPEN_ROUTER_DEFAULT_API_URL,
+            "Bearer",
+            old_api_key.as_bytes(),
+        );
+        cx.write_credentials(
+            "extension-llm-open-router:open-router",
+            "Bearer",
+            existing_key.as_bytes(),
+        );
+
+        cx.update(|cx| {
+            migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
+        });
+
+        cx.run_until_parked();
+
+        let credentials = cx.read_credentials("extension-llm-open-router:open-router");
+        let (_, password) = credentials.unwrap();
+        assert_eq!(
+            String::from_utf8(password).unwrap(),
+            existing_key,
+            "Should not overwrite existing credentials"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_skips_migration_if_no_old_credentials(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
+        });
+
+        cx.run_until_parked();
+
+        let credentials = cx.read_credentials("extension-llm-open-router:open-router");
+        assert!(
+            credentials.is_none(),
+            "Should not create credentials if none existed"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
+        let api_key = "sk-or-test-key";
+
+        cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
+
+        cx.update(|cx| {
+            migrate_open_router_credentials_if_needed("some-other-extension", cx);
+        });
+
+        cx.run_until_parked();
+
+        let credentials = cx.read_credentials("extension-llm-open-router:open-router");
+        assert!(
+            credentials.is_none(),
+            "Should not migrate for other extensions"
+        );
+    }
+}
diff --git a/crates/extension_host/src/openai_migration.rs b/crates/extension_host/src/openai_migration.rs
new file mode 100644
index 0000000000000000000000000000000000000000..b535b9f8cef6c4b2039a606e3232e4fb89d3cc35
--- /dev/null
+++ b/crates/extension_host/src/openai_migration.rs
@@ -0,0 +1,156 @@
+use credentials_provider::CredentialsProvider;
+use gpui::App;
+
+const OPENAI_EXTENSION_ID: &str = "openai";
+const OPENAI_PROVIDER_ID: &str = "openai";
+const OPENAI_DEFAULT_API_URL: &str = "https://api.openai.com/v1";
+
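+/// Migrates an OpenAI API key that the built-in provider stored under the
+/// default API URL to the credential key used by the OpenAI extension.
+/// Existing extension credentials are never overwritten.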
+pub fn migrate_openai_credentials_if_needed(extension_id: &str, cx: &mut App) {
+    if extension_id != OPENAI_EXTENSION_ID {
+        return;
+    }
+
+    let extension_credential_key = format!(
+        "extension-llm-{}:{}",
+        OPENAI_EXTENSION_ID, OPENAI_PROVIDER_ID
+    );
+
+    let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+    cx.spawn(async move |cx| {
+        let existing_credential = credentials_provider
+            .read_credentials(&extension_credential_key, &cx)
+            .await
+            .ok()
+            .flatten();
+
+        if existing_credential.is_some() {
+            log::debug!("OpenAI extension already has credentials, skipping migration");
+            return;
+        }
+
+        let old_credential = credentials_provider
+            .read_credentials(OPENAI_DEFAULT_API_URL, &cx)
+            .await
+            .ok()
+            .flatten();
+
+        let api_key = match old_credential {
+            Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
+                Ok(key) => key,
+                Err(_) => {
+                    log::error!("Failed to decode OpenAI API key as UTF-8");
+                    return;
+                }
+            },
+            None => {
+                log::debug!("No existing OpenAI API key found to migrate");
+                return;
+            }
+        };
+
+        log::info!("Migrating existing OpenAI API key to OpenAI extension");
+
+        match credentials_provider
+            .write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
+            .await
+        {
+            Ok(()) => {
+                log::info!("Successfully migrated OpenAI API key to extension");
+            }
+            Err(err) => {
+                log::error!("Failed to migrate OpenAI API key: {}", err);
+            }
+        }
+    })
+    .detach();
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use gpui::TestAppContext;
+
+    #[gpui::test]
+    async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
+        let api_key = "sk-test-key-12345";
+
+        cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
+
+        cx.update(|cx| {
+            migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
+        });
+
+        cx.run_until_parked();
+
+        let migrated = cx.read_credentials("extension-llm-openai:openai");
+        assert!(migrated.is_some(), "Credentials should have been migrated");
+        let (username, password) = migrated.unwrap();
+        assert_eq!(username, "Bearer");
+        assert_eq!(String::from_utf8(password).unwrap(), api_key);
+    }
+
+    #[gpui::test]
+    async fn test_skips_migration_if_extension_already_has_credentials(cx: &mut TestAppContext) {
+        let old_api_key = "sk-old-key";
+        let existing_key = "sk-existing-key";
+
+        cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", old_api_key.as_bytes());
+        cx.write_credentials(
+            "extension-llm-openai:openai",
+            "Bearer",
+            existing_key.as_bytes(),
+        );
+
+        cx.update(|cx| {
+            migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
+        });
+
+        cx.run_until_parked();
+
+        let credentials = cx.read_credentials("extension-llm-openai:openai");
+        let (_, password) = credentials.unwrap();
+        assert_eq!(
+            String::from_utf8(password).unwrap(),
+            existing_key,
+            "Should not overwrite existing credentials"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_skips_migration_if_no_old_credentials(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
+        });
+
+        cx.run_until_parked();
+
+        let credentials = cx.read_credentials("extension-llm-openai:openai");
+        assert!(
+            credentials.is_none(),
+            "Should not create credentials if none existed"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
+        let api_key = "sk-test-key";
+
+        cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
+
+        cx.update(|cx| {
+            migrate_openai_credentials_if_needed("some-other-extension", cx);
+        });
+
+        cx.run_until_parked();
+
+        let credentials = cx.read_credentials("extension-llm-openai:openai");
+        assert!(
+            credentials.is_none(),
+            "Should not migrate for other extensions"
+        );
+    }
+}
diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml
index decb32c5aa400012bd21b397b5e7b359e0760b43..b9b354f2fad7ea9bc7573a6bb2af880f773d15e3 100644
--- a/crates/language_models/Cargo.toml
+++ b/crates/language_models/Cargo.toml
@@ -25,7 +25,6 @@ cloud_llm_client.workspace = true
 collections.workspace = true
 component.workspace = true
 convert_case.workspace = true
-copilot.workspace = true
 credentials_provider.workspace = true
 deepseek = { workspace = true, features = ["schemars"] }
 extension.workspace = true
@@ -35,7 +34,6 @@ google_ai = { workspace = true, features = ["schemars"] }
 gpui.workspace = true
 gpui_tokio.workspace = true
 http_client.workspace = true
-language.workspace = true
 language_model.workspace = true
 lmstudio = { workspace = true, features = ["schemars"] }
 log.workspace = true
@@ -43,8 +41,6 @@ menu.workspace = true
 mistral = { workspace = true, features = ["schemars"] }
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
-open_router = { workspace = true, features = ["schemars"] }
-partial-json-fixer.workspace = true
 release_channel.workspace = true
 schemars.workspace = true
 semver.workspace = true
diff --git a/crates/language_models/src/api_key.rs b/crates/language_models/src/api_key.rs
index 122234b6ced6d0bf1b7a0d684683c841824ccd2d..20d83c9d95e90380f99731a1bcfb903bb8ab93e9 100644
--- a/crates/language_models/src/api_key.rs
+++ b/crates/language_models/src/api_key.rs
@@ -223,27 +223,13 @@ impl ApiKeyState {
 }
 
 impl ApiKey {
-    pub fn key(&self) -> &str {
-        &self.key
-    }
-
-    pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
+    fn from_env(env_var_name: SharedString, key: &str) -> Self {
         Self {
             source: ApiKeySource::EnvVar(env_var_name),
             key: key.into(),
         }
     }
 
-    pub async fn load_from_system_keychain(
-        url: &str,
-        credentials_provider: &dyn CredentialsProvider,
-        cx: &AsyncApp,
-    ) -> Result<Self, AuthenticateError> {
-        Self::load_from_system_keychain_impl(url, credentials_provider, cx)
-            .await
-            .into_authenticate_result()
-    }
-
     async fn load_from_system_keychain_impl(
         url: &str,
         credentials_provider: &dyn CredentialsProvider,
diff --git a/crates/language_models/src/google_ai_api_key.rs b/crates/language_models/src/google_ai_api_key.rs
new file mode 100644
index 0000000000000000000000000000000000000000..300e0d11ac967aa07429d2eea6c174d8bfd7d79c
--- /dev/null
+++ b/crates/language_models/src/google_ai_api_key.rs
@@ -0,0 +1,44 @@
+use anyhow::Result;
+use credentials_provider::CredentialsProvider;
+use gpui::{App, Task};
+
+const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
+const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
+const GOOGLE_AI_EXTENSION_CREDENTIAL_KEY: &str = "extension-llm-google-ai:google-ai";
+
+/// Returns the Google AI API key for use by the Gemini CLI.
+///
+/// This function checks the following sources in order:
+/// 1. `GEMINI_API_KEY` environment variable
+/// 2. `GOOGLE_AI_API_KEY` environment variable
+/// 3. Extension credential store (`extension-llm-google-ai:google-ai`)
+pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
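+    // Environment variables take precedence; a variable that is set but empty is treated as unset.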
+    if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR_NAME) {
+        if !key.is_empty() {
+            return Task::ready(Ok(key));
+        }
+    }
+
+    if let Ok(key) = std::env::var(GOOGLE_AI_API_KEY_VAR_NAME) {
+        if !key.is_empty() {
+            return Task::ready(Ok(key));
+        }
+    }
+
+    let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+    cx.spawn(async move |cx| {
+        let credential = credentials_provider
+            .read_credentials(GOOGLE_AI_EXTENSION_CREDENTIAL_KEY, &cx)
+            .await?;
+
+        match credential {
+            Some((_, key_bytes)) => {
+                let key = String::from_utf8(key_bytes)?;
+                Ok(key)
+            }
+            None => Err(anyhow::anyhow!("No Google AI API key found")),
+        }
+    })
+}
diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs
index 8b8ca1e2912e2a65714ebf1f8d05d3d6a83ac3d4..b07fa39033159fc7ba2b33d168cc3bdf9217ef28 100644
--- a/crates/language_models/src/language_models.rs
+++ b/crates/language_models/src/language_models.rs
@@ -10,20 +10,19 @@ use provider::deepseek::DeepSeekLanguageModelProvider;
 
 mod api_key;
 mod extension;
+mod google_ai_api_key;
 pub mod provider;
 mod settings;
 pub mod ui;
 
+pub use google_ai_api_key::api_key_for_gemini_cli;
+
 use crate::provider::bedrock::BedrockLanguageModelProvider;
 use crate::provider::cloud::CloudLanguageModelProvider;
-use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
-use crate::provider::google::GoogleLanguageModelProvider;
 use crate::provider::lmstudio::LmStudioLanguageModelProvider;
 pub use crate::provider::mistral::MistralLanguageModelProvider;
 use crate::provider::ollama::OllamaLanguageModelProvider;
-use crate::provider::open_ai::OpenAiLanguageModelProvider;
 use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
-use crate::provider::open_router::OpenRouterLanguageModelProvider;
 use crate::provider::vercel::VercelLanguageModelProvider;
 use crate::provider::x_ai::XAiLanguageModelProvider;
 pub use crate::settings::*;
@@ -118,10 +117,6 @@ fn register_language_model_providers(
         )),
         cx,
     );
-    registry.register_provider(
-        Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
-        cx,
-    );
     registry.register_provider(
         Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
         cx,
@@ -134,10 +129,6 @@ fn register_language_model_providers(
         Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
-    registry.register_provider(
-        Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
-        cx,
-    );
     registry.register_provider(
         MistralLanguageModelProvider::global(client.http_client(), cx),
         cx,
@@ -146,13 +137,6 @@ fn register_language_model_providers(
         Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
-    registry.register_provider(
-        Arc::new(OpenRouterLanguageModelProvider::new(
-            client.http_client(),
-            cx,
-        )),
-        cx,
-    );
     registry.register_provider(
         Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
         cx,
@@ -161,5 +145,4 @@ fn register_language_model_providers(
         Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
         cx,
     );
-    registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
 }
diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs
index e585fc06f6b52313dd79dd3b307d0ab305817977..b5d10c1ede3d70e4d7dc8725131cf9e19a216ca3 100644
--- a/crates/language_models/src/provider.rs
+++ b/crates/language_models/src/provider.rs
@@ -1,6 +1,5 @@
 pub mod bedrock;
 pub mod cloud;
-pub mod copilot_chat;
 pub mod deepseek;
 pub mod google;
 pub mod lmstudio;
@@ -8,6 +7,5 @@ pub mod mistral;
 pub mod ollama;
 pub mod open_ai;
 pub mod open_ai_compatible;
-pub mod open_router;
 pub mod vercel;
 pub mod x_ai;
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
deleted file mode 100644
index 92ac342a39ff04ae42f5b01b5777a5d16563c37f..0000000000000000000000000000000000000000
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ /dev/null
@@ -1,1565 +0,0 @@
-use std::pin::Pin;
-use std::str::FromStr as _;
-use std::sync::Arc;
-
-use anyhow::{Result, anyhow};
-use cloud_llm_client::CompletionIntent;
-use collections::HashMap;
-use copilot::copilot_chat::{
-    ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
-    Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
-    ToolCall,
-};
-use copilot::{Copilot, Status};
-use futures::future::BoxFuture;
-use futures::stream::BoxStream;
-use futures::{FutureExt, Stream, StreamExt};
-use gpui::{Action, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, svg};
-use http_client::StatusCode;
-use language::language_settings::all_language_settings;
-use language_model::{
-    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
-    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
-    LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
-    StopReason, TokenUsage,
-};
-use settings::SettingsStore;
-use ui::{CommonAnimationExt, prelude::*};
-use util::debug_panic;
-
-use crate::ui::ConfiguredApiCard;
-
-const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
-const PROVIDER_NAME: LanguageModelProviderName =
-    LanguageModelProviderName::new("GitHub Copilot Chat");
-
-pub struct CopilotChatLanguageModelProvider {
-    state: Entity<State>,
-}
-
-pub struct State {
-    _copilot_chat_subscription: Option<Subscription>,
-    _settings_subscription: Subscription,
-}
-
-impl State {
-    fn is_authenticated(&self, cx: &App) -> bool {
-        CopilotChat::global(cx)
-            .map(|m| m.read(cx).is_authenticated())
-            .unwrap_or(false)
-    }
-}
-
-impl CopilotChatLanguageModelProvider {
-    pub fn new(cx: &mut App) -> Self {
-        let state = cx.new(|cx| {
-            let copilot_chat_subscription = CopilotChat::global(cx)
-                .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
-            State {
-                _copilot_chat_subscription: copilot_chat_subscription,
-                _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
-                    if let Some(copilot_chat) = CopilotChat::global(cx) {
-                        let language_settings = all_language_settings(None, cx);
-                        let configuration = copilot::copilot_chat::CopilotChatConfiguration {
-                            enterprise_uri: language_settings
-                                .edit_predictions
-                                .copilot
-                                .enterprise_uri
-                                .clone(),
-                        };
-                        copilot_chat.update(cx, |chat, cx| {
-                            chat.set_configuration(configuration, cx);
-                        });
-                    }
-                    cx.notify();
-                }),
-            }
-        });
-
-        Self { state }
-    }
-
-    fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
-        Arc::new(CopilotChatLanguageModel {
-            model,
-            request_limiter: RateLimiter::new(4),
-        })
-    }
-}
-
-impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
-    type ObservableEntity = State;
-
-    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
-        Some(self.state.clone())
-    }
-}
-
-impl LanguageModelProvider for CopilotChatLanguageModelProvider {
-    fn id(&self) -> LanguageModelProviderId {
-        PROVIDER_ID
-    }
-
-    fn name(&self) -> LanguageModelProviderName {
-        PROVIDER_NAME
-    }
-
-    fn icon(&self) -> IconName {
-        IconName::Copilot
-    }
-
-    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
-        models
-            .first()
-            .map(|model| self.create_language_model(model.clone()))
-    }
-
-    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
-        // model (e.g. 4o) and a sensible choice when considering premium requests
-        self.default_model(cx)
-    }
-
-    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
-        let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
-            return Vec::new();
-        };
-        models
-            .iter()
-            .map(|model| self.create_language_model(model.clone()))
-            .collect()
-    }
-
-    fn is_authenticated(&self, cx: &App) -> bool {
-        self.state.read(cx).is_authenticated(cx)
-    }
-
-    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
-        if self.is_authenticated(cx) {
-            return Task::ready(Ok(()));
-        };
-
-        let Some(copilot) = Copilot::global(cx) else {
-            return Task::ready(Err(anyhow!(concat!(
-                "Copilot must be enabled for Copilot Chat to work. ",
-                "Please enable Copilot and try again."
-            ))
-            .into()));
-        };
-
-        let err = match copilot.read(cx).status() {
-            Status::Authorized => return Task::ready(Ok(())),
-            Status::Disabled => anyhow!(
-                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
-            ),
-            Status::Error(err) => anyhow!(format!(
-                "Received the following error while signing into Copilot: {err}"
-            )),
-            Status::Starting { task: _ } => anyhow!(
-                "Copilot is still starting, please wait for Copilot to start then try again"
-            ),
-            Status::Unauthorized => anyhow!(
-                "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
-            ),
-            Status::SignedOut { .. } => {
-                anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
-            }
-            Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
-        };
-
-        Task::ready(Err(err.into()))
-    }
-
-    fn configuration_view(
-        &self,
-        _target_agent: language_model::ConfigurationViewTargetAgent,
-        _: &mut Window,
-        cx: &mut App,
-    ) -> AnyView {
-        let state = self.state.clone();
-        cx.new(|cx| ConfigurationView::new(state, cx)).into()
-    }
-
-    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
-        Task::ready(Err(anyhow!(
-            "Signing out of GitHub Copilot Chat is currently not supported."
-        )))
-    }
-}
-
-fn collect_tiktoken_messages(
-    request: LanguageModelRequest,
-) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
-    request
-        .messages
-        .into_iter()
-        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-            role: match message.role {
-                Role::User => "user".into(),
-                Role::Assistant => "assistant".into(),
-                Role::System => "system".into(),
-            },
-            content: Some(message.string_contents()),
-            name: None,
-            function_call: None,
-        })
-        .collect::<Vec<_>>()
-}
-
-pub struct CopilotChatLanguageModel {
-    model: CopilotChatModel,
-    request_limiter: RateLimiter,
-}
-
-impl LanguageModel for CopilotChatLanguageModel {
-    fn id(&self) -> LanguageModelId {
-        LanguageModelId::from(self.model.id().to_string())
-    }
-
-    fn name(&self) -> LanguageModelName {
-        LanguageModelName::from(self.model.display_name().to_string())
-    }
-
-    fn provider_id(&self) -> LanguageModelProviderId {
-        PROVIDER_ID
-    }
-
-    fn provider_name(&self) -> LanguageModelProviderName {
-        PROVIDER_NAME
-    }
-
-    fn supports_tools(&self) -> bool {
-        self.model.supports_tools()
-    }
-
-    fn supports_images(&self) -> bool {
-        self.model.supports_vision()
-    }
-
-    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
-        match self.model.vendor() {
-            ModelVendor::OpenAI | ModelVendor::Anthropic => {
-                LanguageModelToolSchemaFormat::JsonSchema
-            }
-            ModelVendor::Google | ModelVendor::XAI | ModelVendor::Unknown => {
-                LanguageModelToolSchemaFormat::JsonSchemaSubset
-            }
-        }
-    }
-
-    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
-        match choice {
-            LanguageModelToolChoice::Auto
-            | LanguageModelToolChoice::Any
-            | LanguageModelToolChoice::None => self.supports_tools(),
-        }
-    }
-
-    fn telemetry_id(&self) -> String {
-        format!("copilot_chat/{}", self.model.id())
-    }
-
-    fn max_token_count(&self) -> u64 {
-        self.model.max_token_count()
-    }
-
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        let model = self.model.clone();
-        cx.background_spawn(async move {
-            let messages = collect_tiktoken_messages(request);
-            // Copilot uses the OpenAI tiktoken tokenizer for all its models, irrespective of the underlying provider (vendor).
-            let tokenizer_model = match model.tokenizer() {
-                Some("o200k_base") => "gpt-4o",
-                Some("cl100k_base") => "gpt-4",
-                _ => "gpt-4o",
-            };
-
-            tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages)
-                .map(|tokens| tokens as u64)
-        })
-        .boxed()
-    }
-
-    fn stream_completion(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<
-            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
-            LanguageModelCompletionError,
-        >,
-    > {
-        let is_user_initiated = request.intent.is_none_or(|intent| match intent {
-            CompletionIntent::UserPrompt
-            | CompletionIntent::ThreadContextSummarization
-            | CompletionIntent::InlineAssist
-            | CompletionIntent::TerminalInlineAssist
-            | CompletionIntent::GenerateGitCommitMessage => true,
-
-            CompletionIntent::ToolResults
-            | CompletionIntent::ThreadSummarization
-            | CompletionIntent::CreateFile
-            | CompletionIntent::EditFile => false,
-        });
-
-        if self.model.supports_response() {
-            let responses_request = into_copilot_responses(&self.model, request);
-            let request_limiter = self.request_limiter.clone();
-            let future = cx.spawn(async move |cx| {
-                let request =
-                    CopilotChat::stream_response(responses_request, is_user_initiated, cx.clone());
-                request_limiter
-                    .stream(async move {
-                        let stream = request.await?;
-                        let mapper = CopilotResponsesEventMapper::new();
-                        Ok(mapper.map_stream(stream).boxed())
-                    })
-                    .await
-            });
-            return async move { Ok(future.await?.boxed()) }.boxed();
-        }
-
-        let copilot_request = match into_copilot_chat(&self.model, request) {
-            Ok(request) => request,
-            Err(err) => return futures::future::ready(Err(err.into())).boxed(),
-        };
-        let is_streaming = copilot_request.stream;
-
-        let request_limiter = self.request_limiter.clone();
-        let future = cx.spawn(async move |cx| {
-            let request =
-                CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
-            request_limiter
-                .stream(async move {
-                    let response = request.await?;
-                    Ok(map_to_language_model_completion_events(
-                        response,
-                        is_streaming,
-                    ))
-                })
-                .await
-        });
-        async move { Ok(future.await?.boxed()) }.boxed()
-    }
-}
-
-pub fn map_to_language_model_completion_events(
-    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
-    is_streaming: bool,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-    #[derive(Default)]
-    struct RawToolCall {
-        id: String,
-        name: String,
-        arguments: String,
-        thought_signature: Option<String>,
-    }
-
-    struct State {
-        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
-        tool_calls_by_index: HashMap<usize, RawToolCall>,
-        reasoning_opaque: Option<String>,
-        reasoning_text: Option<String>,
-    }
-
-    futures::stream::unfold(
-        State {
-            events,
-            tool_calls_by_index: HashMap::default(),
-            reasoning_opaque: None,
-            reasoning_text: None,
-        },
-        move |mut state| async move {
-            if let Some(event) = state.events.next().await {
-                match event {
-                    Ok(event) => {
-                        let Some(choice) = event.choices.first() else {
-                            return Some((
-                                vec![Err(anyhow!("Response contained no choices").into())],
-                                state,
-                            ));
-                        };
-
-                        let delta = if is_streaming {
-                            choice.delta.as_ref()
-                        } else {
-                            choice.message.as_ref()
-                        };
-
-                        let Some(delta) = delta else {
-                            return Some((
-                                vec![Err(anyhow!("Response contained no delta").into())],
-                                state,
-                            ));
-                        };
-
-                        let mut events = Vec::new();
-                        if let Some(content) = delta.content.clone() {
-                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
-                        }
-
-                        // Capture reasoning data from the delta (e.g. for Gemini 3)
-                        if let Some(opaque) = delta.reasoning_opaque.clone() {
-                            state.reasoning_opaque = Some(opaque);
-                        }
-                        if let Some(text) = delta.reasoning_text.clone() {
-                            state.reasoning_text = Some(text);
-                        }
-
-                        for (index, tool_call) in delta.tool_calls.iter().enumerate() {
-                            let tool_index = tool_call.index.unwrap_or(index);
-                            let entry = state.tool_calls_by_index.entry(tool_index).or_default();
-
-                            if let Some(tool_id) = tool_call.id.clone() {
-                                entry.id = tool_id;
-                            }
-
-                            if let Some(function) = tool_call.function.as_ref() {
-                                if let Some(name) = function.name.clone() {
-                                    entry.name = name;
-                                }
-
-                                if let Some(arguments) = function.arguments.clone() {
-                                    entry.arguments.push_str(&arguments);
-                                }
-
-                                if let Some(thought_signature) = function.thought_signature.clone()
-                                {
-                                    entry.thought_signature = Some(thought_signature);
-                                }
-                            }
-                        }
-
-                        if let Some(usage) = event.usage {
-                            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
-                                TokenUsage {
-                                    input_tokens: usage.prompt_tokens,
-                                    output_tokens: usage.completion_tokens,
-                                    cache_creation_input_tokens: 0,
-                                    cache_read_input_tokens: 0,
-                                },
-                            )));
-                        }
-
-                        match choice.finish_reason.as_deref() {
-                            Some("stop") => {
-                                events.push(Ok(LanguageModelCompletionEvent::Stop(
-                                    StopReason::EndTurn,
-                                )));
-                            }
-                            Some("tool_calls") => {
-                                // Gemini 3 models send reasoning_opaque/reasoning_text that must
-                                // be preserved and sent back in subsequent requests. Emit as
-                                // ReasoningDetails so the agent stores it in the message.
-                                if state.reasoning_opaque.is_some()
-                                    || state.reasoning_text.is_some()
-                                {
-                                    let mut details = serde_json::Map::new();
-                                    if let Some(opaque) = state.reasoning_opaque.take() {
-                                        details.insert(
-                                            "reasoning_opaque".to_string(),
-                                            serde_json::Value::String(opaque),
-                                        );
-                                    }
-                                    if let Some(text) = state.reasoning_text.take() {
-                                        details.insert(
-                                            "reasoning_text".to_string(),
-                                            serde_json::Value::String(text),
-                                        );
-                                    }
-                                    events.push(Ok(
-                                        LanguageModelCompletionEvent::ReasoningDetails(
-                                            serde_json::Value::Object(details),
-                                        ),
-                                    ));
-                                }
-
-                                events.extend(state.tool_calls_by_index.drain().map(
-                                    |(_, tool_call)| {
-                                        // The model can output an empty string
-                                        // to indicate the absence of arguments.
-                                        // When that happens, create an empty
-                                        // object instead.
-                                        let arguments = if tool_call.arguments.is_empty() {
-                                            Ok(serde_json::Value::Object(Default::default()))
-                                        } else {
-                                            serde_json::Value::from_str(&tool_call.arguments)
-                                        };
-                                        match arguments {
-                                            Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
-                                                LanguageModelToolUse {
-                                                    id: tool_call.id.into(),
-                                                    name: tool_call.name.as_str().into(),
-                                                    is_input_complete: true,
-                                                    input,
-                                                    raw_input: tool_call.arguments,
-                                                    thought_signature: tool_call.thought_signature,
-                                                },
-                                            )),
-                                            Err(error) => Ok(
-                                                LanguageModelCompletionEvent::ToolUseJsonParseError {
-                                                    id: tool_call.id.into(),
-                                                    tool_name: tool_call.name.as_str().into(),
-                                                    raw_input: tool_call.arguments.into(),
-                                                    json_parse_error: error.to_string(),
-                                                },
-                                            ),
-                                        }
-                                    },
-                                ));
-
-                                events.push(Ok(LanguageModelCompletionEvent::Stop(
-                                    StopReason::ToolUse,
-                                )));
-                            }
-                            Some(stop_reason) => {
-                                log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
-                                events.push(Ok(LanguageModelCompletionEvent::Stop(
-                                    StopReason::EndTurn,
-                                )));
-                            }
-                            None => {}
-                        }
-
-                        return Some((events, state));
-                    }
-                    Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
-                }
-            }
-
-            None
-        },
-    )
-    .flat_map(futures::stream::iter)
-}
-
-pub struct CopilotResponsesEventMapper {
-    pending_stop_reason: Option<StopReason>,
-}
-
-impl CopilotResponsesEventMapper {
-    pub fn new() -> Self {
-        Self {
-            pending_stop_reason: None,
-        }
-    }
-
-    pub fn map_stream(
-        mut self,
-        events: Pin<Box<dyn Send + Stream<Item = Result<copilot::copilot_responses::StreamEvent>>>>,
-    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
-    {
-        events.flat_map(move |event| {
-            futures::stream::iter(match event {
-                Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
-            })
-        })
-    }
-
-    fn map_event(
-        &mut self,
-        event: copilot::copilot_responses::StreamEvent,
-    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-        match event {
-            copilot::copilot_responses::StreamEvent::OutputItemAdded { item, .. } => match item {
-                copilot::copilot_responses::ResponseOutputItem::Message { id, .. } => {
-                    vec![Ok(LanguageModelCompletionEvent::StartMessage {
-                        message_id: id,
-                    })]
-                }
-                _ => Vec::new(),
-            },
-
-            copilot::copilot_responses::StreamEvent::OutputTextDelta { delta, .. } => {
-                if delta.is_empty() {
-                    Vec::new()
-                } else {
-                    vec![Ok(LanguageModelCompletionEvent::Text(delta))]
-                }
-            }
-
-            copilot::copilot_responses::StreamEvent::OutputItemDone { item, .. } => match item {
-                copilot::copilot_responses::ResponseOutputItem::Message { .. } => Vec::new(),
-                copilot::copilot_responses::ResponseOutputItem::FunctionCall {
-                    call_id,
-                    name,
-                    arguments,
-                    thought_signature,
-                    ..
-                } => {
-                    let mut events = Vec::new();
-                    match serde_json::from_str::<serde_json::Value>(&arguments) {
-                        Ok(input) => events.push(Ok(LanguageModelCompletionEvent::ToolUse(
-                            LanguageModelToolUse {
-                                id: call_id.into(),
-                                name: name.as_str().into(),
-                                is_input_complete: true,
-                                input,
-                                raw_input: arguments.clone(),
-                                thought_signature,
-                            },
-                        ))),
-                        Err(error) => {
-                            events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
-                                id: call_id.into(),
-                                tool_name: name.as_str().into(),
-                                raw_input: arguments.clone().into(),
-                                json_parse_error: error.to_string(),
-                            }))
-                        }
-                    }
-                    // Record that we already emitted a tool-use stop so we can avoid duplicating
-                    // a Stop event on Completed.
-                    self.pending_stop_reason = Some(StopReason::ToolUse);
-                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
-                    events
-                }
-                copilot::copilot_responses::ResponseOutputItem::Reasoning {
-                    summary,
-                    encrypted_content,
-                    ..
-                } => {
-                    let mut events = Vec::new();
-
-                    if let Some(blocks) = summary {
-                        let mut text = String::new();
-                        for block in blocks {
-                            text.push_str(&block.text);
-                        }
-                        if !text.is_empty() {
-                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
-                                text,
-                                signature: None,
-                            }));
-                        }
-                    }
-
-                    if let Some(data) = encrypted_content {
-                        events.push(Ok(LanguageModelCompletionEvent::RedactedThinking { data }));
-                    }
-
-                    events
-                }
-            },
-
-            copilot::copilot_responses::StreamEvent::Completed { response } => {
-                let mut events = Vec::new();
-                if let Some(usage) = response.usage {
-                    events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
-                        input_tokens: usage.input_tokens.unwrap_or(0),
-                        output_tokens: usage.output_tokens.unwrap_or(0),
-                        cache_creation_input_tokens: 0,
-                        cache_read_input_tokens: 0,
-                    })));
-                }
-                if self.pending_stop_reason.take() != Some(StopReason::ToolUse) {
-                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
-                }
-                events
-            }
-
-            copilot::copilot_responses::StreamEvent::Incomplete { response } => {
-                let reason = response
-                    .incomplete_details
-                    .as_ref()
-                    .and_then(|details| details.reason.as_ref());
-                let stop_reason = match reason {
-                    Some(copilot::copilot_responses::IncompleteReason::MaxOutputTokens) => {
-                        StopReason::MaxTokens
-                    }
-                    Some(copilot::copilot_responses::IncompleteReason::ContentFilter) => {
-                        StopReason::Refusal
-                    }
-                    _ => self
-                        .pending_stop_reason
-                        .take()
-                        .unwrap_or(StopReason::EndTurn),
-                };
-
-                let mut events = Vec::new();
-                if let Some(usage) = response.usage {
-                    events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
-                        input_tokens: usage.input_tokens.unwrap_or(0),
-                        output_tokens: usage.output_tokens.unwrap_or(0),
-                        cache_creation_input_tokens: 0,
-                        cache_read_input_tokens: 0,
-                    })));
-                }
-                events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason)));
-                events
-            }
-
-            copilot::copilot_responses::StreamEvent::Failed { response } => {
-                let provider = PROVIDER_NAME;
-                let (status_code, message) = match response.error {
-                    Some(error) => {
-                        let status_code = StatusCode::from_str(&error.code)
-                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
-                        (status_code, error.message)
-                    }
-                    None => (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        "response.failed".to_string(),
-                    ),
-                };
-                vec![Err(LanguageModelCompletionError::HttpResponseError {
-                    provider,
-                    status_code,
-                    message,
-                })]
-            }
-
-            copilot::copilot_responses::StreamEvent::GenericError { error } => vec![Err(
-                LanguageModelCompletionError::Other(anyhow!(format!("{error:?}"))),
-            )],
-
-            copilot::copilot_responses::StreamEvent::Created { .. }
-            | copilot::copilot_responses::StreamEvent::Unknown => Vec::new(),
-        }
-    }
-}
-
-fn into_copilot_chat(
-    model: &copilot::copilot_chat::Model,
-    request: LanguageModelRequest,
-) -> Result<CopilotChatRequest> {
-    let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
-    for message in request.messages {
-        if let Some(last_message) = request_messages.last_mut() {
-            if last_message.role == message.role {
-                last_message.content.extend(message.content);
-            } else {
-                request_messages.push(message);
-            }
-        } else {
-            request_messages.push(message);
-        }
-    }
-
-    let mut messages: Vec<ChatMessage> = Vec::new();
-    for message in request_messages {
-        match message.role {
-            Role::User => {
-                for content in &message.content {
-                    if let MessageContent::ToolResult(tool_result) = content {
-                        let content = match &tool_result.content {
-                            LanguageModelToolResultContent::Text(text) => text.to_string().into(),
-                            LanguageModelToolResultContent::Image(image) => {
-                                if model.supports_vision() {
-                                    ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
-                                        image_url: ImageUrl {
-                                            url: image.to_base64_url(),
-                                        },
-                                    }])
-                                } else {
-                                    debug_panic!(
-                                        "This should be caught at {} level",
-                                        tool_result.tool_name
-                                    );
-                                    "[Tool responded with an image, but this model does not support vision]".to_string().into()
-                                }
-                            }
-                        };
-
-                        messages.push(ChatMessage::Tool {
-                            tool_call_id: tool_result.tool_use_id.to_string(),
-                            content,
-                        });
-                    }
-                }
-
-                let mut content_parts = Vec::new();
-                for content in &message.content {
-                    match content {
-                        MessageContent::Text(text) | MessageContent::Thinking { text, .. }
-                            if !text.is_empty() =>
-                        {
-                            if let Some(ChatMessagePart::Text { text: text_content }) =
-                                content_parts.last_mut()
-                            {
-                                text_content.push_str(text);
-                            } else {
-                                content_parts.push(ChatMessagePart::Text {
-                                    text: text.to_string(),
-                                });
-                            }
-                        }
-                        MessageContent::Image(image) if model.supports_vision() => {
-                            content_parts.push(ChatMessagePart::Image {
-                                image_url: ImageUrl {
-                                    url: image.to_base64_url(),
-                                },
-                            });
-                        }
-                        _ => {}
-                    }
-                }
-
-                if !content_parts.is_empty() {
-                    messages.push(ChatMessage::User {
-                        content: content_parts.into(),
-                    });
-                }
-            }
-            Role::Assistant => {
-                let mut tool_calls = Vec::new();
-                for content in &message.content {
-                    if let MessageContent::ToolUse(tool_use) = content {
-                        tool_calls.push(ToolCall {
-                            id: tool_use.id.to_string(),
-                            content: copilot::copilot_chat::ToolCallContent::Function {
-                                function: copilot::copilot_chat::FunctionContent {
-                                    name: tool_use.name.to_string(),
-                                    arguments: serde_json::to_string(&tool_use.input)?,
-                                    thought_signature: tool_use.thought_signature.clone(),
-                                },
-                            },
-                        });
-                    }
-                }
-
-                let text_content = {
-                    let mut buffer = String::new();
-                    for string in message.content.iter().filter_map(|content| match content {
-                        MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
-                            Some(text.as_str())
-                        }
-                        MessageContent::ToolUse(_)
-                        | MessageContent::RedactedThinking(_)
-                        | MessageContent::ToolResult(_)
-                        | MessageContent::Image(_) => None,
-                    }) {
-                        buffer.push_str(string);
-                    }
-
-                    buffer
-                };
-
-                // Extract reasoning_opaque and reasoning_text from reasoning_details
-                let (reasoning_opaque, reasoning_text) =
-                    if let Some(details) = &message.reasoning_details {
-                        let opaque = details
-                            .get("reasoning_opaque")
-                            .and_then(|v| v.as_str())
-                            .map(|s| s.to_string());
-                        let text = details
-                            .get("reasoning_text")
-                            .and_then(|v| v.as_str())
-                            .map(|s| s.to_string());
-                        (opaque, text)
-                    } else {
-                        (None, None)
-                    };
-
-                messages.push(ChatMessage::Assistant {
-                    content: if text_content.is_empty() {
-                        ChatMessageContent::empty()
-                    } else {
-                        text_content.into()
-                    },
-                    tool_calls,
-                    reasoning_opaque,
-                    reasoning_text,
-                });
-            }
-            Role::System => messages.push(ChatMessage::System {
-                content: message.string_contents(),
-            }),
-        }
-    }
-
-    let tools = request
-        .tools
-        .iter()
-        .map(|tool| Tool::Function {
-            function: copilot::copilot_chat::Function {
-                name: tool.name.clone(),
-                description: tool.description.clone(),
-                parameters: tool.input_schema.clone(),
-            },
-        })
-        .collect::<Vec<_>>();
-
-    Ok(CopilotChatRequest {
-        intent: true,
-        n: 1,
-        stream: model.uses_streaming(),
-        temperature: 0.1,
-        model: model.id().to_string(),
-        messages,
-        tools,
-        tool_choice: request.tool_choice.map(|choice| match choice {
-            LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
-            LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
-            LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
-        }),
-    })
-}
-
-fn into_copilot_responses(
-    model: &copilot::copilot_chat::Model,
-    request: LanguageModelRequest,
-) -> copilot::copilot_responses::Request {
-    use copilot::copilot_responses as responses;
-
-    let LanguageModelRequest {
-        thread_id: _,
-        prompt_id: _,
-        intent: _,
-        mode: _,
-        messages,
-        tools,
-        tool_choice,
-        stop: _,
-        temperature,
-        thinking_allowed: _,
-    } = request;
-
-    let mut input_items: Vec<responses::ResponseInputItem> = Vec::new();
-
-    for message in messages {
-        match message.role {
-            Role::User => {
-                for content in &message.content {
-                    if let MessageContent::ToolResult(tool_result) = content {
-                        let output = if let Some(out) = &tool_result.output {
-                            match out {
-                                serde_json::Value::String(s) => {
-                                    responses::ResponseFunctionOutput::Text(s.clone())
-                                }
-                                serde_json::Value::Null => {
-                                    responses::ResponseFunctionOutput::Text(String::new())
-                                }
-                                other => responses::ResponseFunctionOutput::Text(other.to_string()),
-                            }
-                        } else {
-                            match &tool_result.content {
-                                LanguageModelToolResultContent::Text(text) => {
-                                    responses::ResponseFunctionOutput::Text(text.to_string())
-                                }
-                                LanguageModelToolResultContent::Image(image) => {
-                                    if model.supports_vision() {
-                                        responses::ResponseFunctionOutput::Content(vec![
-                                            responses::ResponseInputContent::InputImage {
-                                                image_url: Some(image.to_base64_url()),
-                                                detail: Default::default(),
-                                            },
-                                        ])
-                                    } else {
-                                        debug_panic!(
-                                            "This should be caught at {} level",
-                                            tool_result.tool_name
-                                        );
-                                        responses::ResponseFunctionOutput::Text(
-                                            "[Tool responded with an image, but this model does not support vision]".into(),
-                                        )
-                                    }
-                                }
-                            }
-                        };
-
-                        input_items.push(responses::ResponseInputItem::FunctionCallOutput {
-                            call_id: tool_result.tool_use_id.to_string(),
-                            output,
-                            status: None,
-                        });
-                    }
-                }
-
-                let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
-                for content in &message.content {
-                    match content {
-                        MessageContent::Text(text) => {
-                            parts.push(responses::ResponseInputContent::InputText {
-                                text: text.clone(),
-                            });
-                        }
-
-                        MessageContent::Image(image) => {
-                            if model.supports_vision() {
-                                parts.push(responses::ResponseInputContent::InputImage {
-                                    image_url: Some(image.to_base64_url()),
-                                    detail: Default::default(),
-                                });
-                            }
-                        }
-                        _ => {}
-                    }
-                }
-
-                if !parts.is_empty() {
-                    input_items.push(responses::ResponseInputItem::Message {
-                        role: "user".into(),
-                        content: Some(parts),
-                        status: None,
-                    });
-                }
-            }
-
-            Role::Assistant => {
-                for content in &message.content {
-                    if let MessageContent::ToolUse(tool_use) = content {
-                        input_items.push(responses::ResponseInputItem::FunctionCall {
-                            call_id: tool_use.id.to_string(),
-                            name: tool_use.name.to_string(),
-                            arguments: tool_use.raw_input.clone(),
-                            status: None,
-                            thought_signature: tool_use.thought_signature.clone(),
-                        });
-                    }
-                }
-
-                for content in &message.content {
-                    if let MessageContent::RedactedThinking(data) = content {
-                        input_items.push(responses::ResponseInputItem::Reasoning {
-                            id: None,
-                            summary: Vec::new(),
-                            encrypted_content: data.clone(),
-                        });
-                    }
-                }
-
-                let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
-                for content in &message.content {
-                    match content {
-                        MessageContent::Text(text) => {
-                            parts.push(responses::ResponseInputContent::OutputText {
-                                text: text.clone(),
-                            });
-                        }
-                        MessageContent::Image(_) => {
-                            parts.push(responses::ResponseInputContent::OutputText {
-                                text: "[image omitted]".to_string(),
-                            });
-                        }
-                        _ => {}
-                    }
-                }
-
-                if !parts.is_empty() {
-                    input_items.push(responses::ResponseInputItem::Message {
-                        role: "assistant".into(),
-                        content: Some(parts),
-                        status: Some("completed".into()),
-                    });
-                }
-            }
-
-            Role::System => {
-                let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
-                for content in &message.content {
-                    if let MessageContent::Text(text) = content {
-                        parts.push(responses::ResponseInputContent::InputText {
-                            text: text.clone(),
-                        });
-                    }
-                }
-
-                if !parts.is_empty() {
-                    input_items.push(responses::ResponseInputItem::Message {
-                        role: "system".into(),
-                        content: Some(parts),
-                        status: None,
-                    });
-                }
-            }
-        }
-    }
-
-    let converted_tools: Vec<responses::ToolDefinition> = tools
-        .into_iter()
-        .map(|tool| responses::ToolDefinition::Function {
-            name: tool.name,
-            description: Some(tool.description),
-            parameters: Some(tool.input_schema),
-            strict: None,
-        })
-        .collect();
-
-    let mapped_tool_choice = tool_choice.map(|choice| match choice {
-        LanguageModelToolChoice::Auto => responses::ToolChoice::Auto,
-        LanguageModelToolChoice::Any => responses::ToolChoice::Any,
-        LanguageModelToolChoice::None => responses::ToolChoice::None,
-    });
-
-    responses::Request {
-        model: model.id().to_string(),
-        input: input_items,
-        stream: model.uses_streaming(),
-        temperature,
-        tools: converted_tools,
-        tool_choice: mapped_tool_choice,
-        reasoning: None, // We would need to add support for configuring this from user settings.
-        include: Some(vec![
-            copilot::copilot_responses::ResponseIncludable::ReasoningEncryptedContent,
-        ]),
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use copilot::copilot_responses as responses;
-    use futures::StreamExt;
-
-    fn map_events(events: Vec<responses::StreamEvent>) -> Vec<LanguageModelCompletionEvent> {
-        futures::executor::block_on(async {
-            CopilotResponsesEventMapper::new()
-                .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
-                .collect::<Vec<_>>()
-                .await
-                .into_iter()
-                .map(Result::unwrap)
-                .collect()
-        })
-    }
-
-    #[test]
-    fn responses_stream_maps_text_and_usage() {
-        let events = vec![
-            responses::StreamEvent::OutputItemAdded {
-                output_index: 0,
-                sequence_number: None,
-                item: responses::ResponseOutputItem::Message {
-                    id: "msg_1".into(),
-                    role: "assistant".into(),
-                    content: Some(Vec::new()),
-                },
-            },
-            responses::StreamEvent::OutputTextDelta {
-                item_id: "msg_1".into(),
-                output_index: 0,
-                delta: "Hello".into(),
-            },
-            responses::StreamEvent::Completed {
-                response: responses::Response {
-                    usage: Some(responses::ResponseUsage {
-                        input_tokens: Some(5),
-                        output_tokens: Some(3),
-                        total_tokens: Some(8),
-                    }),
-                    ..Default::default()
-                },
-            },
-        ];
-
-        let mapped = map_events(events);
-        assert!(matches!(
-            mapped[0],
-            LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_1"
-        ));
-        assert!(matches!(
-            mapped[1],
-            LanguageModelCompletionEvent::Text(ref text) if text == "Hello"
-        ));
-        assert!(matches!(
-            mapped[2],
-            LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
-                input_tokens: 5,
-                output_tokens: 3,
-                ..
-            })
-        ));
-        assert!(matches!(
-            mapped[3],
-            LanguageModelCompletionEvent::Stop(StopReason::EndTurn)
-        ));
-    }
-
-    #[test]
-    fn responses_stream_maps_tool_calls() {
-        let events = vec![responses::StreamEvent::OutputItemDone {
-            output_index: 0,
-            sequence_number: None,
-            item: responses::ResponseOutputItem::FunctionCall {
-                id: Some("fn_1".into()),
-                call_id: "call_1".into(),
-                name: "do_it".into(),
-                arguments: "{\"x\":1}".into(),
-                status: None,
-                thought_signature: None,
-            },
-        }];
-
-        let mapped = map_events(events);
-        assert!(matches!(
-            mapped[0],
-            LanguageModelCompletionEvent::ToolUse(ref use_) if use_.id.to_string() == "call_1" && use_.name.as_ref() == "do_it"
-        ));
-        assert!(matches!(
-            mapped[1],
-            LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
-        ));
-    }
-
-    #[test]
-    fn responses_stream_handles_json_parse_error() {
-        let events = vec![responses::StreamEvent::OutputItemDone {
-            output_index: 0,
-            sequence_number: None,
-            item: responses::ResponseOutputItem::FunctionCall {
-                id: Some("fn_1".into()),
-                call_id: "call_1".into(),
-                name: "do_it".into(),
-                arguments: "{not json}".into(),
-                status: None,
-                thought_signature: None,
-            },
-        }];
-
-        let mapped = map_events(events);
-        assert!(matches!(
-            mapped[0],
-            LanguageModelCompletionEvent::ToolUseJsonParseError { ref id, ref tool_name, .. }
-                if id.to_string() == "call_1" && tool_name.as_ref() == "do_it"
-        ));
-        assert!(matches!(
-            mapped[1],
-            LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
-        ));
-    }
-
-    #[test]
-    fn responses_stream_maps_reasoning_summary_and_encrypted_content() {
-        let events = vec![responses::StreamEvent::OutputItemDone {
-            output_index: 0,
-            sequence_number: None,
-            item: responses::ResponseOutputItem::Reasoning {
-                id: "r1".into(),
-                summary: Some(vec![responses::ResponseReasoningItem {
-                    kind: "summary_text".into(),
-                    text: "Chain".into(),
-                }]),
-                encrypted_content: Some("ENC".into()),
-            },
-        }];
-
-        let mapped = map_events(events);
-        assert!(matches!(
-            mapped[0],
-            LanguageModelCompletionEvent::Thinking { ref text, signature: None } if text == "Chain"
-        ));
-        assert!(matches!(
-            mapped[1],
-            LanguageModelCompletionEvent::RedactedThinking { ref data } if data == "ENC"
-        ));
-    }
-
-    #[test]
-    fn responses_stream_handles_incomplete_max_tokens() {
-        let events = vec![responses::StreamEvent::Incomplete {
-            response: responses::Response {
-                usage: Some(responses::ResponseUsage {
-                    input_tokens: Some(10),
-                    output_tokens: Some(0),
-                    total_tokens: Some(10),
-                }),
-                incomplete_details: Some(responses::IncompleteDetails {
-                    reason: Some(responses::IncompleteReason::MaxOutputTokens),
-                }),
-                ..Default::default()
-            },
-        }];
-
-        let mapped = map_events(events);
-        assert!(matches!(
-            mapped[0],
-            LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
-                input_tokens: 10,
-                output_tokens: 0,
-                ..
-            })
-        ));
-        assert!(matches!(
-            mapped[1],
-            LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
-        ));
-    }
-
-    #[test]
-    fn responses_stream_handles_incomplete_content_filter() {
-        let events = vec![responses::StreamEvent::Incomplete {
-            response: responses::Response {
-                usage: None,
-                incomplete_details: Some(responses::IncompleteDetails {
-                    reason: Some(responses::IncompleteReason::ContentFilter),
-                }),
-                ..Default::default()
-            },
-        }];
-
-        let mapped = map_events(events);
-        assert!(matches!(
-            mapped.last().unwrap(),
-            LanguageModelCompletionEvent::Stop(StopReason::Refusal)
-        ));
-    }
-
-    #[test]
-    fn responses_stream_completed_no_duplicate_after_tool_use() {
-        let events = vec![
-            responses::StreamEvent::OutputItemDone {
-                output_index: 0,
-                sequence_number: None,
-                item: responses::ResponseOutputItem::FunctionCall {
-                    id: Some("fn_1".into()),
-                    call_id: "call_1".into(),
-                    name: "do_it".into(),
-                    arguments: "{}".into(),
-                    status: None,
-                    thought_signature: None,
-                },
-            },
-            responses::StreamEvent::Completed {
-                response: responses::Response::default(),
-            },
-        ];
-
-        let mapped = map_events(events);
-
-        let mut stop_count = 0usize;
-        let mut saw_tool_use_stop = false;
-        for event in mapped {
-            if let LanguageModelCompletionEvent::Stop(reason) = event {
-                stop_count += 1;
-                if matches!(reason, StopReason::ToolUse) {
-                    saw_tool_use_stop = true;
-                }
-            }
-        }
-        assert_eq!(stop_count, 1, "should emit exactly one Stop event");
-        assert!(saw_tool_use_stop, "Stop reason should be ToolUse");
-    }
-
-    #[test]
-    fn responses_stream_failed_maps_http_response_error() {
-        let events = vec![responses::StreamEvent::Failed {
-            response: responses::Response {
-                error: Some(responses::ResponseError {
-                    code: "429".into(),
-                    message: "too many requests".into(),
-                }),
-                ..Default::default()
-            },
-        }];
-
-        let mapped_results = futures::executor::block_on(async {
-            CopilotResponsesEventMapper::new()
-                .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
-                .collect::<Vec<_>>()
-                .await
-        });
-
-        assert_eq!(mapped_results.len(), 1);
-        match &mapped_results[0] {
-            Err(LanguageModelCompletionError::HttpResponseError {
-                status_code,
-                message,
-                ..
-            }) => {
-                assert_eq!(*status_code, http_client::StatusCode::TOO_MANY_REQUESTS);
-                assert_eq!(message, "too many requests");
-            }
-            other => panic!("expected HttpResponseError, got {:?}", other),
-        }
-    }
-
-    #[test]
-    fn chat_completions_stream_maps_reasoning_data() {
-        use copilot::copilot_chat::ResponseEvent;
-
-        let events = vec![
-            ResponseEvent {
-                choices: vec![copilot::copilot_chat::ResponseChoice {
-                    index: Some(0),
-                    finish_reason: None,
-                    delta: Some(copilot::copilot_chat::ResponseDelta {
-                        content: None,
-                        role: Some(copilot::copilot_chat::Role::Assistant),
-                        tool_calls: vec![copilot::copilot_chat::ToolCallChunk {
-                            index: Some(0),
-                            id: Some("call_abc123".to_string()),
-                            function: Some(copilot::copilot_chat::FunctionChunk {
-                                name: Some("list_directory".to_string()),
-                                arguments: Some("{\"path\":\"test\"}".to_string()),
-                                thought_signature: None,
-                            }),
-                        }],
-                        reasoning_opaque: Some("encrypted_reasoning_token_xyz".to_string()),
-                        reasoning_text: Some("Let me check the directory".to_string()),
-                    }),
-                    message: None,
-                }],
-                id: "chatcmpl-123".to_string(),
-                usage: None,
-            },
-            ResponseEvent {
-                choices: vec![copilot::copilot_chat::ResponseChoice {
-                    index: Some(0),
-                    finish_reason: Some("tool_calls".to_string()),
-                    delta: Some(copilot::copilot_chat::ResponseDelta {
-                        content: None,
-                        role: None,
-                        tool_calls: vec![],
-                        reasoning_opaque: None,
-                        reasoning_text: None,
-                    }),
-                    message: None,
-                }],
-                id: "chatcmpl-123".to_string(),
-                usage: None,
-            },
-        ];
-
-        let mapped = futures::executor::block_on(async {
-            map_to_language_model_completion_events(
-                Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
-                true,
-            )
-            .collect::<Vec<_>>()
-            .await
-        });
-
-        let mut has_reasoning_details = false;
-        let mut has_tool_use = false;
-        let mut reasoning_opaque_value: Option<String> = None;
-        let mut reasoning_text_value: Option<String> = None;
-
-        for event_result in mapped {
-            match event_result {
-                Ok(LanguageModelCompletionEvent::ReasoningDetails(details)) => {
-                    has_reasoning_details = true;
-                    reasoning_opaque_value = details
-                        .get("reasoning_opaque")
-                        .and_then(|v| v.as_str())
-                        .map(|s| s.to_string());
-                    reasoning_text_value = details
-                        .get("reasoning_text")
-                        .and_then(|v| v.as_str())
-                        .map(|s| s.to_string());
-                }
-                Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
-                    has_tool_use = true;
-                    assert_eq!(tool_use.id.to_string(), "call_abc123");
-                    assert_eq!(tool_use.name.as_ref(), "list_directory");
-                }
-                _ => {}
-            }
-        }
-
-        assert!(
-            has_reasoning_details,
-            "Should emit ReasoningDetails event for Gemini 3 reasoning"
-        );
-        assert!(has_tool_use, "Should emit ToolUse event");
-        assert_eq!(
-            reasoning_opaque_value,
-            Some("encrypted_reasoning_token_xyz".to_string()),
-            "Should capture reasoning_opaque"
-        );
-        assert_eq!(
-            reasoning_text_value,
-            Some("Let me check the directory".to_string()),
-            "Should capture reasoning_text"
-        );
-    }
-}
-struct ConfigurationView {
-    copilot_status: Option<Status>,
-    state: Entity<State>,
-    _subscription: Option<Subscription>,
-}
-
-impl ConfigurationView {
-    pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
-        let copilot = Copilot::global(cx);
-
-        Self {
-            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
-            state,
-            _subscription: copilot.as_ref().map(|copilot| {
-                cx.observe(copilot, |this, model, cx| {
-                    this.copilot_status = Some(model.read(cx).status());
-                    cx.notify();
-                })
-            }),
-        }
-    }
-}
-
-impl Render for ConfigurationView {
-    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
-        if self.state.read(cx).is_authenticated(cx) {
-            ConfiguredApiCard::new("Authorized")
-                .button_label("Sign Out")
-                .on_click(|_, window, cx| {
-                    window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
-                })
-                .into_any_element()
-        } else {
-            let loading_icon = Icon::new(IconName::ArrowCircle).with_rotate_animation(4);
-
-            const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
-
-            match &self.copilot_status {
-                Some(status) => match status {
-                    Status::Starting { task: _ } => h_flex()
-                        .gap_2()
-                        .child(loading_icon)
-                        .child(Label::new("Starting Copilot…"))
-                        .into_any_element(),
-                    Status::SigningIn { prompt: _ }
-                    | Status::SignedOut {
-                        awaiting_signing_in: true,
-                    } => h_flex()
-                        .gap_2()
-                        .child(loading_icon)
-                        .child(Label::new("Signing into Copilot…"))
-                        .into_any_element(),
-                    Status::Error(_) => {
-                        const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
-                        v_flex()
-                            .gap_6()
-                            .child(Label::new(LABEL))
-                            .child(svg().size_8().path(IconName::CopilotError.path()))
-                            .into_any_element()
-                    }
-                    _ => {
-                        const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
-
-                        v_flex()
-                            .gap_2()
-                            .child(Label::new(LABEL))
-                            .child(
-                                Button::new("sign_in", "Sign in to use GitHub Copilot")
-                                    .full_width()
-                                    .style(ButtonStyle::Outlined)
-                                    .icon_color(Color::Muted)
-                                    .icon(IconName::Github)
-                                    .icon_position(IconPosition::Start)
-                                    .icon_size(IconSize::Small)
-                                    .on_click(|_, window, cx| {
-                                        copilot::initiate_sign_in(window, cx)
-                                    }),
-                            )
-                            .into_any_element()
-                    }
-                },
-                None => v_flex()
-                    .gap_6()
-                    .child(Label::new(ERROR_LABEL))
-                    .into_any_element(),
-            }
-        }
-    }
-}
diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs
index c5a5affcd3d9e8c34f6306f86cb5348f86397892..fdea1aabc013085f09c930b5cfa4a283d92f1a8b 100644
--- a/crates/language_models/src/provider/google.rs
+++ b/crates/language_models/src/provider/google.rs
@@ -1,44 +1,22 @@
-use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
-use credentials_provider::CredentialsProvider;
-use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
+use anyhow::Result;
+use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
 use google_ai::{
     FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
     ThinkingConfig, UsageMetadata,
 };
-use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
-use http_client::HttpClient;
+use gpui::{App, AppContext as _};
 use language_model::{
-    AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
-    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
-    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
-};
-use language_model::{
-    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, RateLimiter, Role,
+    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
+    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
+    StopReason,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 pub use settings::GoogleAvailableModel as AvailableModel;
-use settings::{Settings, SettingsStore};
-use std::pin::Pin;
-use std::sync::{
-    Arc, LazyLock,
-    atomic::{self, AtomicU64},
+use std::{
+    pin::Pin,
+    sync::atomic::{self, AtomicU64},
 };
-use strum::IntoEnumIterator;
-use ui::{List, prelude::*};
-use ui_input::InputField;
-use util::ResultExt;
-use zed_env_vars::EnvVar;
-
-use crate::api_key::ApiKey;
-use crate::api_key::ApiKeyState;
-use crate::ui::{ConfiguredApiCard, InstructionListItem};
-
-const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
-const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct GoogleSettings {
@@ -57,346 +35,6 @@ pub enum ModelMode {
     },
 }
 
-pub struct GoogleLanguageModelProvider {
-    http_client: Arc<dyn HttpClient>,
-    state: Entity<State>,
-}
-
-pub struct State {
-    api_key_state: ApiKeyState,
-}
-
-const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
-const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
-
-static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
-    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
-    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
-});
-
-impl State {
-    fn is_authenticated(&self) -> bool {
-        self.api_key_state.has_key()
-    }
-
-    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
-        let api_url = GoogleLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
-    }
-
-    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
-        let api_url = GoogleLanguageModelProvider::api_url(cx);
-        self.api_key_state.load_if_needed(
-            api_url,
-            &API_KEY_ENV_VAR,
-            |this| &mut this.api_key_state,
-            cx,
-        )
-    }
-}
-
-impl GoogleLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
-        let state = cx.new(|cx| {
-            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                let api_url = Self::api_url(cx);
-                this.api_key_state.handle_url_change(
-                    api_url,
-                    &API_KEY_ENV_VAR,
-                    |this| &mut this.api_key_state,
-                    cx,
-                );
-                cx.notify();
-            })
-            .detach();
-            State {
-                api_key_state: ApiKeyState::new(Self::api_url(cx)),
-            }
-        });
-
-        Self { http_client, state }
-    }
-
-    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
-        Arc::new(GoogleLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        })
-    }
-
-    pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
-        if let Some(key) = API_KEY_ENV_VAR.value.clone() {
-            return Task::ready(Ok(key));
-        }
-        let credentials_provider = <dyn CredentialsProvider>::global(cx);
-        let api_url = Self::api_url(cx).to_string();
-        cx.spawn(async move |cx| {
-            Ok(
-                ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
-                    .await?
-                    .key()
-                    .to_string(),
-            )
-        })
-    }
-
-    fn settings(cx: &App) -> &GoogleSettings {
-        &crate::AllLanguageModelSettings::get_global(cx).google
-    }
-
-    fn api_url(cx: &App) -> SharedString {
-        let api_url = &Self::settings(cx).api_url;
-        if api_url.is_empty() {
-            google_ai::API_URL.into()
-        } else {
-            SharedString::new(api_url.as_str())
-        }
-    }
-}
-
-impl LanguageModelProviderState for GoogleLanguageModelProvider {
-    type ObservableEntity = State;
-
-    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
-        Some(self.state.clone())
-    }
-}
-
-impl LanguageModelProvider for GoogleLanguageModelProvider {
-    fn id(&self) -> LanguageModelProviderId {
-        PROVIDER_ID
-    }
-
-    fn name(&self) -> LanguageModelProviderName {
-        PROVIDER_NAME
-    }
-
-    fn icon(&self) -> IconName {
-        IconName::AiGoogle
-    }
-
-    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        Some(self.create_language_model(google_ai::Model::default()))
-    }
-
-    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        Some(self.create_language_model(google_ai::Model::default_fast()))
-    }
-
-    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models = BTreeMap::default();
-
-        // Add base models from google_ai::Model::iter()
-        for model in google_ai::Model::iter() {
-            if !matches!(model, google_ai::Model::Custom { .. }) {
-                models.insert(model.id().to_string(), model);
-            }
-        }
-
-        // Override with available models from settings
-        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
-            models.insert(
-                model.name.clone(),
-                google_ai::Model::Custom {
-                    name: model.name.clone(),
-                    display_name: model.display_name.clone(),
-                    max_tokens: model.max_tokens,
-                    mode: model.mode.unwrap_or_default(),
-                },
-            );
-        }
-
-        models
-            .into_values()
-            .map(|model| {
-                Arc::new(GoogleLanguageModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    state: self.state.clone(),
-                    http_client: self.http_client.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
-            .collect()
-    }
-
-    fn is_authenticated(&self, cx: &App) -> bool {
-        self.state.read(cx).is_authenticated()
-    }
-
-    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
-        self.state.update(cx, |state, cx| state.authenticate(cx))
-    }
-
-    fn configuration_view(
-        &self,
-        target_agent: language_model::ConfigurationViewTargetAgent,
-        window: &mut Window,
-        cx: &mut App,
-    ) -> AnyView {
-        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
-            .into()
-    }
-
-    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
-        self.state
-            .update(cx, |state, cx| state.set_api_key(None, cx))
-    }
-}
-
-pub struct GoogleLanguageModel {
-    id: LanguageModelId,
-    model: google_ai::Model,
-    state: Entity<State>,
-    http_client: Arc<dyn HttpClient>,
-    request_limiter: RateLimiter,
-}
-
-impl GoogleLanguageModel {
-    fn stream_completion(
-        &self,
-        request: google_ai::GenerateContentRequest,
-        cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
-    > {
-        let http_client = self.http_client.clone();
-
-        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
-            let api_url = GoogleLanguageModelProvider::api_url(cx);
-            (state.api_key_state.key(&api_url), api_url)
-        }) else {
-            return future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
-        async move {
-            let api_key = api_key.context("Missing Google API key")?;
-            let request = google_ai::stream_generate_content(
-                http_client.as_ref(),
-                &api_url,
-                &api_key,
-                request,
-            );
-            request.await.context("failed to stream completion")
-        }
-        .boxed()
-    }
-}
-
-impl LanguageModel for GoogleLanguageModel {
-    fn id(&self) -> LanguageModelId {
-        self.id.clone()
-    }
-
-    fn name(&self) -> LanguageModelName {
-        LanguageModelName::from(self.model.display_name().to_string())
-    }
-
-    fn provider_id(&self) -> LanguageModelProviderId {
-        PROVIDER_ID
-    }
-
-    fn provider_name(&self) -> LanguageModelProviderName {
-        PROVIDER_NAME
-    }
-
-    fn supports_tools(&self) -> bool {
-        self.model.supports_tools()
-    }
-
-    fn supports_images(&self) -> bool {
-        self.model.supports_images()
-    }
-
-    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
-        match choice {
-            LanguageModelToolChoice::Auto
-            | LanguageModelToolChoice::Any
-            | LanguageModelToolChoice::None => true,
-        }
-    }
-
-    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
-        LanguageModelToolSchemaFormat::JsonSchemaSubset
-    }
-
-    fn telemetry_id(&self) -> String {
-        format!("google/{}", self.model.request_id())
-    }
-
-    fn max_token_count(&self) -> u64 {
-        self.model.max_token_count()
-    }
-
-    fn max_output_tokens(&self) -> Option<u64> {
-        self.model.max_output_tokens()
-    }
-
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        let model_id = self.model.request_id().to_string();
-        let request = into_google(request, model_id, self.model.mode());
-        let http_client = self.http_client.clone();
-        let api_url = GoogleLanguageModelProvider::api_url(cx);
-        let api_key = self.state.read(cx).api_key_state.key(&api_url);
-
-        async move {
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                }
-                .into());
-            };
-            let response = google_ai::count_tokens(
-                http_client.as_ref(),
-                &api_url,
-                &api_key,
-                google_ai::CountTokensRequest {
-                    generate_content_request: request,
-                },
-            )
-            .await?;
-            Ok(response.total_tokens)
-        }
-        .boxed()
-    }
-
-    fn stream_completion(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<
-            futures::stream::BoxStream<
-                'static,
-                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
-            >,
-            LanguageModelCompletionError,
-        >,
-    > {
-        let request = into_google(
-            request,
-            self.model.request_id().to_string(),
-            self.model.mode(),
-        );
-        let request = self.stream_completion(request, cx);
-        let future = self.request_limiter.stream(async move {
-            let response = request.await.map_err(LanguageModelCompletionError::from)?;
-            Ok(GoogleEventMapper::new().map_stream(response))
-        });
-        async move { Ok(future.await?.boxed()) }.boxed()
-    }
-}
-
 pub fn into_google(
     mut request: LanguageModelRequest,
     model_id: String,
@@ -439,7 +77,6 @@ pub fn into_google(
                 })]
             }
             language_model::MessageContent::ToolUse(tool_use) => {
-                // Normalize empty string signatures to None
                 let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
 
                 vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
@@ -457,7 +94,6 @@ pub fn into_google(
                         google_ai::FunctionResponsePart {
                             function_response: google_ai::FunctionResponse {
                                 name: tool_result.tool_name.to_string(),
-                                // The API expects a valid JSON object
                                 response: serde_json::json!({
                                     "output": text
                                 }),
@@ -470,7 +106,6 @@ pub fn into_google(
                     Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                         function_response: google_ai::FunctionResponse {
                             name: tool_result.tool_name.to_string(),
-                            // The API expects a valid JSON object
                            response: serde_json::json!({
                                 "output": "Tool responded with an image"
                             }),
@@ -519,7 +154,7 @@ pub fn into_google(
             role: match message.role {
                 Role::User => google_ai::Role::User,
                 Role::Assistant => google_ai::Role::Model,
-                Role::System => google_ai::Role::User, // Google AI doesn't have a system role
+                Role::System => google_ai::Role::User,
             },
         })
     }
@@ -653,13 +288,13 @@ impl GoogleEventMapper {
                     Part::InlineDataPart(_) => {}
                     Part::FunctionCallPart(function_call_part) => {
                         wants_to_use_tool = true;
-                        let name: Arc<str> = function_call_part.function_call.name.into();
+                        let name: std::sync::Arc<str> =
+                            function_call_part.function_call.name.into();
                         let next_tool_id =
                             TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                         let id: LanguageModelToolUseId =
                             format!("{}-{}", name, next_tool_id).into();
-                        // Normalize empty string signatures to None
                         let thought_signature = function_call_part
                             .thought_signature
                             .filter(|s| !s.is_empty());
@@ -678,7 +313,7 @@ impl GoogleEventMapper {
                     Part::FunctionResponsePart(_) => {}
                     Part::ThoughtPart(part) => {
                         events.push(Ok(LanguageModelCompletionEvent::Thinking {
-                            text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
+                            text: "(Encrypted thought)".to_string(),
                             signature: Some(part.thought_signature),
                         }));
                     }
@@ -686,8 +321,6 @@ impl GoogleEventMapper {
             }
         }
 
-        // Even when Gemini wants to use a Tool, the API
-        // responds with `finish_reason: STOP`
         if wants_to_use_tool {
             self.stop_reason = StopReason::ToolUse;
             events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
@@ -700,8 +333,6 @@ pub fn count_google_tokens(
     request: LanguageModelRequest,
     cx: &App,
 ) -> BoxFuture<'static, Result<u64>> {
-    // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
-    // So we have to use tokenizer from tiktoken_rs to count tokens.
     cx.background_spawn(async move {
         let messages = request
             .messages
@@ -718,8 +349,6 @@ pub fn count_google_tokens(
             })
             .collect::<Vec<_>>();
 
-        // Tiktoken doesn't yet support these models, so we manually use the
-        // same tokenizer as GPT-4.
         tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
     })
     .boxed()
@@ -760,148 +389,6 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
     }
 }
 
-struct ConfigurationView {
-    api_key_editor: Entity<InputField>,
-    state: Entity<State>,
-    target_agent: language_model::ConfigurationViewTargetAgent,
-    load_credentials_task: Option<Task<()>>,
-}
-
-impl ConfigurationView {
-    fn new(
-        state: Entity<State>,
-        target_agent: language_model::ConfigurationViewTargetAgent,
-        window: &mut Window,
-        cx: &mut Context<Self>,
-    ) -> Self {
-        cx.observe(&state, |_, _, cx| {
-            cx.notify();
-        })
-        .detach();
-
-        let load_credentials_task = Some(cx.spawn_in(window, {
-            let state = state.clone();
-            async move |this, cx| {
-                if let Some(task) = state
-                    .update(cx, |state, cx| state.authenticate(cx))
-                    .log_err()
-                {
-                    // We don't log an error, because "not signed in" is also an error.
-                    let _ = task.await;
-                }
-                this.update(cx, |this, cx| {
-                    this.load_credentials_task = None;
-                    cx.notify();
-                })
-                .log_err();
-            }
-        }));
-
-        Self {
-            api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
-            target_agent,
-            state,
-            load_credentials_task,
-        }
-    }
-
-    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
-        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
-        if api_key.is_empty() {
-            return;
-        }
-
-        // url changes can cause the editor to be displayed again
-        self.api_key_editor
-            .update(cx, |editor, cx| editor.set_text("", window, cx));
-
-        let state = self.state.clone();
-        cx.spawn_in(window, async move |_, cx| {
-            state
-                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
-                .await
-        })
-        .detach_and_log_err(cx);
-    }
-
-    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
-        self.api_key_editor
-            .update(cx, |editor, cx| editor.set_text("", window, cx));
-
-        let state = self.state.clone();
-        cx.spawn_in(window, async move |_, cx| {
-            state
-                .update(cx, |state, cx| state.set_api_key(None, cx))?
-                .await
-        })
-        .detach_and_log_err(cx);
-    }
-
-    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
-        !self.state.read(cx).is_authenticated()
-    }
-}
-
-impl Render for ConfigurationView {
-    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
-        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
-        let configured_card_label = if env_var_set {
-            format!(
-                "API key set in {} environment variable",
-                API_KEY_ENV_VAR.name
-            )
-        } else {
-            let api_url = GoogleLanguageModelProvider::api_url(cx);
-            if api_url == google_ai::API_URL {
-                "API key configured".to_string()
-            } else {
-                format!("API key configured for {}", api_url)
-            }
-        };
-
-        if self.load_credentials_task.is_some() {
-            div()
-                .child(Label::new("Loading credentials..."))
-                .into_any_element()
-        } else if self.should_render_editor(cx) {
-            v_flex()
-                .size_full()
-                .on_action(cx.listener(Self::save_api_key))
-                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
-                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
-                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
-                })))
-                .child(
-                    List::new()
-                        .child(InstructionListItem::new(
-                            "Create one by visiting",
-                            Some("Google AI's console"),
-                            Some("https://aistudio.google.com/app/apikey"),
-                        ))
-                        .child(InstructionListItem::text_only(
-                            "Paste your API key below and hit enter to start using the assistant",
-                        )),
-                )
-                .child(self.api_key_editor.clone())
-                .child(
-                    Label::new(
-                        format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
-                    )
-                    .size(LabelSize::Small).color(Color::Muted),
-                )
-                .into_any_element()
-        } else {
-            ConfiguredApiCard::new(configured_card_label)
-                .disabled(env_var_set)
-                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
-                .when(env_var_set, |this| {
-                    this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
-                })
-                .into_any_element()
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -940,7 +427,7 @@ mod tests {
 
         let events = mapper.map_event(response);
 
-        assert_eq!(events.len(), 2); // ToolUse event + Stop event
+        assert_eq!(events.len(), 2);
 
         if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
             assert_eq!(tool_use.name.as_ref(), "test_function");
@@ -1034,18 +521,25 @@ mod tests {
                     parts: vec![
                         Part::FunctionCallPart(FunctionCallPart {
                             function_call: FunctionCall {
-                                name: "function_1".to_string(),
-                                args: json!({"arg": "value1"}),
+                                name: "function_a".to_string(),
+                                args: json!({}),
                             },
-                            thought_signature: Some("signature_1".to_string()),
+                            thought_signature: Some("sig_a".to_string()),
                         }),
                         Part::FunctionCallPart(FunctionCallPart {
                             function_call: FunctionCall {
-                                name: "function_2".to_string(),
-                                args: json!({"arg": "value2"}),
+                                name: "function_b".to_string(),
+                                args: json!({}),
                             },
                             thought_signature: None,
                         }),
+                        Part::FunctionCallPart(FunctionCallPart {
+                            function_call: FunctionCall {
+                                name: "function_c".to_string(),
+                                args: json!({}),
+                            },
+                            thought_signature: Some("sig_c".to_string()),
+                        }),
                     ],
                     role: GoogleRole::Model,
                 },
@@ -1060,35 +554,35 @@ mod tests {
 
         let events = mapper.map_event(response);
 
-        assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
-
-        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
-            assert_eq!(tool_use.name.as_ref(), "function_1");
-            assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
-        } else {
-            panic!("Expected ToolUse event for function_1");
-        }
+        let tool_uses: Vec<_> = events
+            .iter()
+            .filter_map(|e| {
+                if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = e {
+                    Some(tool_use)
+                } else {
+                    None
+                }
+            })
+            .collect();
 
-        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
-            assert_eq!(tool_use.name.as_ref(), "function_2");
-            assert_eq!(tool_use.thought_signature, None);
-        } else {
-            panic!("Expected ToolUse event for function_2");
-        }
+        assert_eq!(tool_uses.len(), 3);
+        assert_eq!(tool_uses[0].thought_signature.as_deref(), Some("sig_a"));
+        assert_eq!(tool_uses[1].thought_signature, None);
+        assert_eq!(tool_uses[2].thought_signature.as_deref(), Some("sig_c"));
     }
 
     #[test]
     fn test_tool_use_with_signature_converts_to_function_call_part() {
         let tool_use = language_model::LanguageModelToolUse {
-            id: LanguageModelToolUseId::from("test_id"),
-            name: "test_function".into(),
-            raw_input: json!({"arg": "value"}).to_string(),
-            input: json!({"arg": "value"}),
+            id: LanguageModelToolUseId::from("test-id"),
+            name: "test_tool".into(),
+            input: json!({"key": "value"}),
+            raw_input: r#"{"key": "value"}"#.to_string(),
             is_input_complete: true,
-            thought_signature: Some("test_signature_456".to_string()),
+            thought_signature: Some("test_sig".to_string()),
         };
 
-        let request = super::into_google(
+        let request = into_google(
             LanguageModelRequest {
                 messages: vec![language_model::LanguageModelRequestMessage {
                     role: Role::Assistant,
@@ -1102,13 +596,11 @@ mod tests {
             GoogleModelMode::Default,
         );
 
-        assert_eq!(request.contents[0].parts.len(), 1);
-        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
-            assert_eq!(fc_part.function_call.name, "test_function");
-            assert_eq!(
-                fc_part.thought_signature.as_deref(),
-                Some("test_signature_456")
-            );
+        let parts = &request.contents[0].parts;
+        assert_eq!(parts.len(), 1);
+
+        if let Part::FunctionCallPart(fcp) = &parts[0] {
+            assert_eq!(fcp.thought_signature.as_deref(), Some("test_sig"));
         } else {
             panic!("Expected FunctionCallPart");
         }
@@ -1117,15 +609,15 @@ mod tests {
     #[test]
     fn test_tool_use_without_signature_omits_field() {
         let tool_use = language_model::LanguageModelToolUse {
-            id: LanguageModelToolUseId::from("test_id"),
-            name: "test_function".into(),
-            raw_input: json!({"arg": "value"}).to_string(),
-            input: json!({"arg": "value"}),
+            id: LanguageModelToolUseId::from("test-id"),
+            name: "test_tool".into(),
+            input: json!({"key": "value"}),
+            raw_input: r#"{"key": "value"}"#.to_string(),
             is_input_complete: true,
             thought_signature: None,
         };
 
-        let request = super::into_google(
+        let request = into_google(
             LanguageModelRequest {
                 messages: vec![language_model::LanguageModelRequestMessage {
                     role: Role::Assistant,
@@ -1139,9 +631,10 @@ mod tests {
             GoogleModelMode::Default,
         );
 
-        assert_eq!(request.contents[0].parts.len(), 1);
-        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
-            assert_eq!(fc_part.thought_signature, None);
+        let parts = &request.contents[0].parts;
+
+        if let Part::FunctionCallPart(fcp) = &parts[0] {
+            assert_eq!(fcp.thought_signature, None);
         } else {
             panic!("Expected FunctionCallPart");
         }
@@ -1150,15 +643,15 @@ mod tests {
     #[test]
     fn test_empty_signature_in_tool_use_normalized_to_none() {
         let tool_use = language_model::LanguageModelToolUse {
-            id: LanguageModelToolUseId::from("test_id"),
-            name: "test_function".into(),
-            raw_input: json!({"arg": "value"}).to_string(),
-            input: json!({"arg": "value"}),
+            id: LanguageModelToolUseId::from("test-id"),
+            name: "test_tool".into(),
+            input: json!({}),
+            raw_input: "{}".to_string(),
             is_input_complete: true,
             thought_signature: Some("".to_string()),
         };
 
-        let request = super::into_google(
+        let request = into_google(
             LanguageModelRequest {
                 messages: vec![language_model::LanguageModelRequestMessage {
                     role: Role::Assistant,
@@ -1172,8 +665,10 @@ mod tests {
             GoogleModelMode::Default,
        );
 
-        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
-            assert_eq!(fc_part.thought_signature, None);
+        let parts = &request.contents[0].parts;
+
+        if let Part::FunctionCallPart(fcp) = &parts[0] {
+            assert_eq!(fcp.thought_signature, None);
         } else {
             panic!("Expected FunctionCallPart");
         }
@@ -1181,9 +676,8 @@ mod tests {
 
     #[test]
     fn test_round_trip_preserves_signature() {
-        let mut mapper = GoogleEventMapper::new();
+        let original_signature = "original_thought_signature_abc123";
 
-        // Simulate receiving a response from Google with a signature
         let response = GenerateContentResponse {
             candidates: Some(vec![GenerateContentCandidate {
                 index: Some(0),
@@ -1193,7 +687,7 @@ mod tests {
                             name: "test_function".to_string(),
                             args: json!({"arg": "value"}),
                         },
-                        thought_signature: Some("round_trip_sig".to_string()),
+                        thought_signature: Some(original_signature.to_string()),
                     })],
                     role: GoogleRole::Model,
                 },
@@ -1206,6 +700,7 @@ mod tests {
             usage_metadata: None,
         };
 
+        let mut mapper = GoogleEventMapper::new();
         let events = mapper.map_event(response);
 
         let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
@@ -1214,8 +709,7 @@ mod tests {
             panic!("Expected ToolUse event");
         };
 
-        // Convert back to Google format
-        let request = super::into_google(
+        let request = into_google(
             LanguageModelRequest {
                 messages: vec![language_model::LanguageModelRequestMessage {
                     role: Role::Assistant,
@@ -1229,9 +723,9 @@ mod tests {
             GoogleModelMode::Default,
         );
 
-        // Verify signature is preserved
-        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
-            assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
+        let parts = &request.contents[0].parts;
+        if let Part::FunctionCallPart(fcp) = &parts[0] {
+            assert_eq!(fcp.thought_signature.as_deref(), Some(original_signature));
         } else {
             panic!("Expected FunctionCallPart");
         }
@@ -1247,14 +741,14 @@ mod tests {
             content: Content {
                 parts: vec![
                     Part::TextPart(TextPart {
-                        text: "I'll help with that.".to_string(),
+                        text: "Let me help you with that.".to_string(),
                     }),
                     Part::FunctionCallPart(FunctionCallPart {
                         function_call: FunctionCall {
-                            name: "helper_function".to_string(),
-                            args: json!({"query": "help"}),
+                            name: "search".to_string(),
+                            args: json!({"query": "test"}),
                         },
-                        thought_signature: Some("mixed_sig".to_string()),
+                        thought_signature: Some("thinking_sig".to_string()),
                     }),
                 ],
                 role: GoogleRole::Model,
@@ -1270,27 +764,35 @@ mod tests {
 
         let events = mapper.map_event(response);
 
-        assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
+        let mut found_text = false;
+        let mut found_tool_with_sig = false;
 
-        if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
-            assert_eq!(text, "I'll help with that.");
-        } else {
-            panic!("Expected Text event");
+        for event in events {
+            match event {
+                Ok(LanguageModelCompletionEvent::Text(text)) => {
+                    assert_eq!(text, "Let me help you with that.");
+                    found_text = true;
+                }
+                Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
+                    assert_eq!(tool_use.thought_signature.as_deref(), Some("thinking_sig"));
+                    found_tool_with_sig = true;
+                }
+                _ => {}
+            }
         }
 
-        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
-            assert_eq!(tool_use.name.as_ref(), "helper_function");
-            assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
-        } else {
-            panic!("Expected ToolUse event");
-        }
+        assert!(found_text, "Should have found text event");
+        assert!(
+            found_tool_with_sig,
+            "Should have found tool use with signature"
+        );
     }
 
     #[test]
     fn test_special_characters_in_signature_preserved() {
-        let mut mapper = GoogleEventMapper::new();
+        let special_signature = "sig/with+special=chars&more%stuff";
 
-        let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
+        let mut mapper = GoogleEventMapper::new();
 
         let response = GenerateContentResponse {
             candidates: Some(vec![GenerateContentCandidate {
@@ -1298,10 +800,10 @@ mod tests {
                 content: Content {
                     parts: vec![Part::FunctionCallPart(FunctionCallPart {
                         function_call: FunctionCall {
-                            name: "test_function".to_string(),
-                            args: json!({"arg": "value"}),
+                            name: "test".to_string(),
+                            args: json!({}),
                         },
-                        thought_signature: Some(signature_with_special_chars.clone()),
+                        thought_signature: Some(special_signature.to_string()),
                     })],
                     role: GoogleRole::Model,
                 },
@@ -1319,7 +821,7 @@ mod tests {
         if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
             assert_eq!(
                 tool_use.thought_signature.as_deref(),
-                Some(signature_with_special_chars.as_str())
+                Some(special_signature)
             );
         } else {
             panic!("Expected ToolUse event");
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 46cea34e3e01cb0f8ad0f859827881f3ec74cad7..fbe13dfca0a5c61db8fd421f13289e718914c51a 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -1,38 +1,17 @@
 use anyhow::{Result, anyhow};
-use collections::{BTreeMap, HashMap};
-use futures::Stream;
-use futures::{FutureExt, StreamExt, future, future::BoxFuture};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
-use http_client::HttpClient;
+use collections::HashMap;
+use futures::{FutureExt, Stream, future::BoxFuture};
+use gpui::{App, AppContext as _};
 use language_model::{
-    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
-    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason, TokenUsage,
+    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
+    LanguageModelToolChoice, LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
 };
-use menu;
-use open_ai::{
-    ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion,
-};
-use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore};
+use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent};
+pub use settings::OpenAiAvailableModel as AvailableModel;
 use std::pin::Pin;
-use std::str::FromStr as _;
-use std::sync::{Arc, LazyLock};
-use strum::IntoEnumIterator;
-use ui::{List, prelude::*};
-use ui_input::InputField;
-use util::ResultExt;
-use zed_env_vars::{EnvVar, env_var};
-
-use crate::ui::ConfiguredApiCard;
-use crate::{api_key::ApiKeyState, ui::InstructionListItem};
-
-const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
-const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
+use std::str::FromStr;
 
-const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
-static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+use language_model::LanguageModelToolResultContent;
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct OpenAiSettings {
@@ -40,314 +19,6 @@ pub struct OpenAiSettings {
     pub available_models: Vec<AvailableModel>,
 }
 
-pub struct OpenAiLanguageModelProvider {
-    http_client: Arc<dyn HttpClient>,
-    state: Entity<State>,
-}
-
-pub struct State {
-    api_key_state: ApiKeyState,
-}
-
-impl State {
-    fn is_authenticated(&self) -> bool {
-        self.api_key_state.has_key()
-    }
-
-    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
-        let api_url = OpenAiLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
-    }
-
-    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
-        let api_url = OpenAiLanguageModelProvider::api_url(cx);
-        self.api_key_state.load_if_needed(
-            api_url,
-            &API_KEY_ENV_VAR,
-            |this| &mut this.api_key_state,
-            cx,
-        )
-    }
-}
-
-impl OpenAiLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
-        let state = cx.new(|cx| {
-            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
-                let api_url = Self::api_url(cx);
-                this.api_key_state.handle_url_change(
-                    api_url,
-                    &API_KEY_ENV_VAR,
-                    |this| &mut this.api_key_state,
-                    cx,
-                );
-                cx.notify();
-            })
-            .detach();
-            State {
-                api_key_state: ApiKeyState::new(Self::api_url(cx)),
-            }
-        });
-
-        Self { http_client, state }
-    }
-
-    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
-        Arc::new(OpenAiLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        })
-    }
-
-    fn settings(cx: &App) -> &OpenAiSettings {
-        &crate::AllLanguageModelSettings::get_global(cx).openai
-    }
-
-    fn api_url(cx: &App) -> SharedString {
-        let api_url = &Self::settings(cx).api_url;
-        if api_url.is_empty() {
-            open_ai::OPEN_AI_API_URL.into()
-        } else {
-            SharedString::new(api_url.as_str())
-        }
-    }
-}
-
-impl LanguageModelProviderState for OpenAiLanguageModelProvider {
-    type ObservableEntity = State;
-
-    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
-        Some(self.state.clone())
-    }
-}
-
-impl LanguageModelProvider for OpenAiLanguageModelProvider {
-    fn id(&self) -> LanguageModelProviderId {
-        PROVIDER_ID
-    }
-
-    fn name(&self) -> LanguageModelProviderName {
-        PROVIDER_NAME
-    }
-
-    fn icon(&self) -> IconName {
-        IconName::AiOpenAi
-    }
-
-    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        Some(self.create_language_model(open_ai::Model::default()))
-    }
-
-    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        Some(self.create_language_model(open_ai::Model::default_fast()))
-    }
-
-    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models = BTreeMap::default();
-
-        // Add base models from open_ai::Model::iter()
-        for model in open_ai::Model::iter() {
-            if !matches!(model, open_ai::Model::Custom { .. }) {
-                models.insert(model.id().to_string(), model);
-            }
-        }
-
-        // Override with available models from settings
-        for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
-            models.insert(
-                model.name.clone(),
-                open_ai::Model::Custom {
-                    name: model.name.clone(),
-                    display_name: model.display_name.clone(),
-                    max_tokens: model.max_tokens,
-                    max_output_tokens: model.max_output_tokens,
-                    max_completion_tokens: model.max_completion_tokens,
-                    reasoning_effort: model.reasoning_effort.clone(),
-                },
-            );
-        }
-
-        models
-            .into_values()
-            .map(|model| self.create_language_model(model))
-            .collect()
-    }
-
-    fn is_authenticated(&self, cx: &App) -> bool {
-        self.state.read(cx).is_authenticated()
-    }
-
-    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
-        self.state.update(cx, |state, cx| state.authenticate(cx))
-    }
-
-    fn configuration_view(
-        &self,
-        _target_agent: language_model::ConfigurationViewTargetAgent,
-        window: &mut Window,
-        cx: &mut App,
-    ) -> AnyView {
-        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
-            .into()
-    }
-
-    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
-        self.state
-            .update(cx, |state, cx| state.set_api_key(None, cx))
-    }
-}
-
-pub struct OpenAiLanguageModel {
-    id: LanguageModelId,
-    model: open_ai::Model,
-    state: Entity<State>,
-    http_client: Arc<dyn HttpClient>,
-    request_limiter: RateLimiter,
-}
-
-impl OpenAiLanguageModel {
-    fn stream_completion(
-        &self,
-        request: open_ai::Request,
-        cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
-    {
-        let http_client = self.http_client.clone();
-
-        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
-            let api_url = OpenAiLanguageModelProvider::api_url(cx);
-            (state.api_key_state.key(&api_url), api_url)
-        }) else {
-            return future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
-        let future = self.request_limiter.stream(async move {
-            let provider = PROVIDER_NAME;
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
-            };
-            let request = stream_completion(
-                http_client.as_ref(),
-                provider.0.as_str(),
-                &api_url,
-                &api_key,
-                request,
-            );
-            let response = request.await?;
-            Ok(response)
-        });
-
-        async move { Ok(future.await?.boxed()) }.boxed()
-    }
-}
-
-impl LanguageModel for OpenAiLanguageModel {
-    fn id(&self) -> LanguageModelId {
-        self.id.clone()
-    }
-
-    fn name(&self) -> LanguageModelName {
-        LanguageModelName::from(self.model.display_name().to_string())
-    }
-
-    fn provider_id(&self) -> LanguageModelProviderId {
-        PROVIDER_ID
-    }
-
-    fn provider_name(&self) -> LanguageModelProviderName {
-        PROVIDER_NAME
-    }
-
-    fn supports_tools(&self) -> bool {
-        true
-    }
-
-    fn supports_images(&self) -> bool {
-        use open_ai::Model;
-        match &self.model {
-            Model::FourOmni
-            | Model::FourOmniMini
-            | Model::FourPointOne
-            | Model::FourPointOneMini
-            | Model::FourPointOneNano
-            | Model::Five
-            | Model::FiveMini
-            | Model::FiveNano
-            | Model::FivePointOne
-            | Model::O1
-            | Model::O3
-            | Model::O4Mini => true,
-            Model::ThreePointFiveTurbo
-            | Model::Four
-            | Model::FourTurbo
-            | Model::O3Mini
-            | Model::Custom { .. } => false,
-        }
-    }
-
-    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
-        match choice {
-            LanguageModelToolChoice::Auto => true,
-            LanguageModelToolChoice::Any => true,
-            LanguageModelToolChoice::None => true,
-        }
-    }
-
-    fn telemetry_id(&self) -> String {
-        format!("openai/{}", self.model.id())
-    }
-
-    fn max_token_count(&self) -> u64 {
-        self.model.max_token_count()
-    }
-
-    fn max_output_tokens(&self) -> Option<u64> {
-        self.model.max_output_tokens()
-    }
-
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        count_open_ai_tokens(request, self.model.clone(), cx)
-    }
-
-    fn stream_completion(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<
-            futures::stream::BoxStream<
-                'static,
-                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
-            >,
-            LanguageModelCompletionError,
-        >,
-    > {
-        let request = into_open_ai(
-            request,
-            self.model.id(),
-            self.model.supports_parallel_tool_calls(),
-            self.model.supports_prompt_cache_key(),
-            self.max_output_tokens(),
-            self.model.reasoning_effort(),
-        );
-        let completions = self.stream_completion(request, cx);
-        async move {
-            let mapper = OpenAiEventMapper::new();
-            Ok(mapper.map_stream(completions.await?).boxed())
-        }
-        .boxed()
-    }
-}
-
 pub fn into_open_ai(
     request: LanguageModelRequest,
     model_id: &str,
@@ -441,7 +112,6 @@ pub fn into_open_ai(
         temperature: request.temperature.unwrap_or(1.0),
         max_completion_tokens: max_output_tokens,
         parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
-            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
            Some(false)
         } else {
             None
@@ -521,6 +191,7 @@ impl OpenAiEventMapper {
         events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
     ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
     {
+        use futures::StreamExt;
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
@@ -648,19 +319,12 @@ pub fn count_open_ai_tokens(
     match model {
         Model::Custom { max_tokens, .. } => {
             let model = if max_tokens >= 100_000 {
-                // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
                 "gpt-4o"
             } else {
-                // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
-                // supported with this tiktoken method
                 "gpt-4"
             };
             tiktoken_rs::num_tokens_from_messages(model, &messages)
         }
-        // Currently supported by tiktoken_rs
-        // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch
-        // arm with an override. We enumerate all supported models here so that we can check if new
-        // models are supported yet or not.
         Model::ThreePointFiveTurbo
         | Model::Four
         | Model::FourTurbo
@@ -675,7 +339,7 @@ pub fn count_open_ai_tokens(
         | Model::O4Mini
         | Model::Five
         | Model::FiveMini
-        | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), // GPT-5.1 doesn't have tiktoken support yet; fall back on gpt-4o tokenizer
+        | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
         Model::FivePointOne => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),
     }
     .map(|tokens| tokens as u64)
@@ -683,191 +347,11 @@ pub fn count_open_ai_tokens(
     .boxed()
 }
 
-struct ConfigurationView {
-    api_key_editor: Entity<InputField>,
-    state: Entity<State>,
-    load_credentials_task: Option<Task<()>>,
-}
-
-impl ConfigurationView {
-    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
-        let api_key_editor = cx.new(|cx| {
-            InputField::new(
-                window,
-                cx,
-                "sk-000000000000000000000000000000000000000000000000",
-            )
-        });
-
-        cx.observe(&state, |_, _, cx| {
-            cx.notify();
-        })
-        .detach();
-
-        let load_credentials_task = Some(cx.spawn_in(window, {
-            let state = state.clone();
-            async move |this, cx| {
-                if let Some(task) = state
-                    .update(cx, |state, cx| state.authenticate(cx))
-                    .log_err()
-                {
-                    // We don't log an error, because "not signed in" is also an error.
-                    let _ = task.await;
-                }
-                this.update(cx, |this, cx| {
-                    this.load_credentials_task = None;
-                    cx.notify();
-                })
-                .log_err();
-            }
-        }));
-
-        Self {
-            api_key_editor,
-            state,
-            load_credentials_task,
-        }
-    }
-
-    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
-        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
-        if api_key.is_empty() {
-            return;
-        }
-
-        // url changes can cause the editor to be displayed again
-        self.api_key_editor
-            .update(cx, |editor, cx| editor.set_text("", window, cx));
-
-        let state = self.state.clone();
-        cx.spawn_in(window, async move |_, cx| {
-            state
-                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
-                .await
-        })
-        .detach_and_log_err(cx);
-    }
-
-    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
-        self.api_key_editor
-            .update(cx, |input, cx| input.set_text("", window, cx));
-
-        let state = self.state.clone();
-        cx.spawn_in(window, async move |_, cx| {
-            state
-                .update(cx, |state, cx| state.set_api_key(None, cx))?
-                .await
-        })
-        .detach_and_log_err(cx);
-    }
-
-    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
-        !self.state.read(cx).is_authenticated()
-    }
-}
-
-impl Render for ConfigurationView {
-    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
-        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
-        let configured_card_label = if env_var_set {
-            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
-        } else {
-            let api_url = OpenAiLanguageModelProvider::api_url(cx);
-            if api_url == OPEN_AI_API_URL {
-                "API key configured".to_string()
-            } else {
-                format!("API key configured for {}", api_url)
-            }
-        };
-
-        let api_key_section = if self.should_render_editor(cx) {
-            v_flex()
-                .on_action(cx.listener(Self::save_api_key))
-                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
-                .child(
-                    List::new()
-                        .child(InstructionListItem::new(
-                            "Create one by visiting",
-                            Some("OpenAI's console"),
-                            Some("https://platform.openai.com/api-keys"),
-                        ))
-                        .child(InstructionListItem::text_only(
-                            "Ensure your OpenAI account has credits",
-                        ))
-                        .child(InstructionListItem::text_only(
-                            "Paste your API key below and hit enter to start using the assistant",
-                        )),
-                )
-                .child(self.api_key_editor.clone())
-                .child(
-                    Label::new(format!(
-                        "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
-                    ))
-                    .size(LabelSize::Small)
-                    .color(Color::Muted),
-                )
-                .child(
-                    Label::new(
-                        "Note that having a subscription for another service like GitHub Copilot won't work.",
-                    )
-                    .size(LabelSize::Small).color(Color::Muted),
-                )
-                .into_any_element()
-        } else {
-            ConfiguredApiCard::new(configured_card_label)
-                .disabled(env_var_set)
-                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
-                .when(env_var_set, |this| {
-                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
-                })
-                .into_any_element()
-        };
-
-        let compatible_api_section = h_flex()
-            .mt_1p5()
-            .gap_0p5()
-            .flex_wrap()
-            .when(self.should_render_editor(cx), |this| {
-                this.pt_1p5()
-                    .border_t_1()
-                    .border_color(cx.theme().colors().border_variant)
-            })
-            .child(
-                h_flex()
-                    .gap_2()
-                    .child(
-                        Icon::new(IconName::Info)
-                            .size(IconSize::XSmall)
-                            .color(Color::Muted),
-                    )
-                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
-            )
-            .child(
-                Button::new("docs", "Learn More")
-                    .icon(IconName::ArrowUpRight)
-                    .icon_size(IconSize::Small)
-                    .icon_color(Color::Muted)
-                    .on_click(move |_, _window, cx| {
-                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
-                    }),
-            );
-
-        if self.load_credentials_task.is_some() {
-            div().child(Label::new("Loading credentials…")).into_any()
-        } else {
-            v_flex()
-                .size_full()
-                .child(api_key_section)
-                .child(compatible_api_section)
-                .into_any()
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use gpui::TestAppContext;
     use language_model::LanguageModelRequestMessage;
+    use strum::IntoEnumIterator;
 
     use super::*;
 
@@ -891,7 +375,6 @@ mod tests {
             thinking_allowed: true,
         };
 
-        // Validate that all models are supported by tiktoken-rs
         for model in Model::iter() {
             let count = cx
                 .executor()
diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs
deleted file mode 100644
index 7b10ebf963033603ede691fa72d2fa523bcdbab9..0000000000000000000000000000000000000000
--- a/crates/language_models/src/provider/open_router.rs
+++ /dev/null
@@ -1,1095 +0,0 @@
-use anyhow::{Result, anyhow};
-use collections::HashMap;
-use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
-use http_client::HttpClient;
-use language_model::{
-    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
-    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
-    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
-};
-use open_router::{
-    Model, ModelMode as OpenRouterModelMode, OPEN_ROUTER_API_URL, ResponseStreamEvent, list_models,
-};
-use settings::{OpenRouterAvailableModel as AvailableModel, Settings, SettingsStore};
-use std::pin::Pin;
-use std::str::FromStr as _;
-use std::sync::{Arc, LazyLock};
-use ui::{List, prelude::*};
-use ui_input::InputField;
-use util::ResultExt;
-use zed_env_vars::{EnvVar, env_var};
-
-use crate::ui::ConfiguredApiCard;
-use crate::{api_key::ApiKeyState, ui::InstructionListItem};
-
-const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
-const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
-
-const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY";
-static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
-
-#[derive(Default, Clone, Debug, PartialEq)]
-pub struct OpenRouterSettings {
-    pub api_url: String,
-    pub available_models: Vec<AvailableModel>,
-}
-
-pub struct OpenRouterLanguageModelProvider {
-    http_client: Arc<dyn HttpClient>,
-    state: Entity<State>,
-}
-
-pub struct State {
-    api_key_state: ApiKeyState,
-    http_client: Arc<dyn HttpClient>,
-    available_models: Vec<Model>,
-    fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
-}
-
-impl State {
-    fn is_authenticated(&self) -> bool {
-        self.api_key_state.has_key()
-    }
-
-    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
-        let api_url = OpenRouterLanguageModelProvider::api_url(cx);
-        self.api_key_state
-            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
-    }
-
-    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
-        let api_url = OpenRouterLanguageModelProvider::api_url(cx);
-        let task = self.api_key_state.load_if_needed(
-            api_url,
-            &API_KEY_ENV_VAR,
-            |this| &mut this.api_key_state,
-            cx,
-        );
-
-        cx.spawn(async move |this, cx| {
-            let result = task.await;
-            this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
-                .ok();
-            result
-        })
-    }
-
-    fn fetch_models(
-        &mut self,
-        cx: &mut Context<Self>,
-    ) -> Task<Result<(), LanguageModelCompletionError>> {
-        let http_client = self.http_client.clone();
-        let api_url = OpenRouterLanguageModelProvider::api_url(cx);
-        let Some(api_key) = self.api_key_state.key(&api_url) else {
-            return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
-                provider: PROVIDER_NAME,
-            }));
-        };
-        cx.spawn(async move |this, cx| {
-            let models = list_models(http_client.as_ref(), &api_url, &api_key)
-                .await
-                .map_err(|e| {
-                    LanguageModelCompletionError::Other(anyhow::anyhow!(
-                        "OpenRouter error: {:?}",
-                        e
-                    ))
-                })?;
-
-            this.update(cx, |this, cx| {
-                this.available_models = models;
-                cx.notify();
-            })
-            .map_err(|e| LanguageModelCompletionError::Other(e))?;
-
-            Ok(())
-        })
-    }
-
-    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
-        if self.is_authenticated() {
-            let task = self.fetch_models(cx);
-            self.fetch_models_task.replace(task);
-        } else {
-            self.available_models = Vec::new();
-        }
-    }
-}
-
-impl OpenRouterLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
-        let state = cx.new(|cx| {
-            cx.observe_global::<SettingsStore>({
-                let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
-                move |this: &mut State, cx| {
-                    let current_settings = OpenRouterLanguageModelProvider::settings(cx);
-                    let settings_changed = current_settings != &last_settings;
-                    if settings_changed {
-                        last_settings = current_settings.clone();
-                        this.authenticate(cx).detach();
-                        cx.notify();
-                    }
-                }
-            })
-            .detach();
-            State {
-                api_key_state: ApiKeyState::new(Self::api_url(cx)),
-                http_client: http_client.clone(),
-                available_models: Vec::new(),
-                fetch_models_task: None,
-            }
-        });
-
-        Self { http_client, state }
-    }
-
-    fn settings(cx: &App) -> &OpenRouterSettings {
-        &crate::AllLanguageModelSettings::get_global(cx).open_router
-    }
-
-    fn api_url(cx: &App) -> SharedString {
-        let api_url = &Self::settings(cx).api_url;
-        if api_url.is_empty() {
-            OPEN_ROUTER_API_URL.into()
-        } else {
-            SharedString::new(api_url.as_str())
-        }
-    }
-
-    fn create_language_model(&self, model: open_router::Model) -> Arc<dyn LanguageModel> {
-        Arc::new(OpenRouterLanguageModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        })
-    }
-}
-
-impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
-    type ObservableEntity = State;
-
-    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
-        Some(self.state.clone())
-    }
-}
-
-impl LanguageModelProvider for OpenRouterLanguageModelProvider {
-    fn id(&self) -> LanguageModelProviderId {
-        PROVIDER_ID
-    }
-
-    fn name(&self) -> LanguageModelProviderName {
-        PROVIDER_NAME
-    }
-
-    fn icon(&self) -> IconName {
-        IconName::AiOpenRouter
-    }
-
-    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        Some(self.create_language_model(open_router::Model::default()))
-    }
-
-    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        Some(self.create_language_model(open_router::Model::default_fast()))
-    }
-
-    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
-        let mut models_from_api = self.state.read(cx).available_models.clone();
-        let mut settings_models = Vec::new();
-
-        for model in &Self::settings(cx).available_models {
-            settings_models.push(open_router::Model {
-                name: model.name.clone(),
-                display_name: model.display_name.clone(),
-                max_tokens: model.max_tokens,
-                supports_tools: model.supports_tools,
-                supports_images: model.supports_images,
-                mode: model.mode.unwrap_or_default(),
-                provider: model.provider.clone(),
-            });
-        }
-
-        for settings_model in &settings_models {
-            if let Some(pos) = models_from_api
-                .iter()
-                .position(|m| m.name == settings_model.name)
-            {
-                models_from_api[pos] = settings_model.clone();
-            } else {
-                models_from_api.push(settings_model.clone());
-            }
-        }
-
-        models_from_api
-            .into_iter()
-            .map(|model| self.create_language_model(model))
-            .collect()
-    }
-
-    fn is_authenticated(&self, cx: &App) -> bool {
-        self.state.read(cx).is_authenticated()
-    }
-
-    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
-        self.state.update(cx, |state, cx| state.authenticate(cx))
-    }
-
-    fn configuration_view(
-        &self,
-        _target_agent: language_model::ConfigurationViewTargetAgent,
-        window: &mut Window,
-        cx: &mut App,
-    ) -> AnyView {
-        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
-            .into()
-    }
-
-    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
-        self.state
-            .update(cx, |state, cx| state.set_api_key(None, cx))
-    }
-}
-
-pub struct OpenRouterLanguageModel {
-    id: LanguageModelId,
-    model: open_router::Model,
-    state: Entity<State>,
-    http_client: Arc<dyn HttpClient>,
-    request_limiter: RateLimiter,
-}
-
-impl OpenRouterLanguageModel {
-    fn stream_completion(
-        &self,
-        request: open_router::Request,
-        cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<
-            futures::stream::BoxStream<
-                'static,
-                Result<ResponseStreamEvent>,
-            >,
-            LanguageModelCompletionError,
-        >,
-    > {
-        let http_client = self.http_client.clone();
-        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
-            let api_url = OpenRouterLanguageModelProvider::api_url(cx);
-            (state.api_key_state.key(&api_url), api_url)
-        }) else {
-            return future::ready(Err(anyhow!("App state dropped").into())).boxed();
-        };
-
-        async move {
-            let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
-            };
-            let request =
-                open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
-            request.await.map_err(Into::into)
-        }
-        .boxed()
-    }
-}
-
-impl LanguageModel for OpenRouterLanguageModel {
-    fn id(&self) -> LanguageModelId {
-        self.id.clone()
-    }
-
-    fn name(&self) -> LanguageModelName {
-        LanguageModelName::from(self.model.display_name().to_string())
-    }
-
-    fn provider_id(&self) -> LanguageModelProviderId {
-        PROVIDER_ID
-    }
-
-    fn provider_name(&self) -> LanguageModelProviderName {
-        PROVIDER_NAME
-    }
-
-    fn supports_tools(&self) -> bool {
-        self.model.supports_tool_calls()
-    }
-
-    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
-        let model_id = self.model.id().trim().to_lowercase();
-        if model_id.contains("gemini") || model_id.contains("grok") {
-            LanguageModelToolSchemaFormat::JsonSchemaSubset
-        } else {
-            LanguageModelToolSchemaFormat::JsonSchema
-        }
-    }
-
-    fn telemetry_id(&self) -> String {
-        format!("openrouter/{}", self.model.id())
-    }
-
-    fn max_token_count(&self) -> u64 {
-        self.model.max_token_count()
-    }
-
-    fn max_output_tokens(&self) -> Option<u64> {
-        self.model.max_output_tokens()
-    }
-
-    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
-        match choice {
-            LanguageModelToolChoice::Auto => true,
-            LanguageModelToolChoice::Any => true,
-            LanguageModelToolChoice::None => true,
-        }
-    }
-
-    fn supports_images(&self) -> bool {
-        self.model.supports_images.unwrap_or(false)
-    }
-
-    fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &App,
-    ) -> BoxFuture<'static, Result<u64>> {
-        count_open_router_tokens(request, self.model.clone(), cx)
-    }
-
-    fn stream_completion(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AsyncApp,
-    ) -> BoxFuture<
-        'static,
-        Result<
-            futures::stream::BoxStream<
-                'static,
-                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
-            >,
-            LanguageModelCompletionError,
-        >,
-    > {
-        let request = into_open_router(request, &self.model, self.max_output_tokens());
-        let request = self.stream_completion(request, cx);
-        let future = self.request_limiter.stream(async move {
-            let response = request.await?;
-            Ok(OpenRouterEventMapper::new().map_stream(response))
-        });
-        async move { Ok(future.await?.boxed()) }.boxed()
-    }
-}
-
-pub fn into_open_router(
-    request: LanguageModelRequest,
-    model: &Model,
-    max_output_tokens: Option<u64>,
-) -> open_router::Request {
-    let mut messages = Vec::new();
-    for message in request.messages {
-        let reasoning_details = message.reasoning_details.clone();
-        for content in message.content {
-            match content {
-                MessageContent::Text(text) => add_message_content_part(
-                    open_router::MessagePart::Text { text },
-                    message.role,
-                    &mut messages,
-                ),
-                MessageContent::Thinking { .. } => {}
-                MessageContent::RedactedThinking(_) => {}
-                MessageContent::Image(image) => {
-                    add_message_content_part(
-                        open_router::MessagePart::Image {
-                            image_url: image.to_base64_url(),
-                        },
-                        message.role,
-                        &mut messages,
-                    );
-                }
-                MessageContent::ToolUse(tool_use) => {
-                    let tool_call = open_router::ToolCall {
-                        id: tool_use.id.to_string(),
-                        content: open_router::ToolCallContent::Function {
-                            function: open_router::FunctionContent {
-                                name: tool_use.name.to_string(),
-                                arguments: serde_json::to_string(&tool_use.input)
-                                    .unwrap_or_default(),
-                                thought_signature: tool_use.thought_signature.clone(),
-                            },
-                        },
-                    };
-
-                    if let Some(open_router::RequestMessage::Assistant {
-                        tool_calls,
-                        reasoning_details: existing_reasoning,
-                        ..
-                    }) = messages.last_mut()
-                    {
-                        tool_calls.push(tool_call);
-                        if existing_reasoning.is_none() && reasoning_details.is_some() {
-                            *existing_reasoning = reasoning_details.clone();
-                        }
-                    } else {
-                        messages.push(open_router::RequestMessage::Assistant {
-                            content: None,
-                            tool_calls: vec![tool_call],
-                            reasoning_details: reasoning_details.clone(),
-                        });
-                    }
-                }
-                MessageContent::ToolResult(tool_result) => {
-                    let content = match &tool_result.content {
-                        LanguageModelToolResultContent::Text(text) => {
-                            vec![open_router::MessagePart::Text {
-                                text: text.to_string(),
-                            }]
-                        }
-                        LanguageModelToolResultContent::Image(image) => {
-                            vec![open_router::MessagePart::Image {
-                                image_url: image.to_base64_url(),
-                            }]
-                        }
-                    };
-
-                    messages.push(open_router::RequestMessage::Tool {
-                        content: content.into(),
-                        tool_call_id: tool_result.tool_use_id.to_string(),
-                    });
-                }
-            }
-        }
-    }
-
-    open_router::Request {
-        model: model.id().into(),
-        messages,
-        stream: true,
-        stop: request.stop,
-        temperature: request.temperature.unwrap_or(0.4),
-        max_tokens: max_output_tokens,
-        parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
-            Some(false)
-        } else {
-            None
-        },
-        usage: open_router::RequestUsage { include: true },
-        reasoning: if request.thinking_allowed
-            && let OpenRouterModelMode::Thinking { budget_tokens } = model.mode
-        {
-            Some(open_router::Reasoning {
-                effort: None,
-                max_tokens: budget_tokens,
-                exclude: Some(false),
-                enabled: Some(true),
-            })
-        } else {
-            None
-        },
-        tools: request
-            .tools
-            .into_iter()
-            .map(|tool| open_router::ToolDefinition::Function {
-                function: open_router::FunctionDefinition {
-                    name: tool.name,
-                    description: Some(tool.description),
-                    parameters: Some(tool.input_schema),
-                },
-            })
-            .collect(),
-        tool_choice: request.tool_choice.map(|choice| match choice {
-            LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto,
-            LanguageModelToolChoice::Any => open_router::ToolChoice::Required,
-            LanguageModelToolChoice::None => open_router::ToolChoice::None,
-        }),
-        provider: model.provider.clone(),
-    }
-}
-
-fn add_message_content_part(
-    new_part: open_router::MessagePart,
-    role: Role,
-    messages: &mut Vec<open_router::RequestMessage>,
-) {
-    match (role, messages.last_mut()) {
-        (Role::User, Some(open_router::RequestMessage::User { content }))
-        | (Role::System, Some(open_router::RequestMessage::System { content })) => {
-            content.push_part(new_part);
-        }
-        (
-            Role::Assistant,
-            Some(open_router::RequestMessage::Assistant {
-                content: Some(content),
-                ..
-            }),
-        ) => {
-            content.push_part(new_part);
-        }
-        _ => {
-            messages.push(match role {
-                Role::User => open_router::RequestMessage::User {
-                    content: open_router::MessageContent::from(vec![new_part]),
-                },
-                Role::Assistant => open_router::RequestMessage::Assistant {
-                    content: Some(open_router::MessageContent::from(vec![new_part])),
-                    tool_calls: Vec::new(),
-                    reasoning_details: None,
-                },
-                Role::System => open_router::RequestMessage::System {
-                    content: open_router::MessageContent::from(vec![new_part]),
-                },
-            });
-        }
-    }
-}
-
-pub struct OpenRouterEventMapper {
-    tool_calls_by_index: HashMap<usize, RawToolCall>,
-    reasoning_details: Option<serde_json::Value>,
-}
-
-impl OpenRouterEventMapper {
-    pub fn new() -> Self {
-        Self {
-            tool_calls_by_index: HashMap::default(),
-            reasoning_details: None,
-        }
-    }
-
-    pub fn map_stream(
-        mut self,
-        events: Pin<
-            Box<
-                dyn Send + Stream<Item = Result<ResponseStreamEvent>>,
-            >,
-        >,
-    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
-    {
-        events.flat_map(move |event| {
-            futures::stream::iter(match event {
-                Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(error.into())],
-            })
-        })
-    }
-
-    pub fn map_event(
-        &mut self,
-        event: ResponseStreamEvent,
-    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-        let Some(choice) = event.choices.first() else {
-            return vec![Err(LanguageModelCompletionError::from(anyhow!(
-                "Response contained no choices"
-            )))];
-        };
-
-        let mut events = Vec::new();
-
-        if let Some(details) = choice.delta.reasoning_details.clone() {
-            // Emit reasoning_details immediately.
-            events.push(Ok(LanguageModelCompletionEvent::ReasoningDetails(
-                details.clone(),
-            )));
-            self.reasoning_details = Some(details);
-        }
-
-        if let Some(reasoning) = choice.delta.reasoning.clone() {
-            events.push(Ok(LanguageModelCompletionEvent::Thinking {
-                text: reasoning,
-                signature: None,
-            }));
-        }
-
-        if let Some(content) = choice.delta.content.clone() {
-            // OpenRouter sends an empty content string alongside reasoning content;
-            // skipping it works around this OpenRouter API bug.
-            if !content.is_empty() {
-                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
-            }
-        }
-
-        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
-            for tool_call in tool_calls {
-                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
-
-                if let Some(tool_id) = tool_call.id.clone() {
-                    entry.id = tool_id;
-                }
-
-                if let Some(function) = tool_call.function.as_ref() {
-                    if let Some(name) = function.name.clone() {
-                        entry.name = name;
-                    }
-
-                    if let Some(arguments) = function.arguments.clone() {
-                        entry.arguments.push_str(&arguments);
-                    }
-
-                    if let Some(signature) = function.thought_signature.clone() {
-                        entry.thought_signature = Some(signature);
-                    }
-                }
-            }
-        }
-
-        if let Some(usage) = event.usage {
-            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
-                input_tokens: usage.prompt_tokens,
-                output_tokens: usage.completion_tokens,
-                cache_creation_input_tokens: 0,
-                cache_read_input_tokens: 0,
-            })));
-        }
-
-        match choice.finish_reason.as_deref() {
-            Some("stop") => {
-                // Don't emit reasoning_details here; they were already emitted when captured.
-                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
-            }
-            Some("tool_calls") => {
-                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
-                    match serde_json::Value::from_str(&tool_call.arguments) {
-                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
-                            LanguageModelToolUse {
-                                id: tool_call.id.clone().into(),
-                                name: tool_call.name.as_str().into(),
-                                is_input_complete: true,
-                                input,
-                                raw_input: tool_call.arguments.clone(),
-                                thought_signature: tool_call.thought_signature.clone(),
-                            },
-                        )),
-                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
-                            id: tool_call.id.clone().into(),
-                            tool_name: tool_call.name.as_str().into(),
-                            raw_input: tool_call.arguments.clone().into(),
-                            json_parse_error: error.to_string(),
-                        }),
-                    }
-                }));
-
-                // Don't emit reasoning_details here; they were already emitted when captured.
-                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
-            }
-            Some(stop_reason) => {
-                log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}");
-                // Don't emit reasoning_details here; they were already emitted when captured.
-                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
-            }
-            None => {}
-        }
-
-        events
-    }
-}
-
-#[derive(Default)]
-struct RawToolCall {
-    id: String,
-    name: String,
-    arguments: String,
-    thought_signature: Option<String>,
-}
-
-pub fn count_open_router_tokens(
-    request: LanguageModelRequest,
-    _model: open_router::Model,
-    cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
-    cx.background_spawn(async move {
-        let messages = request
-            .messages
-            .into_iter()
-            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-                role: match message.role {
-                    Role::User => "user".into(),
-                    Role::Assistant => "assistant".into(),
-                    Role::System => "system".into(),
-                },
-                content: Some(message.string_contents()),
-                name: None,
-                function_call: None,
-            })
-            .collect::<Vec<_>>();
-
-        tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
-    })
-    .boxed()
-}
-
-struct ConfigurationView {
-    api_key_editor: Entity<InputField>,
-    state: Entity<State>,
-    load_credentials_task: Option<Task<()>>,
-}
-
-impl ConfigurationView {
-    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
-        let api_key_editor = cx.new(|cx| {
-            InputField::new(
-                window,
-                cx,
-                "sk_or_000000000000000000000000000000000000000000000000",
-            )
-        });
-
-        cx.observe(&state, |_, _, cx| {
-            cx.notify();
-        })
-        .detach();
-
-        let load_credentials_task = Some(cx.spawn_in(window, {
-            let state = state.clone();
-            async move |this, cx| {
-                if let Some(task) = state
-                    .update(cx, |state, cx| state.authenticate(cx))
-                    .log_err()
-                {
-                    let _ = task.await;
-                }
-
-                this.update(cx, |this, cx| {
-                    this.load_credentials_task = None;
-                    cx.notify();
-                })
-                .log_err();
-            }
-        }));
-
-        Self {
-            api_key_editor,
-            state,
-            load_credentials_task,
-        }
-    }
-
-    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
-        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
-        if api_key.is_empty() {
-            return;
-        }
-
-        // URL changes can cause the editor to be displayed again, so clear its text first.
-        self.api_key_editor
-            .update(cx, |editor, cx| editor.set_text("", window, cx));
-
-        let state = self.state.clone();
-        cx.spawn_in(window, async move |_, cx| {
-            state
-                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
-                .await
-        })
-        .detach_and_log_err(cx);
-    }
-
-    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
-        self.api_key_editor
-            .update(cx, |editor, cx| editor.set_text("", window, cx));
-
-        let state = self.state.clone();
-        cx.spawn_in(window, async move |_, cx| {
-            state
-                .update(cx, |state, cx| state.set_api_key(None, cx))?
-                .await
-        })
-        .detach_and_log_err(cx);
-    }
-
-    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
-        !self.state.read(cx).is_authenticated()
-    }
-}
-
-impl Render for ConfigurationView {
-    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
-        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
-        let configured_card_label = if env_var_set {
-            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
-        } else {
-            let api_url = OpenRouterLanguageModelProvider::api_url(cx);
-            if api_url == OPEN_ROUTER_API_URL {
-                "API key configured".to_string()
-            } else {
-                format!("API key configured for {}", api_url)
-            }
-        };
-
-        if self.load_credentials_task.is_some() {
-            div()
-                .child(Label::new("Loading credentials..."))
-                .into_any_element()
-        } else if self.should_render_editor(cx) {
-            v_flex()
-                .size_full()
-                .on_action(cx.listener(Self::save_api_key))
-                .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:"))
-                .child(
-                    List::new()
-                        .child(InstructionListItem::new(
-                            "Create an API key by visiting",
-                            Some("OpenRouter's console"),
-                            Some("https://openrouter.ai/keys"),
-                        ))
-                        .child(InstructionListItem::text_only(
-                            "Ensure your OpenRouter account has credits",
-                        ))
-                        .child(InstructionListItem::text_only(
-                            "Paste your API key below and hit enter to start using the assistant",
-                        )),
-                )
-                .child(self.api_key_editor.clone())
-                .child(
-                    Label::new(
-                        format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
-                    )
-                    .size(LabelSize::Small).color(Color::Muted),
-                )
-                .into_any_element()
-        } else {
-            ConfiguredApiCard::new(configured_card_label)
-                .disabled(env_var_set)
-                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
-                .when(env_var_set, |this| {
-                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
-                })
-                .into_any_element()
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    use open_router::{ChoiceDelta, FunctionChunk, ResponseMessageDelta, ToolCallChunk};
-
-    #[gpui::test]
-    async fn test_reasoning_details_preservation_with_tool_calls() {
-        // This test verifies that reasoning_details are properly captured and preserved
-        // when a model uses tool calling with reasoning/thinking tokens.
-        //
-        // The key regression this prevents:
-        // - OpenRouter sends multiple reasoning_details updates during streaming:
-        //   first with actual content (encrypted reasoning data),
-        //   then with an empty array on completion.
-        // - We must NOT overwrite the real data with the empty array.
-
-        let mut mapper = OpenRouterEventMapper::new();
-
-        // Simulate the streaming events as they come from OpenRouter/Gemini.
-        let events = vec![
-            // Event 1: Initial reasoning details with text.
-            ResponseStreamEvent {
-                id: Some("response_123".into()),
-                created: 1234567890,
-                model: "google/gemini-3-pro-preview".into(),
-                choices: vec![ChoiceDelta {
-                    index: 0,
-                    delta: ResponseMessageDelta {
-                        role: None,
-                        content: None,
-                        reasoning: None,
-                        tool_calls: None,
-                        reasoning_details: Some(serde_json::json!([
-                            {
-                                "type": "reasoning.text",
-                                "text": "Let me analyze this request...",
-                                "format": "google-gemini-v1",
-                                "index": 0
-                            }
-                        ])),
-                    },
-                    finish_reason: None,
-                }],
-                usage: None,
-            },
-            // Event 2: More reasoning details.
-            ResponseStreamEvent {
-                id: Some("response_123".into()),
-                created: 1234567890,
-                model: "google/gemini-3-pro-preview".into(),
-                choices: vec![ChoiceDelta {
-                    index: 0,
-                    delta: ResponseMessageDelta {
-                        role: None,
-                        content: None,
-                        reasoning: None,
-                        tool_calls: None,
-                        reasoning_details: Some(serde_json::json!([
-                            {
-                                "type": "reasoning.encrypted",
-                                "data": "EtgDCtUDAdHtim9OF5jm4aeZSBAtl/randomized123",
-                                "format": "google-gemini-v1",
-                                "index": 0,
-                                "id": "tool_call_abc123"
-                            }
-                        ])),
-                    },
-                    finish_reason: None,
-                }],
-                usage: None,
-            },
-            // Event 3: Tool call starts.
-            ResponseStreamEvent {
-                id: Some("response_123".into()),
-                created: 1234567890,
-                model: "google/gemini-3-pro-preview".into(),
-                choices: vec![ChoiceDelta {
-                    index: 0,
-                    delta: ResponseMessageDelta {
-                        role: None,
-                        content: None,
-                        reasoning: None,
-                        tool_calls: Some(vec![ToolCallChunk {
-                            index: 0,
-                            id: Some("tool_call_abc123".into()),
-                            function: Some(FunctionChunk {
-                                name: Some("list_directory".into()),
-                                arguments: Some("{\"path\":\"test\"}".into()),
-                                thought_signature: Some("sha256:test_signature_xyz789".into()),
-                            }),
-                        }]),
-                        reasoning_details: None,
-                    },
-                    finish_reason: None,
-                }],
-                usage: None,
-            },
-            // Event 4: Empty reasoning_details on tool_calls finish.
-            // This is the critical event; we must not overwrite with this empty array!
-            ResponseStreamEvent {
-                id: Some("response_123".into()),
-                created: 1234567890,
-                model: "google/gemini-3-pro-preview".into(),
-                choices: vec![ChoiceDelta {
-                    index: 0,
-                    delta: ResponseMessageDelta {
-                        role: None,
-                        content: None,
-                        reasoning: None,
-                        tool_calls: None,
-                        reasoning_details: Some(serde_json::json!([])),
-                    },
-                    finish_reason: Some("tool_calls".into()),
-                }],
-                usage: None,
-            },
-        ];
-
-        // Process all events.
-        let mut collected_events = Vec::new();
-        for event in events {
-            let mapped = mapper.map_event(event);
-            collected_events.extend(mapped);
-        }
-
-        // Verify we got the expected events.
-        let mut has_tool_use = false;
-        let mut reasoning_details_events = Vec::new();
-        let mut thought_signature_value = None;
-
-        for event_result in collected_events {
-            match event_result {
-                Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
-                    has_tool_use = true;
-                    assert_eq!(tool_use.id.to_string(), "tool_call_abc123");
-                    assert_eq!(tool_use.name.as_ref(), "list_directory");
-                    thought_signature_value = tool_use.thought_signature.clone();
-                }
-                Ok(LanguageModelCompletionEvent::ReasoningDetails(details)) => {
-                    reasoning_details_events.push(details);
-                }
-                _ => {}
-            }
-        }
-
-        // Assertions
-        assert!(has_tool_use, "Should have emitted ToolUse event");
-        assert!(
-            !reasoning_details_events.is_empty(),
-            "Should have emitted ReasoningDetails events"
-        );
-
-        // We should have received multiple reasoning_details events (text, encrypted, empty).
-        // The agent layer is responsible for keeping only the first non-empty one.
-        assert!(
-            reasoning_details_events.len() >= 2,
-            "Should have multiple reasoning_details events from streaming"
-        );
-
-        // Verify at least one contains the encrypted data.
-        let has_encrypted = reasoning_details_events.iter().any(|details| {
-            if let serde_json::Value::Array(arr) = details {
-                arr.iter().any(|item| {
-                    item["type"] == "reasoning.encrypted"
-                        && item["data"]
-                            .as_str()
-                            .map_or(false, |s| s.contains("EtgDCtUDAdHtim9OF5jm4aeZSBAtl"))
-                })
-            } else {
-                false
-            }
-        });
-        assert!(
-            has_encrypted,
-            "Should have at least one reasoning_details with encrypted data"
-        );
-
-        // Verify thought_signature was captured.
-        assert!(
-            thought_signature_value.is_some(),
-            "Tool use should have thought_signature"
-        );
-        assert_eq!(
-            thought_signature_value.unwrap(),
-            "sha256:test_signature_xyz789"
-        );
-    }
-
-    #[gpui::test]
-    async fn test_agent_prevents_empty_reasoning_details_overwrite() {
-        // This test verifies that the agent layer prevents empty reasoning_details
-        // from overwriting non-empty ones, even though the mapper emits all events.
-
-        // Simulate what the agent does when it receives multiple ReasoningDetails events.
-        let mut agent_reasoning_details: Option<serde_json::Value> = None;
-
-        let events = vec![
-            // First event: non-empty reasoning_details.
-            serde_json::json!([
-                {
-                    "type": "reasoning.encrypted",
-                    "data": "real_data_here",
-                    "format": "google-gemini-v1"
-                }
-            ]),
-            // Second event: empty array (should not overwrite).
-            serde_json::json!([]),
-        ];
-
-        for details in events {
-            // This mimics the agent's logic: only store if we don't already have it.
-            if agent_reasoning_details.is_none() {
-                agent_reasoning_details = Some(details);
-            }
-        }
-
-        // Verify the agent kept the first non-empty reasoning_details.
-        assert!(agent_reasoning_details.is_some());
-        let final_details = agent_reasoning_details.unwrap();
-        if let serde_json::Value::Array(arr) = &final_details {
-            assert!(
-                !arr.is_empty(),
-                "Agent should have kept the non-empty reasoning_details"
-            );
-            assert_eq!(arr[0]["data"], "real_data_here");
-        } else {
-            panic!("Expected array");
-        }
-    }
-}
diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs
index 15a3c936705194891ad8fbbdc4b369e27d64b261..9c029b6aa7a8588250fb95b1f429da20164ce7cd 100644
--- a/crates/language_models/src/settings.rs
+++ b/crates/language_models/src/settings.rs
@@ -7,9 +7,17 @@ use crate::provider::{
     bedrock::AmazonBedrockSettings, cloud::ZedDotDevSettings, deepseek::DeepSeekSettings,
     google::GoogleSettings, lmstudio::LmStudioSettings, mistral::MistralSettings,
     ollama::OllamaSettings, open_ai::OpenAiSettings, open_ai_compatible::OpenAiCompatibleSettings,
-    open_router::OpenRouterSettings, vercel::VercelSettings, x_ai::XAiSettings,
+    vercel::VercelSettings, x_ai::XAiSettings,
 };
 
+pub use settings::OpenRouterAvailableModel as AvailableModel;
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct OpenRouterSettings {
+    pub api_url: String,
+    pub available_models: Vec<AvailableModel>,
+}
+
 #[derive(Debug, RegisterSetting)]
 pub struct AllLanguageModelSettings {
     pub bedrock: AmazonBedrockSettings,
@@ -47,9 +55,9 @@ impl settings::Settings for AllLanguageModelSettings {
             bedrock: AmazonBedrockSettings {
                 available_models: bedrock.available_models.unwrap_or_default(),
                 region: bedrock.region,
-                endpoint: bedrock.endpoint_url, // todo(should be api_url)
+                endpoint: bedrock.endpoint_url,
                 profile_name: bedrock.profile,
-                role_arn: None, // todo(was never a setting for this...)
+                role_arn: None,
                 authentication_method: bedrock.authentication_method.map(Into::into),
                 allow_global: bedrock.allow_global,
             },