Detailed changes
@@ -8915,7 +8915,6 @@ dependencies = [
"collections",
"component",
"convert_case 0.8.0",
- "copilot",
"credentials_provider",
"deepseek",
"editor",
@@ -8926,7 +8925,6 @@ dependencies = [
"gpui",
"gpui_tokio",
"http_client",
- "language",
"language_model",
"lmstudio",
"log",
@@ -8934,8 +8932,6 @@ dependencies = [
"mistral",
"ollama",
"open_ai",
- "open_router",
- "partial-json-fixer",
"project",
"release_channel",
"schemars",
@@ -11347,12 +11343,6 @@ dependencies = [
"num-traits",
]
-[[package]]
-name = "partial-json-fixer"
-version = "0.5.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13"
-
[[package]]
name = "password-hash"
version = "0.4.2"
@@ -5,7 +5,7 @@ use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
use acp_thread::AgentConnection;
use anyhow::{Context as _, Result};
use gpui::{App, SharedString, Task};
-use language_models::provider::google::GoogleLanguageModelProvider;
+use language_models::api_key_for_gemini_cli;
use project::agent_server_store::GEMINI_NAME;
#[derive(Clone)]
@@ -42,11 +42,7 @@ impl AgentServer for Gemini {
cx.spawn(async move |cx| {
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
- if let Some(api_key) = cx
- .update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
- .await
- .ok()
- {
+ if let Some(api_key) = cx.update(api_key_for_gemini_cli)?.await.ok() {
extra_env.insert("GEMINI_API_KEY".into(), api_key);
}
let (command, root_dir, login) = store
@@ -4,6 +4,8 @@ mod copilot_migration;
pub mod extension_settings;
mod google_ai_migration;
pub mod headless_host;
+mod open_router_migration;
+mod openai_migration;
pub mod wasm_host;
#[cfg(test)]
@@ -893,6 +895,11 @@ impl ExtensionStore {
copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx);
anthropic_migration::migrate_anthropic_credentials_if_needed(&extension_id, cx);
google_ai_migration::migrate_google_ai_credentials_if_needed(&extension_id, cx);
+ openai_migration::migrate_openai_credentials_if_needed(&extension_id, cx);
+ open_router_migration::migrate_open_router_credentials_if_needed(
+ &extension_id,
+ cx,
+ );
})
.ok();
}
@@ -0,0 +1,157 @@
+use credentials_provider::CredentialsProvider;
+use gpui::App;
+
+const OPEN_ROUTER_EXTENSION_ID: &str = "open-router";
+const OPEN_ROUTER_PROVIDER_ID: &str = "open-router";
+const OPEN_ROUTER_DEFAULT_API_URL: &str = "https://openrouter.ai/api/v1";
+
+pub fn migrate_open_router_credentials_if_needed(extension_id: &str, cx: &mut App) {
+ if extension_id != OPEN_ROUTER_EXTENSION_ID {
+ return;
+ }
+
+ let extension_credential_key = format!(
+ "extension-llm-{}:{}",
+ OPEN_ROUTER_EXTENSION_ID, OPEN_ROUTER_PROVIDER_ID
+ );
+
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+ cx.spawn(async move |cx| {
+ let existing_credential = credentials_provider
+ .read_credentials(&extension_credential_key, &cx)
+ .await
+ .ok()
+ .flatten();
+
+ if existing_credential.is_some() {
+ log::debug!("OpenRouter extension already has credentials, skipping migration");
+ return;
+ }
+
+ let old_credential = credentials_provider
+ .read_credentials(OPEN_ROUTER_DEFAULT_API_URL, &cx)
+ .await
+ .ok()
+ .flatten();
+
+ let api_key = match old_credential {
+ Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
+ Ok(key) => key,
+ Err(_) => {
+ log::error!("Failed to decode OpenRouter API key as UTF-8");
+ return;
+ }
+ },
+ None => {
+ log::debug!("No existing OpenRouter API key found to migrate");
+ return;
+ }
+ };
+
+ log::info!("Migrating existing OpenRouter API key to OpenRouter extension");
+
+ match credentials_provider
+ .write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
+ .await
+ {
+ Ok(()) => {
+ log::info!("Successfully migrated OpenRouter API key to extension");
+ }
+ Err(err) => {
+ log::error!("Failed to migrate OpenRouter API key: {}", err);
+ }
+ }
+ })
+ .detach();
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use gpui::TestAppContext;
+
+ #[gpui::test]
+ async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
+ let api_key = "sk-or-test-key-12345";
+
+ cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
+
+ cx.update(|cx| {
+ migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
+ });
+
+ cx.run_until_parked();
+
+ let migrated = cx.read_credentials("extension-llm-open-router:open-router");
+ assert!(migrated.is_some(), "Credentials should have been migrated");
+ let (username, password) = migrated.unwrap();
+ assert_eq!(username, "Bearer");
+ assert_eq!(String::from_utf8(password).unwrap(), api_key);
+ }
+
+ #[gpui::test]
+ async fn test_skips_migration_if_extension_already_has_credentials(cx: &mut TestAppContext) {
+ let old_api_key = "sk-or-old-key";
+ let existing_key = "sk-or-existing-key";
+
+ cx.write_credentials(
+ OPEN_ROUTER_DEFAULT_API_URL,
+ "Bearer",
+ old_api_key.as_bytes(),
+ );
+ cx.write_credentials(
+ "extension-llm-open-router:open-router",
+ "Bearer",
+ existing_key.as_bytes(),
+ );
+
+ cx.update(|cx| {
+ migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
+ });
+
+ cx.run_until_parked();
+
+ let credentials = cx.read_credentials("extension-llm-open-router:open-router");
+ let (_, password) = credentials.unwrap();
+ assert_eq!(
+ String::from_utf8(password).unwrap(),
+ existing_key,
+ "Should not overwrite existing credentials"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_skips_migration_if_no_old_credentials(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
+ });
+
+ cx.run_until_parked();
+
+ let credentials = cx.read_credentials("extension-llm-open-router:open-router");
+ assert!(
+ credentials.is_none(),
+ "Should not create credentials if none existed"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
+ let api_key = "sk-or-test-key";
+
+ cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
+
+ cx.update(|cx| {
+ migrate_open_router_credentials_if_needed("some-other-extension", cx);
+ });
+
+ cx.run_until_parked();
+
+ let credentials = cx.read_credentials("extension-llm-open-router:open-router");
+ assert!(
+ credentials.is_none(),
+ "Should not migrate for other extensions"
+ );
+ }
+}
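
Both this module and the openai_migration module below key the copied secret by the same convention, extension-llm-<extension-id>:<provider-id>. A minimal, self-contained sketch of that convention (the helper name is hypothetical and not part of this diff; both modules inline it via format!):

    fn extension_credential_key(extension_id: &str, provider_id: &str) -> String {
        // Mirrors the format! calls in the two migration modules in this diff.
        format!("extension-llm-{extension_id}:{provider_id}")
    }

    fn main() {
        // For the OpenRouter migration above, this is the key its tests read back.
        assert_eq!(
            extension_credential_key("open-router", "open-router"),
            "extension-llm-open-router:open-router"
        );
    }
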
@@ -0,0 +1,153 @@
+use credentials_provider::CredentialsProvider;
+use gpui::App;
+
+const OPENAI_EXTENSION_ID: &str = "openai";
+const OPENAI_PROVIDER_ID: &str = "openai";
+const OPENAI_DEFAULT_API_URL: &str = "https://api.openai.com/v1";
+
+pub fn migrate_openai_credentials_if_needed(extension_id: &str, cx: &mut App) {
+ if extension_id != OPENAI_EXTENSION_ID {
+ return;
+ }
+
+ let extension_credential_key = format!(
+ "extension-llm-{}:{}",
+ OPENAI_EXTENSION_ID, OPENAI_PROVIDER_ID
+ );
+
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+ cx.spawn(async move |cx| {
+ let existing_credential = credentials_provider
+ .read_credentials(&extension_credential_key, &cx)
+ .await
+ .ok()
+ .flatten();
+
+ if existing_credential.is_some() {
+ log::debug!("OpenAI extension already has credentials, skipping migration");
+ return;
+ }
+
+ let old_credential = credentials_provider
+ .read_credentials(OPENAI_DEFAULT_API_URL, &cx)
+ .await
+ .ok()
+ .flatten();
+
+ let api_key = match old_credential {
+ Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
+ Ok(key) => key,
+ Err(_) => {
+ log::error!("Failed to decode OpenAI API key as UTF-8");
+ return;
+ }
+ },
+ None => {
+ log::debug!("No existing OpenAI API key found to migrate");
+ return;
+ }
+ };
+
+ log::info!("Migrating existing OpenAI API key to OpenAI extension");
+
+ match credentials_provider
+ .write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
+ .await
+ {
+ Ok(()) => {
+ log::info!("Successfully migrated OpenAI API key to extension");
+ }
+ Err(err) => {
+ log::error!("Failed to migrate OpenAI API key: {}", err);
+ }
+ }
+ })
+ .detach();
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use gpui::TestAppContext;
+
+ #[gpui::test]
+ async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
+ let api_key = "sk-test-key-12345";
+
+ cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
+
+ cx.update(|cx| {
+ migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
+ });
+
+ cx.run_until_parked();
+
+ let migrated = cx.read_credentials("extension-llm-openai:openai");
+ assert!(migrated.is_some(), "Credentials should have been migrated");
+ let (username, password) = migrated.unwrap();
+ assert_eq!(username, "Bearer");
+ assert_eq!(String::from_utf8(password).unwrap(), api_key);
+ }
+
+ #[gpui::test]
+ async fn test_skips_migration_if_extension_already_has_credentials(cx: &mut TestAppContext) {
+ let old_api_key = "sk-old-key";
+ let existing_key = "sk-existing-key";
+
+ cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", old_api_key.as_bytes());
+ cx.write_credentials(
+ "extension-llm-openai:openai",
+ "Bearer",
+ existing_key.as_bytes(),
+ );
+
+ cx.update(|cx| {
+ migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
+ });
+
+ cx.run_until_parked();
+
+ let credentials = cx.read_credentials("extension-llm-openai:openai");
+ let (_, password) = credentials.unwrap();
+ assert_eq!(
+ String::from_utf8(password).unwrap(),
+ existing_key,
+ "Should not overwrite existing credentials"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_skips_migration_if_no_old_credentials(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
+ });
+
+ cx.run_until_parked();
+
+ let credentials = cx.read_credentials("extension-llm-openai:openai");
+ assert!(
+ credentials.is_none(),
+ "Should not create credentials if none existed"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
+ let api_key = "sk-test-key";
+
+ cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
+
+ cx.update(|cx| {
+ migrate_openai_credentials_if_needed("some-other-extension", cx);
+ });
+
+ cx.run_until_parked();
+
+ let credentials = cx.read_credentials("extension-llm-openai:openai");
+ assert!(
+ credentials.is_none(),
+ "Should not migrate for other extensions"
+ );
+ }
+}
@@ -25,7 +25,6 @@ cloud_llm_client.workspace = true
collections.workspace = true
component.workspace = true
convert_case.workspace = true
-copilot.workspace = true
credentials_provider.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
extension.workspace = true
@@ -35,7 +34,6 @@ google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
gpui_tokio.workspace = true
http_client.workspace = true
-language.workspace = true
language_model.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
@@ -43,8 +41,6 @@ menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
-open_router = { workspace = true, features = ["schemars"] }
-partial-json-fixer.workspace = true
release_channel.workspace = true
schemars.workspace = true
semver.workspace = true
@@ -223,27 +223,13 @@ impl ApiKeyState {
}
impl ApiKey {
- pub fn key(&self) -> &str {
- &self.key
- }
-
- pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
+ fn from_env(env_var_name: SharedString, key: &str) -> Self {
Self {
source: ApiKeySource::EnvVar(env_var_name),
key: key.into(),
}
}
- pub async fn load_from_system_keychain(
- url: &str,
- credentials_provider: &dyn CredentialsProvider,
- cx: &AsyncApp,
- ) -> Result<Self, AuthenticateError> {
- Self::load_from_system_keychain_impl(url, credentials_provider, cx)
- .await
- .into_authenticate_result()
- }
-
async fn load_from_system_keychain_impl(
url: &str,
credentials_provider: &dyn CredentialsProvider,
@@ -0,0 +1,43 @@
+use anyhow::Result;
+use credentials_provider::CredentialsProvider;
+use gpui::{App, Task};
+
+const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
+const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
+const GOOGLE_AI_EXTENSION_CREDENTIAL_KEY: &str = "extension-llm-google-ai:google-ai";
+
+/// Returns the Google AI API key for use by the Gemini CLI.
+///
+/// This function checks the following sources in order:
+/// 1. `GEMINI_API_KEY` environment variable
+/// 2. `GOOGLE_AI_API_KEY` environment variable
+/// 3. Extension credential store (`extension-llm-google-ai:google-ai`)
+pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
+ if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR_NAME) {
+ if !key.is_empty() {
+ return Task::ready(Ok(key));
+ }
+ }
+
+ if let Ok(key) = std::env::var(GOOGLE_AI_API_KEY_VAR_NAME) {
+ if !key.is_empty() {
+ return Task::ready(Ok(key));
+ }
+ }
+
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+ cx.spawn(async move |cx| {
+ let credential = credentials_provider
+ .read_credentials(GOOGLE_AI_EXTENSION_CREDENTIAL_KEY, &cx)
+ .await?;
+
+ match credential {
+ Some((_, key_bytes)) => {
+ let key = String::from_utf8(key_bytes)?;
+ Ok(key)
+ }
+ None => Err(anyhow::anyhow!("No Google AI API key found")),
+ }
+ })
+}
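
For reference, a minimal sketch of calling api_key_for_gemini_cli from async code, mirroring the agent_servers Gemini hunk earlier in this diff (it assumes, as that hunk does, an AsyncApp handle named cx and an extra_env map; error handling elided):

    // Resolve the key from env vars or the extension credential store, then
    // expose it to the spawned Gemini CLI process via its environment.
    if let Ok(api_key) = cx.update(language_models::api_key_for_gemini_cli)?.await {
        extra_env.insert("GEMINI_API_KEY".into(), api_key);
    }
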
@@ -10,20 +10,19 @@ use provider::deepseek::DeepSeekLanguageModelProvider;
mod api_key;
mod extension;
+mod google_ai_api_key;
pub mod provider;
mod settings;
pub mod ui;
+pub use google_ai_api_key::api_key_for_gemini_cli;
+
use crate::provider::bedrock::BedrockLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider;
-use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
-use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
pub use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
-use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
-use crate::provider::open_router::OpenRouterLanguageModelProvider;
use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;
@@ -118,10 +117,6 @@ fn register_language_model_providers(
)),
cx,
);
- registry.register_provider(
- Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
- cx,
- );
registry.register_provider(
Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
cx,
@@ -134,10 +129,6 @@ fn register_language_model_providers(
Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
- registry.register_provider(
- Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
- cx,
- );
registry.register_provider(
MistralLanguageModelProvider::global(client.http_client(), cx),
cx,
@@ -146,13 +137,6 @@ fn register_language_model_providers(
Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
- registry.register_provider(
- Arc::new(OpenRouterLanguageModelProvider::new(
- client.http_client(),
- cx,
- )),
- cx,
- );
registry.register_provider(
Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
cx,
@@ -161,5 +145,4 @@ fn register_language_model_providers(
Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
- registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
}
@@ -1,6 +1,5 @@
pub mod bedrock;
pub mod cloud;
-pub mod copilot_chat;
pub mod deepseek;
pub mod google;
pub mod lmstudio;
@@ -8,6 +7,5 @@ pub mod mistral;
pub mod ollama;
pub mod open_ai;
pub mod open_ai_compatible;
-pub mod open_router;
pub mod vercel;
pub mod x_ai;
@@ -1,1565 +0,0 @@
-use std::pin::Pin;
-use std::str::FromStr as _;
-use std::sync::Arc;
-
-use anyhow::{Result, anyhow};
-use cloud_llm_client::CompletionIntent;
-use collections::HashMap;
-use copilot::copilot_chat::{
- ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
- Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
- ToolCall,
-};
-use copilot::{Copilot, Status};
-use futures::future::BoxFuture;
-use futures::stream::BoxStream;
-use futures::{FutureExt, Stream, StreamExt};
-use gpui::{Action, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, svg};
-use http_client::StatusCode;
-use language::language_settings::all_language_settings;
-use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
- LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
- LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
- StopReason, TokenUsage,
-};
-use settings::SettingsStore;
-use ui::{CommonAnimationExt, prelude::*};
-use util::debug_panic;
-
-use crate::ui::ConfiguredApiCard;
-
-const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
-const PROVIDER_NAME: LanguageModelProviderName =
- LanguageModelProviderName::new("GitHub Copilot Chat");
-
-pub struct CopilotChatLanguageModelProvider {
- state: Entity<State>,
-}
-
-pub struct State {
- _copilot_chat_subscription: Option<Subscription>,
- _settings_subscription: Subscription,
-}
-
-impl State {
- fn is_authenticated(&self, cx: &App) -> bool {
- CopilotChat::global(cx)
- .map(|m| m.read(cx).is_authenticated())
- .unwrap_or(false)
- }
-}
-
-impl CopilotChatLanguageModelProvider {
- pub fn new(cx: &mut App) -> Self {
- let state = cx.new(|cx| {
- let copilot_chat_subscription = CopilotChat::global(cx)
- .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
- State {
- _copilot_chat_subscription: copilot_chat_subscription,
- _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
- if let Some(copilot_chat) = CopilotChat::global(cx) {
- let language_settings = all_language_settings(None, cx);
- let configuration = copilot::copilot_chat::CopilotChatConfiguration {
- enterprise_uri: language_settings
- .edit_predictions
- .copilot
- .enterprise_uri
- .clone(),
- };
- copilot_chat.update(cx, |chat, cx| {
- chat.set_configuration(configuration, cx);
- });
- }
- cx.notify();
- }),
- }
- });
-
- Self { state }
- }
-
- fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
- Arc::new(CopilotChatLanguageModel {
- model,
- request_limiter: RateLimiter::new(4),
- })
- }
-}
-
-impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
- type ObservableEntity = State;
-
- fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
- Some(self.state.clone())
- }
-}
-
-impl LanguageModelProvider for CopilotChatLanguageModelProvider {
- fn id(&self) -> LanguageModelProviderId {
- PROVIDER_ID
- }
-
- fn name(&self) -> LanguageModelProviderName {
- PROVIDER_NAME
- }
-
- fn icon(&self) -> IconName {
- IconName::Copilot
- }
-
- fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
- let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
- models
- .first()
- .map(|model| self.create_language_model(model.clone()))
- }
-
- fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
- // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
- // model (e.g. 4o) and a sensible choice when considering premium requests
- self.default_model(cx)
- }
-
- fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
- return Vec::new();
- };
- models
- .iter()
- .map(|model| self.create_language_model(model.clone()))
- .collect()
- }
-
- fn is_authenticated(&self, cx: &App) -> bool {
- self.state.read(cx).is_authenticated(cx)
- }
-
- fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated(cx) {
- return Task::ready(Ok(()));
- };
-
- let Some(copilot) = Copilot::global(cx) else {
- return Task::ready(Err(anyhow!(concat!(
- "Copilot must be enabled for Copilot Chat to work. ",
- "Please enable Copilot and try again."
- ))
- .into()));
- };
-
- let err = match copilot.read(cx).status() {
- Status::Authorized => return Task::ready(Ok(())),
- Status::Disabled => anyhow!(
- "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
- ),
- Status::Error(err) => anyhow!(format!(
- "Received the following error while signing into Copilot: {err}"
- )),
- Status::Starting { task: _ } => anyhow!(
- "Copilot is still starting, please wait for Copilot to start then try again"
- ),
- Status::Unauthorized => anyhow!(
- "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
- ),
- Status::SignedOut { .. } => {
- anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
- }
- Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
- };
-
- Task::ready(Err(err.into()))
- }
-
- fn configuration_view(
- &self,
- _target_agent: language_model::ConfigurationViewTargetAgent,
- _: &mut Window,
- cx: &mut App,
- ) -> AnyView {
- let state = self.state.clone();
- cx.new(|cx| ConfigurationView::new(state, cx)).into()
- }
-
- fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
- Task::ready(Err(anyhow!(
- "Signing out of GitHub Copilot Chat is currently not supported."
- )))
- }
-}
-
-fn collect_tiktoken_messages(
- request: LanguageModelRequest,
-) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
- request
- .messages
- .into_iter()
- .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
- role: match message.role {
- Role::User => "user".into(),
- Role::Assistant => "assistant".into(),
- Role::System => "system".into(),
- },
- content: Some(message.string_contents()),
- name: None,
- function_call: None,
- })
- .collect::<Vec<_>>()
-}
-
-pub struct CopilotChatLanguageModel {
- model: CopilotChatModel,
- request_limiter: RateLimiter,
-}
-
-impl LanguageModel for CopilotChatLanguageModel {
- fn id(&self) -> LanguageModelId {
- LanguageModelId::from(self.model.id().to_string())
- }
-
- fn name(&self) -> LanguageModelName {
- LanguageModelName::from(self.model.display_name().to_string())
- }
-
- fn provider_id(&self) -> LanguageModelProviderId {
- PROVIDER_ID
- }
-
- fn provider_name(&self) -> LanguageModelProviderName {
- PROVIDER_NAME
- }
-
- fn supports_tools(&self) -> bool {
- self.model.supports_tools()
- }
-
- fn supports_images(&self) -> bool {
- self.model.supports_vision()
- }
-
- fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
- match self.model.vendor() {
- ModelVendor::OpenAI | ModelVendor::Anthropic => {
- LanguageModelToolSchemaFormat::JsonSchema
- }
- ModelVendor::Google | ModelVendor::XAI | ModelVendor::Unknown => {
- LanguageModelToolSchemaFormat::JsonSchemaSubset
- }
- }
- }
-
- fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
- match choice {
- LanguageModelToolChoice::Auto
- | LanguageModelToolChoice::Any
- | LanguageModelToolChoice::None => self.supports_tools(),
- }
- }
-
- fn telemetry_id(&self) -> String {
- format!("copilot_chat/{}", self.model.id())
- }
-
- fn max_token_count(&self) -> u64 {
- self.model.max_token_count()
- }
-
- fn count_tokens(
- &self,
- request: LanguageModelRequest,
- cx: &App,
- ) -> BoxFuture<'static, Result<u64>> {
- let model = self.model.clone();
- cx.background_spawn(async move {
- let messages = collect_tiktoken_messages(request);
-            // Copilot uses the OpenAI tiktoken tokenizer for all its models, irrespective of the underlying provider (vendor).
- let tokenizer_model = match model.tokenizer() {
- Some("o200k_base") => "gpt-4o",
- Some("cl100k_base") => "gpt-4",
- _ => "gpt-4o",
- };
-
- tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages)
- .map(|tokens| tokens as u64)
- })
- .boxed()
- }
-
- fn stream_completion(
- &self,
- request: LanguageModelRequest,
- cx: &AsyncApp,
- ) -> BoxFuture<
- 'static,
- Result<
- BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
- LanguageModelCompletionError,
- >,
- > {
- let is_user_initiated = request.intent.is_none_or(|intent| match intent {
- CompletionIntent::UserPrompt
- | CompletionIntent::ThreadContextSummarization
- | CompletionIntent::InlineAssist
- | CompletionIntent::TerminalInlineAssist
- | CompletionIntent::GenerateGitCommitMessage => true,
-
- CompletionIntent::ToolResults
- | CompletionIntent::ThreadSummarization
- | CompletionIntent::CreateFile
- | CompletionIntent::EditFile => false,
- });
-
- if self.model.supports_response() {
- let responses_request = into_copilot_responses(&self.model, request);
- let request_limiter = self.request_limiter.clone();
- let future = cx.spawn(async move |cx| {
- let request =
- CopilotChat::stream_response(responses_request, is_user_initiated, cx.clone());
- request_limiter
- .stream(async move {
- let stream = request.await?;
- let mapper = CopilotResponsesEventMapper::new();
- Ok(mapper.map_stream(stream).boxed())
- })
- .await
- });
- return async move { Ok(future.await?.boxed()) }.boxed();
- }
-
- let copilot_request = match into_copilot_chat(&self.model, request) {
- Ok(request) => request,
- Err(err) => return futures::future::ready(Err(err.into())).boxed(),
- };
- let is_streaming = copilot_request.stream;
-
- let request_limiter = self.request_limiter.clone();
- let future = cx.spawn(async move |cx| {
- let request =
- CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
- request_limiter
- .stream(async move {
- let response = request.await?;
- Ok(map_to_language_model_completion_events(
- response,
- is_streaming,
- ))
- })
- .await
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- }
-}
-
-pub fn map_to_language_model_completion_events(
- events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
- is_streaming: bool,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- #[derive(Default)]
- struct RawToolCall {
- id: String,
- name: String,
- arguments: String,
- thought_signature: Option<String>,
- }
-
- struct State {
- events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
- tool_calls_by_index: HashMap<usize, RawToolCall>,
- reasoning_opaque: Option<String>,
- reasoning_text: Option<String>,
- }
-
- futures::stream::unfold(
- State {
- events,
- tool_calls_by_index: HashMap::default(),
- reasoning_opaque: None,
- reasoning_text: None,
- },
- move |mut state| async move {
- if let Some(event) = state.events.next().await {
- match event {
- Ok(event) => {
- let Some(choice) = event.choices.first() else {
- return Some((
- vec![Err(anyhow!("Response contained no choices").into())],
- state,
- ));
- };
-
- let delta = if is_streaming {
- choice.delta.as_ref()
- } else {
- choice.message.as_ref()
- };
-
- let Some(delta) = delta else {
- return Some((
- vec![Err(anyhow!("Response contained no delta").into())],
- state,
- ));
- };
-
- let mut events = Vec::new();
- if let Some(content) = delta.content.clone() {
- events.push(Ok(LanguageModelCompletionEvent::Text(content)));
- }
-
- // Capture reasoning data from the delta (e.g. for Gemini 3)
- if let Some(opaque) = delta.reasoning_opaque.clone() {
- state.reasoning_opaque = Some(opaque);
- }
- if let Some(text) = delta.reasoning_text.clone() {
- state.reasoning_text = Some(text);
- }
-
- for (index, tool_call) in delta.tool_calls.iter().enumerate() {
- let tool_index = tool_call.index.unwrap_or(index);
- let entry = state.tool_calls_by_index.entry(tool_index).or_default();
-
- if let Some(tool_id) = tool_call.id.clone() {
- entry.id = tool_id;
- }
-
- if let Some(function) = tool_call.function.as_ref() {
- if let Some(name) = function.name.clone() {
- entry.name = name;
- }
-
- if let Some(arguments) = function.arguments.clone() {
- entry.arguments.push_str(&arguments);
- }
-
- if let Some(thought_signature) = function.thought_signature.clone()
- {
- entry.thought_signature = Some(thought_signature);
- }
- }
- }
-
- if let Some(usage) = event.usage {
- events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
- TokenUsage {
- input_tokens: usage.prompt_tokens,
- output_tokens: usage.completion_tokens,
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
- },
- )));
- }
-
- match choice.finish_reason.as_deref() {
- Some("stop") => {
- events.push(Ok(LanguageModelCompletionEvent::Stop(
- StopReason::EndTurn,
- )));
- }
- Some("tool_calls") => {
- // Gemini 3 models send reasoning_opaque/reasoning_text that must
- // be preserved and sent back in subsequent requests. Emit as
- // ReasoningDetails so the agent stores it in the message.
- if state.reasoning_opaque.is_some()
- || state.reasoning_text.is_some()
- {
- let mut details = serde_json::Map::new();
- if let Some(opaque) = state.reasoning_opaque.take() {
- details.insert(
- "reasoning_opaque".to_string(),
- serde_json::Value::String(opaque),
- );
- }
- if let Some(text) = state.reasoning_text.take() {
- details.insert(
- "reasoning_text".to_string(),
- serde_json::Value::String(text),
- );
- }
- events.push(Ok(
- LanguageModelCompletionEvent::ReasoningDetails(
- serde_json::Value::Object(details),
- ),
- ));
- }
-
- events.extend(state.tool_calls_by_index.drain().map(
- |(_, tool_call)| {
- // The model can output an empty string
- // to indicate the absence of arguments.
- // When that happens, create an empty
- // object instead.
- let arguments = if tool_call.arguments.is_empty() {
- Ok(serde_json::Value::Object(Default::default()))
- } else {
- serde_json::Value::from_str(&tool_call.arguments)
- };
- match arguments {
- Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_call.id.into(),
- name: tool_call.name.as_str().into(),
- is_input_complete: true,
- input,
- raw_input: tool_call.arguments,
- thought_signature: tool_call.thought_signature,
- },
- )),
- Err(error) => Ok(
- LanguageModelCompletionEvent::ToolUseJsonParseError {
- id: tool_call.id.into(),
- tool_name: tool_call.name.as_str().into(),
- raw_input: tool_call.arguments.into(),
- json_parse_error: error.to_string(),
- },
- ),
- }
- },
- ));
-
- events.push(Ok(LanguageModelCompletionEvent::Stop(
- StopReason::ToolUse,
- )));
- }
- Some(stop_reason) => {
- log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
- events.push(Ok(LanguageModelCompletionEvent::Stop(
- StopReason::EndTurn,
- )));
- }
- None => {}
- }
-
- return Some((events, state));
- }
- Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
- }
- }
-
- None
- },
- )
- .flat_map(futures::stream::iter)
-}
-
-pub struct CopilotResponsesEventMapper {
- pending_stop_reason: Option<StopReason>,
-}
-
-impl CopilotResponsesEventMapper {
- pub fn new() -> Self {
- Self {
- pending_stop_reason: None,
- }
- }
-
- pub fn map_stream(
- mut self,
- events: Pin<Box<dyn Send + Stream<Item = Result<copilot::copilot_responses::StreamEvent>>>>,
- ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
- {
- events.flat_map(move |event| {
- futures::stream::iter(match event {
- Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
- })
- })
- }
-
- fn map_event(
- &mut self,
- event: copilot::copilot_responses::StreamEvent,
- ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- match event {
- copilot::copilot_responses::StreamEvent::OutputItemAdded { item, .. } => match item {
- copilot::copilot_responses::ResponseOutputItem::Message { id, .. } => {
- vec![Ok(LanguageModelCompletionEvent::StartMessage {
- message_id: id,
- })]
- }
- _ => Vec::new(),
- },
-
- copilot::copilot_responses::StreamEvent::OutputTextDelta { delta, .. } => {
- if delta.is_empty() {
- Vec::new()
- } else {
- vec![Ok(LanguageModelCompletionEvent::Text(delta))]
- }
- }
-
- copilot::copilot_responses::StreamEvent::OutputItemDone { item, .. } => match item {
- copilot::copilot_responses::ResponseOutputItem::Message { .. } => Vec::new(),
- copilot::copilot_responses::ResponseOutputItem::FunctionCall {
- call_id,
- name,
- arguments,
- thought_signature,
- ..
- } => {
- let mut events = Vec::new();
- match serde_json::from_str::<serde_json::Value>(&arguments) {
- Ok(input) => events.push(Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: call_id.into(),
- name: name.as_str().into(),
- is_input_complete: true,
- input,
- raw_input: arguments.clone(),
- thought_signature,
- },
- ))),
- Err(error) => {
- events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
- id: call_id.into(),
- tool_name: name.as_str().into(),
- raw_input: arguments.clone().into(),
- json_parse_error: error.to_string(),
- }))
- }
- }
- // Record that we already emitted a tool-use stop so we can avoid duplicating
- // a Stop event on Completed.
- self.pending_stop_reason = Some(StopReason::ToolUse);
- events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
- events
- }
- copilot::copilot_responses::ResponseOutputItem::Reasoning {
- summary,
- encrypted_content,
- ..
- } => {
- let mut events = Vec::new();
-
- if let Some(blocks) = summary {
- let mut text = String::new();
- for block in blocks {
- text.push_str(&block.text);
- }
- if !text.is_empty() {
- events.push(Ok(LanguageModelCompletionEvent::Thinking {
- text,
- signature: None,
- }));
- }
- }
-
- if let Some(data) = encrypted_content {
- events.push(Ok(LanguageModelCompletionEvent::RedactedThinking { data }));
- }
-
- events
- }
- },
-
- copilot::copilot_responses::StreamEvent::Completed { response } => {
- let mut events = Vec::new();
- if let Some(usage) = response.usage {
- events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
- input_tokens: usage.input_tokens.unwrap_or(0),
- output_tokens: usage.output_tokens.unwrap_or(0),
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
- })));
- }
- if self.pending_stop_reason.take() != Some(StopReason::ToolUse) {
- events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
- }
- events
- }
-
- copilot::copilot_responses::StreamEvent::Incomplete { response } => {
- let reason = response
- .incomplete_details
- .as_ref()
- .and_then(|details| details.reason.as_ref());
- let stop_reason = match reason {
- Some(copilot::copilot_responses::IncompleteReason::MaxOutputTokens) => {
- StopReason::MaxTokens
- }
- Some(copilot::copilot_responses::IncompleteReason::ContentFilter) => {
- StopReason::Refusal
- }
- _ => self
- .pending_stop_reason
- .take()
- .unwrap_or(StopReason::EndTurn),
- };
-
- let mut events = Vec::new();
- if let Some(usage) = response.usage {
- events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
- input_tokens: usage.input_tokens.unwrap_or(0),
- output_tokens: usage.output_tokens.unwrap_or(0),
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
- })));
- }
- events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason)));
- events
- }
-
- copilot::copilot_responses::StreamEvent::Failed { response } => {
- let provider = PROVIDER_NAME;
- let (status_code, message) = match response.error {
- Some(error) => {
- let status_code = StatusCode::from_str(&error.code)
- .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
- (status_code, error.message)
- }
- None => (
- StatusCode::INTERNAL_SERVER_ERROR,
- "response.failed".to_string(),
- ),
- };
- vec![Err(LanguageModelCompletionError::HttpResponseError {
- provider,
- status_code,
- message,
- })]
- }
-
- copilot::copilot_responses::StreamEvent::GenericError { error } => vec![Err(
- LanguageModelCompletionError::Other(anyhow!(format!("{error:?}"))),
- )],
-
- copilot::copilot_responses::StreamEvent::Created { .. }
- | copilot::copilot_responses::StreamEvent::Unknown => Vec::new(),
- }
- }
-}
-
-fn into_copilot_chat(
- model: &copilot::copilot_chat::Model,
- request: LanguageModelRequest,
-) -> Result<CopilotChatRequest> {
- let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
- for message in request.messages {
- if let Some(last_message) = request_messages.last_mut() {
- if last_message.role == message.role {
- last_message.content.extend(message.content);
- } else {
- request_messages.push(message);
- }
- } else {
- request_messages.push(message);
- }
- }
-
- let mut messages: Vec<ChatMessage> = Vec::new();
- for message in request_messages {
- match message.role {
- Role::User => {
- for content in &message.content {
- if let MessageContent::ToolResult(tool_result) = content {
- let content = match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => text.to_string().into(),
- LanguageModelToolResultContent::Image(image) => {
- if model.supports_vision() {
- ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
- image_url: ImageUrl {
- url: image.to_base64_url(),
- },
- }])
- } else {
- debug_panic!(
- "This should be caught at {} level",
- tool_result.tool_name
- );
- "[Tool responded with an image, but this model does not support vision]".to_string().into()
- }
- }
- };
-
- messages.push(ChatMessage::Tool {
- tool_call_id: tool_result.tool_use_id.to_string(),
- content,
- });
- }
- }
-
- let mut content_parts = Vec::new();
- for content in &message.content {
- match content {
- MessageContent::Text(text) | MessageContent::Thinking { text, .. }
- if !text.is_empty() =>
- {
- if let Some(ChatMessagePart::Text { text: text_content }) =
- content_parts.last_mut()
- {
- text_content.push_str(text);
- } else {
- content_parts.push(ChatMessagePart::Text {
- text: text.to_string(),
- });
- }
- }
- MessageContent::Image(image) if model.supports_vision() => {
- content_parts.push(ChatMessagePart::Image {
- image_url: ImageUrl {
- url: image.to_base64_url(),
- },
- });
- }
- _ => {}
- }
- }
-
- if !content_parts.is_empty() {
- messages.push(ChatMessage::User {
- content: content_parts.into(),
- });
- }
- }
- Role::Assistant => {
- let mut tool_calls = Vec::new();
- for content in &message.content {
- if let MessageContent::ToolUse(tool_use) = content {
- tool_calls.push(ToolCall {
- id: tool_use.id.to_string(),
- content: copilot::copilot_chat::ToolCallContent::Function {
- function: copilot::copilot_chat::FunctionContent {
- name: tool_use.name.to_string(),
- arguments: serde_json::to_string(&tool_use.input)?,
- thought_signature: tool_use.thought_signature.clone(),
- },
- },
- });
- }
- }
-
- let text_content = {
- let mut buffer = String::new();
- for string in message.content.iter().filter_map(|content| match content {
- MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
- Some(text.as_str())
- }
- MessageContent::ToolUse(_)
- | MessageContent::RedactedThinking(_)
- | MessageContent::ToolResult(_)
- | MessageContent::Image(_) => None,
- }) {
- buffer.push_str(string);
- }
-
- buffer
- };
-
- // Extract reasoning_opaque and reasoning_text from reasoning_details
- let (reasoning_opaque, reasoning_text) =
- if let Some(details) = &message.reasoning_details {
- let opaque = details
- .get("reasoning_opaque")
- .and_then(|v| v.as_str())
- .map(|s| s.to_string());
- let text = details
- .get("reasoning_text")
- .and_then(|v| v.as_str())
- .map(|s| s.to_string());
- (opaque, text)
- } else {
- (None, None)
- };
-
- messages.push(ChatMessage::Assistant {
- content: if text_content.is_empty() {
- ChatMessageContent::empty()
- } else {
- text_content.into()
- },
- tool_calls,
- reasoning_opaque,
- reasoning_text,
- });
- }
- Role::System => messages.push(ChatMessage::System {
- content: message.string_contents(),
- }),
- }
- }
-
- let tools = request
- .tools
- .iter()
- .map(|tool| Tool::Function {
- function: copilot::copilot_chat::Function {
- name: tool.name.clone(),
- description: tool.description.clone(),
- parameters: tool.input_schema.clone(),
- },
- })
- .collect::<Vec<_>>();
-
- Ok(CopilotChatRequest {
- intent: true,
- n: 1,
- stream: model.uses_streaming(),
- temperature: 0.1,
- model: model.id().to_string(),
- messages,
- tools,
- tool_choice: request.tool_choice.map(|choice| match choice {
- LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
- LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
- LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
- }),
- })
-}
-
-fn into_copilot_responses(
- model: &copilot::copilot_chat::Model,
- request: LanguageModelRequest,
-) -> copilot::copilot_responses::Request {
- use copilot::copilot_responses as responses;
-
- let LanguageModelRequest {
- thread_id: _,
- prompt_id: _,
- intent: _,
- mode: _,
- messages,
- tools,
- tool_choice,
- stop: _,
- temperature,
- thinking_allowed: _,
- } = request;
-
- let mut input_items: Vec<responses::ResponseInputItem> = Vec::new();
-
- for message in messages {
- match message.role {
- Role::User => {
- for content in &message.content {
- if let MessageContent::ToolResult(tool_result) = content {
- let output = if let Some(out) = &tool_result.output {
- match out {
- serde_json::Value::String(s) => {
- responses::ResponseFunctionOutput::Text(s.clone())
- }
- serde_json::Value::Null => {
- responses::ResponseFunctionOutput::Text(String::new())
- }
- other => responses::ResponseFunctionOutput::Text(other.to_string()),
- }
- } else {
- match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- responses::ResponseFunctionOutput::Text(text.to_string())
- }
- LanguageModelToolResultContent::Image(image) => {
- if model.supports_vision() {
- responses::ResponseFunctionOutput::Content(vec![
- responses::ResponseInputContent::InputImage {
- image_url: Some(image.to_base64_url()),
- detail: Default::default(),
- },
- ])
- } else {
- debug_panic!(
- "This should be caught at {} level",
- tool_result.tool_name
- );
- responses::ResponseFunctionOutput::Text(
- "[Tool responded with an image, but this model does not support vision]".into(),
- )
- }
- }
- }
- };
-
- input_items.push(responses::ResponseInputItem::FunctionCallOutput {
- call_id: tool_result.tool_use_id.to_string(),
- output,
- status: None,
- });
- }
- }
-
- let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
- for content in &message.content {
- match content {
- MessageContent::Text(text) => {
- parts.push(responses::ResponseInputContent::InputText {
- text: text.clone(),
- });
- }
-
- MessageContent::Image(image) => {
- if model.supports_vision() {
- parts.push(responses::ResponseInputContent::InputImage {
- image_url: Some(image.to_base64_url()),
- detail: Default::default(),
- });
- }
- }
- _ => {}
- }
- }
-
- if !parts.is_empty() {
- input_items.push(responses::ResponseInputItem::Message {
- role: "user".into(),
- content: Some(parts),
- status: None,
- });
- }
- }
-
- Role::Assistant => {
- for content in &message.content {
- if let MessageContent::ToolUse(tool_use) = content {
- input_items.push(responses::ResponseInputItem::FunctionCall {
- call_id: tool_use.id.to_string(),
- name: tool_use.name.to_string(),
- arguments: tool_use.raw_input.clone(),
- status: None,
- thought_signature: tool_use.thought_signature.clone(),
- });
- }
- }
-
- for content in &message.content {
- if let MessageContent::RedactedThinking(data) = content {
- input_items.push(responses::ResponseInputItem::Reasoning {
- id: None,
- summary: Vec::new(),
- encrypted_content: data.clone(),
- });
- }
- }
-
- let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
- for content in &message.content {
- match content {
- MessageContent::Text(text) => {
- parts.push(responses::ResponseInputContent::OutputText {
- text: text.clone(),
- });
- }
- MessageContent::Image(_) => {
- parts.push(responses::ResponseInputContent::OutputText {
- text: "[image omitted]".to_string(),
- });
- }
- _ => {}
- }
- }
-
- if !parts.is_empty() {
- input_items.push(responses::ResponseInputItem::Message {
- role: "assistant".into(),
- content: Some(parts),
- status: Some("completed".into()),
- });
- }
- }
-
- Role::System => {
- let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
- for content in &message.content {
- if let MessageContent::Text(text) = content {
- parts.push(responses::ResponseInputContent::InputText {
- text: text.clone(),
- });
- }
- }
-
- if !parts.is_empty() {
- input_items.push(responses::ResponseInputItem::Message {
- role: "system".into(),
- content: Some(parts),
- status: None,
- });
- }
- }
- }
- }
-
- let converted_tools: Vec<responses::ToolDefinition> = tools
- .into_iter()
- .map(|tool| responses::ToolDefinition::Function {
- name: tool.name,
- description: Some(tool.description),
- parameters: Some(tool.input_schema),
- strict: None,
- })
- .collect();
-
- let mapped_tool_choice = tool_choice.map(|choice| match choice {
- LanguageModelToolChoice::Auto => responses::ToolChoice::Auto,
- LanguageModelToolChoice::Any => responses::ToolChoice::Any,
- LanguageModelToolChoice::None => responses::ToolChoice::None,
- });
-
- responses::Request {
- model: model.id().to_string(),
- input: input_items,
- stream: model.uses_streaming(),
- temperature,
- tools: converted_tools,
- tool_choice: mapped_tool_choice,
- reasoning: None, // We would need to add support for setting from user settings.
- include: Some(vec![
- copilot::copilot_responses::ResponseIncludable::ReasoningEncryptedContent,
- ]),
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use copilot::copilot_responses as responses;
- use futures::StreamExt;
-
- fn map_events(events: Vec<responses::StreamEvent>) -> Vec<LanguageModelCompletionEvent> {
- futures::executor::block_on(async {
- CopilotResponsesEventMapper::new()
- .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
- .collect::<Vec<_>>()
- .await
- .into_iter()
- .map(Result::unwrap)
- .collect()
- })
- }
-
- #[test]
- fn responses_stream_maps_text_and_usage() {
- let events = vec![
- responses::StreamEvent::OutputItemAdded {
- output_index: 0,
- sequence_number: None,
- item: responses::ResponseOutputItem::Message {
- id: "msg_1".into(),
- role: "assistant".into(),
- content: Some(Vec::new()),
- },
- },
- responses::StreamEvent::OutputTextDelta {
- item_id: "msg_1".into(),
- output_index: 0,
- delta: "Hello".into(),
- },
- responses::StreamEvent::Completed {
- response: responses::Response {
- usage: Some(responses::ResponseUsage {
- input_tokens: Some(5),
- output_tokens: Some(3),
- total_tokens: Some(8),
- }),
- ..Default::default()
- },
- },
- ];
-
- let mapped = map_events(events);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_1"
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::Text(ref text) if text == "Hello"
- ));
- assert!(matches!(
- mapped[2],
- LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
- input_tokens: 5,
- output_tokens: 3,
- ..
- })
- ));
- assert!(matches!(
- mapped[3],
- LanguageModelCompletionEvent::Stop(StopReason::EndTurn)
- ));
- }
-
- #[test]
- fn responses_stream_maps_tool_calls() {
- let events = vec![responses::StreamEvent::OutputItemDone {
- output_index: 0,
- sequence_number: None,
- item: responses::ResponseOutputItem::FunctionCall {
- id: Some("fn_1".into()),
- call_id: "call_1".into(),
- name: "do_it".into(),
- arguments: "{\"x\":1}".into(),
- status: None,
- thought_signature: None,
- },
- }];
-
- let mapped = map_events(events);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::ToolUse(ref use_) if use_.id.to_string() == "call_1" && use_.name.as_ref() == "do_it"
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
- ));
- }
-
- #[test]
- fn responses_stream_handles_json_parse_error() {
- let events = vec![responses::StreamEvent::OutputItemDone {
- output_index: 0,
- sequence_number: None,
- item: responses::ResponseOutputItem::FunctionCall {
- id: Some("fn_1".into()),
- call_id: "call_1".into(),
- name: "do_it".into(),
- arguments: "{not json}".into(),
- status: None,
- thought_signature: None,
- },
- }];
-
- let mapped = map_events(events);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::ToolUseJsonParseError { ref id, ref tool_name, .. }
- if id.to_string() == "call_1" && tool_name.as_ref() == "do_it"
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
- ));
- }
-
- #[test]
- fn responses_stream_maps_reasoning_summary_and_encrypted_content() {
- let events = vec![responses::StreamEvent::OutputItemDone {
- output_index: 0,
- sequence_number: None,
- item: responses::ResponseOutputItem::Reasoning {
- id: "r1".into(),
- summary: Some(vec![responses::ResponseReasoningItem {
- kind: "summary_text".into(),
- text: "Chain".into(),
- }]),
- encrypted_content: Some("ENC".into()),
- },
- }];
-
- let mapped = map_events(events);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::Thinking { ref text, signature: None } if text == "Chain"
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::RedactedThinking { ref data } if data == "ENC"
- ));
- }
-
- #[test]
- fn responses_stream_handles_incomplete_max_tokens() {
- let events = vec![responses::StreamEvent::Incomplete {
- response: responses::Response {
- usage: Some(responses::ResponseUsage {
- input_tokens: Some(10),
- output_tokens: Some(0),
- total_tokens: Some(10),
- }),
- incomplete_details: Some(responses::IncompleteDetails {
- reason: Some(responses::IncompleteReason::MaxOutputTokens),
- }),
- ..Default::default()
- },
- }];
-
- let mapped = map_events(events);
- assert!(matches!(
- mapped[0],
- LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
- input_tokens: 10,
- output_tokens: 0,
- ..
- })
- ));
- assert!(matches!(
- mapped[1],
- LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
- ));
- }
-
- #[test]
- fn responses_stream_handles_incomplete_content_filter() {
- let events = vec![responses::StreamEvent::Incomplete {
- response: responses::Response {
- usage: None,
- incomplete_details: Some(responses::IncompleteDetails {
- reason: Some(responses::IncompleteReason::ContentFilter),
- }),
- ..Default::default()
- },
- }];
-
- let mapped = map_events(events);
- assert!(matches!(
- mapped.last().unwrap(),
- LanguageModelCompletionEvent::Stop(StopReason::Refusal)
- ));
- }
-
- #[test]
- fn responses_stream_completed_no_duplicate_after_tool_use() {
- let events = vec![
- responses::StreamEvent::OutputItemDone {
- output_index: 0,
- sequence_number: None,
- item: responses::ResponseOutputItem::FunctionCall {
- id: Some("fn_1".into()),
- call_id: "call_1".into(),
- name: "do_it".into(),
- arguments: "{}".into(),
- status: None,
- thought_signature: None,
- },
- },
- responses::StreamEvent::Completed {
- response: responses::Response::default(),
- },
- ];
-
- let mapped = map_events(events);
-
- let mut stop_count = 0usize;
- let mut saw_tool_use_stop = false;
- for event in mapped {
- if let LanguageModelCompletionEvent::Stop(reason) = event {
- stop_count += 1;
- if matches!(reason, StopReason::ToolUse) {
- saw_tool_use_stop = true;
- }
- }
- }
- assert_eq!(stop_count, 1, "should emit exactly one Stop event");
- assert!(saw_tool_use_stop, "Stop reason should be ToolUse");
- }
-
- #[test]
- fn responses_stream_failed_maps_http_response_error() {
- let events = vec![responses::StreamEvent::Failed {
- response: responses::Response {
- error: Some(responses::ResponseError {
- code: "429".into(),
- message: "too many requests".into(),
- }),
- ..Default::default()
- },
- }];
-
- let mapped_results = futures::executor::block_on(async {
- CopilotResponsesEventMapper::new()
- .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
- .collect::<Vec<_>>()
- .await
- });
-
- assert_eq!(mapped_results.len(), 1);
- match &mapped_results[0] {
- Err(LanguageModelCompletionError::HttpResponseError {
- status_code,
- message,
- ..
- }) => {
- assert_eq!(*status_code, http_client::StatusCode::TOO_MANY_REQUESTS);
- assert_eq!(message, "too many requests");
- }
- other => panic!("expected HttpResponseError, got {:?}", other),
- }
- }
-
- #[test]
- fn chat_completions_stream_maps_reasoning_data() {
- use copilot::copilot_chat::ResponseEvent;
-
- let events = vec![
- ResponseEvent {
- choices: vec![copilot::copilot_chat::ResponseChoice {
- index: Some(0),
- finish_reason: None,
- delta: Some(copilot::copilot_chat::ResponseDelta {
- content: None,
- role: Some(copilot::copilot_chat::Role::Assistant),
- tool_calls: vec![copilot::copilot_chat::ToolCallChunk {
- index: Some(0),
- id: Some("call_abc123".to_string()),
- function: Some(copilot::copilot_chat::FunctionChunk {
- name: Some("list_directory".to_string()),
- arguments: Some("{\"path\":\"test\"}".to_string()),
- thought_signature: None,
- }),
- }],
- reasoning_opaque: Some("encrypted_reasoning_token_xyz".to_string()),
- reasoning_text: Some("Let me check the directory".to_string()),
- }),
- message: None,
- }],
- id: "chatcmpl-123".to_string(),
- usage: None,
- },
- ResponseEvent {
- choices: vec![copilot::copilot_chat::ResponseChoice {
- index: Some(0),
- finish_reason: Some("tool_calls".to_string()),
- delta: Some(copilot::copilot_chat::ResponseDelta {
- content: None,
- role: None,
- tool_calls: vec![],
- reasoning_opaque: None,
- reasoning_text: None,
- }),
- message: None,
- }],
- id: "chatcmpl-123".to_string(),
- usage: None,
- },
- ];
-
- let mapped = futures::executor::block_on(async {
- map_to_language_model_completion_events(
- Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
- true,
- )
- .collect::<Vec<_>>()
- .await
- });
-
- let mut has_reasoning_details = false;
- let mut has_tool_use = false;
- let mut reasoning_opaque_value: Option<String> = None;
- let mut reasoning_text_value: Option<String> = None;
-
- for event_result in mapped {
- match event_result {
- Ok(LanguageModelCompletionEvent::ReasoningDetails(details)) => {
- has_reasoning_details = true;
- reasoning_opaque_value = details
- .get("reasoning_opaque")
- .and_then(|v| v.as_str())
- .map(|s| s.to_string());
- reasoning_text_value = details
- .get("reasoning_text")
- .and_then(|v| v.as_str())
- .map(|s| s.to_string());
- }
- Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
- has_tool_use = true;
- assert_eq!(tool_use.id.to_string(), "call_abc123");
- assert_eq!(tool_use.name.as_ref(), "list_directory");
- }
- _ => {}
- }
- }
-
- assert!(
- has_reasoning_details,
- "Should emit ReasoningDetails event for Gemini 3 reasoning"
- );
- assert!(has_tool_use, "Should emit ToolUse event");
- assert_eq!(
- reasoning_opaque_value,
- Some("encrypted_reasoning_token_xyz".to_string()),
- "Should capture reasoning_opaque"
- );
- assert_eq!(
- reasoning_text_value,
- Some("Let me check the directory".to_string()),
- "Should capture reasoning_text"
- );
- }
-}
-struct ConfigurationView {
- copilot_status: Option<copilot::Status>,
- state: Entity<State>,
- _subscription: Option<Subscription>,
-}
-
-impl ConfigurationView {
- pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
- let copilot = Copilot::global(cx);
-
- Self {
- copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
- state,
- _subscription: copilot.as_ref().map(|copilot| {
- cx.observe(copilot, |this, model, cx| {
- this.copilot_status = Some(model.read(cx).status());
- cx.notify();
- })
- }),
- }
- }
-}
-
-impl Render for ConfigurationView {
- fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- if self.state.read(cx).is_authenticated(cx) {
- ConfiguredApiCard::new("Authorized")
- .button_label("Sign Out")
- .on_click(|_, window, cx| {
- window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
- })
- .into_any_element()
- } else {
- let loading_icon = Icon::new(IconName::ArrowCircle).with_rotate_animation(4);
-
- const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
-
- match &self.copilot_status {
- Some(status) => match status {
- Status::Starting { task: _ } => h_flex()
- .gap_2()
- .child(loading_icon)
-                        .child(Label::new("Starting Copilot…"))
- .into_any_element(),
- Status::SigningIn { prompt: _ }
- | Status::SignedOut {
- awaiting_signing_in: true,
- } => h_flex()
- .gap_2()
- .child(loading_icon)
-                        .child(Label::new("Signing into Copilot…"))
- .into_any_element(),
- Status::Error(_) => {
- const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
- v_flex()
- .gap_6()
- .child(Label::new(LABEL))
- .child(svg().size_8().path(IconName::CopilotError.path()))
- .into_any_element()
- }
- _ => {
- const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
-
- v_flex()
- .gap_2()
- .child(Label::new(LABEL))
- .child(
- Button::new("sign_in", "Sign in to use GitHub Copilot")
- .full_width()
- .style(ButtonStyle::Outlined)
- .icon_color(Color::Muted)
- .icon(IconName::Github)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
- .on_click(|_, window, cx| {
- copilot::initiate_sign_in(window, cx)
- }),
- )
- .into_any_element()
- }
- },
- None => v_flex()
- .gap_6()
- .child(Label::new(ERROR_LABEL))
- .into_any_element(),
- }
- }
- }
-}
@@ -1,44 +1,22 @@
-use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
-use credentials_provider::CredentialsProvider;
-use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
+use anyhow::Result;
+use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
ThinkingConfig, UsageMetadata,
};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
-use http_client::HttpClient;
+use gpui::{App, AppContext as _};
use language_model::{
- AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
- LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
- LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
-};
-use language_model::{
- LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
- LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelRequest, RateLimiter, Role,
+ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
+ LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
+ StopReason,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
pub use settings::GoogleAvailableModel as AvailableModel;
-use settings::{Settings, SettingsStore};
-use std::pin::Pin;
-use std::sync::{
- Arc, LazyLock,
- atomic::{self, AtomicU64},
+use std::{
+ pin::Pin,
+ sync::atomic::{self, AtomicU64},
};
-use strum::IntoEnumIterator;
-use ui::{List, prelude::*};
-use ui_input::InputField;
-use util::ResultExt;
-use zed_env_vars::EnvVar;
-
-use crate::api_key::ApiKey;
-use crate::api_key::ApiKeyState;
-use crate::ui::{ConfiguredApiCard, InstructionListItem};
-
-const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
-const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
@@ -57,346 +35,6 @@ pub enum ModelMode {
},
}
-pub struct GoogleLanguageModelProvider {
- http_client: Arc<dyn HttpClient>,
- state: Entity<State>,
-}
-
-pub struct State {
- api_key_state: ApiKeyState,
-}
-
-const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
-const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
-
-static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
- // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
- EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
-});
-
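
The removed `API_KEY_ENV_VAR` above encoded a precedence rule: prefer `GEMINI_API_KEY`, fall back to `GOOGLE_AI_API_KEY`. A minimal std-only sketch of that lookup (the real code wraps it in `zed_env_vars::EnvVar`; the empty-string filter here is an assumption of this sketch):

```rust
// Sketch only: prefer GEMINI_API_KEY, then GOOGLE_AI_API_KEY.
// Filtering out empty values is this sketch's assumption, not
// necessarily what zed_env_vars::EnvVar does.
fn gemini_api_key_from_env() -> Option<String> {
    std::env::var("GEMINI_API_KEY")
        .ok()
        .filter(|key| !key.is_empty())
        .or_else(|| {
            std::env::var("GOOGLE_AI_API_KEY")
                .ok()
                .filter(|key| !key.is_empty())
        })
}

fn main() {
    match gemini_api_key_from_env() {
        Some(_) => println!("found a Gemini API key in the environment"),
        None => println!("no Gemini API key set"),
    }
}
```
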
-impl State {
- fn is_authenticated(&self) -> bool {
- self.api_key_state.has_key()
- }
-
- fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
- let api_url = GoogleLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
- }
-
- fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- let api_url = GoogleLanguageModelProvider::api_url(cx);
- self.api_key_state.load_if_needed(
- api_url,
- &API_KEY_ENV_VAR,
- |this| &mut this.api_key_state,
- cx,
- )
- }
-}
-
-impl GoogleLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| {
- cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
- let api_url = Self::api_url(cx);
- this.api_key_state.handle_url_change(
- api_url,
- &API_KEY_ENV_VAR,
- |this| &mut this.api_key_state,
- cx,
- );
- cx.notify();
- })
- .detach();
- State {
- api_key_state: ApiKeyState::new(Self::api_url(cx)),
- }
- });
-
- Self { http_client, state }
- }
-
- fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
- Arc::new(GoogleLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- })
- }
-
- pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
- if let Some(key) = API_KEY_ENV_VAR.value.clone() {
- return Task::ready(Ok(key));
- }
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = Self::api_url(cx).to_string();
- cx.spawn(async move |cx| {
- Ok(
- ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
- .await?
- .key()
- .to_string(),
- )
- })
- }
-
- fn settings(cx: &App) -> &GoogleSettings {
- &crate::AllLanguageModelSettings::get_global(cx).google
- }
-
- fn api_url(cx: &App) -> SharedString {
- let api_url = &Self::settings(cx).api_url;
- if api_url.is_empty() {
- google_ai::API_URL.into()
- } else {
- SharedString::new(api_url.as_str())
- }
- }
-}
-
-impl LanguageModelProviderState for GoogleLanguageModelProvider {
- type ObservableEntity = State;
-
- fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
- Some(self.state.clone())
- }
-}
-
-impl LanguageModelProvider for GoogleLanguageModelProvider {
- fn id(&self) -> LanguageModelProviderId {
- PROVIDER_ID
- }
-
- fn name(&self) -> LanguageModelProviderName {
- PROVIDER_NAME
- }
-
- fn icon(&self) -> IconName {
- IconName::AiGoogle
- }
-
- fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- Some(self.create_language_model(google_ai::Model::default()))
- }
-
- fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- Some(self.create_language_model(google_ai::Model::default_fast()))
- }
-
- fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- let mut models = BTreeMap::default();
-
- // Add base models from google_ai::Model::iter()
- for model in google_ai::Model::iter() {
- if !matches!(model, google_ai::Model::Custom { .. }) {
- models.insert(model.id().to_string(), model);
- }
- }
-
- // Override with available models from settings
- for model in &GoogleLanguageModelProvider::settings(cx).available_models {
- models.insert(
- model.name.clone(),
- google_ai::Model::Custom {
- name: model.name.clone(),
- display_name: model.display_name.clone(),
- max_tokens: model.max_tokens,
- mode: model.mode.unwrap_or_default(),
- },
- );
- }
-
- models
- .into_values()
- .map(|model| {
- Arc::new(GoogleLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
- })
- .collect()
- }
-
- fn is_authenticated(&self, cx: &App) -> bool {
- self.state.read(cx).is_authenticated()
- }
-
- fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
- self.state.update(cx, |state, cx| state.authenticate(cx))
- }
-
- fn configuration_view(
- &self,
- target_agent: language_model::ConfigurationViewTargetAgent,
- window: &mut Window,
- cx: &mut App,
- ) -> AnyView {
- cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
- .into()
- }
-
- fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state
- .update(cx, |state, cx| state.set_api_key(None, cx))
- }
-}
-
-pub struct GoogleLanguageModel {
- id: LanguageModelId,
- model: google_ai::Model,
- state: Entity<State>,
- http_client: Arc<dyn HttpClient>,
- request_limiter: RateLimiter,
-}
-
-impl GoogleLanguageModel {
- fn stream_completion(
- &self,
- request: google_ai::GenerateContentRequest,
- cx: &AsyncApp,
- ) -> BoxFuture<
- 'static,
- Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
- > {
- let http_client = self.http_client.clone();
-
- let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
- let api_url = GoogleLanguageModelProvider::api_url(cx);
- (state.api_key_state.key(&api_url), api_url)
- }) else {
- return future::ready(Err(anyhow!("App state dropped"))).boxed();
- };
-
- async move {
- let api_key = api_key.context("Missing Google API key")?;
- let request = google_ai::stream_generate_content(
- http_client.as_ref(),
- &api_url,
- &api_key,
- request,
- );
- request.await.context("failed to stream completion")
- }
- .boxed()
- }
-}
-
-impl LanguageModel for GoogleLanguageModel {
- fn id(&self) -> LanguageModelId {
- self.id.clone()
- }
-
- fn name(&self) -> LanguageModelName {
- LanguageModelName::from(self.model.display_name().to_string())
- }
-
- fn provider_id(&self) -> LanguageModelProviderId {
- PROVIDER_ID
- }
-
- fn provider_name(&self) -> LanguageModelProviderName {
- PROVIDER_NAME
- }
-
- fn supports_tools(&self) -> bool {
- self.model.supports_tools()
- }
-
- fn supports_images(&self) -> bool {
- self.model.supports_images()
- }
-
- fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
- match choice {
- LanguageModelToolChoice::Auto
- | LanguageModelToolChoice::Any
- | LanguageModelToolChoice::None => true,
- }
- }
-
- fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
- LanguageModelToolSchemaFormat::JsonSchemaSubset
- }
-
- fn telemetry_id(&self) -> String {
- format!("google/{}", self.model.request_id())
- }
-
- fn max_token_count(&self) -> u64 {
- self.model.max_token_count()
- }
-
- fn max_output_tokens(&self) -> Option<u64> {
- self.model.max_output_tokens()
- }
-
- fn count_tokens(
- &self,
- request: LanguageModelRequest,
- cx: &App,
- ) -> BoxFuture<'static, Result<u64>> {
- let model_id = self.model.request_id().to_string();
- let request = into_google(request, model_id, self.model.mode());
- let http_client = self.http_client.clone();
- let api_url = GoogleLanguageModelProvider::api_url(cx);
- let api_key = self.state.read(cx).api_key_state.key(&api_url);
-
- async move {
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey {
- provider: PROVIDER_NAME,
- }
- .into());
- };
- let response = google_ai::count_tokens(
- http_client.as_ref(),
- &api_url,
- &api_key,
- google_ai::CountTokensRequest {
- generate_content_request: request,
- },
- )
- .await?;
- Ok(response.total_tokens)
- }
- .boxed()
- }
-
- fn stream_completion(
- &self,
- request: LanguageModelRequest,
- cx: &AsyncApp,
- ) -> BoxFuture<
- 'static,
- Result<
- futures::stream::BoxStream<
- 'static,
- Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
- >,
- LanguageModelCompletionError,
- >,
- > {
- let request = into_google(
- request,
- self.model.request_id().to_string(),
- self.model.mode(),
- );
- let request = self.stream_completion(request, cx);
- let future = self.request_limiter.stream(async move {
- let response = request.await.map_err(LanguageModelCompletionError::from)?;
- Ok(GoogleEventMapper::new().map_stream(response))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- }
-}
-
pub fn into_google(
mut request: LanguageModelRequest,
model_id: String,
@@ -439,7 +77,6 @@ pub fn into_google(
})]
}
language_model::MessageContent::ToolUse(tool_use) => {
- // Normalize empty string signatures to None
let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
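
This hunk drops the comment but keeps the behavior: `Option::filter` collapses an empty-string thought signature to `None`, so only real signatures are forwarded. A self-contained illustration:

```rust
// Gemini can send thought signatures as empty strings; Option::filter
// normalizes those to None so downstream code sees only real signatures.
fn normalize_signature(sig: Option<String>) -> Option<String> {
    sig.filter(|s| !s.is_empty())
}

fn main() {
    assert_eq!(normalize_signature(Some(String::new())), None);
    assert_eq!(
        normalize_signature(Some("sig_a".into())),
        Some("sig_a".to_string())
    );
    assert_eq!(normalize_signature(None), None);
}
```
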
@@ -457,7 +94,6 @@ pub fn into_google(
google_ai::FunctionResponsePart {
function_response: google_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
- // The API expects a valid JSON object
response: serde_json::json!({
"output": text
}),
@@ -470,7 +106,6 @@ pub fn into_google(
Part::FunctionResponsePart(google_ai::FunctionResponsePart {
function_response: google_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
- // The API expects a valid JSON object
response: serde_json::json!({
"output": "Tool responded with an image"
}),
@@ -519,7 +154,7 @@ pub fn into_google(
role: match message.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
- Role::System => google_ai::Role::User, // Google AI doesn't have a system role
+ Role::System => google_ai::Role::User,
},
})
}
@@ -653,13 +288,13 @@ impl GoogleEventMapper {
Part::InlineDataPart(_) => {}
Part::FunctionCallPart(function_call_part) => {
wants_to_use_tool = true;
- let name: Arc<str> = function_call_part.function_call.name.into();
+ let name: std::sync::Arc<str> =
+ function_call_part.function_call.name.into();
let next_tool_id =
TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
let id: LanguageModelToolUseId =
format!("{}-{}", name, next_tool_id).into();
- // Normalize empty string signatures to None
let thought_signature = function_call_part
.thought_signature
.filter(|s| !s.is_empty());
@@ -678,7 +313,7 @@ impl GoogleEventMapper {
Part::FunctionResponsePart(_) => {}
Part::ThoughtPart(part) => {
events.push(Ok(LanguageModelCompletionEvent::Thinking {
- text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
+ text: "(Encrypted thought)".to_string(),
signature: Some(part.thought_signature),
}));
}
@@ -686,8 +321,6 @@ impl GoogleEventMapper {
}
}
- // Even when Gemini wants to use a Tool, the API
- // responds with `finish_reason: STOP`
if wants_to_use_tool {
self.stop_reason = StopReason::ToolUse;
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
@@ -700,8 +333,6 @@ pub fn count_google_tokens(
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
- // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
- // So we have to use tokenizer from tiktoken_rs to count tokens.
cx.background_spawn(async move {
let messages = request
.messages
@@ -718,8 +349,6 @@ pub fn count_google_tokens(
})
.collect::<Vec<_>>();
- // Tiktoken doesn't yet support these models, so we manually use the
- // same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
})
.boxed()
@@ -760,148 +389,6 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
}
}
-struct ConfigurationView {
- api_key_editor: Entity<InputField>,
- state: Entity<State>,
- target_agent: language_model::ConfigurationViewTargetAgent,
- load_credentials_task: Option<Task<()>>,
-}
-
-impl ConfigurationView {
- fn new(
- state: Entity<State>,
- target_agent: language_model::ConfigurationViewTargetAgent,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) -> Self {
- cx.observe(&state, |_, _, cx| {
- cx.notify();
- })
- .detach();
-
- let load_credentials_task = Some(cx.spawn_in(window, {
- let state = state.clone();
- async move |this, cx| {
- if let Some(task) = state
- .update(cx, |state, cx| state.authenticate(cx))
- .log_err()
- {
- // We don't log an error, because "not signed in" is also an error.
- let _ = task.await;
- }
- this.update(cx, |this, cx| {
- this.load_credentials_task = None;
- cx.notify();
- })
- .log_err();
- }
- }));
-
- Self {
- api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
- target_agent,
- state,
- load_credentials_task,
- }
- }
-
- fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
- if api_key.is_empty() {
- return;
- }
-
- // url changes can cause the editor to be displayed again
- self.api_key_editor
- .update(cx, |editor, cx| editor.set_text("", window, cx));
-
- let state = self.state.clone();
- cx.spawn_in(window, async move |_, cx| {
- state
- .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
- .await
- })
- .detach_and_log_err(cx);
- }
-
- fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- self.api_key_editor
- .update(cx, |editor, cx| editor.set_text("", window, cx));
-
- let state = self.state.clone();
- cx.spawn_in(window, async move |_, cx| {
- state
- .update(cx, |state, cx| state.set_api_key(None, cx))?
- .await
- })
- .detach_and_log_err(cx);
- }
-
- fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
- !self.state.read(cx).is_authenticated()
- }
-}
-
-impl Render for ConfigurationView {
- fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
- let configured_card_label = if env_var_set {
- format!(
- "API key set in {} environment variable",
- API_KEY_ENV_VAR.name
- )
- } else {
- let api_url = GoogleLanguageModelProvider::api_url(cx);
- if api_url == google_ai::API_URL {
- "API key configured".to_string()
- } else {
- format!("API key configured for {}", api_url)
- }
- };
-
- if self.load_credentials_task.is_some() {
- div()
- .child(Label::new("Loading credentials..."))
- .into_any_element()
- } else if self.should_render_editor(cx) {
- v_flex()
- .size_full()
- .on_action(cx.listener(Self::save_api_key))
- .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
- ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
- ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
- })))
- .child(
- List::new()
- .child(InstructionListItem::new(
- "Create one by visiting",
- Some("Google AI's console"),
- Some("https://aistudio.google.com/app/apikey"),
- ))
- .child(InstructionListItem::text_only(
- "Paste your API key below and hit enter to start using the assistant",
- )),
- )
- .child(self.api_key_editor.clone())
- .child(
- Label::new(
- format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
- )
- .size(LabelSize::Small).color(Color::Muted),
- )
- .into_any_element()
- } else {
- ConfiguredApiCard::new(configured_card_label)
- .disabled(env_var_set)
- .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
- .when(env_var_set, |this| {
- this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
- })
- .into_any_element()
- }
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -940,7 +427,7 @@ mod tests {
let events = mapper.map_event(response);
- assert_eq!(events.len(), 2); // ToolUse event + Stop event
+ assert_eq!(events.len(), 2);
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
assert_eq!(tool_use.name.as_ref(), "test_function");
@@ -1034,18 +521,25 @@ mod tests {
parts: vec![
Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
- name: "function_1".to_string(),
- args: json!({"arg": "value1"}),
+ name: "function_a".to_string(),
+ args: json!({}),
},
- thought_signature: Some("signature_1".to_string()),
+ thought_signature: Some("sig_a".to_string()),
}),
Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
- name: "function_2".to_string(),
- args: json!({"arg": "value2"}),
+ name: "function_b".to_string(),
+ args: json!({}),
},
thought_signature: None,
}),
+ Part::FunctionCallPart(FunctionCallPart {
+ function_call: FunctionCall {
+ name: "function_c".to_string(),
+ args: json!({}),
+ },
+ thought_signature: Some("sig_c".to_string()),
+ }),
],
role: GoogleRole::Model,
},
@@ -1060,35 +554,35 @@ mod tests {
let events = mapper.map_event(response);
- assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
-
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
- assert_eq!(tool_use.name.as_ref(), "function_1");
- assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
- } else {
- panic!("Expected ToolUse event for function_1");
- }
+ let tool_uses: Vec<_> = events
+ .iter()
+ .filter_map(|e| {
+ if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = e {
+ Some(tool_use)
+ } else {
+ None
+ }
+ })
+ .collect();
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
- assert_eq!(tool_use.name.as_ref(), "function_2");
- assert_eq!(tool_use.thought_signature, None);
- } else {
- panic!("Expected ToolUse event for function_2");
- }
+ assert_eq!(tool_uses.len(), 3);
+ assert_eq!(tool_uses[0].thought_signature.as_deref(), Some("sig_a"));
+ assert_eq!(tool_uses[1].thought_signature, None);
+ assert_eq!(tool_uses[2].thought_signature.as_deref(), Some("sig_c"));
}
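
The rewritten test above switches from indexing into the event vector to collecting matches with `filter_map`, so the assertions no longer depend on event ordering. A standalone sketch of that pattern, with a simplified stand-in for `LanguageModelCompletionEvent`:

```rust
// Simplified stand-in for the real completion-event enum.
#[derive(Debug, PartialEq)]
enum Event {
    ToolUse { signature: Option<String> },
    Stop,
}

fn main() {
    let events = vec![
        Event::ToolUse { signature: Some("sig_a".into()) },
        Event::ToolUse { signature: None },
        Event::Stop,
    ];

    // Collect only the tool-use events, wherever they appear.
    let tool_uses: Vec<_> = events
        .iter()
        .filter_map(|e| match e {
            Event::ToolUse { signature } => Some(signature),
            _ => None,
        })
        .collect();

    assert_eq!(tool_uses.len(), 2);
    assert_eq!(tool_uses[0].as_deref(), Some("sig_a"));
    assert!(tool_uses[1].is_none());
}
```
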
#[test]
fn test_tool_use_with_signature_converts_to_function_call_part() {
let tool_use = language_model::LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test_id"),
- name: "test_function".into(),
- raw_input: json!({"arg": "value"}).to_string(),
- input: json!({"arg": "value"}),
+ id: LanguageModelToolUseId::from("test-id"),
+ name: "test_tool".into(),
+ input: json!({"key": "value"}),
+ raw_input: r#"{"key": "value"}"#.to_string(),
is_input_complete: true,
- thought_signature: Some("test_signature_456".to_string()),
+ thought_signature: Some("test_sig".to_string()),
};
- let request = super::into_google(
+ let request = into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
@@ -1102,13 +596,11 @@ mod tests {
GoogleModelMode::Default,
);
- assert_eq!(request.contents[0].parts.len(), 1);
- if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
- assert_eq!(fc_part.function_call.name, "test_function");
- assert_eq!(
- fc_part.thought_signature.as_deref(),
- Some("test_signature_456")
- );
+ let parts = &request.contents[0].parts;
+ assert_eq!(parts.len(), 1);
+
+ if let Part::FunctionCallPart(fcp) = &parts[0] {
+ assert_eq!(fcp.thought_signature.as_deref(), Some("test_sig"));
} else {
panic!("Expected FunctionCallPart");
}
@@ -1117,15 +609,15 @@ mod tests {
#[test]
fn test_tool_use_without_signature_omits_field() {
let tool_use = language_model::LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test_id"),
- name: "test_function".into(),
- raw_input: json!({"arg": "value"}).to_string(),
- input: json!({"arg": "value"}),
+ id: LanguageModelToolUseId::from("test-id"),
+ name: "test_tool".into(),
+ input: json!({"key": "value"}),
+ raw_input: r#"{"key": "value"}"#.to_string(),
is_input_complete: true,
thought_signature: None,
};
- let request = super::into_google(
+ let request = into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
@@ -1139,9 +631,10 @@ mod tests {
GoogleModelMode::Default,
);
- assert_eq!(request.contents[0].parts.len(), 1);
- if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
- assert_eq!(fc_part.thought_signature, None);
+ let parts = &request.contents[0].parts;
+
+ if let Part::FunctionCallPart(fcp) = &parts[0] {
+ assert_eq!(fcp.thought_signature, None);
} else {
panic!("Expected FunctionCallPart");
}
@@ -1150,15 +643,15 @@ mod tests {
#[test]
fn test_empty_signature_in_tool_use_normalized_to_none() {
let tool_use = language_model::LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test_id"),
- name: "test_function".into(),
- raw_input: json!({"arg": "value"}).to_string(),
- input: json!({"arg": "value"}),
+ id: LanguageModelToolUseId::from("test-id"),
+ name: "test_tool".into(),
+ input: json!({}),
+ raw_input: "{}".to_string(),
is_input_complete: true,
thought_signature: Some("".to_string()),
};
- let request = super::into_google(
+ let request = into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
@@ -1172,8 +665,10 @@ mod tests {
GoogleModelMode::Default,
);
- if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
- assert_eq!(fc_part.thought_signature, None);
+ let parts = &request.contents[0].parts;
+
+ if let Part::FunctionCallPart(fcp) = &parts[0] {
+ assert_eq!(fcp.thought_signature, None);
} else {
panic!("Expected FunctionCallPart");
}
@@ -1181,9 +676,8 @@ mod tests {
#[test]
fn test_round_trip_preserves_signature() {
- let mut mapper = GoogleEventMapper::new();
+ let original_signature = "original_thought_signature_abc123";
- // Simulate receiving a response from Google with a signature
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
@@ -1193,7 +687,7 @@ mod tests {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
- thought_signature: Some("round_trip_sig".to_string()),
+ thought_signature: Some(original_signature.to_string()),
})],
role: GoogleRole::Model,
},
@@ -1206,6 +700,7 @@ mod tests {
usage_metadata: None,
};
+ let mut mapper = GoogleEventMapper::new();
let events = mapper.map_event(response);
let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
@@ -1214,8 +709,7 @@ mod tests {
panic!("Expected ToolUse event");
};
- // Convert back to Google format
- let request = super::into_google(
+ let request = into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
@@ -1229,9 +723,9 @@ mod tests {
GoogleModelMode::Default,
);
- // Verify signature is preserved
- if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
- assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
+ let parts = &request.contents[0].parts;
+ if let Part::FunctionCallPart(fcp) = &parts[0] {
+ assert_eq!(fcp.thought_signature.as_deref(), Some(original_signature));
} else {
panic!("Expected FunctionCallPart");
}
@@ -1247,14 +741,14 @@ mod tests {
content: Content {
parts: vec![
Part::TextPart(TextPart {
- text: "I'll help with that.".to_string(),
+ text: "Let me help you with that.".to_string(),
}),
Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
- name: "helper_function".to_string(),
- args: json!({"query": "help"}),
+ name: "search".to_string(),
+ args: json!({"query": "test"}),
},
- thought_signature: Some("mixed_sig".to_string()),
+ thought_signature: Some("thinking_sig".to_string()),
}),
],
role: GoogleRole::Model,
@@ -1270,27 +764,35 @@ mod tests {
let events = mapper.map_event(response);
- assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
+ let mut found_text = false;
+ let mut found_tool_with_sig = false;
- if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
- assert_eq!(text, "I'll help with that.");
- } else {
- panic!("Expected Text event");
+ for event in events {
+ match event {
+ Ok(LanguageModelCompletionEvent::Text(text)) => {
+ assert_eq!(text, "Let me help you with that.");
+ found_text = true;
+ }
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
+ assert_eq!(tool_use.thought_signature.as_deref(), Some("thinking_sig"));
+ found_tool_with_sig = true;
+ }
+ _ => {}
+ }
}
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
- assert_eq!(tool_use.name.as_ref(), "helper_function");
- assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
- } else {
- panic!("Expected ToolUse event");
- }
+ assert!(found_text, "Should have found text event");
+ assert!(
+ found_tool_with_sig,
+ "Should have found tool use with signature"
+ );
}
#[test]
fn test_special_characters_in_signature_preserved() {
- let mut mapper = GoogleEventMapper::new();
+ let special_signature = "sig/with+special=chars&more%stuff";
- let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
+ let mut mapper = GoogleEventMapper::new();
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
@@ -1298,10 +800,10 @@ mod tests {
content: Content {
parts: vec![Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
- name: "test_function".to_string(),
- args: json!({"arg": "value"}),
+ name: "test".to_string(),
+ args: json!({}),
},
- thought_signature: Some(signature_with_special_chars.clone()),
+ thought_signature: Some(special_signature.to_string()),
})],
role: GoogleRole::Model,
},
@@ -1319,7 +821,7 @@ mod tests {
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
assert_eq!(
tool_use.thought_signature.as_deref(),
- Some(signature_with_special_chars.as_str())
+ Some(special_signature)
);
} else {
panic!("Expected ToolUse event");
@@ -1,38 +1,17 @@
use anyhow::{Result, anyhow};
-use collections::{BTreeMap, HashMap};
-use futures::Stream;
-use futures::{FutureExt, StreamExt, future, future::BoxFuture};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
-use http_client::HttpClient;
+use collections::HashMap;
+use futures::{FutureExt, Stream, future::BoxFuture};
+use gpui::{App, AppContext as _};
use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
- LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
- RateLimiter, Role, StopReason, TokenUsage,
+ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
+ LanguageModelToolChoice, LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
};
-use menu;
-use open_ai::{
- ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion,
-};
-use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore};
+use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent};
+pub use settings::OpenAiAvailableModel as AvailableModel;
use std::pin::Pin;
-use std::str::FromStr as _;
-use std::sync::{Arc, LazyLock};
-use strum::IntoEnumIterator;
-use ui::{List, prelude::*};
-use ui_input::InputField;
-use util::ResultExt;
-use zed_env_vars::{EnvVar, env_var};
-
-use crate::ui::ConfiguredApiCard;
-use crate::{api_key::ApiKeyState, ui::InstructionListItem};
-
-const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
-const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
+use std::str::FromStr;
-const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
-static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+use language_model::LanguageModelToolResultContent;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
@@ -40,314 +19,6 @@ pub struct OpenAiSettings {
pub available_models: Vec<AvailableModel>,
}
-pub struct OpenAiLanguageModelProvider {
- http_client: Arc<dyn HttpClient>,
- state: Entity<State>,
-}
-
-pub struct State {
- api_key_state: ApiKeyState,
-}
-
-impl State {
- fn is_authenticated(&self) -> bool {
- self.api_key_state.has_key()
- }
-
- fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
- let api_url = OpenAiLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
- }
-
- fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- let api_url = OpenAiLanguageModelProvider::api_url(cx);
- self.api_key_state.load_if_needed(
- api_url,
- &API_KEY_ENV_VAR,
- |this| &mut this.api_key_state,
- cx,
- )
- }
-}
-
-impl OpenAiLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| {
- cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
- let api_url = Self::api_url(cx);
- this.api_key_state.handle_url_change(
- api_url,
- &API_KEY_ENV_VAR,
- |this| &mut this.api_key_state,
- cx,
- );
- cx.notify();
- })
- .detach();
- State {
- api_key_state: ApiKeyState::new(Self::api_url(cx)),
- }
- });
-
- Self { http_client, state }
- }
-
- fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
- Arc::new(OpenAiLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- })
- }
-
- fn settings(cx: &App) -> &OpenAiSettings {
- &crate::AllLanguageModelSettings::get_global(cx).openai
- }
-
- fn api_url(cx: &App) -> SharedString {
- let api_url = &Self::settings(cx).api_url;
- if api_url.is_empty() {
- open_ai::OPEN_AI_API_URL.into()
- } else {
- SharedString::new(api_url.as_str())
- }
- }
-}
-
-impl LanguageModelProviderState for OpenAiLanguageModelProvider {
- type ObservableEntity = State;
-
- fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
- Some(self.state.clone())
- }
-}
-
-impl LanguageModelProvider for OpenAiLanguageModelProvider {
- fn id(&self) -> LanguageModelProviderId {
- PROVIDER_ID
- }
-
- fn name(&self) -> LanguageModelProviderName {
- PROVIDER_NAME
- }
-
- fn icon(&self) -> IconName {
- IconName::AiOpenAi
- }
-
- fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- Some(self.create_language_model(open_ai::Model::default()))
- }
-
- fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- Some(self.create_language_model(open_ai::Model::default_fast()))
- }
-
- fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- let mut models = BTreeMap::default();
-
- // Add base models from open_ai::Model::iter()
- for model in open_ai::Model::iter() {
- if !matches!(model, open_ai::Model::Custom { .. }) {
- models.insert(model.id().to_string(), model);
- }
- }
-
- // Override with available models from settings
- for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
- models.insert(
- model.name.clone(),
- open_ai::Model::Custom {
- name: model.name.clone(),
- display_name: model.display_name.clone(),
- max_tokens: model.max_tokens,
- max_output_tokens: model.max_output_tokens,
- max_completion_tokens: model.max_completion_tokens,
- reasoning_effort: model.reasoning_effort.clone(),
- },
- );
- }
-
- models
- .into_values()
- .map(|model| self.create_language_model(model))
- .collect()
- }
-
- fn is_authenticated(&self, cx: &App) -> bool {
- self.state.read(cx).is_authenticated()
- }
-
- fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
- self.state.update(cx, |state, cx| state.authenticate(cx))
- }
-
- fn configuration_view(
- &self,
- _target_agent: language_model::ConfigurationViewTargetAgent,
- window: &mut Window,
- cx: &mut App,
- ) -> AnyView {
- cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
- .into()
- }
-
- fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state
- .update(cx, |state, cx| state.set_api_key(None, cx))
- }
-}
-
-pub struct OpenAiLanguageModel {
- id: LanguageModelId,
- model: open_ai::Model,
- state: Entity<State>,
- http_client: Arc<dyn HttpClient>,
- request_limiter: RateLimiter,
-}
-
-impl OpenAiLanguageModel {
- fn stream_completion(
- &self,
- request: open_ai::Request,
- cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
- {
- let http_client = self.http_client.clone();
-
- let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
- let api_url = OpenAiLanguageModelProvider::api_url(cx);
- (state.api_key_state.key(&api_url), api_url)
- }) else {
- return future::ready(Err(anyhow!("App state dropped"))).boxed();
- };
-
- let future = self.request_limiter.stream(async move {
- let provider = PROVIDER_NAME;
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey { provider });
- };
- let request = stream_completion(
- http_client.as_ref(),
- provider.0.as_str(),
- &api_url,
- &api_key,
- request,
- );
- let response = request.await?;
- Ok(response)
- });
-
- async move { Ok(future.await?.boxed()) }.boxed()
- }
-}
-
-impl LanguageModel for OpenAiLanguageModel {
- fn id(&self) -> LanguageModelId {
- self.id.clone()
- }
-
- fn name(&self) -> LanguageModelName {
- LanguageModelName::from(self.model.display_name().to_string())
- }
-
- fn provider_id(&self) -> LanguageModelProviderId {
- PROVIDER_ID
- }
-
- fn provider_name(&self) -> LanguageModelProviderName {
- PROVIDER_NAME
- }
-
- fn supports_tools(&self) -> bool {
- true
- }
-
- fn supports_images(&self) -> bool {
- use open_ai::Model;
- match &self.model {
- Model::FourOmni
- | Model::FourOmniMini
- | Model::FourPointOne
- | Model::FourPointOneMini
- | Model::FourPointOneNano
- | Model::Five
- | Model::FiveMini
- | Model::FiveNano
- | Model::FivePointOne
- | Model::O1
- | Model::O3
- | Model::O4Mini => true,
- Model::ThreePointFiveTurbo
- | Model::Four
- | Model::FourTurbo
- | Model::O3Mini
- | Model::Custom { .. } => false,
- }
- }
-
- fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
- match choice {
- LanguageModelToolChoice::Auto => true,
- LanguageModelToolChoice::Any => true,
- LanguageModelToolChoice::None => true,
- }
- }
-
- fn telemetry_id(&self) -> String {
- format!("openai/{}", self.model.id())
- }
-
- fn max_token_count(&self) -> u64 {
- self.model.max_token_count()
- }
-
- fn max_output_tokens(&self) -> Option<u64> {
- self.model.max_output_tokens()
- }
-
- fn count_tokens(
- &self,
- request: LanguageModelRequest,
- cx: &App,
- ) -> BoxFuture<'static, Result<u64>> {
- count_open_ai_tokens(request, self.model.clone(), cx)
- }
-
- fn stream_completion(
- &self,
- request: LanguageModelRequest,
- cx: &AsyncApp,
- ) -> BoxFuture<
- 'static,
- Result<
- futures::stream::BoxStream<
- 'static,
- Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
- >,
- LanguageModelCompletionError,
- >,
- > {
- let request = into_open_ai(
- request,
- self.model.id(),
- self.model.supports_parallel_tool_calls(),
- self.model.supports_prompt_cache_key(),
- self.max_output_tokens(),
- self.model.reasoning_effort(),
- );
- let completions = self.stream_completion(request, cx);
- async move {
- let mapper = OpenAiEventMapper::new();
- Ok(mapper.map_stream(completions.await?).boxed())
- }
- .boxed()
- }
-}
-
pub fn into_open_ai(
request: LanguageModelRequest,
model_id: &str,
@@ -441,7 +112,6 @@ pub fn into_open_ai(
temperature: request.temperature.unwrap_or(1.0),
max_completion_tokens: max_output_tokens,
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
- // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
Some(false)
} else {
None
@@ -521,6 +191,7 @@ impl OpenAiEventMapper {
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
+ use futures::StreamExt;
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
@@ -648,19 +319,12 @@ pub fn count_open_ai_tokens(
match model {
Model::Custom { max_tokens, .. } => {
let model = if max_tokens >= 100_000 {
- // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
"gpt-4o"
} else {
- // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
- // supported with this tiktoken method
"gpt-4"
};
tiktoken_rs::num_tokens_from_messages(model, &messages)
}
- // Currently supported by tiktoken_rs
- // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch
- // arm with an override. We enumerate all supported models here so that we can check if new
- // models are supported yet or not.
Model::ThreePointFiveTurbo
| Model::Four
| Model::FourTurbo
@@ -675,7 +339,7 @@ pub fn count_open_ai_tokens(
| Model::O4Mini
| Model::Five
| Model::FiveMini
- | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), // GPT-5.1 doesn't have tiktoken support yet; fall back on gpt-4o tokenizer
+ | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
Model::FivePointOne => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),
}
.map(|tokens| tokens as u64)
@@ -683,191 +347,11 @@ pub fn count_open_ai_tokens(
.boxed()
}
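
The comments removed from `count_open_ai_tokens` documented the tokenizer heuristic for custom models: a context window of 100k+ tokens suggests the o200k_base encoding (gpt-4o), otherwise the gpt-4 (cl100k_base) tokenizer is used, since those are the encodings supported by this tiktoken method. A hedged sketch of the same heuristic against tiktoken_rs's BPE API (this counts raw text via `get_bpe_from_model` rather than the message-based counting in the real code, and assumes `anyhow` as a dependency):

```rust
use tiktoken_rs::get_bpe_from_model;

// Heuristic mirrored from the diff: >= 100k context implies the
// o200k_base encoding (gpt-4o); everything else falls back to gpt-4.
fn count_tokens_for_custom_model(max_tokens: u64, text: &str) -> anyhow::Result<usize> {
    let tokenizer_model = if max_tokens >= 100_000 { "gpt-4o" } else { "gpt-4" };
    let bpe = get_bpe_from_model(tokenizer_model)?;
    Ok(bpe.encode_with_special_tokens(text).len())
}

fn main() -> anyhow::Result<()> {
    let tokens = count_tokens_for_custom_model(200_000, "hello from zed")?;
    println!("counted {tokens} tokens with the gpt-4o tokenizer");
    Ok(())
}
```
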
-struct ConfigurationView {
- api_key_editor: Entity<InputField>,
- state: Entity<State>,
- load_credentials_task: Option<Task<()>>,
-}
-
-impl ConfigurationView {
- fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
- let api_key_editor = cx.new(|cx| {
- InputField::new(
- window,
- cx,
- "sk-000000000000000000000000000000000000000000000000",
- )
- });
-
- cx.observe(&state, |_, _, cx| {
- cx.notify();
- })
- .detach();
-
- let load_credentials_task = Some(cx.spawn_in(window, {
- let state = state.clone();
- async move |this, cx| {
- if let Some(task) = state
- .update(cx, |state, cx| state.authenticate(cx))
- .log_err()
- {
- // We don't log an error, because "not signed in" is also an error.
- let _ = task.await;
- }
- this.update(cx, |this, cx| {
- this.load_credentials_task = None;
- cx.notify();
- })
- .log_err();
- }
- }));
-
- Self {
- api_key_editor,
- state,
- load_credentials_task,
- }
- }
-
- fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
- if api_key.is_empty() {
- return;
- }
-
- // url changes can cause the editor to be displayed again
- self.api_key_editor
- .update(cx, |editor, cx| editor.set_text("", window, cx));
-
- let state = self.state.clone();
- cx.spawn_in(window, async move |_, cx| {
- state
- .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
- .await
- })
- .detach_and_log_err(cx);
- }
-
- fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- self.api_key_editor
- .update(cx, |input, cx| input.set_text("", window, cx));
-
- let state = self.state.clone();
- cx.spawn_in(window, async move |_, cx| {
- state
- .update(cx, |state, cx| state.set_api_key(None, cx))?
- .await
- })
- .detach_and_log_err(cx);
- }
-
- fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
- !self.state.read(cx).is_authenticated()
- }
-}
-
-impl Render for ConfigurationView {
- fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
- let configured_card_label = if env_var_set {
- format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
- } else {
- let api_url = OpenAiLanguageModelProvider::api_url(cx);
- if api_url == OPEN_AI_API_URL {
- "API key configured".to_string()
- } else {
- format!("API key configured for {}", api_url)
- }
- };
-
- let api_key_section = if self.should_render_editor(cx) {
- v_flex()
- .on_action(cx.listener(Self::save_api_key))
- .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
- .child(
- List::new()
- .child(InstructionListItem::new(
- "Create one by visiting",
- Some("OpenAI's console"),
- Some("https://platform.openai.com/api-keys"),
- ))
- .child(InstructionListItem::text_only(
- "Ensure your OpenAI account has credits",
- ))
- .child(InstructionListItem::text_only(
- "Paste your API key below and hit enter to start using the assistant",
- )),
- )
- .child(self.api_key_editor.clone())
- .child(
- Label::new(format!(
- "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
- ))
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .child(
- Label::new(
- "Note that having a subscription for another service like GitHub Copilot won't work.",
- )
- .size(LabelSize::Small).color(Color::Muted),
- )
- .into_any_element()
- } else {
- ConfiguredApiCard::new(configured_card_label)
- .disabled(env_var_set)
- .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
- .when(env_var_set, |this| {
- this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
- })
- .into_any_element()
- };
-
- let compatible_api_section = h_flex()
- .mt_1p5()
- .gap_0p5()
- .flex_wrap()
- .when(self.should_render_editor(cx), |this| {
- this.pt_1p5()
- .border_t_1()
- .border_color(cx.theme().colors().border_variant)
- })
- .child(
- h_flex()
- .gap_2()
- .child(
- Icon::new(IconName::Info)
- .size(IconSize::XSmall)
- .color(Color::Muted),
- )
- .child(Label::new("Zed also supports OpenAI-compatible models.")),
- )
- .child(
- Button::new("docs", "Learn More")
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .on_click(move |_, _window, cx| {
- cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
- }),
- );
-
- if self.load_credentials_task.is_some() {
- div().child(Label::new("Loading credentialsβ¦")).into_any()
- } else {
- v_flex()
- .size_full()
- .child(api_key_section)
- .child(compatible_api_section)
- .into_any()
- }
- }
-}
-
#[cfg(test)]
mod tests {
use gpui::TestAppContext;
use language_model::LanguageModelRequestMessage;
+ use strum::IntoEnumIterator;
use super::*;
@@ -891,7 +375,6 @@ mod tests {
thinking_allowed: true,
};
- // Validate that all models are supported by tiktoken-rs
for model in Model::iter() {
let count = cx
.executor()
@@ -1,1095 +0,0 @@
-use anyhow::{Result, anyhow};
-use collections::HashMap;
-use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
-use http_client::HttpClient;
-use language_model::{
- AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
- LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
- LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
- LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
-};
-use open_router::{
- Model, ModelMode as OpenRouterModelMode, OPEN_ROUTER_API_URL, ResponseStreamEvent, list_models,
-};
-use settings::{OpenRouterAvailableModel as AvailableModel, Settings, SettingsStore};
-use std::pin::Pin;
-use std::str::FromStr as _;
-use std::sync::{Arc, LazyLock};
-use ui::{List, prelude::*};
-use ui_input::InputField;
-use util::ResultExt;
-use zed_env_vars::{EnvVar, env_var};
-
-use crate::ui::ConfiguredApiCard;
-use crate::{api_key::ApiKeyState, ui::InstructionListItem};
-
-const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
-const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
-
-const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY";
-static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
-
-#[derive(Default, Clone, Debug, PartialEq)]
-pub struct OpenRouterSettings {
- pub api_url: String,
- pub available_models: Vec<AvailableModel>,
-}
-
-pub struct OpenRouterLanguageModelProvider {
- http_client: Arc<dyn HttpClient>,
- state: Entity<State>,
-}
-
-pub struct State {
- api_key_state: ApiKeyState,
- http_client: Arc<dyn HttpClient>,
- available_models: Vec<open_router::Model>,
- fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
-}
-
-impl State {
- fn is_authenticated(&self) -> bool {
- self.api_key_state.has_key()
- }
-
- fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
- let api_url = OpenRouterLanguageModelProvider::api_url(cx);
- self.api_key_state
- .store(api_url, api_key, |this| &mut this.api_key_state, cx)
- }
-
- fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- let api_url = OpenRouterLanguageModelProvider::api_url(cx);
- let task = self.api_key_state.load_if_needed(
- api_url,
- &API_KEY_ENV_VAR,
- |this| &mut this.api_key_state,
- cx,
- );
-
- cx.spawn(async move |this, cx| {
- let result = task.await;
- this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
- .ok();
- result
- })
- }
-
- fn fetch_models(
- &mut self,
- cx: &mut Context<Self>,
- ) -> Task<Result<(), LanguageModelCompletionError>> {
- let http_client = self.http_client.clone();
- let api_url = OpenRouterLanguageModelProvider::api_url(cx);
- let Some(api_key) = self.api_key_state.key(&api_url) else {
- return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
- provider: PROVIDER_NAME,
- }));
- };
- cx.spawn(async move |this, cx| {
- let models = list_models(http_client.as_ref(), &api_url, &api_key)
- .await
- .map_err(|e| {
- LanguageModelCompletionError::Other(anyhow::anyhow!(
- "OpenRouter error: {:?}",
- e
- ))
- })?;
-
- this.update(cx, |this, cx| {
- this.available_models = models;
- cx.notify();
- })
- .map_err(|e| LanguageModelCompletionError::Other(e))?;
-
- Ok(())
- })
- }
-
- fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
- if self.is_authenticated() {
- let task = self.fetch_models(cx);
- self.fetch_models_task.replace(task);
- } else {
- self.available_models = Vec::new();
- }
- }
-}
-
-impl OpenRouterLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| {
- cx.observe_global::<SettingsStore>({
- let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
- move |this: &mut State, cx| {
- let current_settings = OpenRouterLanguageModelProvider::settings(cx);
- let settings_changed = current_settings != &last_settings;
- if settings_changed {
- last_settings = current_settings.clone();
- this.authenticate(cx).detach();
- cx.notify();
- }
- }
- })
- .detach();
- State {
- api_key_state: ApiKeyState::new(Self::api_url(cx)),
- http_client: http_client.clone(),
- available_models: Vec::new(),
- fetch_models_task: None,
- }
- });
-
- Self { http_client, state }
- }
-
- fn settings(cx: &App) -> &OpenRouterSettings {
- &crate::AllLanguageModelSettings::get_global(cx).open_router
- }
-
- fn api_url(cx: &App) -> SharedString {
- let api_url = &Self::settings(cx).api_url;
- if api_url.is_empty() {
- OPEN_ROUTER_API_URL.into()
- } else {
- SharedString::new(api_url.as_str())
- }
- }
-
- fn create_language_model(&self, model: open_router::Model) -> Arc<dyn LanguageModel> {
- Arc::new(OpenRouterLanguageModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- })
- }
-}
-
-impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
- type ObservableEntity = State;
-
- fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
- Some(self.state.clone())
- }
-}
-
-impl LanguageModelProvider for OpenRouterLanguageModelProvider {
- fn id(&self) -> LanguageModelProviderId {
- PROVIDER_ID
- }
-
- fn name(&self) -> LanguageModelProviderName {
- PROVIDER_NAME
- }
-
- fn icon(&self) -> IconName {
- IconName::AiOpenRouter
- }
-
- fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- Some(self.create_language_model(open_router::Model::default()))
- }
-
- fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
- Some(self.create_language_model(open_router::Model::default_fast()))
- }
-
- fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
- let mut models_from_api = self.state.read(cx).available_models.clone();
- let mut settings_models = Vec::new();
-
- for model in &Self::settings(cx).available_models {
- settings_models.push(open_router::Model {
- name: model.name.clone(),
- display_name: model.display_name.clone(),
- max_tokens: model.max_tokens,
- supports_tools: model.supports_tools,
- supports_images: model.supports_images,
- mode: model.mode.unwrap_or_default(),
- provider: model.provider.clone(),
- });
- }
-
- for settings_model in &settings_models {
- if let Some(pos) = models_from_api
- .iter()
- .position(|m| m.name == settings_model.name)
- {
- models_from_api[pos] = settings_model.clone();
- } else {
- models_from_api.push(settings_model.clone());
- }
- }
-
- models_from_api
- .into_iter()
- .map(|model| self.create_language_model(model))
- .collect()
- }
-
- fn is_authenticated(&self, cx: &App) -> bool {
- self.state.read(cx).is_authenticated()
- }
-
- fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
- self.state.update(cx, |state, cx| state.authenticate(cx))
- }
-
- fn configuration_view(
- &self,
- _target_agent: language_model::ConfigurationViewTargetAgent,
- window: &mut Window,
- cx: &mut App,
- ) -> AnyView {
- cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
- .into()
- }
-
- fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state
- .update(cx, |state, cx| state.set_api_key(None, cx))
- }
-}
-
-pub struct OpenRouterLanguageModel {
- id: LanguageModelId,
- model: open_router::Model,
- state: Entity<State>,
- http_client: Arc<dyn HttpClient>,
- request_limiter: RateLimiter,
-}
-
-impl OpenRouterLanguageModel {
- fn stream_completion(
- &self,
- request: open_router::Request,
- cx: &AsyncApp,
- ) -> BoxFuture<
- 'static,
- Result<
- futures::stream::BoxStream<
- 'static,
- Result<ResponseStreamEvent, open_router::OpenRouterError>,
- >,
- LanguageModelCompletionError,
- >,
- > {
- let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
- let api_url = OpenRouterLanguageModelProvider::api_url(cx);
- (state.api_key_state.key(&api_url), api_url)
- }) else {
- return future::ready(Err(anyhow!("App state dropped").into())).boxed();
- };
-
- async move {
- let Some(api_key) = api_key else {
- return Err(LanguageModelCompletionError::NoApiKey {
- provider: PROVIDER_NAME,
- });
- };
- let request =
- open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
- request.await.map_err(Into::into)
- }
- .boxed()
- }
-}
-
-impl LanguageModel for OpenRouterLanguageModel {
- fn id(&self) -> LanguageModelId {
- self.id.clone()
- }
-
- fn name(&self) -> LanguageModelName {
- LanguageModelName::from(self.model.display_name().to_string())
- }
-
- fn provider_id(&self) -> LanguageModelProviderId {
- PROVIDER_ID
- }
-
- fn provider_name(&self) -> LanguageModelProviderName {
- PROVIDER_NAME
- }
-
- fn supports_tools(&self) -> bool {
- self.model.supports_tool_calls()
- }
-
- fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
- let model_id = self.model.id().trim().to_lowercase();
- if model_id.contains("gemini") || model_id.contains("grok") {
- LanguageModelToolSchemaFormat::JsonSchemaSubset
- } else {
- LanguageModelToolSchemaFormat::JsonSchema
- }
- }
-
- fn telemetry_id(&self) -> String {
- format!("openrouter/{}", self.model.id())
- }
-
- fn max_token_count(&self) -> u64 {
- self.model.max_token_count()
- }
-
- fn max_output_tokens(&self) -> Option<u64> {
- self.model.max_output_tokens()
- }
-
- fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
- match choice {
- LanguageModelToolChoice::Auto => true,
- LanguageModelToolChoice::Any => true,
- LanguageModelToolChoice::None => true,
- }
- }
-
- fn supports_images(&self) -> bool {
- self.model.supports_images.unwrap_or(false)
- }
-
- fn count_tokens(
- &self,
- request: LanguageModelRequest,
- cx: &App,
- ) -> BoxFuture<'static, Result<u64>> {
- count_open_router_tokens(request, self.model.clone(), cx)
- }
-
- fn stream_completion(
- &self,
- request: LanguageModelRequest,
- cx: &AsyncApp,
- ) -> BoxFuture<
- 'static,
- Result<
- futures::stream::BoxStream<
- 'static,
- Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
- >,
- LanguageModelCompletionError,
- >,
- > {
- let request = into_open_router(request, &self.model, self.max_output_tokens());
- let request = self.stream_completion(request, cx);
- let future = self.request_limiter.stream(async move {
- let response = request.await?;
- Ok(OpenRouterEventMapper::new().map_stream(response))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- }
-}
-
-pub fn into_open_router(
- request: LanguageModelRequest,
- model: &Model,
- max_output_tokens: Option<u64>,
-) -> open_router::Request {
- let mut messages = Vec::new();
- for message in request.messages {
- let reasoning_details = message.reasoning_details.clone();
- for content in message.content {
- match content {
- MessageContent::Text(text) => add_message_content_part(
- open_router::MessagePart::Text { text },
- message.role,
- &mut messages,
- ),
- MessageContent::Thinking { .. } => {}
- MessageContent::RedactedThinking(_) => {}
- MessageContent::Image(image) => {
- add_message_content_part(
- open_router::MessagePart::Image {
- image_url: image.to_base64_url(),
- },
- message.role,
- &mut messages,
- );
- }
- MessageContent::ToolUse(tool_use) => {
- let tool_call = open_router::ToolCall {
- id: tool_use.id.to_string(),
- content: open_router::ToolCallContent::Function {
- function: open_router::FunctionContent {
- name: tool_use.name.to_string(),
- arguments: serde_json::to_string(&tool_use.input)
- .unwrap_or_default(),
- thought_signature: tool_use.thought_signature.clone(),
- },
- },
- };
-
- if let Some(open_router::RequestMessage::Assistant {
- tool_calls,
- reasoning_details: existing_reasoning,
- ..
- }) = messages.last_mut()
- {
- tool_calls.push(tool_call);
- if existing_reasoning.is_none() && reasoning_details.is_some() {
- *existing_reasoning = reasoning_details.clone();
- }
- } else {
- messages.push(open_router::RequestMessage::Assistant {
- content: None,
- tool_calls: vec![tool_call],
- reasoning_details: reasoning_details.clone(),
- });
- }
- }
- MessageContent::ToolResult(tool_result) => {
- let content = match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- vec![open_router::MessagePart::Text {
- text: text.to_string(),
- }]
- }
- LanguageModelToolResultContent::Image(image) => {
- vec![open_router::MessagePart::Image {
- image_url: image.to_base64_url(),
- }]
- }
- };
-
- messages.push(open_router::RequestMessage::Tool {
- content: content.into(),
- tool_call_id: tool_result.tool_use_id.to_string(),
- });
- }
- }
- }
- }
-
- open_router::Request {
- model: model.id().into(),
- messages,
- stream: true,
- stop: request.stop,
- temperature: request.temperature.unwrap_or(0.4),
- max_tokens: max_output_tokens,
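- // Explicitly opt out of parallel tool calls when the model supports them;
- // otherwise omit the field so it is not sent at all.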
- parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
- Some(false)
- } else {
- None
- },
- usage: open_router::RequestUsage { include: true },
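- // Only request reasoning tokens when thinking is allowed and the model is
- // configured in thinking mode with a token budget.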
- reasoning: if request.thinking_allowed
- && let OpenRouterModelMode::Thinking { budget_tokens } = model.mode
- {
- Some(open_router::Reasoning {
- effort: None,
- max_tokens: budget_tokens,
- exclude: Some(false),
- enabled: Some(true),
- })
- } else {
- None
- },
- tools: request
- .tools
- .into_iter()
- .map(|tool| open_router::ToolDefinition::Function {
- function: open_router::FunctionDefinition {
- name: tool.name,
- description: Some(tool.description),
- parameters: Some(tool.input_schema),
- },
- })
- .collect(),
- tool_choice: request.tool_choice.map(|choice| match choice {
- LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto,
- LanguageModelToolChoice::Any => open_router::ToolChoice::Required,
- LanguageModelToolChoice::None => open_router::ToolChoice::None,
- }),
- provider: model.provider.clone(),
- }
-}
-
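-/// Appends `new_part` to the trailing message when it has the same role;
-/// otherwise starts a new message for that role.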
-fn add_message_content_part(
- new_part: open_router::MessagePart,
- role: Role,
- messages: &mut Vec<open_router::RequestMessage>,
-) {
- match (role, messages.last_mut()) {
- (Role::User, Some(open_router::RequestMessage::User { content }))
- | (Role::System, Some(open_router::RequestMessage::System { content })) => {
- content.push_part(new_part);
- }
- (
- Role::Assistant,
- Some(open_router::RequestMessage::Assistant {
- content: Some(content),
- ..
- }),
- ) => {
- content.push_part(new_part);
- }
- _ => {
- messages.push(match role {
- Role::User => open_router::RequestMessage::User {
- content: open_router::MessageContent::from(vec![new_part]),
- },
- Role::Assistant => open_router::RequestMessage::Assistant {
- content: Some(open_router::MessageContent::from(vec![new_part])),
- tool_calls: Vec::new(),
- reasoning_details: None,
- },
- Role::System => open_router::RequestMessage::System {
- content: open_router::MessageContent::from(vec![new_part]),
- },
- });
- }
- }
-}
-
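-/// Maps OpenRouter stream events to `LanguageModelCompletionEvent`s,
-/// buffering tool-call chunks by index until a finish reason arrives.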
-pub struct OpenRouterEventMapper {
- tool_calls_by_index: HashMap<usize, RawToolCall>,
- reasoning_details: Option<serde_json::Value>,
-}
-
-impl OpenRouterEventMapper {
- pub fn new() -> Self {
- Self {
- tool_calls_by_index: HashMap::default(),
- reasoning_details: None,
- }
- }
-
- pub fn map_stream(
- mut self,
- events: Pin<
- Box<
- dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
- >,
- >,
- ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
- {
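- // A single upstream event can expand into several completion events
- // (thinking, text, tool use, usage, stop), hence the flat_map.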
- events.flat_map(move |event| {
- futures::stream::iter(match event {
- Ok(event) => self.map_event(event),
- Err(error) => vec![Err(error.into())],
- })
- })
- }
-
- pub fn map_event(
- &mut self,
- event: ResponseStreamEvent,
- ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- let Some(choice) = event.choices.first() else {
- return vec![Err(LanguageModelCompletionError::from(anyhow!(
- "Response contained no choices"
- )))];
- };
-
- let mut events = Vec::new();
-
- if let Some(details) = choice.delta.reasoning_details.clone() {
- // Emit reasoning_details immediately
- events.push(Ok(LanguageModelCompletionEvent::ReasoningDetails(
- details.clone(),
- )));
- self.reasoning_details = Some(details);
- }
-
- if let Some(reasoning) = choice.delta.reasoning.clone() {
- events.push(Ok(LanguageModelCompletionEvent::Thinking {
- text: reasoning,
- signature: None,
- }));
- }
-
- if let Some(content) = choice.delta.content.clone() {
- // OpenRouter sends an empty content string alongside reasoning content;
- // skipping it works around that API quirk.
- if !content.is_empty() {
- events.push(Ok(LanguageModelCompletionEvent::Text(content)));
- }
- }
-
- if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
- for tool_call in tool_calls {
- let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
-
- if let Some(tool_id) = tool_call.id.clone() {
- entry.id = tool_id;
- }
-
- if let Some(function) = tool_call.function.as_ref() {
- if let Some(name) = function.name.clone() {
- entry.name = name;
- }
-
- if let Some(arguments) = function.arguments.clone() {
- entry.arguments.push_str(&arguments);
- }
-
- if let Some(signature) = function.thought_signature.clone() {
- entry.thought_signature = Some(signature);
- }
- }
- }
- }
-
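- // Usage arrives in a late chunk because `include: true` is set on the
- // request; cache accounting isn't tracked here, so both fields stay zero.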
- if let Some(usage) = event.usage {
- events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
- input_tokens: usage.prompt_tokens,
- output_tokens: usage.completion_tokens,
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
- })));
- }
-
- match choice.finish_reason.as_deref() {
- Some("stop") => {
- // Don't emit reasoning_details here - already emitted immediately when captured
- events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
- }
- Some("tool_calls") => {
- events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
- match serde_json::Value::from_str(&tool_call.arguments) {
- Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_call.id.clone().into(),
- name: tool_call.name.as_str().into(),
- is_input_complete: true,
- input,
- raw_input: tool_call.arguments.clone(),
- thought_signature: tool_call.thought_signature.clone(),
- },
- )),
- Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
- id: tool_call.id.clone().into(),
- tool_name: tool_call.name.as_str().into(),
- raw_input: tool_call.arguments.clone().into(),
- json_parse_error: error.to_string(),
- }),
- }
- }));
-
- // Don't emit reasoning_details here - already emitted immediately when captured
- events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
- }
- Some(stop_reason) => {
- log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
- // Don't emit reasoning_details here - already emitted immediately when captured
- events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
- }
- None => {}
- }
-
- events
- }
-}
-
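-/// A tool call assembled from streamed chunks; `arguments` accumulates raw
-/// JSON text until the stream reports a finish reason.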
-#[derive(Default)]
-struct RawToolCall {
- id: String,
- name: String,
- arguments: String,
- thought_signature: Option<String>,
-}
-
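-/// Estimates token usage with tiktoken's gpt-4o encoding. OpenRouter fronts
-/// many different models, so treat this as an approximation, not an exact count.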
-pub fn count_open_router_tokens(
- request: LanguageModelRequest,
- _model: open_router::Model,
- cx: &App,
-) -> BoxFuture<'static, Result<u64>> {
- cx.background_spawn(async move {
- let messages = request
- .messages
- .into_iter()
- .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
- role: match message.role {
- Role::User => "user".into(),
- Role::Assistant => "assistant".into(),
- Role::System => "system".into(),
- },
- content: Some(message.string_contents()),
- name: None,
- function_call: None,
- })
- .collect::<Vec<_>>();
-
- tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
- })
- .boxed()
-}
-
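-/// Configuration UI that prompts for an OpenRouter API key and shows whether
-/// one is already set (including via environment variable).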
-struct ConfigurationView {
- api_key_editor: Entity<InputField>,
- state: Entity<State>,
- load_credentials_task: Option<Task<()>>,
-}
-
-impl ConfigurationView {
- fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
- let api_key_editor = cx.new(|cx| {
- InputField::new(
- window,
- cx,
- "sk_or_000000000000000000000000000000000000000000000000",
- )
- });
-
- cx.observe(&state, |_, _, cx| {
- cx.notify();
- })
- .detach();
-
- let load_credentials_task = Some(cx.spawn_in(window, {
- let state = state.clone();
- async move |this, cx| {
- if let Some(task) = state
- .update(cx, |state, cx| state.authenticate(cx))
- .log_err()
- {
- let _ = task.await;
- }
-
- this.update(cx, |this, cx| {
- this.load_credentials_task = None;
- cx.notify();
- })
- .log_err();
- }
- }));
-
- Self {
- api_key_editor,
- state,
- load_credentials_task,
- }
- }
-
- fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
- if api_key.is_empty() {
- return;
- }
-
- // URL changes can cause the editor to be displayed again
- self.api_key_editor
- .update(cx, |editor, cx| editor.set_text("", window, cx));
-
- let state = self.state.clone();
- cx.spawn_in(window, async move |_, cx| {
- state
- .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
- .await
- })
- .detach_and_log_err(cx);
- }
-
- fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- self.api_key_editor
- .update(cx, |editor, cx| editor.set_text("", window, cx));
-
- let state = self.state.clone();
- cx.spawn_in(window, async move |_, cx| {
- state
- .update(cx, |state, cx| state.set_api_key(None, cx))?
- .await
- })
- .detach_and_log_err(cx);
- }
-
- fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
- !self.state.read(cx).is_authenticated()
- }
-}
-
-impl Render for ConfigurationView {
- fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
- let configured_card_label = if env_var_set {
- format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
- } else {
- let api_url = OpenRouterLanguageModelProvider::api_url(cx);
- if api_url == OPEN_ROUTER_API_URL {
- "API key configured".to_string()
- } else {
- format!("API key configured for {}", api_url)
- }
- };
-
- if self.load_credentials_task.is_some() {
- div()
- .child(Label::new("Loading credentials..."))
- .into_any_element()
- } else if self.should_render_editor(cx) {
- v_flex()
- .size_full()
- .on_action(cx.listener(Self::save_api_key))
- .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:"))
- .child(
- List::new()
- .child(InstructionListItem::new(
- "Create an API key by visiting",
- Some("OpenRouter's console"),
- Some("https://openrouter.ai/keys"),
- ))
- .child(InstructionListItem::text_only(
- "Ensure your OpenRouter account has credits",
- ))
- .child(InstructionListItem::text_only(
- "Paste your API key below and hit enter to start using the assistant",
- )),
- )
- .child(self.api_key_editor.clone())
- .child(
- Label::new(
- format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
- )
- .size(LabelSize::Small).color(Color::Muted),
- )
- .into_any_element()
- } else {
- ConfiguredApiCard::new(configured_card_label)
- .disabled(env_var_set)
- .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
- .when(env_var_set, |this| {
- this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
- })
- .into_any_element()
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use open_router::{ChoiceDelta, FunctionChunk, ResponseMessageDelta, ToolCallChunk};
-
- #[gpui::test]
- async fn test_reasoning_details_preservation_with_tool_calls() {
- // This test verifies that reasoning_details are properly captured and preserved
- // when a model uses tool calling with reasoning/thinking tokens.
- //
- // The key regression this prevents:
- // - OpenRouter sends multiple reasoning_details updates during streaming
- // - First with actual content (encrypted reasoning data)
- // - Then with empty array on completion
- // - We must NOT overwrite the real data with the empty array
-
- let mut mapper = OpenRouterEventMapper::new();
-
- // Simulate the streaming events as they come from OpenRouter/Gemini
- let events = vec![
- // Event 1: Initial reasoning details with text
- ResponseStreamEvent {
- id: Some("response_123".into()),
- created: 1234567890,
- model: "google/gemini-3-pro-preview".into(),
- choices: vec![ChoiceDelta {
- index: 0,
- delta: ResponseMessageDelta {
- role: None,
- content: None,
- reasoning: None,
- tool_calls: None,
- reasoning_details: Some(serde_json::json!([
- {
- "type": "reasoning.text",
- "text": "Let me analyze this request...",
- "format": "google-gemini-v1",
- "index": 0
- }
- ])),
- },
- finish_reason: None,
- }],
- usage: None,
- },
- // Event 2: More reasoning details
- ResponseStreamEvent {
- id: Some("response_123".into()),
- created: 1234567890,
- model: "google/gemini-3-pro-preview".into(),
- choices: vec![ChoiceDelta {
- index: 0,
- delta: ResponseMessageDelta {
- role: None,
- content: None,
- reasoning: None,
- tool_calls: None,
- reasoning_details: Some(serde_json::json!([
- {
- "type": "reasoning.encrypted",
- "data": "EtgDCtUDAdHtim9OF5jm4aeZSBAtl/randomized123",
- "format": "google-gemini-v1",
- "index": 0,
- "id": "tool_call_abc123"
- }
- ])),
- },
- finish_reason: None,
- }],
- usage: None,
- },
- // Event 3: Tool call starts
- ResponseStreamEvent {
- id: Some("response_123".into()),
- created: 1234567890,
- model: "google/gemini-3-pro-preview".into(),
- choices: vec![ChoiceDelta {
- index: 0,
- delta: ResponseMessageDelta {
- role: None,
- content: None,
- reasoning: None,
- tool_calls: Some(vec![ToolCallChunk {
- index: 0,
- id: Some("tool_call_abc123".into()),
- function: Some(FunctionChunk {
- name: Some("list_directory".into()),
- arguments: Some("{\"path\":\"test\"}".into()),
- thought_signature: Some("sha256:test_signature_xyz789".into()),
- }),
- }]),
- reasoning_details: None,
- },
- finish_reason: None,
- }],
- usage: None,
- },
- // Event 4: Empty reasoning_details on tool_calls finish
- // This is the critical event - we must not overwrite with this empty array!
- ResponseStreamEvent {
- id: Some("response_123".into()),
- created: 1234567890,
- model: "google/gemini-3-pro-preview".into(),
- choices: vec![ChoiceDelta {
- index: 0,
- delta: ResponseMessageDelta {
- role: None,
- content: None,
- reasoning: None,
- tool_calls: None,
- reasoning_details: Some(serde_json::json!([])),
- },
- finish_reason: Some("tool_calls".into()),
- }],
- usage: None,
- },
- ];
-
- // Process all events
- let mut collected_events = Vec::new();
- for event in events {
- let mapped = mapper.map_event(event);
- collected_events.extend(mapped);
- }
-
- // Verify we got the expected events
- let mut has_tool_use = false;
- let mut reasoning_details_events = Vec::new();
- let mut thought_signature_value = None;
-
- for event_result in collected_events {
- match event_result {
- Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
- has_tool_use = true;
- assert_eq!(tool_use.id.to_string(), "tool_call_abc123");
- assert_eq!(tool_use.name.as_ref(), "list_directory");
- thought_signature_value = tool_use.thought_signature.clone();
- }
- Ok(LanguageModelCompletionEvent::ReasoningDetails(details)) => {
- reasoning_details_events.push(details);
- }
- _ => {}
- }
- }
-
- // Assertions
- assert!(has_tool_use, "Should have emitted ToolUse event");
- assert!(
- !reasoning_details_events.is_empty(),
- "Should have emitted ReasoningDetails events"
- );
-
- // We should have received multiple reasoning_details events (text, encrypted, empty)
- // The agent layer is responsible for keeping only the first non-empty one
- assert!(
- reasoning_details_events.len() >= 2,
- "Should have multiple reasoning_details events from streaming"
- );
-
- // Verify at least one contains the encrypted data
- let has_encrypted = reasoning_details_events.iter().any(|details| {
- if let serde_json::Value::Array(arr) = details {
- arr.iter().any(|item| {
- item["type"] == "reasoning.encrypted"
- && item["data"]
- .as_str()
- .map_or(false, |s| s.contains("EtgDCtUDAdHtim9OF5jm4aeZSBAtl"))
- })
- } else {
- false
- }
- });
- assert!(
- has_encrypted,
- "Should have at least one reasoning_details with encrypted data"
- );
-
- // Verify thought_signature was captured
- assert!(
- thought_signature_value.is_some(),
- "Tool use should have thought_signature"
- );
- assert_eq!(
- thought_signature_value.unwrap(),
- "sha256:test_signature_xyz789"
- );
- }
-
- #[gpui::test]
- async fn test_agent_prevents_empty_reasoning_details_overwrite() {
- // This test verifies that the agent layer prevents empty reasoning_details
- // from overwriting non-empty ones, even though the mapper emits all events.
-
- // Simulate what the agent does when it receives multiple ReasoningDetails events
- let mut agent_reasoning_details: Option<serde_json::Value> = None;
-
- let events = vec![
- // First event: non-empty reasoning_details
- serde_json::json!([
- {
- "type": "reasoning.encrypted",
- "data": "real_data_here",
- "format": "google-gemini-v1"
- }
- ]),
- // Second event: empty array (should not overwrite)
- serde_json::json!([]),
- ];
-
- for details in events {
- // This mimics the agent's logic: only store if we don't already have it
- if agent_reasoning_details.is_none() {
- agent_reasoning_details = Some(details);
- }
- }
-
- // Verify the agent kept the first non-empty reasoning_details
- assert!(agent_reasoning_details.is_some());
- let final_details = agent_reasoning_details.unwrap();
- if let serde_json::Value::Array(arr) = &final_details {
- assert!(
- !arr.is_empty(),
- "Agent should have kept the non-empty reasoning_details"
- );
- assert_eq!(arr[0]["data"], "real_data_here");
- } else {
- panic!("Expected array");
- }
- }
-}
@@ -7,9 +7,17 @@ use crate::provider::{
bedrock::AmazonBedrockSettings, cloud::ZedDotDevSettings, deepseek::DeepSeekSettings,
google::GoogleSettings, lmstudio::LmStudioSettings, mistral::MistralSettings,
ollama::OllamaSettings, open_ai::OpenAiSettings, open_ai_compatible::OpenAiCompatibleSettings,
- open_router::OpenRouterSettings, vercel::VercelSettings, x_ai::XAiSettings,
+ vercel::VercelSettings, x_ai::XAiSettings,
};
+pub use settings::OpenRouterAvailableModel as AvailableModel;
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct OpenRouterSettings {
+ pub api_url: String,
+ pub available_models: Vec<AvailableModel>,
+}
+
#[derive(Debug, RegisterSetting)]
pub struct AllLanguageModelSettings {
pub bedrock: AmazonBedrockSettings,
@@ -47,9 +55,9 @@ impl settings::Settings for AllLanguageModelSettings {
bedrock: AmazonBedrockSettings {
available_models: bedrock.available_models.unwrap_or_default(),
region: bedrock.region,
- endpoint: bedrock.endpoint_url, // todo(should be api_url)
+ endpoint: bedrock.endpoint_url,
profile_name: bedrock.profile,
- role_arn: None, // todo(was never a setting for this...)
+ role_arn: None,
authentication_method: bedrock.authentication_method.map(Into::into),
allow_global: bedrock.allow_global,
},