custom.rs

  1use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
  2use acp_thread::AgentConnection;
  3use agent_client_protocol as acp;
  4use anyhow::{Context as _, Result};
  5use collections::HashSet;
  6use fs::Fs;
  7use gpui::{App, AppContext as _, Entity, Task};
  8use language_model::{ApiKey, EnvVar};
  9use project::{
 10    Project,
 11    agent_server_store::{AgentId, AllAgentServersSettings},
 12};
 13use settings::{SettingsStore, update_settings_file};
 14use std::{rc::Rc, sync::Arc};
 15use ui::IconName;
 16
 17pub const GEMINI_ID: &str = "gemini";
 18pub const CLAUDE_AGENT_ID: &str = "claude-acp";
 19pub const CODEX_ID: &str = "codex-acp";
 20
 21/// A generic agent server implementation for custom user-defined agents
 22pub struct CustomAgentServer {
 23    agent_id: AgentId,
 24}
 25
 26impl CustomAgentServer {
 27    pub fn new(agent_id: AgentId) -> Self {
 28        Self { agent_id }
 29    }
 30}
 31
 32impl AgentServer for CustomAgentServer {
 33    fn agent_id(&self) -> AgentId {
 34        self.agent_id.clone()
 35    }
 36
 37    fn logo(&self) -> IconName {
 38        IconName::Terminal
 39    }
 40
 41    fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
 42        let settings = cx.read_global(|settings: &SettingsStore, _| {
 43            settings
 44                .get::<AllAgentServersSettings>(None)
 45                .get(self.agent_id().0.as_ref())
 46                .cloned()
 47        });
 48
 49        settings
 50            .as_ref()
 51            .and_then(|s| s.default_mode().map(acp::SessionModeId::new))
 52    }
 53
 54    fn favorite_config_option_value_ids(
 55        &self,
 56        config_id: &acp::SessionConfigId,
 57        cx: &mut App,
 58    ) -> HashSet<acp::SessionConfigValueId> {
 59        let settings = cx.read_global(|settings: &SettingsStore, _| {
 60            settings
 61                .get::<AllAgentServersSettings>(None)
 62                .get(self.agent_id().0.as_ref())
 63                .cloned()
 64        });
 65
 66        settings
 67            .as_ref()
 68            .and_then(|s| s.favorite_config_option_values(config_id.0.as_ref()))
 69            .map(|values| {
 70                values
 71                    .iter()
 72                    .cloned()
 73                    .map(acp::SessionConfigValueId::new)
 74                    .collect()
 75            })
 76            .unwrap_or_default()
 77    }
 78
 79    fn toggle_favorite_config_option_value(
 80        &self,
 81        config_id: acp::SessionConfigId,
 82        value_id: acp::SessionConfigValueId,
 83        should_be_favorite: bool,
 84        fs: Arc<dyn Fs>,
 85        cx: &App,
 86    ) {
 87        let agent_id = self.agent_id();
 88        let config_id = config_id.to_string();
 89        let value_id = value_id.to_string();
 90
 91        update_settings_file(fs, cx, move |settings, cx| {
 92            let settings = settings
 93                .agent_servers
 94                .get_or_insert_default()
 95                .entry(agent_id.0.to_string())
 96                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
 97
 98            match settings {
 99                settings::CustomAgentServerSettings::Custom {
100                    favorite_config_option_values,
101                    ..
102                }
103                | settings::CustomAgentServerSettings::Extension {
104                    favorite_config_option_values,
105                    ..
106                }
107                | settings::CustomAgentServerSettings::Registry {
108                    favorite_config_option_values,
109                    ..
110                } => {
111                    let entry = favorite_config_option_values
112                        .entry(config_id.clone())
113                        .or_insert_with(Vec::new);
114
115                    if should_be_favorite {
116                        if !entry.iter().any(|v| v == &value_id) {
117                            entry.push(value_id.clone());
118                        }
119                    } else {
120                        entry.retain(|v| v != &value_id);
121                        if entry.is_empty() {
122                            favorite_config_option_values.remove(&config_id);
123                        }
124                    }
125                }
126            }
127        });
128    }
129
130    fn set_default_mode(&self, mode_id: Option<acp::SessionModeId>, fs: Arc<dyn Fs>, cx: &mut App) {
131        let agent_id = self.agent_id();
132        update_settings_file(fs, cx, move |settings, cx| {
133            let settings = settings
134                .agent_servers
135                .get_or_insert_default()
136                .entry(agent_id.0.to_string())
137                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
138
139            match settings {
140                settings::CustomAgentServerSettings::Custom { default_mode, .. }
141                | settings::CustomAgentServerSettings::Extension { default_mode, .. }
142                | settings::CustomAgentServerSettings::Registry { default_mode, .. } => {
143                    *default_mode = mode_id.map(|m| m.to_string());
144                }
145            }
146        });
147    }
148
149    fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
150        let settings = cx.read_global(|settings: &SettingsStore, _| {
151            settings
152                .get::<AllAgentServersSettings>(None)
153                .get(self.agent_id().as_ref())
154                .cloned()
155        });
156
157        settings
158            .as_ref()
159            .and_then(|s| s.default_model().map(acp::ModelId::new))
160    }
161
162    fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
163        let agent_id = self.agent_id();
164        update_settings_file(fs, cx, move |settings, cx| {
165            let settings = settings
166                .agent_servers
167                .get_or_insert_default()
168                .entry(agent_id.0.to_string())
169                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
170
171            match settings {
172                settings::CustomAgentServerSettings::Custom { default_model, .. }
173                | settings::CustomAgentServerSettings::Extension { default_model, .. }
174                | settings::CustomAgentServerSettings::Registry { default_model, .. } => {
175                    *default_model = model_id.map(|m| m.to_string());
176                }
177            }
178        });
179    }
180
181    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
182        let settings = cx.read_global(|settings: &SettingsStore, _| {
183            settings
184                .get::<AllAgentServersSettings>(None)
185                .get(self.agent_id().as_ref())
186                .cloned()
187        });
188
189        settings
190            .as_ref()
191            .map(|s| {
192                s.favorite_models()
193                    .iter()
194                    .map(|id| acp::ModelId::new(id.clone()))
195                    .collect()
196            })
197            .unwrap_or_default()
198    }
199
200    fn toggle_favorite_model(
201        &self,
202        model_id: acp::ModelId,
203        should_be_favorite: bool,
204        fs: Arc<dyn Fs>,
205        cx: &App,
206    ) {
207        let agent_id = self.agent_id();
208        update_settings_file(fs, cx, move |settings, cx| {
209            let settings = settings
210                .agent_servers
211                .get_or_insert_default()
212                .entry(agent_id.0.to_string())
213                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
214
215            let favorite_models = match settings {
216                settings::CustomAgentServerSettings::Custom {
217                    favorite_models, ..
218                }
219                | settings::CustomAgentServerSettings::Extension {
220                    favorite_models, ..
221                }
222                | settings::CustomAgentServerSettings::Registry {
223                    favorite_models, ..
224                } => favorite_models,
225            };
226
227            let model_id_str = model_id.to_string();
228            if should_be_favorite {
229                if !favorite_models.contains(&model_id_str) {
230                    favorite_models.push(model_id_str);
231                }
232            } else {
233                favorite_models.retain(|id| id != &model_id_str);
234            }
235        });
236    }
237
238    fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
239        let settings = cx.read_global(|settings: &SettingsStore, _| {
240            settings
241                .get::<AllAgentServersSettings>(None)
242                .get(self.agent_id().as_ref())
243                .cloned()
244        });
245
246        settings
247            .as_ref()
248            .and_then(|s| s.default_config_option(config_id).map(|s| s.to_string()))
249    }
250
251    fn set_default_config_option(
252        &self,
253        config_id: &str,
254        value_id: Option<&str>,
255        fs: Arc<dyn Fs>,
256        cx: &mut App,
257    ) {
258        let agent_id = self.agent_id();
259        let config_id = config_id.to_string();
260        let value_id = value_id.map(|s| s.to_string());
261        update_settings_file(fs, cx, move |settings, cx| {
262            let settings = settings
263                .agent_servers
264                .get_or_insert_default()
265                .entry(agent_id.0.to_string())
266                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
267
268            match settings {
269                settings::CustomAgentServerSettings::Custom {
270                    default_config_options,
271                    ..
272                }
273                | settings::CustomAgentServerSettings::Extension {
274                    default_config_options,
275                    ..
276                }
277                | settings::CustomAgentServerSettings::Registry {
278                    default_config_options,
279                    ..
280                } => {
281                    if let Some(value) = value_id.clone() {
282                        default_config_options.insert(config_id.clone(), value);
283                    } else {
284                        default_config_options.remove(&config_id);
285                    }
286                }
287            }
288        });
289    }
290
291    fn connect(
292        &self,
293        delegate: AgentServerDelegate,
294        project: Entity<Project>,
295        cx: &mut App,
296    ) -> Task<Result<Rc<dyn AgentConnection>>> {
297        let agent_id = self.agent_id();
298        let default_mode = self.default_mode(cx);
299        let default_model = self.default_model(cx);
300        let is_registry_agent = is_registry_agent(agent_id.clone(), cx);
301        let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
302            settings
303                .get::<AllAgentServersSettings>(None)
304                .get(self.agent_id().as_ref())
305                .map(|s| match s {
306                    project::agent_server_store::CustomAgentServerSettings::Custom {
307                        default_config_options,
308                        ..
309                    }
310                    | project::agent_server_store::CustomAgentServerSettings::Extension {
311                        default_config_options,
312                        ..
313                    }
314                    | project::agent_server_store::CustomAgentServerSettings::Registry {
315                        default_config_options,
316                        ..
317                    } => default_config_options.clone(),
318                })
319                .unwrap_or_default()
320        });
321
322        if is_registry_agent {
323            if let Some(registry_store) = project::AgentRegistryStore::try_global(cx) {
324                registry_store.update(cx, |store, cx| store.refresh_if_stale(cx));
325            }
326        }
327
328        let mut extra_env = load_proxy_env(cx);
329        if delegate.store.read(cx).no_browser() {
330            extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned());
331        }
332        if is_registry_agent {
333            match agent_id.as_ref() {
334                CLAUDE_AGENT_ID => {
335                    extra_env.insert("ANTHROPIC_API_KEY".into(), "".into());
336                }
337                CODEX_ID => {
338                    if let Ok(api_key) = std::env::var("CODEX_API_KEY") {
339                        extra_env.insert("CODEX_API_KEY".into(), api_key);
340                    }
341                    if let Ok(api_key) = std::env::var("OPEN_AI_API_KEY") {
342                        extra_env.insert("OPEN_AI_API_KEY".into(), api_key);
343                    }
344                }
345                GEMINI_ID => {
346                    extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
347                }
348                _ => {}
349            }
350        }
351        let store = delegate.store.downgrade();
352        cx.spawn(async move |cx| {
353            if is_registry_agent && agent_id.as_ref() == GEMINI_ID {
354                if let Some(api_key) = cx.update(api_key_for_gemini_cli).await.ok() {
355                    extra_env.insert("GEMINI_API_KEY".into(), api_key);
356                }
357            }
358            let command = store
359                .update(cx, |store, cx| {
360                    let agent = store.get_external_agent(&agent_id).with_context(|| {
361                        format!("Custom agent server `{}` is not registered", agent_id)
362                    })?;
363                    anyhow::Ok(agent.get_command(
364                        extra_env,
365                        delegate.new_version_available,
366                        &mut cx.to_async(),
367                    ))
368                })??
369                .await?;
370            let connection = crate::acp::connect(
371                agent_id,
372                project,
373                command,
374                default_mode,
375                default_model,
376                default_config_options,
377                cx,
378            )
379            .await?;
380            Ok(connection)
381        })
382    }
383
384    fn into_any(self: Rc<Self>) -> Rc<dyn std::any::Any> {
385        self
386    }
387}
388
389fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
390    let env_var = EnvVar::new("GEMINI_API_KEY".into()).or(EnvVar::new("GOOGLE_AI_API_KEY".into()));
391    if let Some(key) = env_var.value {
392        return Task::ready(Ok(key));
393    }
394    let credentials_provider = zed_credentials_provider::global(cx);
395    let api_url = google_ai::API_URL.to_string();
396    cx.spawn(async move |cx| {
397        Ok(
398            ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
399                .await?
400                .key()
401                .to_string(),
402        )
403    })
404}
405
406fn is_registry_agent(agent_id: impl Into<AgentId>, cx: &App) -> bool {
407    let agent_id = agent_id.into();
408    let is_in_registry = project::AgentRegistryStore::try_global(cx)
409        .map(|store| store.read(cx).agent(&agent_id).is_some())
410        .unwrap_or(false);
411    let is_settings_registry = cx.read_global(|settings: &SettingsStore, _| {
412        settings
413            .get::<AllAgentServersSettings>(None)
414            .get(agent_id.as_ref())
415            .is_some_and(|s| {
416                matches!(
417                    s,
418                    project::agent_server_store::CustomAgentServerSettings::Registry { .. }
419                )
420            })
421    });
422    is_in_registry || is_settings_registry
423}
424
425fn default_settings_for_agent(
426    agent_id: impl Into<AgentId>,
427    cx: &App,
428) -> settings::CustomAgentServerSettings {
429    if is_registry_agent(agent_id, cx) {
430        settings::CustomAgentServerSettings::Registry {
431            default_model: None,
432            default_mode: None,
433            env: Default::default(),
434            favorite_models: Vec::new(),
435            default_config_options: Default::default(),
436            favorite_config_option_values: Default::default(),
437        }
438    } else {
439        settings::CustomAgentServerSettings::Extension {
440            default_model: None,
441            default_mode: None,
442            env: Default::default(),
443            favorite_models: Vec::new(),
444            default_config_options: Default::default(),
445            favorite_config_option_values: Default::default(),
446        }
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use collections::HashMap;
    use gpui::TestAppContext;
    use project::agent_registry_store::{
        AgentRegistryStore, RegistryAgent, RegistryAgentMetadata, RegistryNpxAgent,
    };
    use settings::Settings as _;
    use ui::SharedString;

    /// Installs a fresh test `SettingsStore` as the global settings store.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
    }

    /// Installs a test `AgentRegistryStore` that lists one npx agent per id in `agent_ids`.
    fn init_registry_with_agents(cx: &mut TestAppContext, agent_ids: &[&str]) {
        let agents = agent_ids
            .iter()
            .map(|&raw_id| {
                let id = SharedString::from(raw_id.to_string());
                let metadata = RegistryAgentMetadata {
                    id: AgentId::new(id.clone()),
                    name: id.clone(),
                    description: SharedString::from(""),
                    version: SharedString::from("1.0.0"),
                    repository: None,
                    website: None,
                    icon_path: None,
                };
                RegistryAgent::Npx(RegistryNpxAgent {
                    metadata,
                    package: id,
                    args: Vec::new(),
                    env: HashMap::default(),
                })
            })
            .collect::<Vec<RegistryAgent>>();

        cx.update(|cx| {
            AgentRegistryStore::init_test_global(cx, agents);
        });
    }

    /// Replaces the global `AllAgentServersSettings` with exactly `entries`.
    fn set_agent_server_settings(
        cx: &mut TestAppContext,
        entries: Vec<(&str, settings::CustomAgentServerSettings)>,
    ) {
        cx.update(|cx| {
            AllAgentServersSettings::override_global(
                AllAgentServersSettings(
                    entries
                        .into_iter()
                        .map(|(name, entry)| (name.to_string(), entry.into()))
                        .collect(),
                ),
                cx,
            );
        });
    }

    /// A `Registry` settings entry with every field empty or unset.
    fn empty_registry_entry() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Registry {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    /// An `Extension` settings entry with every field empty or unset.
    fn empty_extension_entry() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Extension {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    #[gpui::test]
    fn test_unknown_agent_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        let is_registry = cx.update(|cx| is_registry_agent("my-custom-agent", cx));
        assert!(!is_registry);
    }

    #[gpui::test]
    fn test_agent_in_registry_store_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["some-new-registry-agent"]);
        cx.update(|cx| {
            assert!(is_registry_agent("some-new-registry-agent", cx));
            assert!(!is_registry_agent("not-in-registry", cx));
        });
    }

    #[gpui::test]
    fn test_agent_with_registry_settings_type_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("agent-from-settings", empty_registry_entry())]);
        let is_registry = cx.update(|cx| is_registry_agent("agent-from-settings", cx));
        assert!(is_registry);
    }

    #[gpui::test]
    fn test_agent_with_extension_settings_type_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("my-extension-agent", empty_extension_entry())]);
        let is_registry = cx.update(|cx| is_registry_agent("my-extension-agent", cx));
        assert!(!is_registry);
    }

    #[gpui::test]
    fn test_default_settings_for_extension_agent(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let defaults = default_settings_for_agent("some-extension-agent", cx);
            assert!(matches!(
                defaults,
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_agent_in_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["new-registry-agent"]);
        cx.update(|cx| {
            let listed = default_settings_for_agent("new-registry-agent", cx);
            assert!(matches!(
                listed,
                settings::CustomAgentServerSettings::Registry { .. }
            ));

            let unlisted = default_settings_for_agent("not-in-registry", cx);
            assert!(matches!(
                unlisted,
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }
}