custom.rs

  1use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
  2use acp_thread::AgentConnection;
  3use agent_client_protocol as acp;
  4use anyhow::{Context as _, Result};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use fs::Fs;
  8use gpui::{App, AppContext as _, SharedString, Task};
  9use language_model::{ApiKey, EnvVar};
 10use project::agent_server_store::{
 11    AllAgentServersSettings, CLAUDE_AGENT_NAME, CODEX_NAME, ExternalAgentServerName, GEMINI_NAME,
 12};
 13use settings::{SettingsStore, update_settings_file};
 14use std::{rc::Rc, sync::Arc};
 15use ui::IconName;
 16
 17/// A generic agent server implementation for custom user-defined agents
 18pub struct CustomAgentServer {
 19    name: SharedString,
 20}
 21
 22impl CustomAgentServer {
 23    pub fn new(name: SharedString) -> Self {
 24        Self { name }
 25    }
 26}
 27
 28impl AgentServer for CustomAgentServer {
 29    fn name(&self) -> SharedString {
 30        self.name.clone()
 31    }
 32
 33    fn logo(&self) -> IconName {
 34        IconName::Terminal
 35    }
 36
 37    fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
 38        let settings = cx.read_global(|settings: &SettingsStore, _| {
 39            settings
 40                .get::<AllAgentServersSettings>(None)
 41                .get(self.name().as_ref())
 42                .cloned()
 43        });
 44
 45        settings
 46            .as_ref()
 47            .and_then(|s| s.default_mode().map(acp::SessionModeId::new))
 48    }
 49
 50    fn favorite_config_option_value_ids(
 51        &self,
 52        config_id: &acp::SessionConfigId,
 53        cx: &mut App,
 54    ) -> HashSet<acp::SessionConfigValueId> {
 55        let settings = cx.read_global(|settings: &SettingsStore, _| {
 56            settings
 57                .get::<AllAgentServersSettings>(None)
 58                .get(self.name().as_ref())
 59                .cloned()
 60        });
 61
 62        settings
 63            .as_ref()
 64            .and_then(|s| s.favorite_config_option_values(config_id.0.as_ref()))
 65            .map(|values| {
 66                values
 67                    .iter()
 68                    .cloned()
 69                    .map(acp::SessionConfigValueId::new)
 70                    .collect()
 71            })
 72            .unwrap_or_default()
 73    }
 74
 75    fn toggle_favorite_config_option_value(
 76        &self,
 77        config_id: acp::SessionConfigId,
 78        value_id: acp::SessionConfigValueId,
 79        should_be_favorite: bool,
 80        fs: Arc<dyn Fs>,
 81        cx: &App,
 82    ) {
 83        let name = self.name();
 84        let config_id = config_id.to_string();
 85        let value_id = value_id.to_string();
 86
 87        update_settings_file(fs, cx, move |settings, cx| {
 88            let settings = settings
 89                .agent_servers
 90                .get_or_insert_default()
 91                .entry(name.to_string())
 92                .or_insert_with(|| default_settings_for_agent(&name, cx));
 93
 94            match settings {
 95                settings::CustomAgentServerSettings::Custom {
 96                    favorite_config_option_values,
 97                    ..
 98                }
 99                | settings::CustomAgentServerSettings::Extension {
100                    favorite_config_option_values,
101                    ..
102                }
103                | settings::CustomAgentServerSettings::Registry {
104                    favorite_config_option_values,
105                    ..
106                } => {
107                    let entry = favorite_config_option_values
108                        .entry(config_id.clone())
109                        .or_insert_with(Vec::new);
110
111                    if should_be_favorite {
112                        if !entry.iter().any(|v| v == &value_id) {
113                            entry.push(value_id.clone());
114                        }
115                    } else {
116                        entry.retain(|v| v != &value_id);
117                        if entry.is_empty() {
118                            favorite_config_option_values.remove(&config_id);
119                        }
120                    }
121                }
122            }
123        });
124    }
125
126    fn set_default_mode(&self, mode_id: Option<acp::SessionModeId>, fs: Arc<dyn Fs>, cx: &mut App) {
127        let name = self.name();
128        update_settings_file(fs, cx, move |settings, cx| {
129            let settings = settings
130                .agent_servers
131                .get_or_insert_default()
132                .entry(name.to_string())
133                .or_insert_with(|| default_settings_for_agent(&name, cx));
134
135            match settings {
136                settings::CustomAgentServerSettings::Custom { default_mode, .. }
137                | settings::CustomAgentServerSettings::Extension { default_mode, .. }
138                | settings::CustomAgentServerSettings::Registry { default_mode, .. } => {
139                    *default_mode = mode_id.map(|m| m.to_string());
140                }
141            }
142        });
143    }
144
145    fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
146        let settings = cx.read_global(|settings: &SettingsStore, _| {
147            settings
148                .get::<AllAgentServersSettings>(None)
149                .get(self.name().as_ref())
150                .cloned()
151        });
152
153        settings
154            .as_ref()
155            .and_then(|s| s.default_model().map(acp::ModelId::new))
156    }
157
158    fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
159        let name = self.name();
160        update_settings_file(fs, cx, move |settings, cx| {
161            let settings = settings
162                .agent_servers
163                .get_or_insert_default()
164                .entry(name.to_string())
165                .or_insert_with(|| default_settings_for_agent(&name, cx));
166
167            match settings {
168                settings::CustomAgentServerSettings::Custom { default_model, .. }
169                | settings::CustomAgentServerSettings::Extension { default_model, .. }
170                | settings::CustomAgentServerSettings::Registry { default_model, .. } => {
171                    *default_model = model_id.map(|m| m.to_string());
172                }
173            }
174        });
175    }
176
177    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
178        let settings = cx.read_global(|settings: &SettingsStore, _| {
179            settings
180                .get::<AllAgentServersSettings>(None)
181                .get(self.name().as_ref())
182                .cloned()
183        });
184
185        settings
186            .as_ref()
187            .map(|s| {
188                s.favorite_models()
189                    .iter()
190                    .map(|id| acp::ModelId::new(id.clone()))
191                    .collect()
192            })
193            .unwrap_or_default()
194    }
195
196    fn toggle_favorite_model(
197        &self,
198        model_id: acp::ModelId,
199        should_be_favorite: bool,
200        fs: Arc<dyn Fs>,
201        cx: &App,
202    ) {
203        let name = self.name();
204        update_settings_file(fs, cx, move |settings, cx| {
205            let settings = settings
206                .agent_servers
207                .get_or_insert_default()
208                .entry(name.to_string())
209                .or_insert_with(|| default_settings_for_agent(&name, cx));
210
211            let favorite_models = match settings {
212                settings::CustomAgentServerSettings::Custom {
213                    favorite_models, ..
214                }
215                | settings::CustomAgentServerSettings::Extension {
216                    favorite_models, ..
217                }
218                | settings::CustomAgentServerSettings::Registry {
219                    favorite_models, ..
220                } => favorite_models,
221            };
222
223            let model_id_str = model_id.to_string();
224            if should_be_favorite {
225                if !favorite_models.contains(&model_id_str) {
226                    favorite_models.push(model_id_str);
227                }
228            } else {
229                favorite_models.retain(|id| id != &model_id_str);
230            }
231        });
232    }
233
234    fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
235        let settings = cx.read_global(|settings: &SettingsStore, _| {
236            settings
237                .get::<AllAgentServersSettings>(None)
238                .get(self.name().as_ref())
239                .cloned()
240        });
241
242        settings
243            .as_ref()
244            .and_then(|s| s.default_config_option(config_id).map(|s| s.to_string()))
245    }
246
247    fn set_default_config_option(
248        &self,
249        config_id: &str,
250        value_id: Option<&str>,
251        fs: Arc<dyn Fs>,
252        cx: &mut App,
253    ) {
254        let name = self.name();
255        let config_id = config_id.to_string();
256        let value_id = value_id.map(|s| s.to_string());
257        update_settings_file(fs, cx, move |settings, cx| {
258            let settings = settings
259                .agent_servers
260                .get_or_insert_default()
261                .entry(name.to_string())
262                .or_insert_with(|| default_settings_for_agent(&name, cx));
263
264            match settings {
265                settings::CustomAgentServerSettings::Custom {
266                    default_config_options,
267                    ..
268                }
269                | settings::CustomAgentServerSettings::Extension {
270                    default_config_options,
271                    ..
272                }
273                | settings::CustomAgentServerSettings::Registry {
274                    default_config_options,
275                    ..
276                } => {
277                    if let Some(value) = value_id.clone() {
278                        default_config_options.insert(config_id.clone(), value);
279                    } else {
280                        default_config_options.remove(&config_id);
281                    }
282                }
283            }
284        });
285    }
286
287    fn connect(
288        &self,
289        delegate: AgentServerDelegate,
290        cx: &mut App,
291    ) -> Task<Result<Rc<dyn AgentConnection>>> {
292        let name = self.name();
293        let display_name = delegate
294            .store
295            .read(cx)
296            .agent_display_name(&ExternalAgentServerName(name.clone()))
297            .unwrap_or_else(|| name.clone());
298        let default_mode = self.default_mode(cx);
299        let default_model = self.default_model(cx);
300        let is_registry_agent = is_registry_agent(&name, cx);
301        let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
302            settings
303                .get::<AllAgentServersSettings>(None)
304                .get(self.name().as_ref())
305                .map(|s| match s {
306                    project::agent_server_store::CustomAgentServerSettings::Custom {
307                        default_config_options,
308                        ..
309                    }
310                    | project::agent_server_store::CustomAgentServerSettings::Extension {
311                        default_config_options,
312                        ..
313                    }
314                    | project::agent_server_store::CustomAgentServerSettings::Registry {
315                        default_config_options,
316                        ..
317                    } => default_config_options.clone(),
318                })
319                .unwrap_or_default()
320        });
321
322        if is_registry_agent {
323            if let Some(registry_store) = project::AgentRegistryStore::try_global(cx) {
324                registry_store.update(cx, |store, cx| store.refresh_if_stale(cx));
325            }
326        }
327
328        let mut extra_env = load_proxy_env(cx);
329        if delegate.store.read(cx).no_browser() {
330            extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned());
331        }
332        if is_registry_agent {
333            match name.as_ref() {
334                CLAUDE_AGENT_NAME => {
335                    extra_env.insert("ANTHROPIC_API_KEY".into(), "".into());
336                }
337                CODEX_NAME => {
338                    if let Ok(api_key) = std::env::var("CODEX_API_KEY") {
339                        extra_env.insert("CODEX_API_KEY".into(), api_key);
340                    }
341                    if let Ok(api_key) = std::env::var("OPEN_AI_API_KEY") {
342                        extra_env.insert("OPEN_AI_API_KEY".into(), api_key);
343                    }
344                }
345                GEMINI_NAME => {
346                    extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
347                }
348                _ => {}
349            }
350        }
351        let store = delegate.store.downgrade();
352        cx.spawn(async move |cx| {
353            if is_registry_agent && name.as_ref() == GEMINI_NAME {
354                if let Some(api_key) = cx.update(api_key_for_gemini_cli).await.ok() {
355                    extra_env.insert("GEMINI_API_KEY".into(), api_key);
356                }
357            }
358            let command = store
359                .update(cx, |store, cx| {
360                    let agent = store
361                        .get_external_agent(&ExternalAgentServerName(name.clone()))
362                        .with_context(|| {
363                            format!("Custom agent server `{}` is not registered", name)
364                        })?;
365                    anyhow::Ok(agent.get_command(
366                        extra_env,
367                        delegate.new_version_available,
368                        &mut cx.to_async(),
369                    ))
370                })??
371                .await?;
372            let connection = crate::acp::connect(
373                name,
374                display_name,
375                command,
376                default_mode,
377                default_model,
378                default_config_options,
379                cx,
380            )
381            .await?;
382            Ok(connection)
383        })
384    }
385
386    fn into_any(self: Rc<Self>) -> Rc<dyn std::any::Any> {
387        self
388    }
389}
390
391fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
392    let env_var = EnvVar::new("GEMINI_API_KEY".into()).or(EnvVar::new("GOOGLE_AI_API_KEY".into()));
393    if let Some(key) = env_var.value {
394        return Task::ready(Ok(key));
395    }
396    let credentials_provider = <dyn CredentialsProvider>::global(cx);
397    let api_url = google_ai::API_URL.to_string();
398    cx.spawn(async move |cx| {
399        Ok(
400            ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
401                .await?
402                .key()
403                .to_string(),
404        )
405    })
406}
407
408fn is_registry_agent(name: &str, cx: &App) -> bool {
409    let is_in_registry = project::AgentRegistryStore::try_global(cx)
410        .map(|store| store.read(cx).agent(name).is_some())
411        .unwrap_or(false);
412    let is_settings_registry = cx.read_global(|settings: &SettingsStore, _| {
413        settings
414            .get::<AllAgentServersSettings>(None)
415            .get(name)
416            .is_some_and(|s| {
417                matches!(
418                    s,
419                    project::agent_server_store::CustomAgentServerSettings::Registry { .. }
420                )
421            })
422    });
423    is_in_registry || is_settings_registry
424}
425
426fn default_settings_for_agent(name: &str, cx: &App) -> settings::CustomAgentServerSettings {
427    if is_registry_agent(name, cx) {
428        settings::CustomAgentServerSettings::Registry {
429            default_model: None,
430            default_mode: None,
431            env: Default::default(),
432            favorite_models: Vec::new(),
433            default_config_options: Default::default(),
434            favorite_config_option_values: Default::default(),
435        }
436    } else {
437        settings::CustomAgentServerSettings::Extension {
438            default_model: None,
439            default_mode: None,
440            env: Default::default(),
441            favorite_models: Vec::new(),
442            default_config_options: Default::default(),
443            favorite_config_option_values: Default::default(),
444        }
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use collections::HashMap;
    use gpui::TestAppContext;
    use project::agent_registry_store::{
        AgentRegistryStore, RegistryAgent, RegistryAgentMetadata, RegistryNpxAgent,
    };
    use settings::Settings as _;

    /// Installs a test `SettingsStore` as the global settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Installs a test `AgentRegistryStore` listing one npx agent per id in `agent_ids`.
    fn init_registry_with_agents(cx: &mut TestAppContext, agent_ids: &[&str]) {
        let agents: Vec<RegistryAgent> = agent_ids
            .iter()
            .map(|id| {
                let id = SharedString::from(id.to_string());
                RegistryAgent::Npx(RegistryNpxAgent {
                    metadata: RegistryAgentMetadata {
                        id: id.clone(),
                        name: id.clone(),
                        description: SharedString::from(""),
                        version: SharedString::from("1.0.0"),
                        repository: None,
                        icon_path: None,
                    },
                    package: id,
                    args: Vec::new(),
                    env: HashMap::default(),
                })
            })
            .collect();
        cx.update(|cx| {
            AgentRegistryStore::init_test_global(cx, agents);
        });
    }

    /// Overrides the global `AllAgentServersSettings` with exactly `entries`.
    fn set_agent_server_settings(
        cx: &mut TestAppContext,
        entries: Vec<(&str, settings::CustomAgentServerSettings)>,
    ) {
        cx.update(|cx| {
            AllAgentServersSettings::override_global(
                project::agent_server_store::AllAgentServersSettings(
                    entries
                        .into_iter()
                        .map(|(name, settings)| (name.to_string(), settings.into()))
                        .collect(),
                ),
                cx,
            );
        });
    }

    /// A `Registry` settings entry with every optional field empty.
    fn empty_registry_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Registry {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    /// An `Extension` settings entry with every optional field empty.
    fn empty_extension_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Extension {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    #[gpui::test]
    fn test_unknown_agent_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            assert!(!is_registry_agent("my-custom-agent", cx));
        });
    }

    #[gpui::test]
    fn test_agent_in_registry_store_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["some-new-registry-agent"]);
        cx.update(|cx| {
            assert!(is_registry_agent("some-new-registry-agent", cx));
            assert!(!is_registry_agent("not-in-registry", cx));
        });
    }

    #[gpui::test]
    fn test_agent_with_registry_settings_type_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("agent-from-settings", empty_registry_settings())]);
        cx.update(|cx| {
            assert!(is_registry_agent("agent-from-settings", cx));
        });
    }

    #[gpui::test]
    fn test_agent_with_extension_settings_type_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("my-extension-agent", empty_extension_settings())]);
        cx.update(|cx| {
            assert!(!is_registry_agent("my-extension-agent", cx));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_extension_agent(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            assert!(matches!(
                default_settings_for_agent("some-extension-agent", cx),
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_agent_in_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["new-registry-agent"]);
        cx.update(|cx| {
            assert!(matches!(
                default_settings_for_agent("new-registry-agent", cx),
                settings::CustomAgentServerSettings::Registry { .. }
            ));
            assert!(matches!(
                default_settings_for_agent("not-in-registry", cx),
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }

    /// An agent that is registry-typed only through its settings entry (not listed in the
    /// registry store) must still default to a `Registry` entry.
    #[gpui::test]
    fn test_default_settings_for_agent_with_registry_settings_type(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("agent-from-settings", empty_registry_settings())]);
        cx.update(|cx| {
            assert!(matches!(
                default_settings_for_agent("agent-from-settings", cx),
                settings::CustomAgentServerSettings::Registry { .. }
            ));
        });
    }
}