custom.rs

  1use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
  2use acp_thread::AgentConnection;
  3use agent_client_protocol as acp;
  4use anyhow::{Context as _, Result};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use fs::Fs;
  8use gpui::{App, AppContext as _, SharedString, Task};
  9use language_model::{ApiKey, EnvVar};
 10use project::agent_server_store::{
 11    AllAgentServersSettings, CLAUDE_AGENT_NAME, CODEX_NAME, ExternalAgentServerName, GEMINI_NAME,
 12};
 13use settings::{SettingsStore, update_settings_file};
 14use std::{rc::Rc, sync::Arc};
 15use ui::IconName;
 16
 17/// A generic agent server implementation for custom user-defined agents
 18pub struct CustomAgentServer {
 19    name: SharedString,
 20}
 21
 22impl CustomAgentServer {
 23    pub fn new(name: SharedString) -> Self {
 24        Self { name }
 25    }
 26}
 27
 28impl AgentServer for CustomAgentServer {
 29    fn name(&self) -> SharedString {
 30        self.name.clone()
 31    }
 32
 33    fn logo(&self) -> IconName {
 34        IconName::Terminal
 35    }
 36
 37    fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
 38        let settings = cx.read_global(|settings: &SettingsStore, _| {
 39            settings
 40                .get::<AllAgentServersSettings>(None)
 41                .get(self.name().as_ref())
 42                .cloned()
 43        });
 44
 45        settings
 46            .as_ref()
 47            .and_then(|s| s.default_mode().map(acp::SessionModeId::new))
 48    }
 49
 50    fn favorite_config_option_value_ids(
 51        &self,
 52        config_id: &acp::SessionConfigId,
 53        cx: &mut App,
 54    ) -> HashSet<acp::SessionConfigValueId> {
 55        let settings = cx.read_global(|settings: &SettingsStore, _| {
 56            settings
 57                .get::<AllAgentServersSettings>(None)
 58                .get(self.name().as_ref())
 59                .cloned()
 60        });
 61
 62        settings
 63            .as_ref()
 64            .and_then(|s| s.favorite_config_option_values(config_id.0.as_ref()))
 65            .map(|values| {
 66                values
 67                    .iter()
 68                    .cloned()
 69                    .map(acp::SessionConfigValueId::new)
 70                    .collect()
 71            })
 72            .unwrap_or_default()
 73    }
 74
 75    fn toggle_favorite_config_option_value(
 76        &self,
 77        config_id: acp::SessionConfigId,
 78        value_id: acp::SessionConfigValueId,
 79        should_be_favorite: bool,
 80        fs: Arc<dyn Fs>,
 81        cx: &App,
 82    ) {
 83        let name = self.name();
 84        let config_id = config_id.to_string();
 85        let value_id = value_id.to_string();
 86
 87        update_settings_file(fs, cx, move |settings, cx| {
 88            let settings = settings
 89                .agent_servers
 90                .get_or_insert_default()
 91                .entry(name.to_string())
 92                .or_insert_with(|| default_settings_for_agent(&name, cx));
 93
 94            match settings {
 95                settings::CustomAgentServerSettings::Custom {
 96                    favorite_config_option_values,
 97                    ..
 98                }
 99                | settings::CustomAgentServerSettings::Extension {
100                    favorite_config_option_values,
101                    ..
102                }
103                | settings::CustomAgentServerSettings::Registry {
104                    favorite_config_option_values,
105                    ..
106                } => {
107                    let entry = favorite_config_option_values
108                        .entry(config_id.clone())
109                        .or_insert_with(Vec::new);
110
111                    if should_be_favorite {
112                        if !entry.iter().any(|v| v == &value_id) {
113                            entry.push(value_id.clone());
114                        }
115                    } else {
116                        entry.retain(|v| v != &value_id);
117                        if entry.is_empty() {
118                            favorite_config_option_values.remove(&config_id);
119                        }
120                    }
121                }
122            }
123        });
124    }
125
126    fn set_default_mode(&self, mode_id: Option<acp::SessionModeId>, fs: Arc<dyn Fs>, cx: &mut App) {
127        let name = self.name();
128        update_settings_file(fs, cx, move |settings, cx| {
129            let settings = settings
130                .agent_servers
131                .get_or_insert_default()
132                .entry(name.to_string())
133                .or_insert_with(|| default_settings_for_agent(&name, cx));
134
135            match settings {
136                settings::CustomAgentServerSettings::Custom { default_mode, .. }
137                | settings::CustomAgentServerSettings::Extension { default_mode, .. }
138                | settings::CustomAgentServerSettings::Registry { default_mode, .. } => {
139                    *default_mode = mode_id.map(|m| m.to_string());
140                }
141            }
142        });
143    }
144
145    fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
146        let settings = cx.read_global(|settings: &SettingsStore, _| {
147            settings
148                .get::<AllAgentServersSettings>(None)
149                .get(self.name().as_ref())
150                .cloned()
151        });
152
153        settings
154            .as_ref()
155            .and_then(|s| s.default_model().map(acp::ModelId::new))
156    }
157
158    fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
159        let name = self.name();
160        update_settings_file(fs, cx, move |settings, cx| {
161            let settings = settings
162                .agent_servers
163                .get_or_insert_default()
164                .entry(name.to_string())
165                .or_insert_with(|| default_settings_for_agent(&name, cx));
166
167            match settings {
168                settings::CustomAgentServerSettings::Custom { default_model, .. }
169                | settings::CustomAgentServerSettings::Extension { default_model, .. }
170                | settings::CustomAgentServerSettings::Registry { default_model, .. } => {
171                    *default_model = model_id.map(|m| m.to_string());
172                }
173            }
174        });
175    }
176
177    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
178        let settings = cx.read_global(|settings: &SettingsStore, _| {
179            settings
180                .get::<AllAgentServersSettings>(None)
181                .get(self.name().as_ref())
182                .cloned()
183        });
184
185        settings
186            .as_ref()
187            .map(|s| {
188                s.favorite_models()
189                    .iter()
190                    .map(|id| acp::ModelId::new(id.clone()))
191                    .collect()
192            })
193            .unwrap_or_default()
194    }
195
196    fn toggle_favorite_model(
197        &self,
198        model_id: acp::ModelId,
199        should_be_favorite: bool,
200        fs: Arc<dyn Fs>,
201        cx: &App,
202    ) {
203        let name = self.name();
204        update_settings_file(fs, cx, move |settings, cx| {
205            let settings = settings
206                .agent_servers
207                .get_or_insert_default()
208                .entry(name.to_string())
209                .or_insert_with(|| default_settings_for_agent(&name, cx));
210
211            let favorite_models = match settings {
212                settings::CustomAgentServerSettings::Custom {
213                    favorite_models, ..
214                }
215                | settings::CustomAgentServerSettings::Extension {
216                    favorite_models, ..
217                }
218                | settings::CustomAgentServerSettings::Registry {
219                    favorite_models, ..
220                } => favorite_models,
221            };
222
223            let model_id_str = model_id.to_string();
224            if should_be_favorite {
225                if !favorite_models.contains(&model_id_str) {
226                    favorite_models.push(model_id_str);
227                }
228            } else {
229                favorite_models.retain(|id| id != &model_id_str);
230            }
231        });
232    }
233
234    fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
235        let settings = cx.read_global(|settings: &SettingsStore, _| {
236            settings
237                .get::<AllAgentServersSettings>(None)
238                .get(self.name().as_ref())
239                .cloned()
240        });
241
242        settings
243            .as_ref()
244            .and_then(|s| s.default_config_option(config_id).map(|s| s.to_string()))
245    }
246
247    fn set_default_config_option(
248        &self,
249        config_id: &str,
250        value_id: Option<&str>,
251        fs: Arc<dyn Fs>,
252        cx: &mut App,
253    ) {
254        let name = self.name();
255        let config_id = config_id.to_string();
256        let value_id = value_id.map(|s| s.to_string());
257        update_settings_file(fs, cx, move |settings, cx| {
258            let settings = settings
259                .agent_servers
260                .get_or_insert_default()
261                .entry(name.to_string())
262                .or_insert_with(|| default_settings_for_agent(&name, cx));
263
264            match settings {
265                settings::CustomAgentServerSettings::Custom {
266                    default_config_options,
267                    ..
268                }
269                | settings::CustomAgentServerSettings::Extension {
270                    default_config_options,
271                    ..
272                }
273                | settings::CustomAgentServerSettings::Registry {
274                    default_config_options,
275                    ..
276                } => {
277                    if let Some(value) = value_id.clone() {
278                        default_config_options.insert(config_id.clone(), value);
279                    } else {
280                        default_config_options.remove(&config_id);
281                    }
282                }
283            }
284        });
285    }
286
287    fn connect(
288        &self,
289        delegate: AgentServerDelegate,
290        cx: &mut App,
291    ) -> Task<Result<Rc<dyn AgentConnection>>> {
292        let name = self.name();
293        let display_name = delegate
294            .store
295            .read(cx)
296            .agent_display_name(&ExternalAgentServerName(name.clone()))
297            .unwrap_or_else(|| name.clone());
298        let default_mode = self.default_mode(cx);
299        let default_model = self.default_model(cx);
300        let is_registry_agent = is_registry_agent(&name, cx);
301        let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
302            settings
303                .get::<AllAgentServersSettings>(None)
304                .get(self.name().as_ref())
305                .map(|s| match s {
306                    project::agent_server_store::CustomAgentServerSettings::Custom {
307                        default_config_options,
308                        ..
309                    }
310                    | project::agent_server_store::CustomAgentServerSettings::Extension {
311                        default_config_options,
312                        ..
313                    }
314                    | project::agent_server_store::CustomAgentServerSettings::Registry {
315                        default_config_options,
316                        ..
317                    } => default_config_options.clone(),
318                })
319                .unwrap_or_default()
320        });
321
322        if is_registry_agent {
323            if let Some(registry_store) = project::AgentRegistryStore::try_global(cx) {
324                registry_store.update(cx, |store, cx| store.refresh_if_stale(cx));
325            }
326        }
327
328        let mut extra_env = load_proxy_env(cx);
329        if delegate.store.read(cx).no_browser() {
330            extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned());
331        }
332        if is_registry_agent {
333            match name.as_ref() {
334                CLAUDE_AGENT_NAME => {
335                    extra_env.insert("ANTHROPIC_API_KEY".into(), "".into());
336                }
337                CODEX_NAME => {
338                    if let Ok(api_key) = std::env::var("CODEX_API_KEY") {
339                        extra_env.insert("CODEX_API_KEY".into(), api_key);
340                    }
341                    if let Ok(api_key) = std::env::var("OPEN_AI_API_KEY") {
342                        extra_env.insert("OPEN_AI_API_KEY".into(), api_key);
343                    }
344                }
345                GEMINI_NAME => {
346                    extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
347                }
348                _ => {}
349            }
350        }
351        let store = delegate.store.downgrade();
352        cx.spawn(async move |cx| {
353            if is_registry_agent && name.as_ref() == GEMINI_NAME {
354                if let Some(api_key) = cx.update(api_key_for_gemini_cli).await.ok() {
355                    extra_env.insert("GEMINI_API_KEY".into(), api_key);
356                }
357            }
358            let command = store
359                .update(cx, |store, cx| {
360                    let agent = store
361                        .get_external_agent(&ExternalAgentServerName(name.clone()))
362                        .with_context(|| {
363                            format!("Custom agent server `{}` is not registered", name)
364                        })?;
365                    anyhow::Ok(agent.get_command(
366                        extra_env,
367                        delegate.new_version_available,
368                        &mut cx.to_async(),
369                    ))
370                })??
371                .await?;
372            let connection = crate::acp::connect(
373                name,
374                display_name,
375                command,
376                default_mode,
377                default_model,
378                default_config_options,
379                cx,
380            )
381            .await?;
382            Ok(connection)
383        })
384    }
385
386    fn into_any(self: Rc<Self>) -> Rc<dyn std::any::Any> {
387        self
388    }
389}
390
391fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
392    let env_var = EnvVar::new("GEMINI_API_KEY".into()).or(EnvVar::new("GOOGLE_AI_API_KEY".into()));
393    if let Some(key) = env_var.value {
394        return Task::ready(Ok(key));
395    }
396    let credentials_provider = <dyn CredentialsProvider>::global(cx);
397    let api_url = google_ai::API_URL.to_string();
398    cx.spawn(async move |cx| {
399        Ok(
400            ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
401                .await?
402                .key()
403                .to_string(),
404        )
405    })
406}
407
408fn is_registry_agent(name: &str, cx: &App) -> bool {
409    let is_previous_built_in = matches!(name, CLAUDE_AGENT_NAME | CODEX_NAME | GEMINI_NAME);
410    let is_in_registry = project::AgentRegistryStore::try_global(cx)
411        .map(|store| store.read(cx).agent(name).is_some())
412        .unwrap_or(false);
413    let is_settings_registry = cx.read_global(|settings: &SettingsStore, _| {
414        settings
415            .get::<AllAgentServersSettings>(None)
416            .get(name)
417            .is_some_and(|s| {
418                matches!(
419                    s,
420                    project::agent_server_store::CustomAgentServerSettings::Registry { .. }
421                )
422            })
423    });
424    is_previous_built_in || is_in_registry || is_settings_registry
425}
426
427fn default_settings_for_agent(name: &str, cx: &App) -> settings::CustomAgentServerSettings {
428    if is_registry_agent(name, cx) {
429        settings::CustomAgentServerSettings::Registry {
430            default_model: None,
431            default_mode: None,
432            env: Default::default(),
433            favorite_models: Vec::new(),
434            default_config_options: Default::default(),
435            favorite_config_option_values: Default::default(),
436        }
437    } else {
438        settings::CustomAgentServerSettings::Extension {
439            default_model: None,
440            default_mode: None,
441            env: Default::default(),
442            favorite_models: Vec::new(),
443            default_config_options: Default::default(),
444            favorite_config_option_values: Default::default(),
445        }
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use collections::HashMap;
    use gpui::TestAppContext;
    use project::agent_registry_store::{
        AgentRegistryStore, RegistryAgent, RegistryAgentMetadata, RegistryNpxAgent,
    };
    use settings::Settings as _;

    /// Installs a test `SettingsStore` as a global; every test needs it
    /// because `is_registry_agent` reads settings through it.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
    }

    /// Builds a minimal npx registry agent whose id, name and package are all
    /// `id`.
    fn npx_agent(id: &str) -> RegistryAgent {
        let id = SharedString::from(id.to_string());
        RegistryAgent::Npx(RegistryNpxAgent {
            metadata: RegistryAgentMetadata {
                id: id.clone(),
                name: id.clone(),
                description: SharedString::from(""),
                version: SharedString::from("1.0.0"),
                repository: None,
                icon_path: None,
            },
            package: id,
            args: Vec::new(),
            env: HashMap::default(),
        })
    }

    /// Installs a global test registry store containing one npx agent per id.
    fn init_registry_with_agents(cx: &mut TestAppContext, agent_ids: &[&str]) {
        let agents: Vec<RegistryAgent> = agent_ids.iter().copied().map(npx_agent).collect();
        cx.update(|cx| {
            AgentRegistryStore::init_test_global(cx, agents);
        });
    }

    /// Overrides the global agent server settings with exactly `entries`.
    fn set_agent_server_settings(
        cx: &mut TestAppContext,
        entries: Vec<(&str, settings::CustomAgentServerSettings)>,
    ) {
        cx.update(|cx| {
            let by_name = entries
                .into_iter()
                .map(|(name, settings)| (name.to_string(), settings.into()))
                .collect();
            AllAgentServersSettings::override_global(
                project::agent_server_store::AllAgentServersSettings(by_name),
                cx,
            );
        });
    }

    /// A `Registry` settings entry with every field empty.
    fn empty_registry_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Registry {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    /// An `Extension` settings entry with every field empty.
    fn empty_extension_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Extension {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    #[gpui::test]
    fn test_previous_builtins_are_registry(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            for name in [CLAUDE_AGENT_NAME, CODEX_NAME, GEMINI_NAME] {
                assert!(is_registry_agent(name, cx));
            }
        });
    }

    #[gpui::test]
    fn test_unknown_agent_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            assert!(!is_registry_agent("my-custom-agent", cx));
        });
    }

    #[gpui::test]
    fn test_agent_in_registry_store_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["some-new-registry-agent"]);
        cx.update(|cx| {
            assert!(is_registry_agent("some-new-registry-agent", cx));
            assert!(!is_registry_agent("not-in-registry", cx));
        });
    }

    #[gpui::test]
    fn test_agent_with_registry_settings_type_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(
            cx,
            vec![("agent-from-settings", empty_registry_settings())],
        );
        cx.update(|cx| {
            assert!(is_registry_agent("agent-from-settings", cx));
        });
    }

    #[gpui::test]
    fn test_agent_with_extension_settings_type_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(
            cx,
            vec![("my-extension-agent", empty_extension_settings())],
        );
        cx.update(|cx| {
            assert!(!is_registry_agent("my-extension-agent", cx));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_builtin_agent(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            for name in [CODEX_NAME, CLAUDE_AGENT_NAME, GEMINI_NAME] {
                assert!(matches!(
                    default_settings_for_agent(name, cx),
                    settings::CustomAgentServerSettings::Registry { .. }
                ));
            }
        });
    }

    #[gpui::test]
    fn test_default_settings_for_extension_agent(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let defaults = default_settings_for_agent("some-extension-agent", cx);
            assert!(matches!(
                defaults,
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_agent_in_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["new-registry-agent"]);
        cx.update(|cx| {
            let listed = default_settings_for_agent("new-registry-agent", cx);
            assert!(matches!(
                listed,
                settings::CustomAgentServerSettings::Registry { .. }
            ));

            let unlisted = default_settings_for_agent("not-in-registry", cx);
            assert!(matches!(
                unlisted,
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }
}