custom.rs

  1use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
  2use acp_thread::AgentConnection;
  3use agent_client_protocol as acp;
  4use anyhow::{Context as _, Result};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use fs::Fs;
  8use gpui::{App, AppContext as _, Entity, Task};
  9use language_model::{ApiKey, EnvVar};
 10use project::{
 11    Project,
 12    agent_server_store::{AgentId, AllAgentServersSettings},
 13};
 14use settings::{SettingsStore, update_settings_file};
 15use std::{rc::Rc, sync::Arc};
 16use ui::IconName;
 17
 18pub const GEMINI_ID: &str = "gemini";
 19pub const CLAUDE_AGENT_ID: &str = "claude-acp";
 20pub const CODEX_ID: &str = "codex-acp";
 21
 22/// A generic agent server implementation for custom user-defined agents
 23pub struct CustomAgentServer {
 24    agent_id: AgentId,
 25}
 26
 27impl CustomAgentServer {
 28    pub fn new(agent_id: AgentId) -> Self {
 29        Self { agent_id }
 30    }
 31}
 32
 33impl AgentServer for CustomAgentServer {
 34    fn agent_id(&self) -> AgentId {
 35        self.agent_id.clone()
 36    }
 37
 38    fn logo(&self) -> IconName {
 39        IconName::Terminal
 40    }
 41
 42    fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
 43        let settings = cx.read_global(|settings: &SettingsStore, _| {
 44            settings
 45                .get::<AllAgentServersSettings>(None)
 46                .get(self.agent_id().0.as_ref())
 47                .cloned()
 48        });
 49
 50        settings
 51            .as_ref()
 52            .and_then(|s| s.default_mode().map(acp::SessionModeId::new))
 53    }
 54
 55    fn favorite_config_option_value_ids(
 56        &self,
 57        config_id: &acp::SessionConfigId,
 58        cx: &mut App,
 59    ) -> HashSet<acp::SessionConfigValueId> {
 60        let settings = cx.read_global(|settings: &SettingsStore, _| {
 61            settings
 62                .get::<AllAgentServersSettings>(None)
 63                .get(self.agent_id().0.as_ref())
 64                .cloned()
 65        });
 66
 67        settings
 68            .as_ref()
 69            .and_then(|s| s.favorite_config_option_values(config_id.0.as_ref()))
 70            .map(|values| {
 71                values
 72                    .iter()
 73                    .cloned()
 74                    .map(acp::SessionConfigValueId::new)
 75                    .collect()
 76            })
 77            .unwrap_or_default()
 78    }
 79
 80    fn toggle_favorite_config_option_value(
 81        &self,
 82        config_id: acp::SessionConfigId,
 83        value_id: acp::SessionConfigValueId,
 84        should_be_favorite: bool,
 85        fs: Arc<dyn Fs>,
 86        cx: &App,
 87    ) {
 88        let agent_id = self.agent_id();
 89        let config_id = config_id.to_string();
 90        let value_id = value_id.to_string();
 91
 92        update_settings_file(fs, cx, move |settings, cx| {
 93            let settings = settings
 94                .agent_servers
 95                .get_or_insert_default()
 96                .entry(agent_id.0.to_string())
 97                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
 98
 99            match settings {
100                settings::CustomAgentServerSettings::Custom {
101                    favorite_config_option_values,
102                    ..
103                }
104                | settings::CustomAgentServerSettings::Extension {
105                    favorite_config_option_values,
106                    ..
107                }
108                | settings::CustomAgentServerSettings::Registry {
109                    favorite_config_option_values,
110                    ..
111                } => {
112                    let entry = favorite_config_option_values
113                        .entry(config_id.clone())
114                        .or_insert_with(Vec::new);
115
116                    if should_be_favorite {
117                        if !entry.iter().any(|v| v == &value_id) {
118                            entry.push(value_id.clone());
119                        }
120                    } else {
121                        entry.retain(|v| v != &value_id);
122                        if entry.is_empty() {
123                            favorite_config_option_values.remove(&config_id);
124                        }
125                    }
126                }
127            }
128        });
129    }
130
131    fn set_default_mode(&self, mode_id: Option<acp::SessionModeId>, fs: Arc<dyn Fs>, cx: &mut App) {
132        let agent_id = self.agent_id();
133        update_settings_file(fs, cx, move |settings, cx| {
134            let settings = settings
135                .agent_servers
136                .get_or_insert_default()
137                .entry(agent_id.0.to_string())
138                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
139
140            match settings {
141                settings::CustomAgentServerSettings::Custom { default_mode, .. }
142                | settings::CustomAgentServerSettings::Extension { default_mode, .. }
143                | settings::CustomAgentServerSettings::Registry { default_mode, .. } => {
144                    *default_mode = mode_id.map(|m| m.to_string());
145                }
146            }
147        });
148    }
149
150    fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
151        let settings = cx.read_global(|settings: &SettingsStore, _| {
152            settings
153                .get::<AllAgentServersSettings>(None)
154                .get(self.agent_id().as_ref())
155                .cloned()
156        });
157
158        settings
159            .as_ref()
160            .and_then(|s| s.default_model().map(acp::ModelId::new))
161    }
162
163    fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
164        let agent_id = self.agent_id();
165        update_settings_file(fs, cx, move |settings, cx| {
166            let settings = settings
167                .agent_servers
168                .get_or_insert_default()
169                .entry(agent_id.0.to_string())
170                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
171
172            match settings {
173                settings::CustomAgentServerSettings::Custom { default_model, .. }
174                | settings::CustomAgentServerSettings::Extension { default_model, .. }
175                | settings::CustomAgentServerSettings::Registry { default_model, .. } => {
176                    *default_model = model_id.map(|m| m.to_string());
177                }
178            }
179        });
180    }
181
182    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
183        let settings = cx.read_global(|settings: &SettingsStore, _| {
184            settings
185                .get::<AllAgentServersSettings>(None)
186                .get(self.agent_id().as_ref())
187                .cloned()
188        });
189
190        settings
191            .as_ref()
192            .map(|s| {
193                s.favorite_models()
194                    .iter()
195                    .map(|id| acp::ModelId::new(id.clone()))
196                    .collect()
197            })
198            .unwrap_or_default()
199    }
200
201    fn toggle_favorite_model(
202        &self,
203        model_id: acp::ModelId,
204        should_be_favorite: bool,
205        fs: Arc<dyn Fs>,
206        cx: &App,
207    ) {
208        let agent_id = self.agent_id();
209        update_settings_file(fs, cx, move |settings, cx| {
210            let settings = settings
211                .agent_servers
212                .get_or_insert_default()
213                .entry(agent_id.0.to_string())
214                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
215
216            let favorite_models = match settings {
217                settings::CustomAgentServerSettings::Custom {
218                    favorite_models, ..
219                }
220                | settings::CustomAgentServerSettings::Extension {
221                    favorite_models, ..
222                }
223                | settings::CustomAgentServerSettings::Registry {
224                    favorite_models, ..
225                } => favorite_models,
226            };
227
228            let model_id_str = model_id.to_string();
229            if should_be_favorite {
230                if !favorite_models.contains(&model_id_str) {
231                    favorite_models.push(model_id_str);
232                }
233            } else {
234                favorite_models.retain(|id| id != &model_id_str);
235            }
236        });
237    }
238
239    fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
240        let settings = cx.read_global(|settings: &SettingsStore, _| {
241            settings
242                .get::<AllAgentServersSettings>(None)
243                .get(self.agent_id().as_ref())
244                .cloned()
245        });
246
247        settings
248            .as_ref()
249            .and_then(|s| s.default_config_option(config_id).map(|s| s.to_string()))
250    }
251
252    fn set_default_config_option(
253        &self,
254        config_id: &str,
255        value_id: Option<&str>,
256        fs: Arc<dyn Fs>,
257        cx: &mut App,
258    ) {
259        let agent_id = self.agent_id();
260        let config_id = config_id.to_string();
261        let value_id = value_id.map(|s| s.to_string());
262        update_settings_file(fs, cx, move |settings, cx| {
263            let settings = settings
264                .agent_servers
265                .get_or_insert_default()
266                .entry(agent_id.0.to_string())
267                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
268
269            match settings {
270                settings::CustomAgentServerSettings::Custom {
271                    default_config_options,
272                    ..
273                }
274                | settings::CustomAgentServerSettings::Extension {
275                    default_config_options,
276                    ..
277                }
278                | settings::CustomAgentServerSettings::Registry {
279                    default_config_options,
280                    ..
281                } => {
282                    if let Some(value) = value_id.clone() {
283                        default_config_options.insert(config_id.clone(), value);
284                    } else {
285                        default_config_options.remove(&config_id);
286                    }
287                }
288            }
289        });
290    }
291
292    fn connect(
293        &self,
294        delegate: AgentServerDelegate,
295        project: Entity<Project>,
296        cx: &mut App,
297    ) -> Task<Result<Rc<dyn AgentConnection>>> {
298        let agent_id = self.agent_id();
299        let default_mode = self.default_mode(cx);
300        let default_model = self.default_model(cx);
301        let is_registry_agent = is_registry_agent(agent_id.clone(), cx);
302        let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
303            settings
304                .get::<AllAgentServersSettings>(None)
305                .get(self.agent_id().as_ref())
306                .map(|s| match s {
307                    project::agent_server_store::CustomAgentServerSettings::Custom {
308                        default_config_options,
309                        ..
310                    }
311                    | project::agent_server_store::CustomAgentServerSettings::Extension {
312                        default_config_options,
313                        ..
314                    }
315                    | project::agent_server_store::CustomAgentServerSettings::Registry {
316                        default_config_options,
317                        ..
318                    } => default_config_options.clone(),
319                })
320                .unwrap_or_default()
321        });
322
323        if is_registry_agent {
324            if let Some(registry_store) = project::AgentRegistryStore::try_global(cx) {
325                registry_store.update(cx, |store, cx| store.refresh_if_stale(cx));
326            }
327        }
328
329        let mut extra_env = load_proxy_env(cx);
330        if delegate.store.read(cx).no_browser() {
331            extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned());
332        }
333        if is_registry_agent {
334            match agent_id.as_ref() {
335                CLAUDE_AGENT_ID => {
336                    extra_env.insert("ANTHROPIC_API_KEY".into(), "".into());
337                }
338                CODEX_ID => {
339                    if let Ok(api_key) = std::env::var("CODEX_API_KEY") {
340                        extra_env.insert("CODEX_API_KEY".into(), api_key);
341                    }
342                    if let Ok(api_key) = std::env::var("OPEN_AI_API_KEY") {
343                        extra_env.insert("OPEN_AI_API_KEY".into(), api_key);
344                    }
345                }
346                GEMINI_ID => {
347                    extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
348                }
349                _ => {}
350            }
351        }
352        let store = delegate.store.downgrade();
353        cx.spawn(async move |cx| {
354            if is_registry_agent && agent_id.as_ref() == GEMINI_ID {
355                if let Some(api_key) = cx.update(api_key_for_gemini_cli).await.ok() {
356                    extra_env.insert("GEMINI_API_KEY".into(), api_key);
357                }
358            }
359            let command = store
360                .update(cx, |store, cx| {
361                    let agent = store.get_external_agent(&agent_id).with_context(|| {
362                        format!("Custom agent server `{}` is not registered", agent_id)
363                    })?;
364                    anyhow::Ok(agent.get_command(
365                        extra_env,
366                        delegate.new_version_available,
367                        &mut cx.to_async(),
368                    ))
369                })??
370                .await?;
371            let connection = crate::acp::connect(
372                agent_id,
373                project,
374                command,
375                default_mode,
376                default_model,
377                default_config_options,
378                cx,
379            )
380            .await?;
381            Ok(connection)
382        })
383    }
384
385    fn into_any(self: Rc<Self>) -> Rc<dyn std::any::Any> {
386        self
387    }
388}
389
390fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
391    let env_var = EnvVar::new("GEMINI_API_KEY".into()).or(EnvVar::new("GOOGLE_AI_API_KEY".into()));
392    if let Some(key) = env_var.value {
393        return Task::ready(Ok(key));
394    }
395    let credentials_provider = <dyn CredentialsProvider>::global(cx);
396    let api_url = google_ai::API_URL.to_string();
397    cx.spawn(async move |cx| {
398        Ok(
399            ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
400                .await?
401                .key()
402                .to_string(),
403        )
404    })
405}
406
407fn is_registry_agent(agent_id: impl Into<AgentId>, cx: &App) -> bool {
408    let agent_id = agent_id.into();
409    let is_in_registry = project::AgentRegistryStore::try_global(cx)
410        .map(|store| store.read(cx).agent(&agent_id).is_some())
411        .unwrap_or(false);
412    let is_settings_registry = cx.read_global(|settings: &SettingsStore, _| {
413        settings
414            .get::<AllAgentServersSettings>(None)
415            .get(agent_id.as_ref())
416            .is_some_and(|s| {
417                matches!(
418                    s,
419                    project::agent_server_store::CustomAgentServerSettings::Registry { .. }
420                )
421            })
422    });
423    is_in_registry || is_settings_registry
424}
425
426fn default_settings_for_agent(
427    agent_id: impl Into<AgentId>,
428    cx: &App,
429) -> settings::CustomAgentServerSettings {
430    if is_registry_agent(agent_id, cx) {
431        settings::CustomAgentServerSettings::Registry {
432            default_model: None,
433            default_mode: None,
434            env: Default::default(),
435            favorite_models: Vec::new(),
436            default_config_options: Default::default(),
437            favorite_config_option_values: Default::default(),
438        }
439    } else {
440        settings::CustomAgentServerSettings::Extension {
441            default_model: None,
442            default_mode: None,
443            env: Default::default(),
444            favorite_models: Vec::new(),
445            default_config_options: Default::default(),
446            favorite_config_option_values: Default::default(),
447        }
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use collections::HashMap;
    use gpui::TestAppContext;
    use project::agent_registry_store::{
        AgentRegistryStore, RegistryAgent, RegistryAgentMetadata, RegistryNpxAgent,
    };
    use settings::Settings as _;
    use ui::SharedString;

    /// Installs a fresh test `SettingsStore` as the global settings store.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
    }

    /// Installs a global test registry containing one npx agent per id in `agent_ids`.
    fn init_registry_with_agents(cx: &mut TestAppContext, agent_ids: &[&str]) {
        let mut agents: Vec<RegistryAgent> = Vec::with_capacity(agent_ids.len());
        for raw_id in agent_ids {
            let id = SharedString::from(raw_id.to_string());
            let metadata = RegistryAgentMetadata {
                id: AgentId::new(id.clone()),
                name: id.clone(),
                description: SharedString::from(""),
                version: SharedString::from("1.0.0"),
                repository: None,
                website: None,
                icon_path: None,
            };
            agents.push(RegistryAgent::Npx(RegistryNpxAgent {
                metadata,
                package: id,
                args: Vec::new(),
                env: HashMap::default(),
            }));
        }
        cx.update(|cx| {
            AgentRegistryStore::init_test_global(cx, agents);
        });
    }

    /// Replaces the global `AllAgentServersSettings` with exactly the given entries.
    fn set_agent_server_settings(
        cx: &mut TestAppContext,
        entries: Vec<(&str, settings::CustomAgentServerSettings)>,
    ) {
        let all_settings = project::agent_server_store::AllAgentServersSettings(
            entries
                .into_iter()
                .map(|(name, settings)| (name.to_string(), settings.into()))
                .collect(),
        );
        cx.update(|cx| {
            AllAgentServersSettings::override_global(all_settings, cx);
        });
    }

    /// A `Registry` settings entry with every field empty.
    fn empty_registry_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Registry {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    /// An `Extension` settings entry with every field empty.
    fn empty_extension_settings() -> settings::CustomAgentServerSettings {
        settings::CustomAgentServerSettings::Extension {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        }
    }

    #[gpui::test]
    fn test_unknown_agent_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let result = is_registry_agent("my-custom-agent", cx);
            assert!(!result);
        });
    }

    #[gpui::test]
    fn test_agent_in_registry_store_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["some-new-registry-agent"]);
        cx.update(|cx| {
            let listed = is_registry_agent("some-new-registry-agent", cx);
            let unlisted = is_registry_agent("not-in-registry", cx);
            assert!(listed);
            assert!(!unlisted);
        });
    }

    #[gpui::test]
    fn test_agent_with_registry_settings_type_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("agent-from-settings", empty_registry_settings())]);
        cx.update(|cx| {
            let result = is_registry_agent("agent-from-settings", cx);
            assert!(result);
        });
    }

    #[gpui::test]
    fn test_agent_with_extension_settings_type_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        set_agent_server_settings(cx, vec![("my-extension-agent", empty_extension_settings())]);
        cx.update(|cx| {
            let result = is_registry_agent("my-extension-agent", cx);
            assert!(!result);
        });
    }

    #[gpui::test]
    fn test_default_settings_for_extension_agent(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let defaults = default_settings_for_agent("some-extension-agent", cx);
            assert!(matches!(
                defaults,
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }

    #[gpui::test]
    fn test_default_settings_for_agent_in_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["new-registry-agent"]);
        cx.update(|cx| {
            let listed = default_settings_for_agent("new-registry-agent", cx);
            let unlisted = default_settings_for_agent("not-in-registry", cx);
            assert!(matches!(
                listed,
                settings::CustomAgentServerSettings::Registry { .. }
            ));
            assert!(matches!(
                unlisted,
                settings::CustomAgentServerSettings::Extension { .. }
            ));
        });
    }
}