custom.rs

  1use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
  2use acp_thread::AgentConnection;
  3use agent_client_protocol as acp;
  4use anyhow::{Context as _, Result};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use fs::Fs;
  8use gpui::{App, AppContext as _, Task};
  9use language_model::{ApiKey, EnvVar};
 10use project::agent_server_store::{
 11    AgentId, AllAgentServersSettings, CLAUDE_AGENT_ID, CODEX_ID, GEMINI_ID,
 12};
 13use settings::{SettingsStore, update_settings_file};
 14use std::{rc::Rc, sync::Arc};
 15use ui::IconName;
 16
 17/// A generic agent server implementation for custom user-defined agents
 18pub struct CustomAgentServer {
 19    agent_id: AgentId,
 20}
 21
 22impl CustomAgentServer {
 23    pub fn new(agent_id: AgentId) -> Self {
 24        Self { agent_id }
 25    }
 26}
 27
 28impl AgentServer for CustomAgentServer {
 29    fn agent_id(&self) -> AgentId {
 30        self.agent_id.clone()
 31    }
 32
 33    fn logo(&self) -> IconName {
 34        IconName::Terminal
 35    }
 36
 37    fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
 38        let settings = cx.read_global(|settings: &SettingsStore, _| {
 39            settings
 40                .get::<AllAgentServersSettings>(None)
 41                .get(self.agent_id().0.as_ref())
 42                .cloned()
 43        });
 44
 45        settings
 46            .as_ref()
 47            .and_then(|s| s.default_mode().map(acp::SessionModeId::new))
 48    }
 49
 50    fn favorite_config_option_value_ids(
 51        &self,
 52        config_id: &acp::SessionConfigId,
 53        cx: &mut App,
 54    ) -> HashSet<acp::SessionConfigValueId> {
 55        let settings = cx.read_global(|settings: &SettingsStore, _| {
 56            settings
 57                .get::<AllAgentServersSettings>(None)
 58                .get(self.agent_id().0.as_ref())
 59                .cloned()
 60        });
 61
 62        settings
 63            .as_ref()
 64            .and_then(|s| s.favorite_config_option_values(config_id.0.as_ref()))
 65            .map(|values| {
 66                values
 67                    .iter()
 68                    .cloned()
 69                    .map(acp::SessionConfigValueId::new)
 70                    .collect()
 71            })
 72            .unwrap_or_default()
 73    }
 74
 75    fn toggle_favorite_config_option_value(
 76        &self,
 77        config_id: acp::SessionConfigId,
 78        value_id: acp::SessionConfigValueId,
 79        should_be_favorite: bool,
 80        fs: Arc<dyn Fs>,
 81        cx: &App,
 82    ) {
 83        let agent_id = self.agent_id();
 84        let config_id = config_id.to_string();
 85        let value_id = value_id.to_string();
 86
 87        update_settings_file(fs, cx, move |settings, cx| {
 88            let settings = settings
 89                .agent_servers
 90                .get_or_insert_default()
 91                .entry(agent_id.0.to_string())
 92                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
 93
 94            match settings {
 95                settings::CustomAgentServerSettings::Custom {
 96                    favorite_config_option_values,
 97                    ..
 98                }
 99                | settings::CustomAgentServerSettings::Extension {
100                    favorite_config_option_values,
101                    ..
102                }
103                | settings::CustomAgentServerSettings::Registry {
104                    favorite_config_option_values,
105                    ..
106                } => {
107                    let entry = favorite_config_option_values
108                        .entry(config_id.clone())
109                        .or_insert_with(Vec::new);
110
111                    if should_be_favorite {
112                        if !entry.iter().any(|v| v == &value_id) {
113                            entry.push(value_id.clone());
114                        }
115                    } else {
116                        entry.retain(|v| v != &value_id);
117                        if entry.is_empty() {
118                            favorite_config_option_values.remove(&config_id);
119                        }
120                    }
121                }
122            }
123        });
124    }
125
126    fn set_default_mode(&self, mode_id: Option<acp::SessionModeId>, fs: Arc<dyn Fs>, cx: &mut App) {
127        let agent_id = self.agent_id();
128        update_settings_file(fs, cx, move |settings, cx| {
129            let settings = settings
130                .agent_servers
131                .get_or_insert_default()
132                .entry(agent_id.0.to_string())
133                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
134
135            match settings {
136                settings::CustomAgentServerSettings::Custom { default_mode, .. }
137                | settings::CustomAgentServerSettings::Extension { default_mode, .. }
138                | settings::CustomAgentServerSettings::Registry { default_mode, .. } => {
139                    *default_mode = mode_id.map(|m| m.to_string());
140                }
141            }
142        });
143    }
144
145    fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
146        let settings = cx.read_global(|settings: &SettingsStore, _| {
147            settings
148                .get::<AllAgentServersSettings>(None)
149                .get(self.agent_id().as_ref())
150                .cloned()
151        });
152
153        settings
154            .as_ref()
155            .and_then(|s| s.default_model().map(acp::ModelId::new))
156    }
157
158    fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
159        let agent_id = self.agent_id();
160        update_settings_file(fs, cx, move |settings, cx| {
161            let settings = settings
162                .agent_servers
163                .get_or_insert_default()
164                .entry(agent_id.0.to_string())
165                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
166
167            match settings {
168                settings::CustomAgentServerSettings::Custom { default_model, .. }
169                | settings::CustomAgentServerSettings::Extension { default_model, .. }
170                | settings::CustomAgentServerSettings::Registry { default_model, .. } => {
171                    *default_model = model_id.map(|m| m.to_string());
172                }
173            }
174        });
175    }
176
177    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
178        let settings = cx.read_global(|settings: &SettingsStore, _| {
179            settings
180                .get::<AllAgentServersSettings>(None)
181                .get(self.agent_id().as_ref())
182                .cloned()
183        });
184
185        settings
186            .as_ref()
187            .map(|s| {
188                s.favorite_models()
189                    .iter()
190                    .map(|id| acp::ModelId::new(id.clone()))
191                    .collect()
192            })
193            .unwrap_or_default()
194    }
195
196    fn toggle_favorite_model(
197        &self,
198        model_id: acp::ModelId,
199        should_be_favorite: bool,
200        fs: Arc<dyn Fs>,
201        cx: &App,
202    ) {
203        let agent_id = self.agent_id();
204        update_settings_file(fs, cx, move |settings, cx| {
205            let settings = settings
206                .agent_servers
207                .get_or_insert_default()
208                .entry(agent_id.0.to_string())
209                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
210
211            let favorite_models = match settings {
212                settings::CustomAgentServerSettings::Custom {
213                    favorite_models, ..
214                }
215                | settings::CustomAgentServerSettings::Extension {
216                    favorite_models, ..
217                }
218                | settings::CustomAgentServerSettings::Registry {
219                    favorite_models, ..
220                } => favorite_models,
221            };
222
223            let model_id_str = model_id.to_string();
224            if should_be_favorite {
225                if !favorite_models.contains(&model_id_str) {
226                    favorite_models.push(model_id_str);
227                }
228            } else {
229                favorite_models.retain(|id| id != &model_id_str);
230            }
231        });
232    }
233
234    fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
235        let settings = cx.read_global(|settings: &SettingsStore, _| {
236            settings
237                .get::<AllAgentServersSettings>(None)
238                .get(self.agent_id().as_ref())
239                .cloned()
240        });
241
242        settings
243            .as_ref()
244            .and_then(|s| s.default_config_option(config_id).map(|s| s.to_string()))
245    }
246
247    fn set_default_config_option(
248        &self,
249        config_id: &str,
250        value_id: Option<&str>,
251        fs: Arc<dyn Fs>,
252        cx: &mut App,
253    ) {
254        let agent_id = self.agent_id();
255        let config_id = config_id.to_string();
256        let value_id = value_id.map(|s| s.to_string());
257        update_settings_file(fs, cx, move |settings, cx| {
258            let settings = settings
259                .agent_servers
260                .get_or_insert_default()
261                .entry(agent_id.0.to_string())
262                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
263
264            match settings {
265                settings::CustomAgentServerSettings::Custom {
266                    default_config_options,
267                    ..
268                }
269                | settings::CustomAgentServerSettings::Extension {
270                    default_config_options,
271                    ..
272                }
273                | settings::CustomAgentServerSettings::Registry {
274                    default_config_options,
275                    ..
276                } => {
277                    if let Some(value) = value_id.clone() {
278                        default_config_options.insert(config_id.clone(), value);
279                    } else {
280                        default_config_options.remove(&config_id);
281                    }
282                }
283            }
284        });
285    }
286
287    fn connect(
288        &self,
289        delegate: AgentServerDelegate,
290        cx: &mut App,
291    ) -> Task<Result<Rc<dyn AgentConnection>>> {
292        let agent_id = self.agent_id();
293        let display_name = delegate
294            .store
295            .read(cx)
296            .agent_display_name(&agent_id)
297            .unwrap_or_else(|| agent_id.0.clone());
298        let default_mode = self.default_mode(cx);
299        let default_model = self.default_model(cx);
300        let is_registry_agent = is_registry_agent(agent_id.clone(), cx);
301        let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
302            settings
303                .get::<AllAgentServersSettings>(None)
304                .get(self.agent_id().as_ref())
305                .map(|s| match s {
306                    project::agent_server_store::CustomAgentServerSettings::Custom {
307                        default_config_options,
308                        ..
309                    }
310                    | project::agent_server_store::CustomAgentServerSettings::Extension {
311                        default_config_options,
312                        ..
313                    }
314                    | project::agent_server_store::CustomAgentServerSettings::Registry {
315                        default_config_options,
316                        ..
317                    } => default_config_options.clone(),
318                })
319                .unwrap_or_default()
320        });
321
322        if is_registry_agent {
323            if let Some(registry_store) = project::AgentRegistryStore::try_global(cx) {
324                registry_store.update(cx, |store, cx| store.refresh_if_stale(cx));
325            }
326        }
327
328        let mut extra_env = load_proxy_env(cx);
329        if delegate.store.read(cx).no_browser() {
330            extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned());
331        }
332        if is_registry_agent {
333            match agent_id.as_ref() {
334                CLAUDE_AGENT_ID => {
335                    extra_env.insert("ANTHROPIC_API_KEY".into(), "".into());
336                }
337                CODEX_ID => {
338                    if let Ok(api_key) = std::env::var("CODEX_API_KEY") {
339                        extra_env.insert("CODEX_API_KEY".into(), api_key);
340                    }
341                    if let Ok(api_key) = std::env::var("OPEN_AI_API_KEY") {
342                        extra_env.insert("OPEN_AI_API_KEY".into(), api_key);
343                    }
344                }
345                GEMINI_ID => {
346                    extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
347                }
348                _ => {}
349            }
350        }
351        let store = delegate.store.downgrade();
352        cx.spawn(async move |cx| {
353            if is_registry_agent && agent_id.as_ref() == GEMINI_ID {
354                if let Some(api_key) = cx.update(api_key_for_gemini_cli).await.ok() {
355                    extra_env.insert("GEMINI_API_KEY".into(), api_key);
356                }
357            }
358            let command = store
359                .update(cx, |store, cx| {
360                    let agent = store.get_external_agent(&agent_id).with_context(|| {
361                        format!("Custom agent server `{}` is not registered", agent_id)
362                    })?;
363                    anyhow::Ok(agent.get_command(
364                        extra_env,
365                        delegate.new_version_available,
366                        &mut cx.to_async(),
367                    ))
368                })??
369                .await?;
370            let connection = crate::acp::connect(
371                agent_id,
372                display_name,
373                command,
374                default_mode,
375                default_model,
376                default_config_options,
377                cx,
378            )
379            .await?;
380            Ok(connection)
381        })
382    }
383
384    fn into_any(self: Rc<Self>) -> Rc<dyn std::any::Any> {
385        self
386    }
387}
388
389fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
390    let env_var = EnvVar::new("GEMINI_API_KEY".into()).or(EnvVar::new("GOOGLE_AI_API_KEY".into()));
391    if let Some(key) = env_var.value {
392        return Task::ready(Ok(key));
393    }
394    let credentials_provider = <dyn CredentialsProvider>::global(cx);
395    let api_url = google_ai::API_URL.to_string();
396    cx.spawn(async move |cx| {
397        Ok(
398            ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
399                .await?
400                .key()
401                .to_string(),
402        )
403    })
404}
405
406fn is_registry_agent(agent_id: impl Into<AgentId>, cx: &App) -> bool {
407    let agent_id = agent_id.into();
408    let is_previous_built_in =
409        matches!(agent_id.0.as_ref(), CLAUDE_AGENT_ID | CODEX_ID | GEMINI_ID);
410    let is_in_registry = project::AgentRegistryStore::try_global(cx)
411        .map(|store| store.read(cx).agent(&agent_id).is_some())
412        .unwrap_or(false);
413    let is_settings_registry = cx.read_global(|settings: &SettingsStore, _| {
414        settings
415            .get::<AllAgentServersSettings>(None)
416            .get(agent_id.as_ref())
417            .is_some_and(|s| {
418                matches!(
419                    s,
420                    project::agent_server_store::CustomAgentServerSettings::Registry { .. }
421                )
422            })
423    });
424    is_previous_built_in || is_in_registry || is_settings_registry
425}
426
427fn default_settings_for_agent(
428    agent_id: impl Into<AgentId>,
429    cx: &App,
430) -> settings::CustomAgentServerSettings {
431    if is_registry_agent(agent_id, cx) {
432        settings::CustomAgentServerSettings::Registry {
433            default_model: None,
434            default_mode: None,
435            env: Default::default(),
436            favorite_models: Vec::new(),
437            default_config_options: Default::default(),
438            favorite_config_option_values: Default::default(),
439        }
440    } else {
441        settings::CustomAgentServerSettings::Extension {
442            default_model: None,
443            default_mode: None,
444            env: Default::default(),
445            favorite_models: Vec::new(),
446            default_config_options: Default::default(),
447            favorite_config_option_values: Default::default(),
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use collections::HashMap;
    use gpui::TestAppContext;
    use project::agent_registry_store::{
        AgentRegistryStore, RegistryAgent, RegistryAgentMetadata, RegistryNpxAgent,
    };
    use settings::Settings as _;
    use ui::SharedString;

    /// Installs a test `SettingsStore` as the global settings store.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
    }

    /// Builds a minimal npx-based registry agent whose id, name and package are `raw_id`.
    fn npx_registry_agent(raw_id: &str) -> RegistryAgent {
        let id = SharedString::from(raw_id.to_string());
        RegistryAgent::Npx(RegistryNpxAgent {
            metadata: RegistryAgentMetadata {
                id: AgentId::new(id.clone()),
                name: id.clone(),
                description: SharedString::from(""),
                version: SharedString::from("1.0.0"),
                repository: None,
                icon_path: None,
            },
            package: id,
            args: Vec::new(),
            env: HashMap::default(),
        })
    }

    /// Installs a global test `AgentRegistryStore` containing one agent per id.
    fn init_registry_with_agents(cx: &mut TestAppContext, agent_ids: &[&str]) {
        let agents: Vec<RegistryAgent> = agent_ids
            .iter()
            .copied()
            .map(npx_registry_agent)
            .collect();
        cx.update(|cx| {
            AgentRegistryStore::init_test_global(cx, agents);
        });
    }

    /// Replaces the global agent-server settings with `entries`.
    fn set_agent_server_settings(
        cx: &mut TestAppContext,
        entries: Vec<(&str, settings::CustomAgentServerSettings)>,
    ) {
        cx.update(|cx| {
            AllAgentServersSettings::override_global(
                project::agent_server_store::AllAgentServersSettings(
                    entries
                        .into_iter()
                        .map(|(name, settings)| (name.to_string(), settings.into()))
                        .collect(),
                ),
                cx,
            );
        });
    }

    /// An empty settings entry of the `Registry` kind.
    fn empty_registry_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Registry {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    /// An empty settings entry of the `Extension` kind.
    fn empty_extension_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Extension {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    #[gpui::test]
    fn test_previous_builtins_are_registry(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            for id in [CLAUDE_AGENT_ID, CODEX_ID, GEMINI_ID] {
                assert!(is_registry_agent(id, cx));
            }
        });
    }

    #[gpui::test]
    fn test_unknown_agent_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            assert!(!is_registry_agent("my-custom-agent", cx));
        });
    }

    #[gpui::test]
    fn test_agent_in_registry_store_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["some-new-registry-agent"]);
        cx.update(|cx| {
            assert!(is_registry_agent("some-new-registry-agent", cx));
            assert!(!is_registry_agent("not-in-registry", cx));
        });
    }

    #[gpui::test]
    fn test_agent_with_registry_settings_type_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("agent-from-settings", empty_registry_settings())]);
        cx.update(|cx| {
            assert!(is_registry_agent("agent-from-settings", cx));
        });
    }

    #[gpui::test]
    fn test_agent_with_extension_settings_type_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("my-extension-agent", empty_extension_settings())]);
        cx.update(|cx| {
            assert!(!is_registry_agent("my-extension-agent", cx));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_builtin_agent(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            for id in [CODEX_ID, CLAUDE_AGENT_ID, GEMINI_ID] {
                assert!(matches!(
                    default_settings_for_agent(id, cx),
                    settings::CustomAgentServerSettings::Registry { .. }
                ));
            }
        });
    }

    #[gpui::test]
    fn test_default_settings_for_extension_agent(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            assert!(matches!(
                default_settings_for_agent("some-extension-agent", cx),
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_agent_in_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["new-registry-agent"]);
        cx.update(|cx| {
            assert!(matches!(
                default_settings_for_agent("new-registry-agent", cx),
                settings::CustomAgentServerSettings::Registry { .. }
            ));
            assert!(matches!(
                default_settings_for_agent("not-in-registry", cx),
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }
}