custom.rs

  1use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
  2use acp_thread::AgentConnection;
  3use agent_client_protocol as acp;
  4use anyhow::{Context as _, Result};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use fs::Fs;
  8use gpui::{App, AppContext as _, Task};
  9use language_model::{ApiKey, EnvVar};
 10use project::agent_server_store::{AgentId, AllAgentServersSettings};
 11use settings::{SettingsStore, update_settings_file};
 12use std::{rc::Rc, sync::Arc};
 13use ui::IconName;
 14
 15pub const GEMINI_ID: &str = "gemini";
 16pub const CLAUDE_AGENT_ID: &str = "claude-acp";
 17pub const CODEX_ID: &str = "codex-acp";
 18
 19/// A generic agent server implementation for custom user-defined agents
 20pub struct CustomAgentServer {
 21    agent_id: AgentId,
 22}
 23
 24impl CustomAgentServer {
 25    pub fn new(agent_id: AgentId) -> Self {
 26        Self { agent_id }
 27    }
 28}
 29
 30impl AgentServer for CustomAgentServer {
 31    fn agent_id(&self) -> AgentId {
 32        self.agent_id.clone()
 33    }
 34
 35    fn logo(&self) -> IconName {
 36        IconName::Terminal
 37    }
 38
 39    fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
 40        let settings = cx.read_global(|settings: &SettingsStore, _| {
 41            settings
 42                .get::<AllAgentServersSettings>(None)
 43                .get(self.agent_id().0.as_ref())
 44                .cloned()
 45        });
 46
 47        settings
 48            .as_ref()
 49            .and_then(|s| s.default_mode().map(acp::SessionModeId::new))
 50    }
 51
 52    fn favorite_config_option_value_ids(
 53        &self,
 54        config_id: &acp::SessionConfigId,
 55        cx: &mut App,
 56    ) -> HashSet<acp::SessionConfigValueId> {
 57        let settings = cx.read_global(|settings: &SettingsStore, _| {
 58            settings
 59                .get::<AllAgentServersSettings>(None)
 60                .get(self.agent_id().0.as_ref())
 61                .cloned()
 62        });
 63
 64        settings
 65            .as_ref()
 66            .and_then(|s| s.favorite_config_option_values(config_id.0.as_ref()))
 67            .map(|values| {
 68                values
 69                    .iter()
 70                    .cloned()
 71                    .map(acp::SessionConfigValueId::new)
 72                    .collect()
 73            })
 74            .unwrap_or_default()
 75    }
 76
 77    fn toggle_favorite_config_option_value(
 78        &self,
 79        config_id: acp::SessionConfigId,
 80        value_id: acp::SessionConfigValueId,
 81        should_be_favorite: bool,
 82        fs: Arc<dyn Fs>,
 83        cx: &App,
 84    ) {
 85        let agent_id = self.agent_id();
 86        let config_id = config_id.to_string();
 87        let value_id = value_id.to_string();
 88
 89        update_settings_file(fs, cx, move |settings, cx| {
 90            let settings = settings
 91                .agent_servers
 92                .get_or_insert_default()
 93                .entry(agent_id.0.to_string())
 94                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
 95
 96            match settings {
 97                settings::CustomAgentServerSettings::Custom {
 98                    favorite_config_option_values,
 99                    ..
100                }
101                | settings::CustomAgentServerSettings::Extension {
102                    favorite_config_option_values,
103                    ..
104                }
105                | settings::CustomAgentServerSettings::Registry {
106                    favorite_config_option_values,
107                    ..
108                } => {
109                    let entry = favorite_config_option_values
110                        .entry(config_id.clone())
111                        .or_insert_with(Vec::new);
112
113                    if should_be_favorite {
114                        if !entry.iter().any(|v| v == &value_id) {
115                            entry.push(value_id.clone());
116                        }
117                    } else {
118                        entry.retain(|v| v != &value_id);
119                        if entry.is_empty() {
120                            favorite_config_option_values.remove(&config_id);
121                        }
122                    }
123                }
124            }
125        });
126    }
127
128    fn set_default_mode(&self, mode_id: Option<acp::SessionModeId>, fs: Arc<dyn Fs>, cx: &mut App) {
129        let agent_id = self.agent_id();
130        update_settings_file(fs, cx, move |settings, cx| {
131            let settings = settings
132                .agent_servers
133                .get_or_insert_default()
134                .entry(agent_id.0.to_string())
135                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
136
137            match settings {
138                settings::CustomAgentServerSettings::Custom { default_mode, .. }
139                | settings::CustomAgentServerSettings::Extension { default_mode, .. }
140                | settings::CustomAgentServerSettings::Registry { default_mode, .. } => {
141                    *default_mode = mode_id.map(|m| m.to_string());
142                }
143            }
144        });
145    }
146
147    fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
148        let settings = cx.read_global(|settings: &SettingsStore, _| {
149            settings
150                .get::<AllAgentServersSettings>(None)
151                .get(self.agent_id().as_ref())
152                .cloned()
153        });
154
155        settings
156            .as_ref()
157            .and_then(|s| s.default_model().map(acp::ModelId::new))
158    }
159
160    fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
161        let agent_id = self.agent_id();
162        update_settings_file(fs, cx, move |settings, cx| {
163            let settings = settings
164                .agent_servers
165                .get_or_insert_default()
166                .entry(agent_id.0.to_string())
167                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
168
169            match settings {
170                settings::CustomAgentServerSettings::Custom { default_model, .. }
171                | settings::CustomAgentServerSettings::Extension { default_model, .. }
172                | settings::CustomAgentServerSettings::Registry { default_model, .. } => {
173                    *default_model = model_id.map(|m| m.to_string());
174                }
175            }
176        });
177    }
178
179    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
180        let settings = cx.read_global(|settings: &SettingsStore, _| {
181            settings
182                .get::<AllAgentServersSettings>(None)
183                .get(self.agent_id().as_ref())
184                .cloned()
185        });
186
187        settings
188            .as_ref()
189            .map(|s| {
190                s.favorite_models()
191                    .iter()
192                    .map(|id| acp::ModelId::new(id.clone()))
193                    .collect()
194            })
195            .unwrap_or_default()
196    }
197
198    fn toggle_favorite_model(
199        &self,
200        model_id: acp::ModelId,
201        should_be_favorite: bool,
202        fs: Arc<dyn Fs>,
203        cx: &App,
204    ) {
205        let agent_id = self.agent_id();
206        update_settings_file(fs, cx, move |settings, cx| {
207            let settings = settings
208                .agent_servers
209                .get_or_insert_default()
210                .entry(agent_id.0.to_string())
211                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
212
213            let favorite_models = match settings {
214                settings::CustomAgentServerSettings::Custom {
215                    favorite_models, ..
216                }
217                | settings::CustomAgentServerSettings::Extension {
218                    favorite_models, ..
219                }
220                | settings::CustomAgentServerSettings::Registry {
221                    favorite_models, ..
222                } => favorite_models,
223            };
224
225            let model_id_str = model_id.to_string();
226            if should_be_favorite {
227                if !favorite_models.contains(&model_id_str) {
228                    favorite_models.push(model_id_str);
229                }
230            } else {
231                favorite_models.retain(|id| id != &model_id_str);
232            }
233        });
234    }
235
236    fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
237        let settings = cx.read_global(|settings: &SettingsStore, _| {
238            settings
239                .get::<AllAgentServersSettings>(None)
240                .get(self.agent_id().as_ref())
241                .cloned()
242        });
243
244        settings
245            .as_ref()
246            .and_then(|s| s.default_config_option(config_id).map(|s| s.to_string()))
247    }
248
249    fn set_default_config_option(
250        &self,
251        config_id: &str,
252        value_id: Option<&str>,
253        fs: Arc<dyn Fs>,
254        cx: &mut App,
255    ) {
256        let agent_id = self.agent_id();
257        let config_id = config_id.to_string();
258        let value_id = value_id.map(|s| s.to_string());
259        update_settings_file(fs, cx, move |settings, cx| {
260            let settings = settings
261                .agent_servers
262                .get_or_insert_default()
263                .entry(agent_id.0.to_string())
264                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
265
266            match settings {
267                settings::CustomAgentServerSettings::Custom {
268                    default_config_options,
269                    ..
270                }
271                | settings::CustomAgentServerSettings::Extension {
272                    default_config_options,
273                    ..
274                }
275                | settings::CustomAgentServerSettings::Registry {
276                    default_config_options,
277                    ..
278                } => {
279                    if let Some(value) = value_id.clone() {
280                        default_config_options.insert(config_id.clone(), value);
281                    } else {
282                        default_config_options.remove(&config_id);
283                    }
284                }
285            }
286        });
287    }
288
289    fn connect(
290        &self,
291        delegate: AgentServerDelegate,
292        cx: &mut App,
293    ) -> Task<Result<Rc<dyn AgentConnection>>> {
294        let agent_id = self.agent_id();
295        let display_name = delegate
296            .store
297            .read(cx)
298            .agent_display_name(&agent_id)
299            .unwrap_or_else(|| agent_id.0.clone());
300        let default_mode = self.default_mode(cx);
301        let default_model = self.default_model(cx);
302        let is_registry_agent = is_registry_agent(agent_id.clone(), cx);
303        let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
304            settings
305                .get::<AllAgentServersSettings>(None)
306                .get(self.agent_id().as_ref())
307                .map(|s| match s {
308                    project::agent_server_store::CustomAgentServerSettings::Custom {
309                        default_config_options,
310                        ..
311                    }
312                    | project::agent_server_store::CustomAgentServerSettings::Extension {
313                        default_config_options,
314                        ..
315                    }
316                    | project::agent_server_store::CustomAgentServerSettings::Registry {
317                        default_config_options,
318                        ..
319                    } => default_config_options.clone(),
320                })
321                .unwrap_or_default()
322        });
323
324        if is_registry_agent {
325            if let Some(registry_store) = project::AgentRegistryStore::try_global(cx) {
326                registry_store.update(cx, |store, cx| store.refresh_if_stale(cx));
327            }
328        }
329
330        let mut extra_env = load_proxy_env(cx);
331        if delegate.store.read(cx).no_browser() {
332            extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned());
333        }
334        if is_registry_agent {
335            match agent_id.as_ref() {
336                CLAUDE_AGENT_ID => {
337                    extra_env.insert("ANTHROPIC_API_KEY".into(), "".into());
338                }
339                CODEX_ID => {
340                    if let Ok(api_key) = std::env::var("CODEX_API_KEY") {
341                        extra_env.insert("CODEX_API_KEY".into(), api_key);
342                    }
343                    if let Ok(api_key) = std::env::var("OPEN_AI_API_KEY") {
344                        extra_env.insert("OPEN_AI_API_KEY".into(), api_key);
345                    }
346                }
347                GEMINI_ID => {
348                    extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
349                }
350                _ => {}
351            }
352        }
353        let store = delegate.store.downgrade();
354        cx.spawn(async move |cx| {
355            if is_registry_agent && agent_id.as_ref() == GEMINI_ID {
356                if let Some(api_key) = cx.update(api_key_for_gemini_cli).await.ok() {
357                    extra_env.insert("GEMINI_API_KEY".into(), api_key);
358                }
359            }
360            let command = store
361                .update(cx, |store, cx| {
362                    let agent = store.get_external_agent(&agent_id).with_context(|| {
363                        format!("Custom agent server `{}` is not registered", agent_id)
364                    })?;
365                    anyhow::Ok(agent.get_command(
366                        extra_env,
367                        delegate.new_version_available,
368                        &mut cx.to_async(),
369                    ))
370                })??
371                .await?;
372            let connection = crate::acp::connect(
373                agent_id,
374                display_name,
375                command,
376                default_mode,
377                default_model,
378                default_config_options,
379                cx,
380            )
381            .await?;
382            Ok(connection)
383        })
384    }
385
386    fn into_any(self: Rc<Self>) -> Rc<dyn std::any::Any> {
387        self
388    }
389}
390
391fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
392    let env_var = EnvVar::new("GEMINI_API_KEY".into()).or(EnvVar::new("GOOGLE_AI_API_KEY".into()));
393    if let Some(key) = env_var.value {
394        return Task::ready(Ok(key));
395    }
396    let credentials_provider = <dyn CredentialsProvider>::global(cx);
397    let api_url = google_ai::API_URL.to_string();
398    cx.spawn(async move |cx| {
399        Ok(
400            ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
401                .await?
402                .key()
403                .to_string(),
404        )
405    })
406}
407
408fn is_registry_agent(agent_id: impl Into<AgentId>, cx: &App) -> bool {
409    let agent_id = agent_id.into();
410    let is_in_registry = project::AgentRegistryStore::try_global(cx)
411        .map(|store| store.read(cx).agent(&agent_id).is_some())
412        .unwrap_or(false);
413    let is_settings_registry = cx.read_global(|settings: &SettingsStore, _| {
414        settings
415            .get::<AllAgentServersSettings>(None)
416            .get(agent_id.as_ref())
417            .is_some_and(|s| {
418                matches!(
419                    s,
420                    project::agent_server_store::CustomAgentServerSettings::Registry { .. }
421                )
422            })
423    });
424    is_in_registry || is_settings_registry
425}
426
427fn default_settings_for_agent(
428    agent_id: impl Into<AgentId>,
429    cx: &App,
430) -> settings::CustomAgentServerSettings {
431    if is_registry_agent(agent_id, cx) {
432        settings::CustomAgentServerSettings::Registry {
433            default_model: None,
434            default_mode: None,
435            env: Default::default(),
436            favorite_models: Vec::new(),
437            default_config_options: Default::default(),
438            favorite_config_option_values: Default::default(),
439        }
440    } else {
441        settings::CustomAgentServerSettings::Extension {
442            default_model: None,
443            default_mode: None,
444            env: Default::default(),
445            favorite_models: Vec::new(),
446            default_config_options: Default::default(),
447            favorite_config_option_values: Default::default(),
448        }
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use collections::HashMap;
    use gpui::TestAppContext;
    use project::agent_registry_store::{
        AgentRegistryStore, RegistryAgent, RegistryAgentMetadata, RegistryNpxAgent,
    };
    use settings::Settings as _;
    use ui::SharedString;

    /// Installs a test `SettingsStore` global, which the helpers under test read.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Installs a test `AgentRegistryStore` global listing one npx agent per id.
    fn init_registry_with_agents(cx: &mut TestAppContext, agent_ids: &[&str]) {
        let agents: Vec<RegistryAgent> = agent_ids
            .iter()
            .map(|id| {
                let id = SharedString::from(id.to_string());
                RegistryAgent::Npx(RegistryNpxAgent {
                    metadata: RegistryAgentMetadata {
                        id: AgentId::new(id.clone()),
                        name: id.clone(),
                        description: SharedString::from(""),
                        version: SharedString::from("1.0.0"),
                        repository: None,
                        icon_path: None,
                    },
                    package: id,
                    args: Vec::new(),
                    env: HashMap::default(),
                })
            })
            .collect();
        cx.update(|cx| {
            AgentRegistryStore::init_test_global(cx, agents);
        });
    }

    /// An empty `Registry` settings entry.
    fn registry_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Registry {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    /// An empty `Extension` settings entry.
    fn extension_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Extension {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    /// Replaces the global `AllAgentServersSettings` with exactly `entries`.
    fn set_agent_server_settings(
        cx: &mut TestAppContext,
        entries: Vec<(&str, settings::CustomAgentServerSettings)>,
    ) {
        cx.update(|cx| {
            AllAgentServersSettings::override_global(
                project::agent_server_store::AllAgentServersSettings(
                    entries
                        .into_iter()
                        .map(|(name, settings)| (name.to_string(), settings.into()))
                        .collect(),
                ),
                cx,
            );
        });
    }

    #[gpui::test]
    fn test_unknown_agent_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            assert!(!is_registry_agent("my-custom-agent", cx));
        });
    }

    #[gpui::test]
    fn test_agent_in_registry_store_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["some-new-registry-agent"]);
        cx.update(|cx| {
            assert!(is_registry_agent("some-new-registry-agent", cx));
            assert!(!is_registry_agent("not-in-registry", cx));
        });
    }

    #[gpui::test]
    fn test_agent_with_registry_settings_type_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("agent-from-settings", registry_settings())]);
        cx.update(|cx| {
            assert!(is_registry_agent("agent-from-settings", cx));
        });
    }

    #[gpui::test]
    fn test_agent_with_extension_settings_type_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("my-extension-agent", extension_settings())]);
        cx.update(|cx| {
            assert!(!is_registry_agent("my-extension-agent", cx));
        });
    }

    #[gpui::test]
    fn test_registry_store_listing_wins_over_extension_settings(cx: &mut TestAppContext) {
        // `is_registry_agent` ORs both sources, so a registry listing is enough
        // even when the settings entry is typed as `Extension`.
        init_test(cx);
        init_registry_with_agents(cx, &["listed-agent"]);
        set_agent_server_settings(cx, vec![("listed-agent", extension_settings())]);
        cx.update(|cx| {
            assert!(is_registry_agent("listed-agent", cx));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_extension_agent(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            assert!(matches!(
                default_settings_for_agent("some-extension-agent", cx),
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_agent_in_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["new-registry-agent"]);
        cx.update(|cx| {
            assert!(matches!(
                default_settings_for_agent("new-registry-agent", cx),
                settings::CustomAgentServerSettings::Registry { .. }
            ));
            assert!(matches!(
                default_settings_for_agent("not-in-registry", cx),
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_agent_with_registry_settings(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("agent-from-settings", registry_settings())]);
        cx.update(|cx| {
            assert!(matches!(
                default_settings_for_agent("agent-from-settings", cx),
                settings::CustomAgentServerSettings::Registry { .. }
            ));
        });
    }
}