Detailed changes
@@ -8854,16 +8854,19 @@ dependencies = [
"collections",
"component",
"convert_case 0.8.0",
+ "copilot",
"credentials_provider",
"deepseek",
"editor",
"extension",
+ "extension_host",
"fs",
"futures 0.3.31",
"google_ai",
"gpui",
"gpui_tokio",
"http_client",
+ "language",
"language_model",
"lmstudio",
"log",
@@ -8871,6 +8874,8 @@ dependencies = [
"mistral",
"ollama",
"open_ai",
+ "open_router",
+ "partial-json-fixer",
"project",
"release_channel",
"schemars",
@@ -11219,6 +11224,12 @@ dependencies = [
"num-traits",
]
+[[package]]
+name = "partial-json-fixer"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13"
+
[[package]]
name = "password-hash"
version = "0.4.2"
@@ -83,14 +83,24 @@ impl AgentConfiguration {
window,
|this, _, event: &language_model::Event, window, cx| match event {
language_model::Event::AddedProvider(provider_id) => {
- let provider = LanguageModelRegistry::read_global(cx).provider(provider_id);
- if let Some(provider) = provider {
- this.add_provider_configuration_view(&provider, window, cx);
+ let registry = LanguageModelRegistry::read_global(cx);
+ // Only add if the provider is visible
+ if let Some(provider) = registry.provider(provider_id) {
+ if !registry.should_hide_provider(provider_id) {
+ this.add_provider_configuration_view(&provider, window, cx);
+ }
}
}
language_model::Event::RemovedProvider(provider_id) => {
this.remove_provider_configuration_view(provider_id);
}
+ language_model::Event::ProvidersChanged => {
+ // Rebuild all provider views when visibility changes
+ this.configuration_views_by_provider.clear();
+ this.expanded_provider_configurations.clear();
+ this.build_provider_configuration_views(window, cx);
+ cx.notify();
+ }
_ => {}
},
);
@@ -117,7 +127,7 @@ impl AgentConfiguration {
}
fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let providers = LanguageModelRegistry::read_global(cx).providers();
+ let providers = LanguageModelRegistry::read_global(cx).visible_providers();
for provider in providers {
self.add_provider_configuration_view(&provider, window, cx);
}
@@ -420,7 +430,7 @@ impl AgentConfiguration {
&mut self,
cx: &mut Context<Self>,
) -> impl IntoElement {
- let providers = LanguageModelRegistry::read_global(cx).providers();
+ let providers = LanguageModelRegistry::read_global(cx).visible_providers();
let popover_menu = PopoverMenu::new("add-provider-popover")
.trigger(
@@ -2291,7 +2291,7 @@ impl AgentPanel {
let history_is_empty = self.history_store.read(cx).is_empty(cx);
let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx)
- .providers()
+ .visible_providers()
.iter()
.any(|provider| {
provider.is_authenticated(cx)
@@ -46,7 +46,9 @@ pub fn language_model_selector(
}
fn all_models(cx: &App) -> GroupedModels {
- let providers = LanguageModelRegistry::global(cx).read(cx).providers();
+ let providers = LanguageModelRegistry::global(cx)
+ .read(cx)
+ .visible_providers();
let recommended = providers
.iter()
@@ -423,7 +425,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let configured_providers = language_model_registry
.read(cx)
- .providers()
+ .visible_providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@@ -44,7 +44,7 @@ impl ApiKeysWithProviders {
fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> {
LanguageModelRegistry::read_global(cx)
- .providers()
+ .visible_providers()
.iter()
.filter(|provider| {
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
@@ -45,7 +45,7 @@ impl AgentPanelOnboarding {
fn has_configured_providers(cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
- .providers()
+ .visible_providers()
.iter()
.any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID)
}
@@ -9,6 +9,7 @@ use strum::{EnumIter, EnumString, IntoStaticStr};
#[strum(serialize_all = "snake_case")]
pub enum IconName {
Ai,
+ AiAnthropic,
AiBedrock,
AiClaude,
AiDeepSeek,
@@ -2,12 +2,16 @@ use crate::{
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashSet};
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use util::maybe;
+/// Function type for checking if a built-in provider should be hidden.
+/// Returns Some(extension_id) if the provider should be hidden when that extension is installed.
+pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
+
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| LanguageModelRegistry::default());
cx.set_global(GlobalLanguageModelRegistry(registry));
@@ -48,6 +52,11 @@ pub struct LanguageModelRegistry {
thread_summary_model: Option<ConfiguredModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
+ /// Set of installed extension IDs that provide language models.
+ /// Used to determine which built-in providers should be hidden.
+ installed_llm_extension_ids: HashSet<Arc<str>>,
+ /// Function to check if a built-in provider should be hidden by an extension.
+ builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
}
#[derive(Debug)]
@@ -104,6 +113,8 @@ pub enum Event {
ProviderStateChanged(LanguageModelProviderId),
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
+ /// Emitted when provider visibility changes due to extension install/uninstall.
+ ProvidersChanged,
}
impl EventEmitter<Event> for LanguageModelRegistry {}
@@ -183,6 +194,65 @@ impl LanguageModelRegistry {
providers
}
+ /// Returns providers, filtering out hidden built-in providers.
+ pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
+ self.providers()
+ .into_iter()
+ .filter(|p| !self.should_hide_provider(&p.id()))
+ .collect()
+ }
+
+ /// Sets the function used to check if a built-in provider should be hidden.
+ pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
+ self.builtin_provider_hiding_fn = Some(hiding_fn);
+ }
+
+ /// Called when an extension is installed/loaded.
+    /// If the extension provides language models, track it so we can hide the corresponding built-in provider.
+ pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
+ if self.installed_llm_extension_ids.insert(extension_id) {
+ cx.emit(Event::ProvidersChanged);
+ cx.notify();
+ }
+ }
+
+ /// Called when an extension is uninstalled/unloaded.
+ pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
+ if self.installed_llm_extension_ids.remove(extension_id) {
+ cx.emit(Event::ProvidersChanged);
+ cx.notify();
+ }
+ }
+
+ /// Sync the set of installed LLM extension IDs.
+ pub fn sync_installed_llm_extensions(
+ &mut self,
+ extension_ids: HashSet<Arc<str>>,
+ cx: &mut Context<Self>,
+ ) {
+ if extension_ids != self.installed_llm_extension_ids {
+ self.installed_llm_extension_ids = extension_ids;
+ cx.emit(Event::ProvidersChanged);
+ cx.notify();
+ }
+ }
+
+ /// Returns the set of installed LLM extension IDs.
+ pub fn installed_llm_extension_ids(&self) -> &HashSet<Arc<str>> {
+ &self.installed_llm_extension_ids
+ }
+
+ /// Returns true if a provider should be hidden from the UI.
+ /// Built-in providers are hidden when their corresponding extension is installed.
+ pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
+ if let Some(ref hiding_fn) = self.builtin_provider_hiding_fn {
+ if let Some(extension_id) = hiding_fn(&provider_id.0) {
+ return self.installed_llm_extension_ids.contains(extension_id);
+ }
+ }
+ false
+ }
+
pub fn configuration_error(
&self,
model: Option<ConfiguredModel>,
@@ -416,4 +486,151 @@ mod tests {
let providers = registry.read(cx).providers();
assert!(providers.is_empty());
}
+
+ #[gpui::test]
+ fn test_provider_hiding_on_extension_install(cx: &mut App) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ let provider = Arc::new(FakeLanguageModelProvider::default());
+ let provider_id = provider.id();
+
+ registry.update(cx, |registry, cx| {
+ registry.register_provider(provider.clone(), cx);
+
+ // Set up a hiding function that hides the fake provider when "fake-extension" is installed
+ registry.set_builtin_provider_hiding_fn(Box::new(|id| {
+ if id == "fake" {
+ Some("fake-extension")
+ } else {
+ None
+ }
+ }));
+ });
+
+ // Provider should be visible initially
+ let visible = registry.read(cx).visible_providers();
+ assert_eq!(visible.len(), 1);
+ assert_eq!(visible[0].id(), provider_id);
+
+ // Install the extension
+ registry.update(cx, |registry, cx| {
+ registry.extension_installed("fake-extension".into(), cx);
+ });
+
+ // Provider should now be hidden
+ let visible = registry.read(cx).visible_providers();
+ assert!(visible.is_empty());
+
+ // But still in providers()
+ let all = registry.read(cx).providers();
+ assert_eq!(all.len(), 1);
+ }
+
+ #[gpui::test]
+ fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ let provider = Arc::new(FakeLanguageModelProvider::default());
+ let provider_id = provider.id();
+
+ registry.update(cx, |registry, cx| {
+ registry.register_provider(provider.clone(), cx);
+
+ // Set up hiding function
+ registry.set_builtin_provider_hiding_fn(Box::new(|id| {
+ if id == "fake" {
+ Some("fake-extension")
+ } else {
+ None
+ }
+ }));
+
+ // Start with extension installed
+ registry.extension_installed("fake-extension".into(), cx);
+ });
+
+ // Provider should be hidden
+ let visible = registry.read(cx).visible_providers();
+ assert!(visible.is_empty());
+
+ // Uninstall the extension
+ registry.update(cx, |registry, cx| {
+ registry.extension_uninstalled("fake-extension", cx);
+ });
+
+ // Provider should now be visible again
+ let visible = registry.read(cx).visible_providers();
+ assert_eq!(visible.len(), 1);
+ assert_eq!(visible[0].id(), provider_id);
+ }
+
+ #[gpui::test]
+ fn test_should_hide_provider(cx: &mut App) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ registry.update(cx, |registry, cx| {
+ // Set up hiding function
+ registry.set_builtin_provider_hiding_fn(Box::new(|id| {
+ if id == "anthropic" {
+ Some("anthropic")
+ } else if id == "openai" {
+ Some("openai")
+ } else {
+ None
+ }
+ }));
+
+ // Install only anthropic extension
+ registry.extension_installed("anthropic".into(), cx);
+ });
+
+ let registry_read = registry.read(cx);
+
+ // Anthropic should be hidden
+ assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));
+
+ // OpenAI should not be hidden (extension not installed)
+ assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));
+
+ // Unknown provider should not be hidden
+ assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
+ }
+
+ #[gpui::test]
+ fn test_sync_installed_llm_extensions(cx: &mut App) {
+ let registry = cx.new(|_| LanguageModelRegistry::default());
+
+ let provider = Arc::new(FakeLanguageModelProvider::default());
+
+ registry.update(cx, |registry, cx| {
+ registry.register_provider(provider.clone(), cx);
+
+ registry.set_builtin_provider_hiding_fn(Box::new(|id| {
+ if id == "fake" {
+ Some("fake-extension")
+ } else {
+ None
+ }
+ }));
+ });
+
+ // Sync with a set containing the extension
+ let mut extension_ids = HashSet::default();
+ extension_ids.insert(Arc::from("fake-extension"));
+
+ registry.update(cx, |registry, cx| {
+ registry.sync_installed_llm_extensions(extension_ids, cx);
+ });
+
+ // Provider should be hidden
+ assert!(registry.read(cx).visible_providers().is_empty());
+
+ // Sync with empty set
+ registry.update(cx, |registry, cx| {
+ registry.sync_installed_llm_extensions(HashSet::default(), cx);
+ });
+
+ // Provider should be visible again
+ assert_eq!(registry.read(cx).visible_providers().len(), 1);
+ }
}
@@ -25,15 +25,18 @@ cloud_llm_client.workspace = true
collections.workspace = true
component.workspace = true
convert_case.workspace = true
+copilot.workspace = true
credentials_provider.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
extension.workspace = true
+extension_host.workspace = true
fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
gpui_tokio.workspace = true
http_client.workspace = true
+language.workspace = true
language_model.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
@@ -41,6 +44,8 @@ menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
+open_router = { workspace = true, features = ["schemars"] }
+partial-json-fixer.workspace = true
release_channel.workspace = true
schemars.workspace = true
semver.workspace = true
@@ -223,13 +223,27 @@ impl ApiKeyState {
}
impl ApiKey {
- fn from_env(env_var_name: SharedString, key: &str) -> Self {
+ pub fn key(&self) -> &str {
+ &self.key
+ }
+
+ pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
Self {
source: ApiKeySource::EnvVar(env_var_name),
key: key.into(),
}
}
+ pub async fn load_from_system_keychain(
+ url: &str,
+ credentials_provider: &dyn CredentialsProvider,
+ cx: &AsyncApp,
+ ) -> Result<Self, AuthenticateError> {
+ Self::load_from_system_keychain_impl(url, credentials_provider, cx)
+ .await
+ .into_authenticate_result()
+ }
+
async fn load_from_system_keychain_impl(
url: &str,
credentials_provider: &dyn CredentialsProvider,
@@ -1,7 +1,31 @@
+use collections::HashMap;
use extension::{ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration};
use gpui::{App, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
+
+/// Maps built-in provider IDs to their corresponding extension IDs.
+/// When an extension with this ID is installed, the built-in provider should be hidden.
+pub static BUILTIN_TO_EXTENSION_MAP: LazyLock<HashMap<&'static str, &'static str>> =
+ LazyLock::new(|| {
+ let mut map = HashMap::default();
+ map.insert("anthropic", "anthropic");
+ map.insert("openai", "openai");
+ map.insert("google", "google-ai");
+ map.insert("open_router", "open-router");
+ map.insert("copilot_chat", "copilot-chat");
+ map
+ });
+
+/// Returns the extension ID that should hide the given built-in provider.
+pub fn extension_for_builtin_provider(provider_id: &str) -> Option<&'static str> {
+ BUILTIN_TO_EXTENSION_MAP.get(provider_id).copied()
+}
+
+/// Returns true if the given provider ID is a built-in provider that can be hidden by an extension.
+pub fn is_hideable_builtin_provider(provider_id: &str) -> bool {
+ BUILTIN_TO_EXTENSION_MAP.contains_key(provider_id)
+}
/// Proxy implementation that registers extension-based language model providers
/// with the LanguageModelRegistry.
@@ -1,6 +1,5 @@
use std::sync::Arc;
-use ::extension::ExtensionHostProxy;
use ::settings::{Settings, SettingsStore};
use client::{Client, UserStore};
use collections::HashSet;
@@ -9,20 +8,25 @@ use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
mod api_key;
-mod extension;
+pub mod extension;
mod google_ai_api_key;
pub mod provider;
mod settings;
pub mod ui;
-pub use google_ai_api_key::api_key_for_gemini_cli;
-
+pub use crate::extension::{extension_for_builtin_provider, is_hideable_builtin_provider};
+pub use crate::google_ai_api_key::api_key_for_gemini_cli;
+use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider;
+use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
+use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider;
pub use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
+use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_ai_compatible::OpenAiCompatibleLanguageModelProvider;
+use crate::provider::open_router::OpenRouterLanguageModelProvider;
use crate::provider::vercel::VercelLanguageModelProvider;
use crate::provider::x_ai::XAiLanguageModelProvider;
pub use crate::settings::*;
@@ -33,11 +37,65 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
register_language_model_providers(registry, user_store, client.clone(), cx);
});
- // Register the extension language model provider proxy
- let extension_proxy = ExtensionHostProxy::default_global(cx);
- extension_proxy.register_language_model_provider_proxy(
- extension::ExtensionLanguageModelProxy::new(registry.clone()),
- );
+ // Set up the provider hiding function
+ registry.update(cx, |registry, _cx| {
+ registry.set_builtin_provider_hiding_fn(Box::new(extension_for_builtin_provider));
+ });
+
+ // Subscribe to extension store events to track LLM extension installations
+ if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
+ cx.subscribe(&extension_store, {
+ let registry = registry.clone();
+ move |extension_store, event, cx| {
+ match event {
+ extension_host::Event::ExtensionInstalled(extension_id) => {
+ // Check if this extension has language_model_providers
+ if let Some(manifest) = extension_store
+ .read(cx)
+ .extension_manifest_for_id(extension_id)
+ {
+ if !manifest.language_model_providers.is_empty() {
+ registry.update(cx, |registry, cx| {
+ registry.extension_installed(extension_id.clone(), cx);
+ });
+ }
+ }
+ }
+ extension_host::Event::ExtensionUninstalled(extension_id) => {
+ registry.update(cx, |registry, cx| {
+ registry.extension_uninstalled(extension_id, cx);
+ });
+ }
+ extension_host::Event::ExtensionsUpdated => {
+ // Re-sync installed extensions on bulk updates
+ let mut new_ids = HashSet::default();
+ for (extension_id, entry) in extension_store.read(cx).installed_extensions()
+ {
+ if !entry.manifest.language_model_providers.is_empty() {
+ new_ids.insert(extension_id.clone());
+ }
+ }
+ registry.update(cx, |registry, cx| {
+ registry.sync_installed_llm_extensions(new_ids, cx);
+ });
+ }
+ _ => {}
+ }
+ }
+ })
+ .detach();
+
+ // Initialize with currently installed extensions
+ registry.update(cx, |registry, cx| {
+ let mut initial_ids = HashSet::default();
+ for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
+ if !entry.manifest.language_model_providers.is_empty() {
+ initial_ids.insert(extension_id.clone());
+ }
+ }
+ registry.sync_installed_llm_extensions(initial_ids, cx);
+ });
+ }
let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
.openai_compatible
@@ -117,6 +175,17 @@ fn register_language_model_providers(
)),
cx,
);
+ registry.register_provider(
+ Arc::new(AnthropicLanguageModelProvider::new(
+ client.http_client(),
+ cx,
+ )),
+ cx,
+ );
+ registry.register_provider(
+ Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)),
+ cx,
+ );
registry.register_provider(
Arc::new(OllamaLanguageModelProvider::new(client.http_client(), cx)),
cx,
@@ -129,6 +198,10 @@ fn register_language_model_providers(
Arc::new(DeepSeekLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
+ registry.register_provider(
+ Arc::new(GoogleLanguageModelProvider::new(client.http_client(), cx)),
+ cx,
+ );
registry.register_provider(
MistralLanguageModelProvider::global(client.http_client(), cx),
cx,
@@ -137,6 +210,13 @@ fn register_language_model_providers(
Arc::new(BedrockLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
+ registry.register_provider(
+ Arc::new(OpenRouterLanguageModelProvider::new(
+ client.http_client(),
+ cx,
+ )),
+ cx,
+ );
registry.register_provider(
Arc::new(VercelLanguageModelProvider::new(client.http_client(), cx)),
cx,
@@ -145,4 +225,5 @@ fn register_language_model_providers(
Arc::new(XAiLanguageModelProvider::new(client.http_client(), cx)),
cx,
);
+ registry.register_provider(Arc::new(CopilotChatLanguageModelProvider::new(cx)), cx);
}
@@ -1,5 +1,7 @@
+pub mod anthropic;
pub mod bedrock;
pub mod cloud;
+pub mod copilot_chat;
pub mod deepseek;
pub mod google;
pub mod lmstudio;
@@ -7,5 +9,6 @@ pub mod mistral;
pub mod ollama;
pub mod open_ai;
pub mod open_ai_compatible;
+pub mod open_router;
pub mod vercel;
pub mod x_ai;
@@ -0,0 +1,1045 @@
+use anthropic::{
+ ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent,
+ ToolResultContent, ToolResultPart, Usage,
+};
+use anyhow::{Result, anyhow};
+use collections::{BTreeMap, HashMap};
+use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, Task};
+use http_client::HttpClient;
+use language_model::{
+ AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
+ LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
+ LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
+ LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
+};
+use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
+use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
+use std::sync::{Arc, LazyLock};
+use strum::IntoEnumIterator;
+use ui::{List, prelude::*};
+use ui_input::InputField;
+use util::ResultExt;
+use zed_env_vars::{EnvVar, env_var};
+
+use crate::api_key::ApiKeyState;
+use crate::ui::{ConfiguredApiCard, InstructionListItem};
+
+pub use settings::AnthropicAvailableModel as AvailableModel;
+
+const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct AnthropicSettings {
+ pub api_url: String,
+ /// Extend Zed's list of Anthropic models.
+ pub available_models: Vec<AvailableModel>,
+}
+
+pub struct AnthropicLanguageModelProvider {
+ http_client: Arc<dyn HttpClient>,
+ state: Entity<State>,
+}
+
+const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+
+pub struct State {
+ api_key_state: ApiKeyState,
+}
+
+impl State {
+ fn is_authenticated(&self) -> bool {
+ self.api_key_state.has_key()
+ }
+
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ }
+
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
+ }
+}
+
+impl AnthropicLanguageModelProvider {
+ pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
+ cx.notify();
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ }
+ });
+
+ Self { http_client, state }
+ }
+
+ fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(AnthropicModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+
+ fn settings(cx: &App) -> &AnthropicSettings {
+ &crate::AllLanguageModelSettings::get_global(cx).anthropic
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ ANTHROPIC_API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
+}
+
+impl LanguageModelProviderState for AnthropicLanguageModelProvider {
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
+ Some(self.state.clone())
+ }
+}
+
+impl LanguageModelProvider for AnthropicLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::AiAnthropic
+ }
+
+ fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(anthropic::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(anthropic::Model::default_fast()))
+ }
+
+ fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+ [
+ anthropic::Model::ClaudeSonnet4_5,
+ anthropic::Model::ClaudeSonnet4_5Thinking,
+ ]
+ .into_iter()
+ .map(|model| self.create_language_model(model))
+ .collect()
+ }
+
+ fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+ let mut models = BTreeMap::default();
+
+ // Add base models from anthropic::Model::iter()
+ for model in anthropic::Model::iter() {
+ if !matches!(model, anthropic::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), model);
+ }
+ }
+
+ // Override with available models from settings
+ for model in &AnthropicLanguageModelProvider::settings(cx).available_models {
+ models.insert(
+ model.name.clone(),
+ anthropic::Model::Custom {
+ name: model.name.clone(),
+ display_name: model.display_name.clone(),
+ max_tokens: model.max_tokens,
+ tool_override: model.tool_override.clone(),
+ cache_configuration: model.cache_configuration.as_ref().map(|config| {
+ anthropic::AnthropicModelCacheConfiguration {
+ max_cache_anchors: config.max_cache_anchors,
+ should_speculate: config.should_speculate,
+ min_total_token: config.min_total_token,
+ }
+ }),
+ max_output_tokens: model.max_output_tokens,
+ default_temperature: model.default_temperature,
+ extra_beta_headers: model.extra_beta_headers.clone(),
+ mode: model.mode.unwrap_or_default().into(),
+ },
+ );
+ }
+
+ models
+ .into_values()
+ .map(|model| self.create_language_model(model))
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated()
+ }
+
+ fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ self.state.update(cx, |state, cx| state.authenticate(cx))
+ }
+
+ fn configuration_view(
+ &self,
+ target_agent: ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
+ .into()
+ }
+
+ fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ }
+}
+
+pub struct AnthropicModel {
+ id: LanguageModelId,
+ model: anthropic::Model,
+ state: Entity<State>,
+ http_client: Arc<dyn HttpClient>,
+ request_limiter: RateLimiter,
+}
+
+pub fn count_anthropic_tokens(
+ request: LanguageModelRequest,
+ cx: &App,
+) -> BoxFuture<'static, Result<u64>> {
+ cx.background_spawn(async move {
+ let messages = request.messages;
+ let mut tokens_from_images = 0;
+ let mut string_messages = Vec::with_capacity(messages.len());
+
+ for message in messages {
+ use language_model::MessageContent;
+
+ let mut string_contents = String::new();
+
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) => {
+ string_contents.push_str(&text);
+ }
+ MessageContent::Thinking { .. } => {
+ // Thinking blocks are not included in the input token count.
+ }
+ MessageContent::RedactedThinking(_) => {
+ // Thinking blocks are not included in the input token count.
+ }
+ MessageContent::Image(image) => {
+ tokens_from_images += image.estimate_tokens();
+ }
+ MessageContent::ToolUse(_tool_use) => {
+ // TODO: Estimate token usage from tool uses.
+ }
+ MessageContent::ToolResult(tool_result) => match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ string_contents.push_str(text);
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ tokens_from_images += image.estimate_tokens();
+ }
+ },
+ }
+ }
+
+ if !string_contents.is_empty() {
+ string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(string_contents),
+ name: None,
+ function_call: None,
+ });
+ }
+ }
+
+ // Tiktoken doesn't yet support these models, so we manually use the
+ // same tokenizer as GPT-4.
+ tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
+ .map(|tokens| (tokens + tokens_from_images) as u64)
+ })
+ .boxed()
+}
+
+impl AnthropicModel {
+ fn stream_completion(
+ &self,
+ request: anthropic::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let http_client = self.http_client.clone();
+
+ let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ (state.api_key_state.key(&api_url), api_url)
+ }) else {
+ return future::ready(Err(anyhow!("App state dropped").into())).boxed();
+ };
+
+ let beta_headers = self.model.beta_headers();
+
+ async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request = anthropic::stream_completion(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ beta_headers,
+ );
+ request.await.map_err(Into::into)
+ }
+ .boxed()
+ }
+}
+
+impl LanguageModel for AnthropicModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn supports_tools(&self) -> bool {
+ true
+ }
+
+ fn supports_images(&self) -> bool {
+ true
+ }
+
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto
+ | LanguageModelToolChoice::Any
+ | LanguageModelToolChoice::None => true,
+ }
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("anthropic/{}", self.model.id())
+ }
+
+ fn api_key(&self, cx: &App) -> Option<String> {
+ self.state.read_with(cx, |state, cx| {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ state.api_key_state.key(&api_url).map(|key| key.to_string())
+ })
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_token_count()
+ }
+
+ fn max_output_tokens(&self) -> Option<u64> {
+ Some(self.model.max_output_tokens())
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result<u64>> {
+ count_anthropic_tokens(request, cx)
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let request = into_anthropic(
+ request,
+ self.model.request_id().into(),
+ self.model.default_temperature(),
+ self.model.max_output_tokens(),
+ self.model.mode(),
+ );
+ let request = self.stream_completion(request, cx);
+ let future = self.request_limiter.stream(async move {
+ let response = request.await?;
+ Ok(AnthropicEventMapper::new().map_stream(response))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+
+ fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+ self.model
+ .cache_configuration()
+ .map(|config| LanguageModelCacheConfiguration {
+ max_cache_anchors: config.max_cache_anchors,
+ should_speculate: config.should_speculate,
+ min_total_token: config.min_total_token,
+ })
+ }
+}
+
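+/// Converts a `LanguageModelRequest` into an Anthropic `Request`, merging consecutive
+/// same-role messages and collecting system messages into the request's `system` field.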
+pub fn into_anthropic(
+ request: LanguageModelRequest,
+ model: String,
+ default_temperature: f32,
+ max_output_tokens: u64,
+ mode: AnthropicModelMode,
+) -> anthropic::Request {
+ let mut new_messages: Vec<anthropic::Message> = Vec::new();
+ let mut system_message = String::new();
+
+ for message in request.messages {
+ if message.contents_empty() {
+ continue;
+ }
+
+ match message.role {
+ Role::User | Role::Assistant => {
+ let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
+ .content
+ .into_iter()
+ .filter_map(|content| match content {
+ MessageContent::Text(text) => {
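+                            // Trim trailing whitespace from text content; the Anthropic API
+                            // rejects a final assistant message that ends with whitespace.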
+ let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
+ text.trim_end().to_string()
+ } else {
+ text
+ };
+ if !text.is_empty() {
+ Some(anthropic::RequestContent::Text {
+ text,
+ cache_control: None,
+ })
+ } else {
+ None
+ }
+ }
+ MessageContent::Thinking {
+ text: thinking,
+ signature,
+ } => {
+ if !thinking.is_empty() {
+ Some(anthropic::RequestContent::Thinking {
+ thinking,
+ signature: signature.unwrap_or_default(),
+ cache_control: None,
+ })
+ } else {
+ None
+ }
+ }
+ MessageContent::RedactedThinking(data) => {
+ if !data.is_empty() {
+ Some(anthropic::RequestContent::RedactedThinking { data })
+ } else {
+ None
+ }
+ }
+ MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
+ source: anthropic::ImageSource {
+ source_type: "base64".to_string(),
+ media_type: "image/png".to_string(),
+ data: image.source.to_string(),
+ },
+ cache_control: None,
+ }),
+ MessageContent::ToolUse(tool_use) => {
+ Some(anthropic::RequestContent::ToolUse {
+ id: tool_use.id.to_string(),
+ name: tool_use.name.to_string(),
+ input: tool_use.input,
+ cache_control: None,
+ })
+ }
+ MessageContent::ToolResult(tool_result) => {
+ Some(anthropic::RequestContent::ToolResult {
+ tool_use_id: tool_result.tool_use_id.to_string(),
+ is_error: tool_result.is_error,
+ content: match tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ ToolResultContent::Plain(text.to_string())
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ ToolResultContent::Multipart(vec![ToolResultPart::Image {
+ source: anthropic::ImageSource {
+ source_type: "base64".to_string(),
+ media_type: "image/png".to_string(),
+ data: image.source.to_string(),
+ },
+ }])
+ }
+ },
+ cache_control: None,
+ })
+ }
+ })
+ .collect();
+ let anthropic_role = match message.role {
+ Role::User => anthropic::Role::User,
+ Role::Assistant => anthropic::Role::Assistant,
+ Role::System => unreachable!("System role should never occur here"),
+ };
+ if let Some(last_message) = new_messages.last_mut()
+ && last_message.role == anthropic_role
+ {
+ last_message.content.extend(anthropic_message_content);
+ continue;
+ }
+
+ // Mark the last segment of the message as cached
+ if message.cache {
+ let cache_control_value = Some(anthropic::CacheControl {
+ cache_type: anthropic::CacheControlType::Ephemeral,
+ });
+ for message_content in anthropic_message_content.iter_mut().rev() {
+ match message_content {
+ anthropic::RequestContent::RedactedThinking { .. } => {
+                                // Caching is not supported on redacted thinking blocks; fall back to the next content item (iterating in reverse).
+ }
+ anthropic::RequestContent::Text { cache_control, .. }
+ | anthropic::RequestContent::Thinking { cache_control, .. }
+ | anthropic::RequestContent::Image { cache_control, .. }
+ | anthropic::RequestContent::ToolUse { cache_control, .. }
+ | anthropic::RequestContent::ToolResult { cache_control, .. } => {
+ *cache_control = cache_control_value;
+ break;
+ }
+ }
+ }
+ }
+
+ new_messages.push(anthropic::Message {
+ role: anthropic_role,
+ content: anthropic_message_content,
+ });
+ }
+ Role::System => {
+ if !system_message.is_empty() {
+ system_message.push_str("\n\n");
+ }
+ system_message.push_str(&message.string_contents());
+ }
+ }
+ }
+
+ anthropic::Request {
+ model,
+ messages: new_messages,
+ max_tokens: max_output_tokens,
+ system: if system_message.is_empty() {
+ None
+ } else {
+ Some(anthropic::StringOrContents::String(system_message))
+ },
+ thinking: if request.thinking_allowed
+ && let AnthropicModelMode::Thinking { budget_tokens } = mode
+ {
+ Some(anthropic::Thinking::Enabled { budget_tokens })
+ } else {
+ None
+ },
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| anthropic::Tool {
+ name: tool.name,
+ description: tool.description,
+ input_schema: tool.input_schema,
+ })
+ .collect(),
+ tool_choice: request.tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
+ LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
+ LanguageModelToolChoice::None => anthropic::ToolChoice::None,
+ }),
+ metadata: None,
+ stop_sequences: Vec::new(),
+ temperature: request.temperature.or(Some(default_temperature)),
+ top_k: None,
+ top_p: None,
+ }
+}
+
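+/// Maps Anthropic streaming events to `LanguageModelCompletionEvent`s, accumulating
+/// partial tool-use JSON, token usage, and the stop reason across the stream.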
+pub struct AnthropicEventMapper {
+ tool_uses_by_index: HashMap<usize, RawToolUse>,
+ usage: Usage,
+ stop_reason: StopReason,
+}
+
+impl AnthropicEventMapper {
+ pub fn new() -> Self {
+ Self {
+ tool_uses_by_index: HashMap::default(),
+ usage: Usage::default(),
+ stop_reason: StopReason::EndTurn,
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(error.into())],
+ })
+ })
+ }
+
+ pub fn map_event(
+ &mut self,
+ event: Event,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ match event {
+ Event::ContentBlockStart {
+ index,
+ content_block,
+ } => match content_block {
+ ResponseContent::Text { text } => {
+ vec![Ok(LanguageModelCompletionEvent::Text(text))]
+ }
+ ResponseContent::Thinking { thinking } => {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: thinking,
+ signature: None,
+ })]
+ }
+ ResponseContent::RedactedThinking { data } => {
+ vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
+ }
+ ResponseContent::ToolUse { id, name, .. } => {
+ self.tool_uses_by_index.insert(
+ index,
+ RawToolUse {
+ id,
+ name,
+ input_json: String::new(),
+ },
+ );
+ Vec::new()
+ }
+ },
+ Event::ContentBlockDelta { index, delta } => match delta {
+ ContentDelta::TextDelta { text } => {
+ vec![Ok(LanguageModelCompletionEvent::Text(text))]
+ }
+ ContentDelta::ThinkingDelta { thinking } => {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: thinking,
+ signature: None,
+ })]
+ }
+ ContentDelta::SignatureDelta { signature } => {
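+                    // The signature arrives separately from the thinking text, so emit a
+                    // Thinking event with empty text that carries only the signature.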
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: "".to_string(),
+ signature: Some(signature),
+ })]
+ }
+ ContentDelta::InputJsonDelta { partial_json } => {
+ if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
+ tool_use.input_json.push_str(&partial_json);
+
+ // Try to convert invalid (incomplete) JSON into
+ // valid JSON that serde can accept, e.g. by closing
+ // unclosed delimiters. This way, we can update the
+ // UI with whatever has been streamed back so far.
+ if let Ok(input) = serde_json::Value::from_str(
+ &partial_json_fixer::fix_json(&tool_use.input_json),
+ ) {
+ return vec![Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id.clone().into(),
+ name: tool_use.name.clone().into(),
+ is_input_complete: false,
+ raw_input: tool_use.input_json.clone(),
+ input,
+ thought_signature: None,
+ },
+ ))];
+ }
+ }
+ vec![]
+ }
+ },
+ Event::ContentBlockStop { index } => {
+ if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
+ let input_json = tool_use.input_json.trim();
+ let input_value = if input_json.is_empty() {
+ Ok(serde_json::Value::Object(serde_json::Map::default()))
+ } else {
+ serde_json::Value::from_str(input_json)
+ };
+ let event_result = match input_value {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id.into(),
+ name: tool_use.name.into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_use.input_json.clone(),
+ thought_signature: None,
+ },
+ )),
+ Err(json_parse_err) => {
+ Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_use.id.into(),
+ tool_name: tool_use.name.into(),
+ raw_input: input_json.into(),
+ json_parse_error: json_parse_err.to_string(),
+ })
+ }
+ };
+
+ vec![event_result]
+ } else {
+ Vec::new()
+ }
+ }
+ Event::MessageStart { message } => {
+ update_usage(&mut self.usage, &message.usage);
+ vec![
+ Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+ &self.usage,
+ ))),
+ Ok(LanguageModelCompletionEvent::StartMessage {
+ message_id: message.id,
+ }),
+ ]
+ }
+ Event::MessageDelta { delta, usage } => {
+ update_usage(&mut self.usage, &usage);
+ if let Some(stop_reason) = delta.stop_reason.as_deref() {
+ self.stop_reason = match stop_reason {
+ "end_turn" => StopReason::EndTurn,
+ "max_tokens" => StopReason::MaxTokens,
+ "tool_use" => StopReason::ToolUse,
+ "refusal" => StopReason::Refusal,
+ _ => {
+ log::error!("Unexpected anthropic stop_reason: {stop_reason}");
+ StopReason::EndTurn
+ }
+ };
+ }
+ vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
+ convert_usage(&self.usage),
+ ))]
+ }
+ Event::MessageStop => {
+ vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
+ }
+ Event::Error { error } => {
+ vec![Err(error.into())]
+ }
+ _ => Vec::new(),
+ }
+ }
+}
+
+struct RawToolUse {
+ id: String,
+ name: String,
+ input_json: String,
+}
+
+/// Updates usage data by preferring counts from `new`.
+fn update_usage(usage: &mut Usage, new: &Usage) {
+ if let Some(input_tokens) = new.input_tokens {
+ usage.input_tokens = Some(input_tokens);
+ }
+ if let Some(output_tokens) = new.output_tokens {
+ usage.output_tokens = Some(output_tokens);
+ }
+ if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
+ usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
+ }
+ if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
+ usage.cache_read_input_tokens = Some(cache_read_input_tokens);
+ }
+}
+
+fn convert_usage(usage: &Usage) -> language_model::TokenUsage {
+ language_model::TokenUsage {
+ input_tokens: usage.input_tokens.unwrap_or(0),
+ output_tokens: usage.output_tokens.unwrap_or(0),
+ cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
+ cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
+ }
+}
+
+struct ConfigurationView {
+ api_key_editor: Entity<InputField>,
+ state: Entity<State>,
+ load_credentials_task: Option<Task<()>>,
+ target_agent: ConfigurationViewTargetAgent,
+}
+
+impl ConfigurationView {
+ const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
+
+ fn new(
+ state: Entity<State>,
+ target_agent: ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
+
+ let load_credentials_task = Some(cx.spawn({
+ let state = state.clone();
+ async move |this, cx| {
+ if let Some(task) = state
+ .update(cx, |state, cx| state.authenticate(cx))
+ .log_err()
+ {
+                    // We don't log an error here, because "not signed in" also surfaces as an error, and we don't want to log that.
+ let _ = task.await;
+ }
+ this.update(cx, |this, cx| {
+ this.load_credentials_task = None;
+ cx.notify();
+ })
+ .log_err();
+ }
+ }));
+
+ Self {
+ api_key_editor: cx.new(|cx| InputField::new(window, cx, Self::PLACEHOLDER_TEXT)),
+ state,
+ load_credentials_task,
+ target_agent,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+ let api_key = self.api_key_editor.read(cx).text(cx);
+ if api_key.is_empty() {
+ return;
+ }
+
+        // URL changes can cause the editor to be displayed again.
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+ let configured_card_label = if env_var_set {
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
+ } else {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ if api_url == ANTHROPIC_API_URL {
+ "API key configured".to_string()
+ } else {
+ format!("API key configured for {}", api_url)
+ }
+ };
+
+ if self.load_credentials_task.is_some() {
+ div()
+ .child(Label::new("Loading credentials..."))
+ .into_any_element()
+ } else if self.should_render_editor(cx) {
+ v_flex()
+ .size_full()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
+ ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(),
+ ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
+ })))
+ .child(
+ List::new()
+ .child(
+ InstructionListItem::new(
+ "Create one by visiting",
+ Some("Anthropic's settings"),
+ Some("https://console.anthropic.com/settings/keys")
+ )
+ )
+ .child(
+ InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent")
+ )
+ )
+ .child(self.api_key_editor.clone())
+ .child(
+ Label::new(
+ format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
+ )
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element()
+ } else {
+ ConfiguredApiCard::new(configured_card_label)
+ .disabled(env_var_set)
+ .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
+ .when(env_var_set, |this| {
+ this.tooltip_label(format!(
+ "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."
+ ))
+ })
+ .into_any_element()
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use anthropic::AnthropicModelMode;
+ use language_model::{LanguageModelRequestMessage, MessageContent};
+
+ #[test]
+ fn test_cache_control_only_on_last_segment() {
+ let request = LanguageModelRequest {
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![
+ MessageContent::Text("Some prompt".to_string()),
+ MessageContent::Image(language_model::LanguageModelImage::empty()),
+ MessageContent::Image(language_model::LanguageModelImage::empty()),
+ MessageContent::Image(language_model::LanguageModelImage::empty()),
+ MessageContent::Image(language_model::LanguageModelImage::empty()),
+ ],
+ cache: true,
+ reasoning_details: None,
+ }],
+ thread_id: None,
+ prompt_id: None,
+ intent: None,
+ mode: None,
+ stop: vec![],
+ temperature: None,
+ tools: vec![],
+ tool_choice: None,
+ thinking_allowed: true,
+ };
+
+ let anthropic_request = into_anthropic(
+ request,
+ "claude-3-5-sonnet".to_string(),
+ 0.7,
+ 4096,
+ AnthropicModelMode::Default,
+ );
+
+ assert_eq!(anthropic_request.messages.len(), 1);
+
+ let message = &anthropic_request.messages[0];
+ assert_eq!(message.content.len(), 5);
+
+ assert!(matches!(
+ message.content[0],
+ anthropic::RequestContent::Text {
+ cache_control: None,
+ ..
+ }
+ ));
+        for i in 1..4 {
+ assert!(matches!(
+ message.content[i],
+ anthropic::RequestContent::Image {
+ cache_control: None,
+ ..
+ }
+ ));
+ }
+
+ assert!(matches!(
+ message.content[4],
+ anthropic::RequestContent::Image {
+ cache_control: Some(anthropic::CacheControl {
+ cache_type: anthropic::CacheControlType::Ephemeral,
+ }),
+ ..
+ }
+ ));
+ }
+}
@@ -0,0 +1,1565 @@
+use std::pin::Pin;
+use std::str::FromStr as _;
+use std::sync::Arc;
+
+use anyhow::{Result, anyhow};
+use cloud_llm_client::CompletionIntent;
+use collections::HashMap;
+use copilot::copilot_chat::{
+ ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
+ Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
+ ToolCall,
+};
+use copilot::{Copilot, Status};
+use futures::future::BoxFuture;
+use futures::stream::BoxStream;
+use futures::{FutureExt, Stream, StreamExt};
+use gpui::{Action, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, svg};
+use http_client::StatusCode;
+use language::language_settings::all_language_settings;
+use language_model::{
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
+ LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
+ StopReason, TokenUsage,
+};
+use settings::SettingsStore;
+use ui::{CommonAnimationExt, prelude::*};
+use util::debug_panic;
+
+use crate::ui::ConfiguredApiCard;
+
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
+const PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("GitHub Copilot Chat");
+
+pub struct CopilotChatLanguageModelProvider {
+ state: Entity<State>,
+}
+
+pub struct State {
+ _copilot_chat_subscription: Option<Subscription>,
+ _settings_subscription: Subscription,
+}
+
+impl State {
+ fn is_authenticated(&self, cx: &App) -> bool {
+ CopilotChat::global(cx)
+ .map(|m| m.read(cx).is_authenticated())
+ .unwrap_or(false)
+ }
+}
+
+impl CopilotChatLanguageModelProvider {
+ pub fn new(cx: &mut App) -> Self {
+ let state = cx.new(|cx| {
+ let copilot_chat_subscription = CopilotChat::global(cx)
+ .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
+ State {
+ _copilot_chat_subscription: copilot_chat_subscription,
+ _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
+ if let Some(copilot_chat) = CopilotChat::global(cx) {
+ let language_settings = all_language_settings(None, cx);
+ let configuration = copilot::copilot_chat::CopilotChatConfiguration {
+ enterprise_uri: language_settings
+ .edit_predictions
+ .copilot
+ .enterprise_uri
+ .clone(),
+ };
+ copilot_chat.update(cx, |chat, cx| {
+ chat.set_configuration(configuration, cx);
+ });
+ }
+ cx.notify();
+ }),
+ }
+ });
+
+ Self { state }
+ }
+
+ fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
+ Arc::new(CopilotChatLanguageModel {
+ model,
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+}
+
+impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
+ Some(self.state.clone())
+ }
+}
+
+impl LanguageModelProvider for CopilotChatLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::Copilot
+ }
+
+ fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
+ models
+ .first()
+ .map(|model| self.create_language_model(model.clone()))
+ }
+
+ fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
+ // model (e.g. 4o) and a sensible choice when considering premium requests
+ self.default_model(cx)
+ }
+
+ fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+ let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
+ return Vec::new();
+ };
+ models
+ .iter()
+ .map(|model| self.create_language_model(model.clone()))
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated(cx)
+ }
+
+ fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ if self.is_authenticated(cx) {
+ return Task::ready(Ok(()));
+ };
+
+ let Some(copilot) = Copilot::global(cx) else {
+ return Task::ready(Err(anyhow!(concat!(
+ "Copilot must be enabled for Copilot Chat to work. ",
+ "Please enable Copilot and try again."
+ ))
+ .into()));
+ };
+
+ let err = match copilot.read(cx).status() {
+ Status::Authorized => return Task::ready(Ok(())),
+ Status::Disabled => anyhow!(
+ "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
+ ),
+ Status::Error(err) => anyhow!(format!(
+ "Received the following error while signing into Copilot: {err}"
+ )),
+ Status::Starting { task: _ } => anyhow!(
+ "Copilot is still starting, please wait for Copilot to start then try again"
+ ),
+ Status::Unauthorized => anyhow!(
+ "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
+ ),
+ Status::SignedOut { .. } => {
+ anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
+ }
+ Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
+ };
+
+ Task::ready(Err(err.into()))
+ }
+
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ _: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
+ let state = self.state.clone();
+ cx.new(|cx| ConfigurationView::new(state, cx)).into()
+ }
+
+ fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
+ Task::ready(Err(anyhow!(
+ "Signing out of GitHub Copilot Chat is currently not supported."
+ )))
+ }
+}
+
+fn collect_tiktoken_messages(
+ request: LanguageModelRequest,
+) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
+ request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::<Vec<_>>()
+}
+
+pub struct CopilotChatLanguageModel {
+ model: CopilotChatModel,
+ request_limiter: RateLimiter,
+}
+
+impl LanguageModel for CopilotChatLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ LanguageModelId::from(self.model.id().to_string())
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn supports_tools(&self) -> bool {
+ self.model.supports_tools()
+ }
+
+ fn supports_images(&self) -> bool {
+ self.model.supports_vision()
+ }
+
+ fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
+ match self.model.vendor() {
+ ModelVendor::OpenAI | ModelVendor::Anthropic => {
+ LanguageModelToolSchemaFormat::JsonSchema
+ }
+ ModelVendor::Google | ModelVendor::XAI | ModelVendor::Unknown => {
+ LanguageModelToolSchemaFormat::JsonSchemaSubset
+ }
+ }
+ }
+
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto
+ | LanguageModelToolChoice::Any
+ | LanguageModelToolChoice::None => self.supports_tools(),
+ }
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("copilot_chat/{}", self.model.id())
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_token_count()
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result<u64>> {
+ let model = self.model.clone();
+ cx.background_spawn(async move {
+ let messages = collect_tiktoken_messages(request);
+            // Copilot uses the OpenAI tiktoken tokenizer for all its models, irrespective of the underlying provider (vendor).
+ let tokenizer_model = match model.tokenizer() {
+ Some("o200k_base") => "gpt-4o",
+ Some("cl100k_base") => "gpt-4",
+ _ => "gpt-4o",
+ };
+
+ tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages)
+ .map(|tokens| tokens as u64)
+ })
+ .boxed()
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let is_user_initiated = request.intent.is_none_or(|intent| match intent {
+ CompletionIntent::UserPrompt
+ | CompletionIntent::ThreadContextSummarization
+ | CompletionIntent::InlineAssist
+ | CompletionIntent::TerminalInlineAssist
+ | CompletionIntent::GenerateGitCommitMessage => true,
+
+ CompletionIntent::ToolResults
+ | CompletionIntent::ThreadSummarization
+ | CompletionIntent::CreateFile
+ | CompletionIntent::EditFile => false,
+ });
+
+ if self.model.supports_response() {
+ let responses_request = into_copilot_responses(&self.model, request);
+ let request_limiter = self.request_limiter.clone();
+ let future = cx.spawn(async move |cx| {
+ let request =
+ CopilotChat::stream_response(responses_request, is_user_initiated, cx.clone());
+ request_limiter
+ .stream(async move {
+ let stream = request.await?;
+ let mapper = CopilotResponsesEventMapper::new();
+ Ok(mapper.map_stream(stream).boxed())
+ })
+ .await
+ });
+ return async move { Ok(future.await?.boxed()) }.boxed();
+ }
+
+ let copilot_request = match into_copilot_chat(&self.model, request) {
+ Ok(request) => request,
+ Err(err) => return futures::future::ready(Err(err.into())).boxed(),
+ };
+ let is_streaming = copilot_request.stream;
+
+ let request_limiter = self.request_limiter.clone();
+ let future = cx.spawn(async move |cx| {
+ let request =
+ CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
+ request_limiter
+ .stream(async move {
+ let response = request.await?;
+ Ok(map_to_language_model_completion_events(
+ response,
+ is_streaming,
+ ))
+ })
+ .await
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
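+/// Maps Copilot Chat (Chat Completions API) stream events into `LanguageModelCompletionEvent`s,
+/// accumulating partial tool calls and reasoning data until a finish reason arrives.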
+pub fn map_to_language_model_completion_events(
+ events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+ is_streaming: bool,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ #[derive(Default)]
+ struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
+ thought_signature: Option<String>,
+ }
+
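+    // Streaming state: partial tool calls keyed by index, plus any reasoning data seen so far.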
+ struct State {
+ events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+ tool_calls_by_index: HashMap<usize, RawToolCall>,
+ reasoning_opaque: Option<String>,
+ reasoning_text: Option<String>,
+ }
+
+ futures::stream::unfold(
+ State {
+ events,
+ tool_calls_by_index: HashMap::default(),
+ reasoning_opaque: None,
+ reasoning_text: None,
+ },
+ move |mut state| async move {
+ if let Some(event) = state.events.next().await {
+ match event {
+ Ok(event) => {
+ let Some(choice) = event.choices.first() else {
+ return Some((
+ vec![Err(anyhow!("Response contained no choices").into())],
+ state,
+ ));
+ };
+
+ let delta = if is_streaming {
+ choice.delta.as_ref()
+ } else {
+ choice.message.as_ref()
+ };
+
+ let Some(delta) = delta else {
+ return Some((
+ vec![Err(anyhow!("Response contained no delta").into())],
+ state,
+ ));
+ };
+
+ let mut events = Vec::new();
+ if let Some(content) = delta.content.clone() {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+
+ // Capture reasoning data from the delta (e.g. for Gemini 3)
+ if let Some(opaque) = delta.reasoning_opaque.clone() {
+ state.reasoning_opaque = Some(opaque);
+ }
+ if let Some(text) = delta.reasoning_text.clone() {
+ state.reasoning_text = Some(text);
+ }
+
+ for (index, tool_call) in delta.tool_calls.iter().enumerate() {
+ let tool_index = tool_call.index.unwrap_or(index);
+ let entry = state.tool_calls_by_index.entry(tool_index).or_default();
+
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
+
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
+ }
+
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
+ }
+
+ if let Some(thought_signature) = function.thought_signature.clone()
+ {
+ entry.thought_signature = Some(thought_signature);
+ }
+ }
+ }
+
+ if let Some(usage) = event.usage {
+ events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+ TokenUsage {
+ input_tokens: usage.prompt_tokens,
+ output_tokens: usage.completion_tokens,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ },
+ )));
+ }
+
+ match choice.finish_reason.as_deref() {
+ Some("stop") => {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::EndTurn,
+ )));
+ }
+ Some("tool_calls") => {
+ // Gemini 3 models send reasoning_opaque/reasoning_text that must
+ // be preserved and sent back in subsequent requests. Emit as
+ // ReasoningDetails so the agent stores it in the message.
+ if state.reasoning_opaque.is_some()
+ || state.reasoning_text.is_some()
+ {
+ let mut details = serde_json::Map::new();
+ if let Some(opaque) = state.reasoning_opaque.take() {
+ details.insert(
+ "reasoning_opaque".to_string(),
+ serde_json::Value::String(opaque),
+ );
+ }
+ if let Some(text) = state.reasoning_text.take() {
+ details.insert(
+ "reasoning_text".to_string(),
+ serde_json::Value::String(text),
+ );
+ }
+ events.push(Ok(
+ LanguageModelCompletionEvent::ReasoningDetails(
+ serde_json::Value::Object(details),
+ ),
+ ));
+ }
+
+ events.extend(state.tool_calls_by_index.drain().map(
+ |(_, tool_call)| {
+ // The model can output an empty string
+ // to indicate the absence of arguments.
+ // When that happens, create an empty
+ // object instead.
+ let arguments = if tool_call.arguments.is_empty() {
+ Ok(serde_json::Value::Object(Default::default()))
+ } else {
+ serde_json::Value::from_str(&tool_call.arguments)
+ };
+ match arguments {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.into(),
+ name: tool_call.name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments,
+ thought_signature: tool_call.thought_signature,
+ },
+ )),
+ Err(error) => Ok(
+ LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.as_str().into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ },
+ ),
+ }
+ },
+ ));
+
+ events.push(Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::ToolUse,
+ )));
+ }
+ Some(stop_reason) => {
+ log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
+ events.push(Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::EndTurn,
+ )));
+ }
+ None => {}
+ }
+
+ return Some((events, state));
+ }
+ Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
+ }
+ }
+
+ None
+ },
+ )
+ .flat_map(futures::stream::iter)
+}
+
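+/// Maps Copilot Responses API stream events into `LanguageModelCompletionEvent`s, tracking whether a
+/// tool-use stop was already emitted so that `Completed` does not produce a duplicate Stop event.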
+pub struct CopilotResponsesEventMapper {
+ pending_stop_reason: Option<StopReason>,
+}
+
+impl CopilotResponsesEventMapper {
+ pub fn new() -> Self {
+ Self {
+ pending_stop_reason: None,
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin<Box<dyn Send + Stream<Item = Result<copilot::copilot_responses::StreamEvent>>>>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
+ })
+ })
+ }
+
+ fn map_event(
+ &mut self,
+ event: copilot::copilot_responses::StreamEvent,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ match event {
+ copilot::copilot_responses::StreamEvent::OutputItemAdded { item, .. } => match item {
+ copilot::copilot_responses::ResponseOutputItem::Message { id, .. } => {
+ vec![Ok(LanguageModelCompletionEvent::StartMessage {
+ message_id: id,
+ })]
+ }
+ _ => Vec::new(),
+ },
+
+ copilot::copilot_responses::StreamEvent::OutputTextDelta { delta, .. } => {
+ if delta.is_empty() {
+ Vec::new()
+ } else {
+ vec![Ok(LanguageModelCompletionEvent::Text(delta))]
+ }
+ }
+
+ copilot::copilot_responses::StreamEvent::OutputItemDone { item, .. } => match item {
+ copilot::copilot_responses::ResponseOutputItem::Message { .. } => Vec::new(),
+ copilot::copilot_responses::ResponseOutputItem::FunctionCall {
+ call_id,
+ name,
+ arguments,
+ thought_signature,
+ ..
+ } => {
+ let mut events = Vec::new();
+ match serde_json::from_str::<serde_json::Value>(&arguments) {
+ Ok(input) => events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: call_id.into(),
+ name: name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: arguments.clone(),
+ thought_signature,
+ },
+ ))),
+ Err(error) => {
+ events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: call_id.into(),
+ tool_name: name.as_str().into(),
+ raw_input: arguments.clone().into(),
+ json_parse_error: error.to_string(),
+ }))
+ }
+ }
+ // Record that we already emitted a tool-use stop so we can avoid duplicating
+ // a Stop event on Completed.
+ self.pending_stop_reason = Some(StopReason::ToolUse);
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ events
+ }
+ copilot::copilot_responses::ResponseOutputItem::Reasoning {
+ summary,
+ encrypted_content,
+ ..
+ } => {
+ let mut events = Vec::new();
+
+ if let Some(blocks) = summary {
+ let mut text = String::new();
+ for block in blocks {
+ text.push_str(&block.text);
+ }
+ if !text.is_empty() {
+ events.push(Ok(LanguageModelCompletionEvent::Thinking {
+ text,
+ signature: None,
+ }));
+ }
+ }
+
+ if let Some(data) = encrypted_content {
+ events.push(Ok(LanguageModelCompletionEvent::RedactedThinking { data }));
+ }
+
+ events
+ }
+ },
+
+ copilot::copilot_responses::StreamEvent::Completed { response } => {
+ let mut events = Vec::new();
+ if let Some(usage) = response.usage {
+ events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+ input_tokens: usage.input_tokens.unwrap_or(0),
+ output_tokens: usage.output_tokens.unwrap_or(0),
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ })));
+ }
+ if self.pending_stop_reason.take() != Some(StopReason::ToolUse) {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ events
+ }
+
+ copilot::copilot_responses::StreamEvent::Incomplete { response } => {
+ let reason = response
+ .incomplete_details
+ .as_ref()
+ .and_then(|details| details.reason.as_ref());
+ let stop_reason = match reason {
+ Some(copilot::copilot_responses::IncompleteReason::MaxOutputTokens) => {
+ StopReason::MaxTokens
+ }
+ Some(copilot::copilot_responses::IncompleteReason::ContentFilter) => {
+ StopReason::Refusal
+ }
+ _ => self
+ .pending_stop_reason
+ .take()
+ .unwrap_or(StopReason::EndTurn),
+ };
+
+ let mut events = Vec::new();
+ if let Some(usage) = response.usage {
+ events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+ input_tokens: usage.input_tokens.unwrap_or(0),
+ output_tokens: usage.output_tokens.unwrap_or(0),
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ })));
+ }
+ events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason)));
+ events
+ }
+
+ copilot::copilot_responses::StreamEvent::Failed { response } => {
+ let provider = PROVIDER_NAME;
+ let (status_code, message) = match response.error {
+ Some(error) => {
+ let status_code = StatusCode::from_str(&error.code)
+ .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
+ (status_code, error.message)
+ }
+ None => (
+ StatusCode::INTERNAL_SERVER_ERROR,
+ "response.failed".to_string(),
+ ),
+ };
+ vec![Err(LanguageModelCompletionError::HttpResponseError {
+ provider,
+ status_code,
+ message,
+ })]
+ }
+
+ copilot::copilot_responses::StreamEvent::GenericError { error } => vec![Err(
+ LanguageModelCompletionError::Other(anyhow!(format!("{error:?}"))),
+ )],
+
+ copilot::copilot_responses::StreamEvent::Created { .. }
+ | copilot::copilot_responses::StreamEvent::Unknown => Vec::new(),
+ }
+ }
+}
+
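+/// Converts a `LanguageModelRequest` into a Copilot Chat Completions request, merging consecutive
+/// messages of the same role and mapping tool results, text, and images into chat messages.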
+fn into_copilot_chat(
+ model: &copilot::copilot_chat::Model,
+ request: LanguageModelRequest,
+) -> Result<CopilotChatRequest> {
+ let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+ for message in request.messages {
+ if let Some(last_message) = request_messages.last_mut() {
+ if last_message.role == message.role {
+ last_message.content.extend(message.content);
+ } else {
+ request_messages.push(message);
+ }
+ } else {
+ request_messages.push(message);
+ }
+ }
+
+ let mut messages: Vec<ChatMessage> = Vec::new();
+ for message in request_messages {
+ match message.role {
+ Role::User => {
+ for content in &message.content {
+ if let MessageContent::ToolResult(tool_result) = content {
+ let content = match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => text.to_string().into(),
+ LanguageModelToolResultContent::Image(image) => {
+ if model.supports_vision() {
+ ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
+ image_url: ImageUrl {
+ url: image.to_base64_url(),
+ },
+ }])
+ } else {
+ debug_panic!(
+ "This should be caught at {} level",
+ tool_result.tool_name
+ );
+ "[Tool responded with an image, but this model does not support vision]".to_string().into()
+ }
+ }
+ };
+
+ messages.push(ChatMessage::Tool {
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ content,
+ });
+ }
+ }
+
+ let mut content_parts = Vec::new();
+ for content in &message.content {
+ match content {
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. }
+ if !text.is_empty() =>
+ {
+ if let Some(ChatMessagePart::Text { text: text_content }) =
+ content_parts.last_mut()
+ {
+ text_content.push_str(text);
+ } else {
+ content_parts.push(ChatMessagePart::Text {
+ text: text.to_string(),
+ });
+ }
+ }
+ MessageContent::Image(image) if model.supports_vision() => {
+ content_parts.push(ChatMessagePart::Image {
+ image_url: ImageUrl {
+ url: image.to_base64_url(),
+ },
+ });
+ }
+ _ => {}
+ }
+ }
+
+ if !content_parts.is_empty() {
+ messages.push(ChatMessage::User {
+ content: content_parts.into(),
+ });
+ }
+ }
+ Role::Assistant => {
+ let mut tool_calls = Vec::new();
+ for content in &message.content {
+ if let MessageContent::ToolUse(tool_use) = content {
+ tool_calls.push(ToolCall {
+ id: tool_use.id.to_string(),
+ content: copilot::copilot_chat::ToolCallContent::Function {
+ function: copilot::copilot_chat::FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)?,
+ thought_signature: tool_use.thought_signature.clone(),
+ },
+ },
+ });
+ }
+ }
+
+ let text_content = {
+ let mut buffer = String::new();
+ for string in message.content.iter().filter_map(|content| match content {
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
+ Some(text.as_str())
+ }
+ MessageContent::ToolUse(_)
+ | MessageContent::RedactedThinking(_)
+ | MessageContent::ToolResult(_)
+ | MessageContent::Image(_) => None,
+ }) {
+ buffer.push_str(string);
+ }
+
+ buffer
+ };
+
+ // Extract reasoning_opaque and reasoning_text from reasoning_details
+ let (reasoning_opaque, reasoning_text) =
+ if let Some(details) = &message.reasoning_details {
+ let opaque = details
+ .get("reasoning_opaque")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string());
+ let text = details
+ .get("reasoning_text")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string());
+ (opaque, text)
+ } else {
+ (None, None)
+ };
+
+ messages.push(ChatMessage::Assistant {
+ content: if text_content.is_empty() {
+ ChatMessageContent::empty()
+ } else {
+ text_content.into()
+ },
+ tool_calls,
+ reasoning_opaque,
+ reasoning_text,
+ });
+ }
+ Role::System => messages.push(ChatMessage::System {
+ content: message.string_contents(),
+ }),
+ }
+ }
+
+ let tools = request
+ .tools
+ .iter()
+ .map(|tool| Tool::Function {
+ function: copilot::copilot_chat::Function {
+ name: tool.name.clone(),
+ description: tool.description.clone(),
+ parameters: tool.input_schema.clone(),
+ },
+ })
+ .collect::<Vec<_>>();
+
+ Ok(CopilotChatRequest {
+ intent: true,
+ n: 1,
+ stream: model.uses_streaming(),
+ temperature: 0.1,
+ model: model.id().to_string(),
+ messages,
+ tools,
+ tool_choice: request.tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
+ LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
+ LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
+ }),
+ })
+}
+
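+/// Converts a `LanguageModelRequest` into a Copilot Responses API request, emitting tool results,
+/// reasoning items, and message content as response input items.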
+fn into_copilot_responses(
+ model: &copilot::copilot_chat::Model,
+ request: LanguageModelRequest,
+) -> copilot::copilot_responses::Request {
+ use copilot::copilot_responses as responses;
+
+ let LanguageModelRequest {
+ thread_id: _,
+ prompt_id: _,
+ intent: _,
+ mode: _,
+ messages,
+ tools,
+ tool_choice,
+ stop: _,
+ temperature,
+ thinking_allowed: _,
+ } = request;
+
+ let mut input_items: Vec<responses::ResponseInputItem> = Vec::new();
+
+ for message in messages {
+ match message.role {
+ Role::User => {
+ for content in &message.content {
+ if let MessageContent::ToolResult(tool_result) = content {
+ let output = if let Some(out) = &tool_result.output {
+ match out {
+ serde_json::Value::String(s) => {
+ responses::ResponseFunctionOutput::Text(s.clone())
+ }
+ serde_json::Value::Null => {
+ responses::ResponseFunctionOutput::Text(String::new())
+ }
+ other => responses::ResponseFunctionOutput::Text(other.to_string()),
+ }
+ } else {
+ match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ responses::ResponseFunctionOutput::Text(text.to_string())
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ if model.supports_vision() {
+ responses::ResponseFunctionOutput::Content(vec![
+ responses::ResponseInputContent::InputImage {
+ image_url: Some(image.to_base64_url()),
+ detail: Default::default(),
+ },
+ ])
+ } else {
+ debug_panic!(
+ "This should be caught at {} level",
+ tool_result.tool_name
+ );
+ responses::ResponseFunctionOutput::Text(
+ "[Tool responded with an image, but this model does not support vision]".into(),
+ )
+ }
+ }
+ }
+ };
+
+ input_items.push(responses::ResponseInputItem::FunctionCallOutput {
+ call_id: tool_result.tool_use_id.to_string(),
+ output,
+ status: None,
+ });
+ }
+ }
+
+ let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
+ for content in &message.content {
+ match content {
+ MessageContent::Text(text) => {
+ parts.push(responses::ResponseInputContent::InputText {
+ text: text.clone(),
+ });
+ }
+
+ MessageContent::Image(image) => {
+ if model.supports_vision() {
+ parts.push(responses::ResponseInputContent::InputImage {
+ image_url: Some(image.to_base64_url()),
+ detail: Default::default(),
+ });
+ }
+ }
+ _ => {}
+ }
+ }
+
+ if !parts.is_empty() {
+ input_items.push(responses::ResponseInputItem::Message {
+ role: "user".into(),
+ content: Some(parts),
+ status: None,
+ });
+ }
+ }
+
+ Role::Assistant => {
+ for content in &message.content {
+ if let MessageContent::ToolUse(tool_use) = content {
+ input_items.push(responses::ResponseInputItem::FunctionCall {
+ call_id: tool_use.id.to_string(),
+ name: tool_use.name.to_string(),
+ arguments: tool_use.raw_input.clone(),
+ status: None,
+ thought_signature: tool_use.thought_signature.clone(),
+ });
+ }
+ }
+
+ for content in &message.content {
+ if let MessageContent::RedactedThinking(data) = content {
+ input_items.push(responses::ResponseInputItem::Reasoning {
+ id: None,
+ summary: Vec::new(),
+ encrypted_content: data.clone(),
+ });
+ }
+ }
+
+ let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
+ for content in &message.content {
+ match content {
+ MessageContent::Text(text) => {
+ parts.push(responses::ResponseInputContent::OutputText {
+ text: text.clone(),
+ });
+ }
+ MessageContent::Image(_) => {
+ parts.push(responses::ResponseInputContent::OutputText {
+ text: "[image omitted]".to_string(),
+ });
+ }
+ _ => {}
+ }
+ }
+
+ if !parts.is_empty() {
+ input_items.push(responses::ResponseInputItem::Message {
+ role: "assistant".into(),
+ content: Some(parts),
+ status: Some("completed".into()),
+ });
+ }
+ }
+
+ Role::System => {
+ let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
+ for content in &message.content {
+ if let MessageContent::Text(text) = content {
+ parts.push(responses::ResponseInputContent::InputText {
+ text: text.clone(),
+ });
+ }
+ }
+
+ if !parts.is_empty() {
+ input_items.push(responses::ResponseInputItem::Message {
+ role: "system".into(),
+ content: Some(parts),
+ status: None,
+ });
+ }
+ }
+ }
+ }
+
+ let converted_tools: Vec<responses::ToolDefinition> = tools
+ .into_iter()
+ .map(|tool| responses::ToolDefinition::Function {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ strict: None,
+ })
+ .collect();
+
+ let mapped_tool_choice = tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => responses::ToolChoice::Auto,
+ LanguageModelToolChoice::Any => responses::ToolChoice::Any,
+ LanguageModelToolChoice::None => responses::ToolChoice::None,
+ });
+
+ responses::Request {
+ model: model.id().to_string(),
+ input: input_items,
+ stream: model.uses_streaming(),
+ temperature,
+ tools: converted_tools,
+ tool_choice: mapped_tool_choice,
+        reasoning: None, // We would need to add support for configuring this via user settings.
+ include: Some(vec![
+ copilot::copilot_responses::ResponseIncludable::ReasoningEncryptedContent,
+ ]),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use copilot::copilot_responses as responses;
+ use futures::StreamExt;
+
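+    /// Drives the given Responses API events through a fresh mapper and collects the resulting
+    /// completion events, panicking on any error.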
+ fn map_events(events: Vec<responses::StreamEvent>) -> Vec<LanguageModelCompletionEvent> {
+ futures::executor::block_on(async {
+ CopilotResponsesEventMapper::new()
+ .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
+ .collect::<Vec<_>>()
+ .await
+ .into_iter()
+ .map(Result::unwrap)
+ .collect()
+ })
+ }
+
+ #[test]
+ fn responses_stream_maps_text_and_usage() {
+ let events = vec![
+ responses::StreamEvent::OutputItemAdded {
+ output_index: 0,
+ sequence_number: None,
+ item: responses::ResponseOutputItem::Message {
+ id: "msg_1".into(),
+ role: "assistant".into(),
+ content: Some(Vec::new()),
+ },
+ },
+ responses::StreamEvent::OutputTextDelta {
+ item_id: "msg_1".into(),
+ output_index: 0,
+ delta: "Hello".into(),
+ },
+ responses::StreamEvent::Completed {
+ response: responses::Response {
+ usage: Some(responses::ResponseUsage {
+ input_tokens: Some(5),
+ output_tokens: Some(3),
+ total_tokens: Some(8),
+ }),
+ ..Default::default()
+ },
+ },
+ ];
+
+ let mapped = map_events(events);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_1"
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::Text(ref text) if text == "Hello"
+ ));
+ assert!(matches!(
+ mapped[2],
+ LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+ input_tokens: 5,
+ output_tokens: 3,
+ ..
+ })
+ ));
+ assert!(matches!(
+ mapped[3],
+ LanguageModelCompletionEvent::Stop(StopReason::EndTurn)
+ ));
+ }
+
+ #[test]
+ fn responses_stream_maps_tool_calls() {
+ let events = vec![responses::StreamEvent::OutputItemDone {
+ output_index: 0,
+ sequence_number: None,
+ item: responses::ResponseOutputItem::FunctionCall {
+ id: Some("fn_1".into()),
+ call_id: "call_1".into(),
+ name: "do_it".into(),
+ arguments: "{\"x\":1}".into(),
+ status: None,
+ thought_signature: None,
+ },
+ }];
+
+ let mapped = map_events(events);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::ToolUse(ref use_) if use_.id.to_string() == "call_1" && use_.name.as_ref() == "do_it"
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+ ));
+ }
+
+ #[test]
+ fn responses_stream_handles_json_parse_error() {
+ let events = vec![responses::StreamEvent::OutputItemDone {
+ output_index: 0,
+ sequence_number: None,
+ item: responses::ResponseOutputItem::FunctionCall {
+ id: Some("fn_1".into()),
+ call_id: "call_1".into(),
+ name: "do_it".into(),
+ arguments: "{not json}".into(),
+ status: None,
+ thought_signature: None,
+ },
+ }];
+
+ let mapped = map_events(events);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::ToolUseJsonParseError { ref id, ref tool_name, .. }
+ if id.to_string() == "call_1" && tool_name.as_ref() == "do_it"
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+ ));
+ }
+
+ #[test]
+ fn responses_stream_maps_reasoning_summary_and_encrypted_content() {
+ let events = vec![responses::StreamEvent::OutputItemDone {
+ output_index: 0,
+ sequence_number: None,
+ item: responses::ResponseOutputItem::Reasoning {
+ id: "r1".into(),
+ summary: Some(vec![responses::ResponseReasoningItem {
+ kind: "summary_text".into(),
+ text: "Chain".into(),
+ }]),
+ encrypted_content: Some("ENC".into()),
+ },
+ }];
+
+ let mapped = map_events(events);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::Thinking { ref text, signature: None } if text == "Chain"
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::RedactedThinking { ref data } if data == "ENC"
+ ));
+ }
+
+ #[test]
+ fn responses_stream_handles_incomplete_max_tokens() {
+ let events = vec![responses::StreamEvent::Incomplete {
+ response: responses::Response {
+ usage: Some(responses::ResponseUsage {
+ input_tokens: Some(10),
+ output_tokens: Some(0),
+ total_tokens: Some(10),
+ }),
+ incomplete_details: Some(responses::IncompleteDetails {
+ reason: Some(responses::IncompleteReason::MaxOutputTokens),
+ }),
+ ..Default::default()
+ },
+ }];
+
+ let mapped = map_events(events);
+ assert!(matches!(
+ mapped[0],
+ LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+ input_tokens: 10,
+ output_tokens: 0,
+ ..
+ })
+ ));
+ assert!(matches!(
+ mapped[1],
+ LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
+ ));
+ }
+
+ #[test]
+ fn responses_stream_handles_incomplete_content_filter() {
+ let events = vec![responses::StreamEvent::Incomplete {
+ response: responses::Response {
+ usage: None,
+ incomplete_details: Some(responses::IncompleteDetails {
+ reason: Some(responses::IncompleteReason::ContentFilter),
+ }),
+ ..Default::default()
+ },
+ }];
+
+ let mapped = map_events(events);
+ assert!(matches!(
+ mapped.last().unwrap(),
+ LanguageModelCompletionEvent::Stop(StopReason::Refusal)
+ ));
+ }
+
+ #[test]
+ fn responses_stream_completed_no_duplicate_after_tool_use() {
+ let events = vec![
+ responses::StreamEvent::OutputItemDone {
+ output_index: 0,
+ sequence_number: None,
+ item: responses::ResponseOutputItem::FunctionCall {
+ id: Some("fn_1".into()),
+ call_id: "call_1".into(),
+ name: "do_it".into(),
+ arguments: "{}".into(),
+ status: None,
+ thought_signature: None,
+ },
+ },
+ responses::StreamEvent::Completed {
+ response: responses::Response::default(),
+ },
+ ];
+
+ let mapped = map_events(events);
+
+ let mut stop_count = 0usize;
+ let mut saw_tool_use_stop = false;
+ for event in mapped {
+ if let LanguageModelCompletionEvent::Stop(reason) = event {
+ stop_count += 1;
+ if matches!(reason, StopReason::ToolUse) {
+ saw_tool_use_stop = true;
+ }
+ }
+ }
+ assert_eq!(stop_count, 1, "should emit exactly one Stop event");
+ assert!(saw_tool_use_stop, "Stop reason should be ToolUse");
+ }
+
+ #[test]
+ fn responses_stream_failed_maps_http_response_error() {
+ let events = vec![responses::StreamEvent::Failed {
+ response: responses::Response {
+ error: Some(responses::ResponseError {
+ code: "429".into(),
+ message: "too many requests".into(),
+ }),
+ ..Default::default()
+ },
+ }];
+
+ let mapped_results = futures::executor::block_on(async {
+ CopilotResponsesEventMapper::new()
+ .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
+ .collect::<Vec<_>>()
+ .await
+ });
+
+ assert_eq!(mapped_results.len(), 1);
+ match &mapped_results[0] {
+ Err(LanguageModelCompletionError::HttpResponseError {
+ status_code,
+ message,
+ ..
+ }) => {
+ assert_eq!(*status_code, http_client::StatusCode::TOO_MANY_REQUESTS);
+ assert_eq!(message, "too many requests");
+ }
+ other => panic!("expected HttpResponseError, got {:?}", other),
+ }
+ }
+
+ #[test]
+ fn chat_completions_stream_maps_reasoning_data() {
+ use copilot::copilot_chat::ResponseEvent;
+
+ let events = vec![
+ ResponseEvent {
+ choices: vec![copilot::copilot_chat::ResponseChoice {
+ index: Some(0),
+ finish_reason: None,
+ delta: Some(copilot::copilot_chat::ResponseDelta {
+ content: None,
+ role: Some(copilot::copilot_chat::Role::Assistant),
+ tool_calls: vec![copilot::copilot_chat::ToolCallChunk {
+ index: Some(0),
+ id: Some("call_abc123".to_string()),
+ function: Some(copilot::copilot_chat::FunctionChunk {
+ name: Some("list_directory".to_string()),
+ arguments: Some("{\"path\":\"test\"}".to_string()),
+ thought_signature: None,
+ }),
+ }],
+ reasoning_opaque: Some("encrypted_reasoning_token_xyz".to_string()),
+ reasoning_text: Some("Let me check the directory".to_string()),
+ }),
+ message: None,
+ }],
+ id: "chatcmpl-123".to_string(),
+ usage: None,
+ },
+ ResponseEvent {
+ choices: vec![copilot::copilot_chat::ResponseChoice {
+ index: Some(0),
+ finish_reason: Some("tool_calls".to_string()),
+ delta: Some(copilot::copilot_chat::ResponseDelta {
+ content: None,
+ role: None,
+ tool_calls: vec![],
+ reasoning_opaque: None,
+ reasoning_text: None,
+ }),
+ message: None,
+ }],
+ id: "chatcmpl-123".to_string(),
+ usage: None,
+ },
+ ];
+
+ let mapped = futures::executor::block_on(async {
+ map_to_language_model_completion_events(
+ Box::pin(futures::stream::iter(events.into_iter().map(Ok))),
+ true,
+ )
+ .collect::<Vec<_>>()
+ .await
+ });
+
+ let mut has_reasoning_details = false;
+ let mut has_tool_use = false;
+ let mut reasoning_opaque_value: Option<String> = None;
+ let mut reasoning_text_value: Option<String> = None;
+
+ for event_result in mapped {
+ match event_result {
+ Ok(LanguageModelCompletionEvent::ReasoningDetails(details)) => {
+ has_reasoning_details = true;
+ reasoning_opaque_value = details
+ .get("reasoning_opaque")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string());
+ reasoning_text_value = details
+ .get("reasoning_text")
+ .and_then(|v| v.as_str())
+ .map(|s| s.to_string());
+ }
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
+ has_tool_use = true;
+ assert_eq!(tool_use.id.to_string(), "call_abc123");
+ assert_eq!(tool_use.name.as_ref(), "list_directory");
+ }
+ _ => {}
+ }
+ }
+
+ assert!(
+ has_reasoning_details,
+ "Should emit ReasoningDetails event for Gemini 3 reasoning"
+ );
+ assert!(has_tool_use, "Should emit ToolUse event");
+ assert_eq!(
+ reasoning_opaque_value,
+ Some("encrypted_reasoning_token_xyz".to_string()),
+ "Should capture reasoning_opaque"
+ );
+ assert_eq!(
+ reasoning_text_value,
+ Some("Let me check the directory".to_string()),
+ "Should capture reasoning_text"
+ );
+ }
+}
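+
+/// Configuration UI for Copilot Chat, rendered from the global Copilot sign-in status.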
+struct ConfigurationView {
+ copilot_status: Option<copilot::Status>,
+ state: Entity<State>,
+ _subscription: Option<Subscription>,
+}
+
+impl ConfigurationView {
+ pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
+ let copilot = Copilot::global(cx);
+
+ Self {
+ copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
+ state,
+ _subscription: copilot.as_ref().map(|copilot| {
+ cx.observe(copilot, |this, model, cx| {
+ this.copilot_status = Some(model.read(cx).status());
+ cx.notify();
+ })
+ }),
+ }
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ if self.state.read(cx).is_authenticated(cx) {
+ ConfiguredApiCard::new("Authorized")
+ .button_label("Sign Out")
+ .on_click(|_, window, cx| {
+ window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
+ })
+ .into_any_element()
+ } else {
+ let loading_icon = Icon::new(IconName::ArrowCircle).with_rotate_animation(4);
+
+ const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
+
+ match &self.copilot_status {
+ Some(status) => match status {
+ Status::Starting { task: _ } => h_flex()
+ .gap_2()
+ .child(loading_icon)
+ .child(Label::new("Starting Copilot…"))
+ .into_any_element(),
+ Status::SigningIn { prompt: _ }
+ | Status::SignedOut {
+ awaiting_signing_in: true,
+ } => h_flex()
+ .gap_2()
+ .child(loading_icon)
+ .child(Label::new("Signing into Copilot…"))
+ .into_any_element(),
+ Status::Error(_) => {
+ const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
+ v_flex()
+ .gap_6()
+ .child(Label::new(LABEL))
+ .child(svg().size_8().path(IconName::CopilotError.path()))
+ .into_any_element()
+ }
+ _ => {
+ const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
+
+ v_flex()
+ .gap_2()
+ .child(Label::new(LABEL))
+ .child(
+ Button::new("sign_in", "Sign in to use GitHub Copilot")
+ .full_width()
+ .style(ButtonStyle::Outlined)
+ .icon_color(Color::Muted)
+ .icon(IconName::Github)
+ .icon_position(IconPosition::Start)
+ .icon_size(IconSize::Small)
+ .on_click(|_, window, cx| {
+ copilot::initiate_sign_in(window, cx)
+ }),
+ )
+ .into_any_element()
+ }
+ },
+ None => v_flex()
+ .gap_6()
+ .child(Label::new(ERROR_LABEL))
+ .into_any_element(),
+ }
+ }
+ }
+}
@@ -1,22 +1,44 @@
-use anyhow::Result;
-use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
+use anyhow::{Context as _, Result, anyhow};
+use collections::BTreeMap;
+use credentials_provider::CredentialsProvider;
+use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
ThinkingConfig, UsageMetadata,
};
-use gpui::{App, AppContext as _};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
+use http_client::HttpClient;
use language_model::{
- LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
- LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
- StopReason,
+ AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
+ LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
+ LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
+};
+use language_model::{
+ LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
pub use settings::GoogleAvailableModel as AvailableModel;
-use std::{
- pin::Pin,
- sync::atomic::{self, AtomicU64},
+use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::sync::{
+ Arc, LazyLock,
+ atomic::{self, AtomicU64},
};
+use strum::IntoEnumIterator;
+use ui::{List, prelude::*};
+use ui_input::InputField;
+use util::ResultExt;
+use zed_env_vars::EnvVar;
+
+use crate::api_key::ApiKey;
+use crate::api_key::ApiKeyState;
+use crate::ui::{ConfiguredApiCard, InstructionListItem};
+
+const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
@@ -35,6 +57,346 @@ pub enum ModelMode {
},
}
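+/// Language model provider backed by the Google AI (Gemini) API.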
+pub struct GoogleLanguageModelProvider {
+ http_client: Arc<dyn HttpClient>,
+ state: Entity<State>,
+}
+
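+/// Holds the Google API key state, loaded from the system keychain or an environment variable.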
+pub struct State {
+ api_key_state: ApiKeyState,
+}
+
+const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
+const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
+
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
+    // Try GEMINI_API_KEY first, falling back to GOOGLE_AI_API_KEY
+ EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
+});
+
+impl State {
+ fn is_authenticated(&self) -> bool {
+ self.api_key_state.has_key()
+ }
+
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = GoogleLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ }
+
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = GoogleLanguageModelProvider::api_url(cx);
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
+ }
+}
+
+impl GoogleLanguageModelProvider {
+ pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
+ cx.notify();
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ }
+ });
+
+ Self { http_client, state }
+ }
+
+ fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(GoogleLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+
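+    /// Returns the Google API key for the Gemini CLI, preferring the environment variable and
+    /// falling back to the key stored in the system keychain.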
+ pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
+ if let Some(key) = API_KEY_ENV_VAR.value.clone() {
+ return Task::ready(Ok(key));
+ }
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ let api_url = Self::api_url(cx).to_string();
+ cx.spawn(async move |cx| {
+ Ok(
+ ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
+ .await?
+ .key()
+ .to_string(),
+ )
+ })
+ }
+
+ fn settings(cx: &App) -> &GoogleSettings {
+ &crate::AllLanguageModelSettings::get_global(cx).google
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ google_ai::API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
+}
+
+impl LanguageModelProviderState for GoogleLanguageModelProvider {
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
+ Some(self.state.clone())
+ }
+}
+
+impl LanguageModelProvider for GoogleLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::AiGoogle
+ }
+
+ fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(google_ai::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(google_ai::Model::default_fast()))
+ }
+
+ fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+ let mut models = BTreeMap::default();
+
+ // Add base models from google_ai::Model::iter()
+ for model in google_ai::Model::iter() {
+ if !matches!(model, google_ai::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), model);
+ }
+ }
+
+ // Override with available models from settings
+ for model in &GoogleLanguageModelProvider::settings(cx).available_models {
+ models.insert(
+ model.name.clone(),
+ google_ai::Model::Custom {
+ name: model.name.clone(),
+ display_name: model.display_name.clone(),
+ max_tokens: model.max_tokens,
+ mode: model.mode.unwrap_or_default(),
+ },
+ );
+ }
+
+ models
+ .into_values()
+ .map(|model| {
+ Arc::new(GoogleLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ }) as Arc<dyn LanguageModel>
+ })
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated()
+ }
+
+ fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ self.state.update(cx, |state, cx| state.authenticate(cx))
+ }
+
+ fn configuration_view(
+ &self,
+ target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
+ .into()
+ }
+
+ fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ }
+}
+
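+/// A single Google AI model exposed through the `LanguageModel` trait, sharing the provider's
+/// API key state and carrying its own rate limiter.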
+pub struct GoogleLanguageModel {
+ id: LanguageModelId,
+ model: google_ai::Model,
+ state: Entity<State>,
+ http_client: Arc<dyn HttpClient>,
+ request_limiter: RateLimiter,
+}
+
+impl GoogleLanguageModel {
+ fn stream_completion(
+ &self,
+ request: google_ai::GenerateContentRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
+ > {
+ let http_client = self.http_client.clone();
+
+ let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
+ let api_url = GoogleLanguageModelProvider::api_url(cx);
+ (state.api_key_state.key(&api_url), api_url)
+ }) else {
+ return future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
+ async move {
+ let api_key = api_key.context("Missing Google API key")?;
+ let request = google_ai::stream_generate_content(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ request.await.context("failed to stream completion")
+ }
+ .boxed()
+ }
+}
+
+impl LanguageModel for GoogleLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn supports_tools(&self) -> bool {
+ self.model.supports_tools()
+ }
+
+ fn supports_images(&self) -> bool {
+ self.model.supports_images()
+ }
+
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto
+ | LanguageModelToolChoice::Any
+ | LanguageModelToolChoice::None => true,
+ }
+ }
+
+ fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
+ LanguageModelToolSchemaFormat::JsonSchemaSubset
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("google/{}", self.model.request_id())
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_token_count()
+ }
+
+ fn max_output_tokens(&self) -> Option<u64> {
+ self.model.max_output_tokens()
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result<u64>> {
+ let model_id = self.model.request_id().to_string();
+ let request = into_google(request, model_id, self.model.mode());
+ let http_client = self.http_client.clone();
+ let api_url = GoogleLanguageModelProvider::api_url(cx);
+ let api_key = self.state.read(cx).api_key_state.key(&api_url);
+
+ async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ }
+ .into());
+ };
+ let response = google_ai::count_tokens(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ google_ai::CountTokensRequest {
+ generate_content_request: request,
+ },
+ )
+ .await?;
+ Ok(response.total_tokens)
+ }
+ .boxed()
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let request = into_google(
+ request,
+ self.model.request_id().to_string(),
+ self.model.mode(),
+ );
+ let request = self.stream_completion(request, cx);
+ let future = self.request_limiter.stream(async move {
+ let response = request.await.map_err(LanguageModelCompletionError::from)?;
+ Ok(GoogleEventMapper::new().map_stream(response))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
pub fn into_google(
mut request: LanguageModelRequest,
model_id: String,
@@ -77,6 +439,7 @@ pub fn into_google(
})]
}
language_model::MessageContent::ToolUse(tool_use) => {
+ // Normalize empty string signatures to None
let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());
vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
@@ -94,6 +457,7 @@ pub fn into_google(
google_ai::FunctionResponsePart {
function_response: google_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
+ // The API expects a valid JSON object
response: serde_json::json!({
"output": text
}),
@@ -106,6 +470,7 @@ pub fn into_google(
Part::FunctionResponsePart(google_ai::FunctionResponsePart {
function_response: google_ai::FunctionResponse {
name: tool_result.tool_name.to_string(),
+ // The API expects a valid JSON object
response: serde_json::json!({
"output": "Tool responded with an image"
}),
@@ -154,7 +519,7 @@ pub fn into_google(
role: match message.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
- Role::System => google_ai::Role::User,
+ Role::System => google_ai::Role::User, // Google AI doesn't have a system role
},
})
}
@@ -288,13 +653,13 @@ impl GoogleEventMapper {
Part::InlineDataPart(_) => {}
Part::FunctionCallPart(function_call_part) => {
wants_to_use_tool = true;
- let name: std::sync::Arc<str> =
- function_call_part.function_call.name.into();
+ let name: Arc<str> = function_call_part.function_call.name.into();
let next_tool_id =
TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
let id: LanguageModelToolUseId =
format!("{}-{}", name, next_tool_id).into();
+ // Normalize empty string signatures to None
let thought_signature = function_call_part
.thought_signature
.filter(|s| !s.is_empty());
@@ -313,7 +678,7 @@ impl GoogleEventMapper {
Part::FunctionResponsePart(_) => {}
Part::ThoughtPart(part) => {
events.push(Ok(LanguageModelCompletionEvent::Thinking {
- text: "(Encrypted thought)".to_string(),
+ text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
signature: Some(part.thought_signature),
}));
}
@@ -321,6 +686,8 @@ impl GoogleEventMapper {
}
}
+ // Even when Gemini wants to use a Tool, the API
+ // responds with `finish_reason: STOP`
if wants_to_use_tool {
self.stop_reason = StopReason::ToolUse;
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
@@ -333,6 +700,8 @@ pub fn count_google_tokens(
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
+    // We can't use GoogleLanguageModelProvider to count tokens here because GitHub Copilot doesn't have direct access to google_ai,
+    // so we fall back to the tiktoken_rs tokenizer instead.
cx.background_spawn(async move {
let messages = request
.messages
@@ -349,6 +718,8 @@ pub fn count_google_tokens(
})
.collect::<Vec<_>>();
+ // Tiktoken doesn't yet support these models, so we manually use the
+ // same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
})
.boxed()
@@ -389,6 +760,148 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
}
}
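+/// API key configuration UI for the Google AI provider.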
+struct ConfigurationView {
+ api_key_editor: Entity<InputField>,
+ state: Entity<State>,
+ target_agent: language_model::ConfigurationViewTargetAgent,
+ load_credentials_task: Option<Task<()>>,
+}
+
+impl ConfigurationView {
+ fn new(
+ state: Entity<State>,
+ target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
+
+ let load_credentials_task = Some(cx.spawn_in(window, {
+ let state = state.clone();
+ async move |this, cx| {
+ if let Some(task) = state
+ .update(cx, |state, cx| state.authenticate(cx))
+ .log_err()
+ {
+ // We don't log an error, because "not signed in" is also an error.
+ let _ = task.await;
+ }
+ this.update(cx, |this, cx| {
+ this.load_credentials_task = None;
+ cx.notify();
+ })
+ .log_err();
+ }
+ }));
+
+ Self {
+ api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
+ target_agent,
+ state,
+ load_credentials_task,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
+ return;
+ }
+
+        // URL changes can cause the editor to be displayed again, so clear the input now.
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+ let configured_card_label = if env_var_set {
+ format!(
+ "API key set in {} environment variable",
+ API_KEY_ENV_VAR.name
+ )
+ } else {
+ let api_url = GoogleLanguageModelProvider::api_url(cx);
+ if api_url == google_ai::API_URL {
+ "API key configured".to_string()
+ } else {
+ format!("API key configured for {}", api_url)
+ }
+ };
+
+ if self.load_credentials_task.is_some() {
+ div()
+ .child(Label::new("Loading credentials..."))
+ .into_any_element()
+ } else if self.should_render_editor(cx) {
+ v_flex()
+ .size_full()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
+ ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
+ ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
+ })))
+ .child(
+ List::new()
+ .child(InstructionListItem::new(
+ "Create one by visiting",
+ Some("Google AI's console"),
+ Some("https://aistudio.google.com/app/apikey"),
+ ))
+ .child(InstructionListItem::text_only(
+ "Paste your API key below and hit enter to start using the assistant",
+ )),
+ )
+ .child(self.api_key_editor.clone())
+ .child(
+ Label::new(
+ format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
+ )
+ .size(LabelSize::Small).color(Color::Muted),
+ )
+ .into_any_element()
+ } else {
+ ConfiguredApiCard::new(configured_card_label)
+ .disabled(env_var_set)
+ .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
+ .when(env_var_set, |this| {
+                this.tooltip_label(format!("To reset your API key, make sure the {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
+ })
+ .into_any_element()
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -427,7 +940,7 @@ mod tests {
let events = mapper.map_event(response);
- assert_eq!(events.len(), 2);
+ assert_eq!(events.len(), 2); // ToolUse event + Stop event
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
assert_eq!(tool_use.name.as_ref(), "test_function");
@@ -521,25 +1034,18 @@ mod tests {
parts: vec![
Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
- name: "function_a".to_string(),
- args: json!({}),
+ name: "function_1".to_string(),
+ args: json!({"arg": "value1"}),
},
- thought_signature: Some("sig_a".to_string()),
+ thought_signature: Some("signature_1".to_string()),
}),
Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
- name: "function_b".to_string(),
- args: json!({}),
+ name: "function_2".to_string(),
+ args: json!({"arg": "value2"}),
},
thought_signature: None,
}),
- Part::FunctionCallPart(FunctionCallPart {
- function_call: FunctionCall {
- name: "function_c".to_string(),
- args: json!({}),
- },
- thought_signature: Some("sig_c".to_string()),
- }),
],
role: GoogleRole::Model,
},
@@ -554,35 +1060,35 @@ mod tests {
let events = mapper.map_event(response);
- let tool_uses: Vec<_> = events
- .iter()
- .filter_map(|e| {
- if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = e {
- Some(tool_use)
- } else {
- None
- }
- })
- .collect();
+ assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
- assert_eq!(tool_uses.len(), 3);
- assert_eq!(tool_uses[0].thought_signature.as_deref(), Some("sig_a"));
- assert_eq!(tool_uses[1].thought_signature, None);
- assert_eq!(tool_uses[2].thought_signature.as_deref(), Some("sig_c"));
+ if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
+ assert_eq!(tool_use.name.as_ref(), "function_1");
+ assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
+ } else {
+ panic!("Expected ToolUse event for function_1");
+ }
+
+ if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
+ assert_eq!(tool_use.name.as_ref(), "function_2");
+ assert_eq!(tool_use.thought_signature, None);
+ } else {
+ panic!("Expected ToolUse event for function_2");
+ }
}
#[test]
fn test_tool_use_with_signature_converts_to_function_call_part() {
let tool_use = language_model::LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test-id"),
- name: "test_tool".into(),
- input: json!({"key": "value"}),
- raw_input: r#"{"key": "value"}"#.to_string(),
+ id: LanguageModelToolUseId::from("test_id"),
+ name: "test_function".into(),
+ raw_input: json!({"arg": "value"}).to_string(),
+ input: json!({"arg": "value"}),
is_input_complete: true,
- thought_signature: Some("test_sig".to_string()),
+ thought_signature: Some("test_signature_456".to_string()),
};
- let request = into_google(
+ let request = super::into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
@@ -596,11 +1102,13 @@ mod tests {
GoogleModelMode::Default,
);
- let parts = &request.contents[0].parts;
- assert_eq!(parts.len(), 1);
-
- if let Part::FunctionCallPart(fcp) = &parts[0] {
- assert_eq!(fcp.thought_signature.as_deref(), Some("test_sig"));
+ assert_eq!(request.contents[0].parts.len(), 1);
+ if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
+ assert_eq!(fc_part.function_call.name, "test_function");
+ assert_eq!(
+ fc_part.thought_signature.as_deref(),
+ Some("test_signature_456")
+ );
} else {
panic!("Expected FunctionCallPart");
}
@@ -609,15 +1117,15 @@ mod tests {
#[test]
fn test_tool_use_without_signature_omits_field() {
let tool_use = language_model::LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test-id"),
- name: "test_tool".into(),
- input: json!({"key": "value"}),
- raw_input: r#"{"key": "value"}"#.to_string(),
+ id: LanguageModelToolUseId::from("test_id"),
+ name: "test_function".into(),
+ raw_input: json!({"arg": "value"}).to_string(),
+ input: json!({"arg": "value"}),
is_input_complete: true,
thought_signature: None,
};
- let request = into_google(
+ let request = super::into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
@@ -631,10 +1139,9 @@ mod tests {
GoogleModelMode::Default,
);
- let parts = &request.contents[0].parts;
-
- if let Part::FunctionCallPart(fcp) = &parts[0] {
- assert_eq!(fcp.thought_signature, None);
+ assert_eq!(request.contents[0].parts.len(), 1);
+ if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
+ assert_eq!(fc_part.thought_signature, None);
} else {
panic!("Expected FunctionCallPart");
}
@@ -643,15 +1150,15 @@ mod tests {
#[test]
fn test_empty_signature_in_tool_use_normalized_to_none() {
let tool_use = language_model::LanguageModelToolUse {
- id: LanguageModelToolUseId::from("test-id"),
- name: "test_tool".into(),
- input: json!({}),
- raw_input: "{}".to_string(),
+ id: LanguageModelToolUseId::from("test_id"),
+ name: "test_function".into(),
+ raw_input: json!({"arg": "value"}).to_string(),
+ input: json!({"arg": "value"}),
is_input_complete: true,
thought_signature: Some("".to_string()),
};
- let request = into_google(
+ let request = super::into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
@@ -665,10 +1172,8 @@ mod tests {
GoogleModelMode::Default,
);
- let parts = &request.contents[0].parts;
-
- if let Part::FunctionCallPart(fcp) = &parts[0] {
- assert_eq!(fcp.thought_signature, None);
+ if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
+ assert_eq!(fc_part.thought_signature, None);
} else {
panic!("Expected FunctionCallPart");
}
@@ -676,8 +1181,9 @@ mod tests {
#[test]
fn test_round_trip_preserves_signature() {
- let original_signature = "original_thought_signature_abc123";
+ let mut mapper = GoogleEventMapper::new();
+ // Simulate receiving a response from Google with a signature
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
@@ -687,7 +1193,7 @@ mod tests {
name: "test_function".to_string(),
args: json!({"arg": "value"}),
},
- thought_signature: Some(original_signature.to_string()),
+ thought_signature: Some("round_trip_sig".to_string()),
})],
role: GoogleRole::Model,
},
@@ -700,7 +1206,6 @@ mod tests {
usage_metadata: None,
};
- let mut mapper = GoogleEventMapper::new();
let events = mapper.map_event(response);
let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
@@ -709,7 +1214,8 @@ mod tests {
panic!("Expected ToolUse event");
};
- let request = into_google(
+ // Convert back to Google format
+ let request = super::into_google(
LanguageModelRequest {
messages: vec![language_model::LanguageModelRequestMessage {
role: Role::Assistant,
@@ -723,9 +1229,9 @@ mod tests {
GoogleModelMode::Default,
);
- let parts = &request.contents[0].parts;
- if let Part::FunctionCallPart(fcp) = &parts[0] {
- assert_eq!(fcp.thought_signature.as_deref(), Some(original_signature));
+ // Verify signature is preserved
+ if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
+ assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
} else {
panic!("Expected FunctionCallPart");
}
@@ -741,14 +1247,14 @@ mod tests {
content: Content {
parts: vec![
Part::TextPart(TextPart {
- text: "Let me help you with that.".to_string(),
+ text: "I'll help with that.".to_string(),
}),
Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
- name: "search".to_string(),
- args: json!({"query": "test"}),
+ name: "helper_function".to_string(),
+ args: json!({"query": "help"}),
},
- thought_signature: Some("thinking_sig".to_string()),
+ thought_signature: Some("mixed_sig".to_string()),
}),
],
role: GoogleRole::Model,
@@ -764,46 +1270,38 @@ mod tests {
let events = mapper.map_event(response);
- let mut found_text = false;
- let mut found_tool_with_sig = false;
+ assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
- for event in events {
- match event {
- Ok(LanguageModelCompletionEvent::Text(text)) => {
- assert_eq!(text, "Let me help you with that.");
- found_text = true;
- }
- Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
- assert_eq!(tool_use.thought_signature.as_deref(), Some("thinking_sig"));
- found_tool_with_sig = true;
- }
- _ => {}
- }
+ if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
+ assert_eq!(text, "I'll help with that.");
+ } else {
+ panic!("Expected Text event");
}
- assert!(found_text, "Should have found text event");
- assert!(
- found_tool_with_sig,
- "Should have found tool use with signature"
- );
+ if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
+ assert_eq!(tool_use.name.as_ref(), "helper_function");
+ assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
+ } else {
+ panic!("Expected ToolUse event");
+ }
}
#[test]
fn test_special_characters_in_signature_preserved() {
- let special_signature = "sig/with+special=chars&more%stuff";
-
let mut mapper = GoogleEventMapper::new();
+ let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
+
let response = GenerateContentResponse {
candidates: Some(vec![GenerateContentCandidate {
index: Some(0),
content: Content {
parts: vec![Part::FunctionCallPart(FunctionCallPart {
function_call: FunctionCall {
- name: "test".to_string(),
- args: json!({}),
+ name: "test_function".to_string(),
+ args: json!({"arg": "value"}),
},
- thought_signature: Some(special_signature.to_string()),
+ thought_signature: Some(signature_with_special_chars.clone()),
})],
role: GoogleRole::Model,
},
@@ -821,7 +1319,7 @@ mod tests {
if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
assert_eq!(
tool_use.thought_signature.as_deref(),
- Some(special_signature)
+ Some(signature_with_special_chars.as_str())
);
} else {
panic!("Expected ToolUse event");
@@ -1,17 +1,38 @@
use anyhow::{Result, anyhow};
-use collections::HashMap;
-use futures::{FutureExt, Stream, future::BoxFuture};
-use gpui::{App, AppContext as _};
+use collections::{BTreeMap, HashMap};
+use futures::Stream;
+use futures::{FutureExt, StreamExt, future, future::BoxFuture};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
+use http_client::HttpClient;
use language_model::{
- LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
- LanguageModelToolChoice, LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+ RateLimiter, Role, StopReason, TokenUsage,
};
-use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent};
-pub use settings::OpenAiAvailableModel as AvailableModel;
+use menu;
+use open_ai::{
+ ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion,
+};
+use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore};
use std::pin::Pin;
-use std::str::FromStr;
+use std::str::FromStr as _;
+use std::sync::{Arc, LazyLock};
+use strum::IntoEnumIterator;
+use ui::{List, prelude::*};
+use ui_input::InputField;
+use util::ResultExt;
+use zed_env_vars::{EnvVar, env_var};
+
+use crate::ui::ConfiguredApiCard;
+use crate::{api_key::ApiKeyState, ui::InstructionListItem};
+
+const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
-use language_model::LanguageModelToolResultContent;
+const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
@@ -19,6 +40,314 @@ pub struct OpenAiSettings {
pub available_models: Vec<AvailableModel>,
}
+pub struct OpenAiLanguageModelProvider {
+ http_client: Arc<dyn HttpClient>,
+ state: Entity<State>,
+}
+
+pub struct State {
+ api_key_state: ApiKeyState,
+}
+
+impl State {
+ fn is_authenticated(&self) -> bool {
+ self.api_key_state.has_key()
+ }
+
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = OpenAiLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ }
+
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = OpenAiLanguageModelProvider::api_url(cx);
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
+ }
+}
+
+impl OpenAiLanguageModelProvider {
+ pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| {
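+            // Re-resolve the stored API key whenever the configured API URL changes in settings.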
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
+ cx.notify();
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ }
+ });
+
+ Self { http_client, state }
+ }
+
+ fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(OpenAiLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+
+ fn settings(cx: &App) -> &OpenAiSettings {
+ &crate::AllLanguageModelSettings::get_global(cx).openai
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ open_ai::OPEN_AI_API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
+}
+
+impl LanguageModelProviderState for OpenAiLanguageModelProvider {
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
+ Some(self.state.clone())
+ }
+}
+
+impl LanguageModelProvider for OpenAiLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::AiOpenAi
+ }
+
+ fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(open_ai::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(open_ai::Model::default_fast()))
+ }
+
+ fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+ let mut models = BTreeMap::default();
+
+ // Add base models from open_ai::Model::iter()
+ for model in open_ai::Model::iter() {
+ if !matches!(model, open_ai::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), model);
+ }
+ }
+
+ // Override with available models from settings
+ for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
+ models.insert(
+ model.name.clone(),
+ open_ai::Model::Custom {
+ name: model.name.clone(),
+ display_name: model.display_name.clone(),
+ max_tokens: model.max_tokens,
+ max_output_tokens: model.max_output_tokens,
+ max_completion_tokens: model.max_completion_tokens,
+ reasoning_effort: model.reasoning_effort.clone(),
+ },
+ );
+ }
+
+ models
+ .into_values()
+ .map(|model| self.create_language_model(model))
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated()
+ }
+
+ fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ self.state.update(cx, |state, cx| state.authenticate(cx))
+ }
+
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+ .into()
+ }
+
+ fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ }
+}
+
+pub struct OpenAiLanguageModel {
+ id: LanguageModelId,
+ model: open_ai::Model,
+ state: Entity<State>,
+ http_client: Arc<dyn HttpClient>,
+ request_limiter: RateLimiter,
+}
+
+impl OpenAiLanguageModel {
+ fn stream_completion(
+ &self,
+ request: open_ai::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
+ {
+ let http_client = self.http_client.clone();
+
+ let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
+ let api_url = OpenAiLanguageModelProvider::api_url(cx);
+ (state.api_key_state.key(&api_url), api_url)
+ }) else {
+ return future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
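+        // Route the request through the provider's rate limiter (at most 4 concurrent requests).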
+ let future = self.request_limiter.stream(async move {
+ let provider = PROVIDER_NAME;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey { provider });
+ };
+ let request = stream_completion(
+ http_client.as_ref(),
+ provider.0.as_str(),
+ &api_url,
+ &api_key,
+ request,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
+impl LanguageModel for OpenAiLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn supports_tools(&self) -> bool {
+ true
+ }
+
+ fn supports_images(&self) -> bool {
+ use open_ai::Model;
+ match &self.model {
+ Model::FourOmni
+ | Model::FourOmniMini
+ | Model::FourPointOne
+ | Model::FourPointOneMini
+ | Model::FourPointOneNano
+ | Model::Five
+ | Model::FiveMini
+ | Model::FiveNano
+ | Model::FivePointOne
+ | Model::O1
+ | Model::O3
+ | Model::O4Mini => true,
+ Model::ThreePointFiveTurbo
+ | Model::Four
+ | Model::FourTurbo
+ | Model::O3Mini
+ | Model::Custom { .. } => false,
+ }
+ }
+
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto => true,
+ LanguageModelToolChoice::Any => true,
+ LanguageModelToolChoice::None => true,
+ }
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("openai/{}", self.model.id())
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_token_count()
+ }
+
+ fn max_output_tokens(&self) -> Option<u64> {
+ self.model.max_output_tokens()
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result<u64>> {
+ count_open_ai_tokens(request, self.model.clone(), cx)
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let request = into_open_ai(
+ request,
+ self.model.id(),
+ self.model.supports_parallel_tool_calls(),
+ self.model.supports_prompt_cache_key(),
+ self.max_output_tokens(),
+ self.model.reasoning_effort(),
+ );
+ let completions = self.stream_completion(request, cx);
+ async move {
+ let mapper = OpenAiEventMapper::new();
+ Ok(mapper.map_stream(completions.await?).boxed())
+ }
+ .boxed()
+ }
+}
+
pub fn into_open_ai(
request: LanguageModelRequest,
model_id: &str,
@@ -112,6 +441,7 @@ pub fn into_open_ai(
temperature: request.temperature.or(Some(1.0)),
max_completion_tokens: max_output_tokens,
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
+ // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
Some(false)
} else {
None
@@ -191,7 +521,6 @@ impl OpenAiEventMapper {
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
- use futures::StreamExt;
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
@@ -319,12 +648,19 @@ pub fn count_open_ai_tokens(
match model {
Model::Custom { max_tokens, .. } => {
let model = if max_tokens >= 100_000 {
+                // If max_tokens is 100k or more, the model likely uses the o200k_base tokenizer from gpt-4o.
"gpt-4o"
} else {
+                // Otherwise fall back to gpt-4, since only cl100k_base and o200k_base are
+                // supported with this tiktoken method.
"gpt-4"
};
tiktoken_rs::num_tokens_from_messages(model, &messages)
}
+            // Models currently supported by tiktoken-rs.
+            // tiktoken-rs is sometimes behind on model support. When that happens, add a new
+            // match arm with an override. We enumerate all supported models here so it is easy
+            // to check whether newly added models are supported yet.
Model::ThreePointFiveTurbo
| Model::Four
| Model::FourTurbo
@@ -339,7 +675,7 @@ pub fn count_open_ai_tokens(
| Model::O4Mini
| Model::Five
| Model::FiveMini
- | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
+                | Model::FiveNano => tiktoken_rs::num_tokens_from_messages(model.id(), &messages), // GPT-5.1 doesn't have tiktoken support yet; the arm below falls back to the "gpt-5" name
Model::FivePointOne => tiktoken_rs::num_tokens_from_messages("gpt-5", &messages),
}
.map(|tokens| tokens as u64)
@@ -347,11 +683,191 @@ pub fn count_open_ai_tokens(
.boxed()
}
+struct ConfigurationView {
+ api_key_editor: Entity<InputField>,
+ state: Entity<State>,
+ load_credentials_task: Option<Task<()>>,
+}
+
+impl ConfigurationView {
+ fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
+ let api_key_editor = cx.new(|cx| {
+ InputField::new(
+ window,
+ cx,
+ "sk-000000000000000000000000000000000000000000000000",
+ )
+ });
+
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
+
+ let load_credentials_task = Some(cx.spawn_in(window, {
+ let state = state.clone();
+ async move |this, cx| {
+ if let Some(task) = state
+ .update(cx, |state, cx| state.authenticate(cx))
+ .log_err()
+ {
+                    // Don't log failures here: "not signed in" is also reported as an error.
+ let _ = task.await;
+ }
+ this.update(cx, |this, cx| {
+ this.load_credentials_task = None;
+ cx.notify();
+ })
+ .log_err();
+ }
+ }));
+
+ Self {
+ api_key_editor,
+ state,
+ load_credentials_task,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
+ return;
+ }
+
+        // URL changes can cause the editor to be displayed again.
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.api_key_editor
+ .update(cx, |input, cx| input.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+ let configured_card_label = if env_var_set {
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
+ } else {
+ let api_url = OpenAiLanguageModelProvider::api_url(cx);
+ if api_url == OPEN_AI_API_URL {
+ "API key configured".to_string()
+ } else {
+ format!("API key configured for {}", api_url)
+ }
+ };
+
+ let api_key_section = if self.should_render_editor(cx) {
+ v_flex()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
+ .child(
+ List::new()
+ .child(InstructionListItem::new(
+ "Create one by visiting",
+ Some("OpenAI's console"),
+ Some("https://platform.openai.com/api-keys"),
+ ))
+ .child(InstructionListItem::text_only(
+ "Ensure your OpenAI account has credits",
+ ))
+ .child(InstructionListItem::text_only(
+ "Paste your API key below and hit enter to start using the assistant",
+ )),
+ )
+ .child(self.api_key_editor.clone())
+ .child(
+ Label::new(format!(
+ "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
+ ))
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new(
+ "Note that having a subscription for another service like GitHub Copilot won't work.",
+ )
+ .size(LabelSize::Small).color(Color::Muted),
+ )
+ .into_any_element()
+ } else {
+ ConfiguredApiCard::new(configured_card_label)
+ .disabled(env_var_set)
+ .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
+ .when(env_var_set, |this| {
+ this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
+ })
+ .into_any_element()
+ };
+
+ let compatible_api_section = h_flex()
+ .mt_1p5()
+ .gap_0p5()
+ .flex_wrap()
+ .when(self.should_render_editor(cx), |this| {
+ this.pt_1p5()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ })
+ .child(
+ h_flex()
+ .gap_2()
+ .child(
+ Icon::new(IconName::Info)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
+ .child(Label::new("Zed also supports OpenAI-compatible models.")),
+ )
+ .child(
+ Button::new("docs", "Learn More")
+ .icon(IconName::ArrowUpRight)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .on_click(move |_, _window, cx| {
+ cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
+ }),
+ );
+
+ if self.load_credentials_task.is_some() {
+ div().child(Label::new("Loading credentials…")).into_any()
+ } else {
+ v_flex()
+ .size_full()
+ .child(api_key_section)
+ .child(compatible_api_section)
+ .into_any()
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use gpui::TestAppContext;
use language_model::LanguageModelRequestMessage;
- use strum::IntoEnumIterator;
use super::*;
@@ -375,6 +891,7 @@ mod tests {
thinking_allowed: true,
};
+ // Validate that all models are supported by tiktoken-rs
for model in Model::iter() {
let count = cx
.executor()
@@ -0,0 +1,1095 @@
+use anyhow::{Result, anyhow};
+use collections::HashMap;
+use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task};
+use http_client::HttpClient;
+use language_model::{
+ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
+ LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
+};
+use open_router::{
+ Model, ModelMode as OpenRouterModelMode, OPEN_ROUTER_API_URL, ResponseStreamEvent, list_models,
+};
+use settings::{OpenRouterAvailableModel as AvailableModel, Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr as _;
+use std::sync::{Arc, LazyLock};
+use ui::{List, prelude::*};
+use ui_input::InputField;
+use util::ResultExt;
+use zed_env_vars::{EnvVar, env_var};
+
+use crate::ui::ConfiguredApiCard;
+use crate::{api_key::ApiKeyState, ui::InstructionListItem};
+
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
+
+const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct OpenRouterSettings {
+ pub api_url: String,
+ pub available_models: Vec<AvailableModel>,
+}
+
+pub struct OpenRouterLanguageModelProvider {
+ http_client: Arc<dyn HttpClient>,
+ state: Entity<State>,
+}
+
+pub struct State {
+ api_key_state: ApiKeyState,
+ http_client: Arc<dyn HttpClient>,
+ available_models: Vec<open_router::Model>,
+ fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
+}
+
+impl State {
+ fn is_authenticated(&self) -> bool {
+ self.api_key_state.has_key()
+ }
+
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = OpenRouterLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
+ }
+
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = OpenRouterLanguageModelProvider::api_url(cx);
+ let task = self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
+
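+        // Once the key has loaded, refresh (or clear) the fetched model list.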
+ cx.spawn(async move |this, cx| {
+ let result = task.await;
+ this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
+ .ok();
+ result
+ })
+ }
+
+ fn fetch_models(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<(), LanguageModelCompletionError>> {
+ let http_client = self.http_client.clone();
+ let api_url = OpenRouterLanguageModelProvider::api_url(cx);
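+        // Listing models requires an API key; bail out early if none is configured.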
+ let Some(api_key) = self.api_key_state.key(&api_url) else {
+ return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ }));
+ };
+ cx.spawn(async move |this, cx| {
+ let models = list_models(http_client.as_ref(), &api_url, &api_key)
+ .await
+ .map_err(|e| {
+ LanguageModelCompletionError::Other(anyhow::anyhow!(
+ "OpenRouter error: {:?}",
+ e
+ ))
+ })?;
+
+ this.update(cx, |this, cx| {
+ this.available_models = models;
+ cx.notify();
+ })
+            .map_err(LanguageModelCompletionError::Other)?;
+
+ Ok(())
+ })
+ }
+
+ fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
+ if self.is_authenticated() {
+ let task = self.fetch_models(cx);
+ self.fetch_models_task.replace(task);
+ } else {
+ self.available_models = Vec::new();
+ }
+ }
+}
+
+impl OpenRouterLanguageModelProvider {
+ pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>({
+ let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
+ move |this: &mut State, cx| {
+ let current_settings = OpenRouterLanguageModelProvider::settings(cx);
+ let settings_changed = current_settings != &last_settings;
+ if settings_changed {
+ last_settings = current_settings.clone();
+ this.authenticate(cx).detach();
+ cx.notify();
+ }
+ }
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ http_client: http_client.clone(),
+ available_models: Vec::new(),
+ fetch_models_task: None,
+ }
+ });
+
+ Self { http_client, state }
+ }
+
+ fn settings(cx: &App) -> &OpenRouterSettings {
+ &crate::AllLanguageModelSettings::get_global(cx).open_router
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ OPEN_ROUTER_API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
+
+ fn create_language_model(&self, model: open_router::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(OpenRouterLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+}
+
+impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
+ type ObservableEntity = State;
+
+ fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
+ Some(self.state.clone())
+ }
+}
+
+impl LanguageModelProvider for OpenRouterLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn icon(&self) -> IconName {
+ IconName::AiOpenRouter
+ }
+
+ fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(open_router::Model::default()))
+ }
+
+ fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
+ Some(self.create_language_model(open_router::Model::default_fast()))
+ }
+
+ fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+ let mut models_from_api = self.state.read(cx).available_models.clone();
+ let mut settings_models = Vec::new();
+
+ for model in &Self::settings(cx).available_models {
+ settings_models.push(open_router::Model {
+ name: model.name.clone(),
+ display_name: model.display_name.clone(),
+ max_tokens: model.max_tokens,
+ supports_tools: model.supports_tools,
+ supports_images: model.supports_images,
+ mode: model.mode.unwrap_or_default(),
+ provider: model.provider.clone(),
+ });
+ }
+
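+        // Models from settings override API-fetched models with the same name; any others are appended.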
+ for settings_model in &settings_models {
+ if let Some(pos) = models_from_api
+ .iter()
+ .position(|m| m.name == settings_model.name)
+ {
+ models_from_api[pos] = settings_model.clone();
+ } else {
+ models_from_api.push(settings_model.clone());
+ }
+ }
+
+ models_from_api
+ .into_iter()
+ .map(|model| self.create_language_model(model))
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &App) -> bool {
+ self.state.read(cx).is_authenticated()
+ }
+
+ fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
+ self.state.update(cx, |state, cx| state.authenticate(cx))
+ }
+
+ fn configuration_view(
+ &self,
+ _target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut App,
+ ) -> AnyView {
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+ .into()
+ }
+
+ fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ }
+}
+
+pub struct OpenRouterLanguageModel {
+ id: LanguageModelId,
+ model: open_router::Model,
+ state: Entity<State>,
+ http_client: Arc<dyn HttpClient>,
+ request_limiter: RateLimiter,
+}
+
+impl OpenRouterLanguageModel {
+ fn stream_completion(
+ &self,
+ request: open_router::Request,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<ResponseStreamEvent, open_router::OpenRouterError>,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let http_client = self.http_client.clone();
+ let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
+ let api_url = OpenRouterLanguageModelProvider::api_url(cx);
+ (state.api_key_state.key(&api_url), api_url)
+ }) else {
+ return future::ready(Err(anyhow!("App state dropped").into())).boxed();
+ };
+
+ async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
+ let request =
+ open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+ request.await.map_err(Into::into)
+ }
+ .boxed()
+ }
+}
+
+impl LanguageModel for OpenRouterLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ PROVIDER_ID
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ PROVIDER_NAME
+ }
+
+ fn supports_tools(&self) -> bool {
+ self.model.supports_tool_calls()
+ }
+
+ fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
+ let model_id = self.model.id().trim().to_lowercase();
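+        // Gemini and Grok models are given the restricted JSON Schema subset for tool parameters.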
+ if model_id.contains("gemini") || model_id.contains("grok") {
+ LanguageModelToolSchemaFormat::JsonSchemaSubset
+ } else {
+ LanguageModelToolSchemaFormat::JsonSchema
+ }
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("openrouter/{}", self.model.id())
+ }
+
+ fn max_token_count(&self) -> u64 {
+ self.model.max_token_count()
+ }
+
+ fn max_output_tokens(&self) -> Option<u64> {
+ self.model.max_output_tokens()
+ }
+
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ match choice {
+ LanguageModelToolChoice::Auto => true,
+ LanguageModelToolChoice::Any => true,
+ LanguageModelToolChoice::None => true,
+ }
+ }
+
+ fn supports_images(&self) -> bool {
+ self.model.supports_images.unwrap_or(false)
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &App,
+ ) -> BoxFuture<'static, Result<u64>> {
+ count_open_router_tokens(request, self.model.clone(), cx)
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
+ let request = into_open_router(request, &self.model, self.max_output_tokens());
+ let request = self.stream_completion(request, cx);
+ let future = self.request_limiter.stream(async move {
+ let response = request.await?;
+ Ok(OpenRouterEventMapper::new().map_stream(response))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
+pub fn into_open_router(
+ request: LanguageModelRequest,
+ model: &Model,
+ max_output_tokens: Option<u64>,
+) -> open_router::Request {
+ let mut messages = Vec::new();
+ for message in request.messages {
+ let reasoning_details = message.reasoning_details.clone();
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) => add_message_content_part(
+ open_router::MessagePart::Text { text },
+ message.role,
+ &mut messages,
+ ),
+ MessageContent::Thinking { .. } => {}
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::Image(image) => {
+ add_message_content_part(
+ open_router::MessagePart::Image {
+ image_url: image.to_base64_url(),
+ },
+ message.role,
+ &mut messages,
+ );
+ }
+ MessageContent::ToolUse(tool_use) => {
+ let tool_call = open_router::ToolCall {
+ id: tool_use.id.to_string(),
+ content: open_router::ToolCallContent::Function {
+ function: open_router::FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ thought_signature: tool_use.thought_signature.clone(),
+ },
+ },
+ };
+
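+                    // Attach the tool call to the previous assistant message when possible; otherwise start a new one.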
+ if let Some(open_router::RequestMessage::Assistant {
+ tool_calls,
+ reasoning_details: existing_reasoning,
+ ..
+ }) = messages.last_mut()
+ {
+ tool_calls.push(tool_call);
+ if existing_reasoning.is_none() && reasoning_details.is_some() {
+ *existing_reasoning = reasoning_details.clone();
+ }
+ } else {
+ messages.push(open_router::RequestMessage::Assistant {
+ content: None,
+ tool_calls: vec![tool_call],
+ reasoning_details: reasoning_details.clone(),
+ });
+ }
+ }
+ MessageContent::ToolResult(tool_result) => {
+ let content = match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ vec![open_router::MessagePart::Text {
+ text: text.to_string(),
+ }]
+ }
+ LanguageModelToolResultContent::Image(image) => {
+ vec![open_router::MessagePart::Image {
+ image_url: image.to_base64_url(),
+ }]
+ }
+ };
+
+ messages.push(open_router::RequestMessage::Tool {
+ content: content.into(),
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ });
+ }
+ }
+ }
+ }
+
+ open_router::Request {
+ model: model.id().into(),
+ messages,
+ stream: true,
+ stop: request.stop,
+ temperature: request.temperature.unwrap_or(0.4),
+ max_tokens: max_output_tokens,
+ parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
+ Some(false)
+ } else {
+ None
+ },
+ usage: open_router::RequestUsage { include: true },
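+        // Request reasoning only for models configured in thinking mode, forwarding the token budget.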
+ reasoning: if request.thinking_allowed
+ && let OpenRouterModelMode::Thinking { budget_tokens } = model.mode
+ {
+ Some(open_router::Reasoning {
+ effort: None,
+ max_tokens: budget_tokens,
+ exclude: Some(false),
+ enabled: Some(true),
+ })
+ } else {
+ None
+ },
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| open_router::ToolDefinition::Function {
+ function: open_router::FunctionDefinition {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ },
+ })
+ .collect(),
+ tool_choice: request.tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto,
+ LanguageModelToolChoice::Any => open_router::ToolChoice::Required,
+ LanguageModelToolChoice::None => open_router::ToolChoice::None,
+ }),
+ provider: model.provider.clone(),
+ }
+}
+
+fn add_message_content_part(
+ new_part: open_router::MessagePart,
+ role: Role,
+ messages: &mut Vec<open_router::RequestMessage>,
+) {
+ match (role, messages.last_mut()) {
+ (Role::User, Some(open_router::RequestMessage::User { content }))
+ | (Role::System, Some(open_router::RequestMessage::System { content })) => {
+ content.push_part(new_part);
+ }
+ (
+ Role::Assistant,
+ Some(open_router::RequestMessage::Assistant {
+ content: Some(content),
+ ..
+ }),
+ ) => {
+ content.push_part(new_part);
+ }
+ _ => {
+ messages.push(match role {
+ Role::User => open_router::RequestMessage::User {
+ content: open_router::MessageContent::from(vec![new_part]),
+ },
+ Role::Assistant => open_router::RequestMessage::Assistant {
+ content: Some(open_router::MessageContent::from(vec![new_part])),
+ tool_calls: Vec::new(),
+ reasoning_details: None,
+ },
+ Role::System => open_router::RequestMessage::System {
+ content: open_router::MessageContent::from(vec![new_part]),
+ },
+ });
+ }
+ }
+}
+
+pub struct OpenRouterEventMapper {
+ tool_calls_by_index: HashMap<usize, RawToolCall>,
+ reasoning_details: Option<serde_json::Value>,
+}
+
+impl OpenRouterEventMapper {
+ pub fn new() -> Self {
+ Self {
+ tool_calls_by_index: HashMap::default(),
+ reasoning_details: None,
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin<
+ Box<
+ dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
+ >,
+ >,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(error.into())],
+ })
+ })
+ }
+
+ pub fn map_event(
+ &mut self,
+ event: ResponseStreamEvent,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ let Some(choice) = event.choices.first() else {
+ return vec![Err(LanguageModelCompletionError::from(anyhow!(
+ "Response contained no choices"
+ )))];
+ };
+
+ let mut events = Vec::new();
+
+ if let Some(details) = choice.delta.reasoning_details.clone() {
+ // Emit reasoning_details immediately
+ events.push(Ok(LanguageModelCompletionEvent::ReasoningDetails(
+ details.clone(),
+ )));
+ self.reasoning_details = Some(details);
+ }
+
+ if let Some(reasoning) = choice.delta.reasoning.clone() {
+ events.push(Ok(LanguageModelCompletionEvent::Thinking {
+ text: reasoning,
+ signature: None,
+ }));
+ }
+
+ if let Some(content) = choice.delta.content.clone() {
+            // OpenRouter sends an empty content string along with the reasoning content;
+            // skip it as a workaround for that OpenRouter API bug.
+ if !content.is_empty() {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+ }
+
+ if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
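+            // Accumulate streamed tool-call chunks by index; ids, names, and arguments arrive incrementally.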
+ for tool_call in tool_calls {
+ let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
+
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
+ }
+
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
+ }
+
+ if let Some(signature) = function.thought_signature.clone() {
+ entry.thought_signature = Some(signature);
+ }
+ }
+ }
+ }
+
+ if let Some(usage) = event.usage {
+ events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+ input_tokens: usage.prompt_tokens,
+ output_tokens: usage.completion_tokens,
+ cache_creation_input_tokens: 0,
+ cache_read_input_tokens: 0,
+ })));
+ }
+
+ match choice.finish_reason.as_deref() {
+ Some("stop") => {
+ // Don't emit reasoning_details here - already emitted immediately when captured
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ Some("tool_calls") => {
+ events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+ match serde_json::Value::from_str(&tool_call.arguments) {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.clone().into(),
+ name: tool_call.name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments.clone(),
+ thought_signature: tool_call.thought_signature.clone(),
+ },
+ )),
+ Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_call.id.clone().into(),
+ tool_name: tool_call.name.as_str().into(),
+ raw_input: tool_call.arguments.clone().into(),
+ json_parse_error: error.to_string(),
+ }),
+ }
+ }));
+
+ // Don't emit reasoning_details here - already emitted immediately when captured
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ }
+ Some(stop_reason) => {
+                log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}");
+ // Don't emit reasoning_details here - already emitted immediately when captured
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ None => {}
+ }
+
+ events
+ }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
+ thought_signature: Option<String>,
+}
+
+pub fn count_open_router_tokens(
+ request: LanguageModelRequest,
+ _model: open_router::Model,
+ cx: &App,
+) -> BoxFuture<'static, Result<u64>> {
+ cx.background_spawn(async move {
+ let messages = request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.string_contents()),
+ name: None,
+ function_call: None,
+ })
+ .collect::<Vec<_>>();
+
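+        // The concrete model's tokenizer isn't known here, so approximate with the gpt-4o tokenizer.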
+ tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
+ })
+ .boxed()
+}
+
+struct ConfigurationView {
+ api_key_editor: Entity<InputField>,
+ state: Entity<State>,
+ load_credentials_task: Option<Task<()>>,
+}
+
+impl ConfigurationView {
+ fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
+ let api_key_editor = cx.new(|cx| {
+ InputField::new(
+ window,
+ cx,
+ "sk_or_000000000000000000000000000000000000000000000000",
+ )
+ });
+
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
+
+ let load_credentials_task = Some(cx.spawn_in(window, {
+ let state = state.clone();
+ async move |this, cx| {
+ if let Some(task) = state
+ .update(cx, |state, cx| state.authenticate(cx))
+ .log_err()
+ {
+ let _ = task.await;
+ }
+
+ this.update(cx, |this, cx| {
+ this.load_credentials_task = None;
+ cx.notify();
+ })
+ .log_err();
+ }
+ }));
+
+ Self {
+ api_key_editor,
+ state,
+ load_credentials_task,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
+ return;
+ }
+
+        // URL changes can cause the editor to be displayed again.
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ self.api_key_editor
+ .update(cx, |editor, cx| editor.set_text("", window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
+ !self.state.read(cx).is_authenticated()
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
+ let configured_card_label = if env_var_set {
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
+ } else {
+ let api_url = OpenRouterLanguageModelProvider::api_url(cx);
+ if api_url == OPEN_ROUTER_API_URL {
+ "API key configured".to_string()
+ } else {
+ format!("API key configured for {}", api_url)
+ }
+ };
+
+ if self.load_credentials_task.is_some() {
+ div()
+ .child(Label::new("Loading credentials..."))
+ .into_any_element()
+ } else if self.should_render_editor(cx) {
+ v_flex()
+ .size_full()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:"))
+ .child(
+ List::new()
+ .child(InstructionListItem::new(
+ "Create an API key by visiting",
+ Some("OpenRouter's console"),
+ Some("https://openrouter.ai/keys"),
+ ))
+ .child(InstructionListItem::text_only(
+ "Ensure your OpenRouter account has credits",
+ ))
+ .child(InstructionListItem::text_only(
+ "Paste your API key below and hit enter to start using the assistant",
+ )),
+ )
+ .child(self.api_key_editor.clone())
+ .child(
+ Label::new(
+ format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
+ )
+ .size(LabelSize::Small).color(Color::Muted),
+ )
+ .into_any_element()
+ } else {
+ ConfiguredApiCard::new(configured_card_label)
+ .disabled(env_var_set)
+ .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
+ .when(env_var_set, |this| {
+ this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
+ })
+ .into_any_element()
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use open_router::{ChoiceDelta, FunctionChunk, ResponseMessageDelta, ToolCallChunk};
+
+ #[gpui::test]
+ async fn test_reasoning_details_preservation_with_tool_calls() {
+ // This test verifies that reasoning_details are properly captured and preserved
+ // when a model uses tool calling with reasoning/thinking tokens.
+ //
+ // The key regression this prevents:
+ // - OpenRouter sends multiple reasoning_details updates during streaming
+ // - First with actual content (encrypted reasoning data)
+ // - Then with empty array on completion
+ // - We must NOT overwrite the real data with the empty array
+
+ let mut mapper = OpenRouterEventMapper::new();
+
+ // Simulate the streaming events as they come from OpenRouter/Gemini
+ let events = vec![
+ // Event 1: Initial reasoning details with text
+ ResponseStreamEvent {
+ id: Some("response_123".into()),
+ created: 1234567890,
+ model: "google/gemini-3-pro-preview".into(),
+ choices: vec![ChoiceDelta {
+ index: 0,
+ delta: ResponseMessageDelta {
+ role: None,
+ content: None,
+ reasoning: None,
+ tool_calls: None,
+ reasoning_details: Some(serde_json::json!([
+ {
+ "type": "reasoning.text",
+ "text": "Let me analyze this request...",
+ "format": "google-gemini-v1",
+ "index": 0
+ }
+ ])),
+ },
+ finish_reason: None,
+ }],
+ usage: None,
+ },
+ // Event 2: More reasoning details
+ ResponseStreamEvent {
+ id: Some("response_123".into()),
+ created: 1234567890,
+ model: "google/gemini-3-pro-preview".into(),
+ choices: vec![ChoiceDelta {
+ index: 0,
+ delta: ResponseMessageDelta {
+ role: None,
+ content: None,
+ reasoning: None,
+ tool_calls: None,
+ reasoning_details: Some(serde_json::json!([
+ {
+ "type": "reasoning.encrypted",
+ "data": "EtgDCtUDAdHtim9OF5jm4aeZSBAtl/randomized123",
+ "format": "google-gemini-v1",
+ "index": 0,
+ "id": "tool_call_abc123"
+ }
+ ])),
+ },
+ finish_reason: None,
+ }],
+ usage: None,
+ },
+ // Event 3: Tool call starts
+ ResponseStreamEvent {
+ id: Some("response_123".into()),
+ created: 1234567890,
+ model: "google/gemini-3-pro-preview".into(),
+ choices: vec![ChoiceDelta {
+ index: 0,
+ delta: ResponseMessageDelta {
+ role: None,
+ content: None,
+ reasoning: None,
+ tool_calls: Some(vec![ToolCallChunk {
+ index: 0,
+ id: Some("tool_call_abc123".into()),
+ function: Some(FunctionChunk {
+ name: Some("list_directory".into()),
+ arguments: Some("{\"path\":\"test\"}".into()),
+ thought_signature: Some("sha256:test_signature_xyz789".into()),
+ }),
+ }]),
+ reasoning_details: None,
+ },
+ finish_reason: None,
+ }],
+ usage: None,
+ },
+ // Event 4: Empty reasoning_details on tool_calls finish
+ // This is the critical event - we must not overwrite with this empty array!
+ ResponseStreamEvent {
+ id: Some("response_123".into()),
+ created: 1234567890,
+ model: "google/gemini-3-pro-preview".into(),
+ choices: vec![ChoiceDelta {
+ index: 0,
+ delta: ResponseMessageDelta {
+ role: None,
+ content: None,
+ reasoning: None,
+ tool_calls: None,
+ reasoning_details: Some(serde_json::json!([])),
+ },
+ finish_reason: Some("tool_calls".into()),
+ }],
+ usage: None,
+ },
+ ];
+
+ // Process all events
+ let mut collected_events = Vec::new();
+ for event in events {
+ let mapped = mapper.map_event(event);
+ collected_events.extend(mapped);
+ }
+
+ // Verify we got the expected events
+ let mut has_tool_use = false;
+ let mut reasoning_details_events = Vec::new();
+ let mut thought_signature_value = None;
+
+ for event_result in collected_events {
+ match event_result {
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
+ has_tool_use = true;
+ assert_eq!(tool_use.id.to_string(), "tool_call_abc123");
+ assert_eq!(tool_use.name.as_ref(), "list_directory");
+ thought_signature_value = tool_use.thought_signature.clone();
+ }
+ Ok(LanguageModelCompletionEvent::ReasoningDetails(details)) => {
+ reasoning_details_events.push(details);
+ }
+ _ => {}
+ }
+ }
+
+ // Assertions
+ assert!(has_tool_use, "Should have emitted ToolUse event");
+ assert!(
+ !reasoning_details_events.is_empty(),
+ "Should have emitted ReasoningDetails events"
+ );
+
+ // We should have received multiple reasoning_details events (text, encrypted, empty)
+ // The agent layer is responsible for keeping only the first non-empty one
+ assert!(
+ reasoning_details_events.len() >= 2,
+ "Should have multiple reasoning_details events from streaming"
+ );
+
+ // Verify at least one contains the encrypted data
+ let has_encrypted = reasoning_details_events.iter().any(|details| {
+ if let serde_json::Value::Array(arr) = details {
+ arr.iter().any(|item| {
+ item["type"] == "reasoning.encrypted"
+ && item["data"]
+ .as_str()
+ .map_or(false, |s| s.contains("EtgDCtUDAdHtim9OF5jm4aeZSBAtl"))
+ })
+ } else {
+ false
+ }
+ });
+ assert!(
+ has_encrypted,
+ "Should have at least one reasoning_details with encrypted data"
+ );
+
+ // Verify thought_signature was captured
+ assert!(
+ thought_signature_value.is_some(),
+ "Tool use should have thought_signature"
+ );
+ assert_eq!(
+ thought_signature_value.unwrap(),
+ "sha256:test_signature_xyz789"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_agent_prevents_empty_reasoning_details_overwrite() {
+ // This test verifies that the agent layer prevents empty reasoning_details
+ // from overwriting non-empty ones, even though the mapper emits all events.
+
+ // Simulate what the agent does when it receives multiple ReasoningDetails events
+ let mut agent_reasoning_details: Option<serde_json::Value> = None;
+
+ let events = vec![
+ // First event: non-empty reasoning_details
+ serde_json::json!([
+ {
+ "type": "reasoning.encrypted",
+ "data": "real_data_here",
+ "format": "google-gemini-v1"
+ }
+ ]),
+ // Second event: empty array (should not overwrite)
+ serde_json::json!([]),
+ ];
+
+ for details in events {
+ // This mimics the agent's logic: only store if we don't already have it
+ if agent_reasoning_details.is_none() {
+ agent_reasoning_details = Some(details);
+ }
+ }
+
+ // Verify the agent kept the first non-empty reasoning_details
+ assert!(agent_reasoning_details.is_some());
+ let final_details = agent_reasoning_details.unwrap();
+ if let serde_json::Value::Array(arr) = &final_details {
+ assert!(
+ !arr.is_empty(),
+ "Agent should have kept the non-empty reasoning_details"
+ );
+ assert_eq!(arr[0]["data"], "real_data_here");
+ } else {
+ panic!("Expected array");
+ }
+ }
+}
@@ -4,22 +4,16 @@ use collections::HashMap;
use settings::RegisterSetting;
use crate::provider::{
- bedrock::AmazonBedrockSettings, cloud::ZedDotDevSettings, deepseek::DeepSeekSettings,
- google::GoogleSettings, lmstudio::LmStudioSettings, mistral::MistralSettings,
- ollama::OllamaSettings, open_ai::OpenAiSettings, open_ai_compatible::OpenAiCompatibleSettings,
+ anthropic::AnthropicSettings, bedrock::AmazonBedrockSettings, cloud::ZedDotDevSettings,
+ deepseek::DeepSeekSettings, google::GoogleSettings, lmstudio::LmStudioSettings,
+ mistral::MistralSettings, ollama::OllamaSettings, open_ai::OpenAiSettings,
+ open_ai_compatible::OpenAiCompatibleSettings, open_router::OpenRouterSettings,
vercel::VercelSettings, x_ai::XAiSettings,
};
-pub use settings::OpenRouterAvailableModel as AvailableModel;
-
-#[derive(Default, Clone, Debug, PartialEq)]
-pub struct OpenRouterSettings {
- pub api_url: String,
- pub available_models: Vec<AvailableModel>,
-}
-
#[derive(Debug, RegisterSetting)]
pub struct AllLanguageModelSettings {
+ pub anthropic: AnthropicSettings,
pub bedrock: AmazonBedrockSettings,
pub deepseek: DeepSeekSettings,
pub google: GoogleSettings,
@@ -39,6 +33,7 @@ impl settings::Settings for AllLanguageModelSettings {
fn from_settings(content: &settings::SettingsContent) -> Self {
let language_models = content.language_models.clone().unwrap();
+ let anthropic = language_models.anthropic.unwrap();
let bedrock = language_models.bedrock.unwrap();
let deepseek = language_models.deepseek.unwrap();
let google = language_models.google.unwrap();
@@ -52,12 +47,16 @@ impl settings::Settings for AllLanguageModelSettings {
let x_ai = language_models.x_ai.unwrap();
let zed_dot_dev = language_models.zed_dot_dev.unwrap();
Self {
+ anthropic: AnthropicSettings {
+ api_url: anthropic.api_url.unwrap(),
+ available_models: anthropic.available_models.unwrap_or_default(),
+ },
bedrock: AmazonBedrockSettings {
available_models: bedrock.available_models.unwrap_or_default(),
region: bedrock.region,
- endpoint: bedrock.endpoint_url,
+ endpoint: bedrock.endpoint_url, // todo(should be api_url)
profile_name: bedrock.profile,
- role_arn: None,
+ role_arn: None, // todo(was never a setting for this...)
authentication_method: bedrock.authentication_method.map(Into::into),
allow_global: bedrock.allow_global,
},
@@ -8,6 +8,7 @@ use std::sync::Arc;
#[with_fallible_options]
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)]
pub struct AllLanguageModelSettingsContent {
+ pub anthropic: Option<AnthropicSettingsContent>,
pub bedrock: Option<AmazonBedrockSettingsContent>,
pub deepseek: Option<DeepseekSettingsContent>,
pub google: Option<GoogleSettingsContent>,
@@ -23,6 +24,35 @@ pub struct AllLanguageModelSettingsContent {
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
}
+#[with_fallible_options]
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)]
+pub struct AnthropicSettingsContent {
+ pub api_url: Option<String>,
+ pub available_models: Option<Vec<AnthropicAvailableModel>>,
+}
+
+#[with_fallible_options]
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom)]
+pub struct AnthropicAvailableModel {
+    /// The model's name in the Anthropic API, e.g. claude-3-5-sonnet-latest or claude-3-opus-20240229.
+ pub name: String,
+ /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
+ pub display_name: Option<String>,
+ /// The model's context window size.
+ pub max_tokens: u64,
+ /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
+ pub tool_override: Option<String>,
+ /// Configuration of Anthropic's caching API.
+ pub cache_configuration: Option<LanguageModelCacheConfiguration>,
+ pub max_output_tokens: Option<u64>,
+ #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")]
+ pub default_temperature: Option<f32>,
+ #[serde(default)]
+ pub extra_beta_headers: Vec<String>,
+ /// The model's mode (e.g. thinking)
+ pub mode: Option<ModelMode>,
+}
+
#[with_fallible_options]
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)]
pub struct AmazonBedrockSettingsContent {