custom.rs

  1use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
  2use acp_thread::AgentConnection;
  3use agent_client_protocol as acp;
  4use anyhow::{Context as _, Result};
  5use collections::HashSet;
  6use credentials_provider::CredentialsProvider;
  7use fs::Fs;
  8use gpui::{App, AppContext as _, Entity, Task};
  9use language_model::{ApiKey, EnvVar};
 10use project::{
 11    Project,
 12    agent_server_store::{AgentId, AllAgentServersSettings},
 13};
 14use settings::{SettingsStore, update_settings_file};
 15use std::{rc::Rc, sync::Arc};
 16use ui::IconName;
 17
 18pub const GEMINI_ID: &str = "gemini";
 19pub const CLAUDE_AGENT_ID: &str = "claude-acp";
 20pub const CODEX_ID: &str = "codex-acp";
 21
 22/// A generic agent server implementation for custom user-defined agents
 23pub struct CustomAgentServer {
 24    agent_id: AgentId,
 25}
 26
 27impl CustomAgentServer {
 28    pub fn new(agent_id: AgentId) -> Self {
 29        Self { agent_id }
 30    }
 31}
 32
 33impl AgentServer for CustomAgentServer {
 34    fn agent_id(&self) -> AgentId {
 35        self.agent_id.clone()
 36    }
 37
 38    fn logo(&self) -> IconName {
 39        IconName::Terminal
 40    }
 41
 42    fn default_mode(&self, cx: &App) -> Option<acp::SessionModeId> {
 43        let settings = cx.read_global(|settings: &SettingsStore, _| {
 44            settings
 45                .get::<AllAgentServersSettings>(None)
 46                .get(self.agent_id().0.as_ref())
 47                .cloned()
 48        });
 49
 50        settings
 51            .as_ref()
 52            .and_then(|s| s.default_mode().map(acp::SessionModeId::new))
 53    }
 54
 55    fn favorite_config_option_value_ids(
 56        &self,
 57        config_id: &acp::SessionConfigId,
 58        cx: &mut App,
 59    ) -> HashSet<acp::SessionConfigValueId> {
 60        let settings = cx.read_global(|settings: &SettingsStore, _| {
 61            settings
 62                .get::<AllAgentServersSettings>(None)
 63                .get(self.agent_id().0.as_ref())
 64                .cloned()
 65        });
 66
 67        settings
 68            .as_ref()
 69            .and_then(|s| s.favorite_config_option_values(config_id.0.as_ref()))
 70            .map(|values| {
 71                values
 72                    .iter()
 73                    .cloned()
 74                    .map(acp::SessionConfigValueId::new)
 75                    .collect()
 76            })
 77            .unwrap_or_default()
 78    }
 79
 80    fn toggle_favorite_config_option_value(
 81        &self,
 82        config_id: acp::SessionConfigId,
 83        value_id: acp::SessionConfigValueId,
 84        should_be_favorite: bool,
 85        fs: Arc<dyn Fs>,
 86        cx: &App,
 87    ) {
 88        let agent_id = self.agent_id();
 89        let config_id = config_id.to_string();
 90        let value_id = value_id.to_string();
 91
 92        update_settings_file(fs, cx, move |settings, cx| {
 93            let settings = settings
 94                .agent_servers
 95                .get_or_insert_default()
 96                .entry(agent_id.0.to_string())
 97                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
 98
 99            match settings {
100                settings::CustomAgentServerSettings::Custom {
101                    favorite_config_option_values,
102                    ..
103                }
104                | settings::CustomAgentServerSettings::Extension {
105                    favorite_config_option_values,
106                    ..
107                }
108                | settings::CustomAgentServerSettings::Registry {
109                    favorite_config_option_values,
110                    ..
111                } => {
112                    let entry = favorite_config_option_values
113                        .entry(config_id.clone())
114                        .or_insert_with(Vec::new);
115
116                    if should_be_favorite {
117                        if !entry.iter().any(|v| v == &value_id) {
118                            entry.push(value_id.clone());
119                        }
120                    } else {
121                        entry.retain(|v| v != &value_id);
122                        if entry.is_empty() {
123                            favorite_config_option_values.remove(&config_id);
124                        }
125                    }
126                }
127            }
128        });
129    }
130
131    fn set_default_mode(&self, mode_id: Option<acp::SessionModeId>, fs: Arc<dyn Fs>, cx: &mut App) {
132        let agent_id = self.agent_id();
133        update_settings_file(fs, cx, move |settings, cx| {
134            let settings = settings
135                .agent_servers
136                .get_or_insert_default()
137                .entry(agent_id.0.to_string())
138                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
139
140            match settings {
141                settings::CustomAgentServerSettings::Custom { default_mode, .. }
142                | settings::CustomAgentServerSettings::Extension { default_mode, .. }
143                | settings::CustomAgentServerSettings::Registry { default_mode, .. } => {
144                    *default_mode = mode_id.map(|m| m.to_string());
145                }
146            }
147        });
148    }
149
150    fn default_model(&self, cx: &App) -> Option<acp::ModelId> {
151        let settings = cx.read_global(|settings: &SettingsStore, _| {
152            settings
153                .get::<AllAgentServersSettings>(None)
154                .get(self.agent_id().as_ref())
155                .cloned()
156        });
157
158        settings
159            .as_ref()
160            .and_then(|s| s.default_model().map(acp::ModelId::new))
161    }
162
163    fn set_default_model(&self, model_id: Option<acp::ModelId>, fs: Arc<dyn Fs>, cx: &mut App) {
164        let agent_id = self.agent_id();
165        update_settings_file(fs, cx, move |settings, cx| {
166            let settings = settings
167                .agent_servers
168                .get_or_insert_default()
169                .entry(agent_id.0.to_string())
170                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
171
172            match settings {
173                settings::CustomAgentServerSettings::Custom { default_model, .. }
174                | settings::CustomAgentServerSettings::Extension { default_model, .. }
175                | settings::CustomAgentServerSettings::Registry { default_model, .. } => {
176                    *default_model = model_id.map(|m| m.to_string());
177                }
178            }
179        });
180    }
181
182    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
183        let settings = cx.read_global(|settings: &SettingsStore, _| {
184            settings
185                .get::<AllAgentServersSettings>(None)
186                .get(self.agent_id().as_ref())
187                .cloned()
188        });
189
190        settings
191            .as_ref()
192            .map(|s| {
193                s.favorite_models()
194                    .iter()
195                    .map(|id| acp::ModelId::new(id.clone()))
196                    .collect()
197            })
198            .unwrap_or_default()
199    }
200
201    fn toggle_favorite_model(
202        &self,
203        model_id: acp::ModelId,
204        should_be_favorite: bool,
205        fs: Arc<dyn Fs>,
206        cx: &App,
207    ) {
208        let agent_id = self.agent_id();
209        update_settings_file(fs, cx, move |settings, cx| {
210            let settings = settings
211                .agent_servers
212                .get_or_insert_default()
213                .entry(agent_id.0.to_string())
214                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
215
216            let favorite_models = match settings {
217                settings::CustomAgentServerSettings::Custom {
218                    favorite_models, ..
219                }
220                | settings::CustomAgentServerSettings::Extension {
221                    favorite_models, ..
222                }
223                | settings::CustomAgentServerSettings::Registry {
224                    favorite_models, ..
225                } => favorite_models,
226            };
227
228            let model_id_str = model_id.to_string();
229            if should_be_favorite {
230                if !favorite_models.contains(&model_id_str) {
231                    favorite_models.push(model_id_str);
232                }
233            } else {
234                favorite_models.retain(|id| id != &model_id_str);
235            }
236        });
237    }
238
239    fn default_config_option(&self, config_id: &str, cx: &App) -> Option<String> {
240        let settings = cx.read_global(|settings: &SettingsStore, _| {
241            settings
242                .get::<AllAgentServersSettings>(None)
243                .get(self.agent_id().as_ref())
244                .cloned()
245        });
246
247        settings
248            .as_ref()
249            .and_then(|s| s.default_config_option(config_id).map(|s| s.to_string()))
250    }
251
252    fn set_default_config_option(
253        &self,
254        config_id: &str,
255        value_id: Option<&str>,
256        fs: Arc<dyn Fs>,
257        cx: &mut App,
258    ) {
259        let agent_id = self.agent_id();
260        let config_id = config_id.to_string();
261        let value_id = value_id.map(|s| s.to_string());
262        update_settings_file(fs, cx, move |settings, cx| {
263            let settings = settings
264                .agent_servers
265                .get_or_insert_default()
266                .entry(agent_id.0.to_string())
267                .or_insert_with(|| default_settings_for_agent(agent_id, cx));
268
269            match settings {
270                settings::CustomAgentServerSettings::Custom {
271                    default_config_options,
272                    ..
273                }
274                | settings::CustomAgentServerSettings::Extension {
275                    default_config_options,
276                    ..
277                }
278                | settings::CustomAgentServerSettings::Registry {
279                    default_config_options,
280                    ..
281                } => {
282                    if let Some(value) = value_id.clone() {
283                        default_config_options.insert(config_id.clone(), value);
284                    } else {
285                        default_config_options.remove(&config_id);
286                    }
287                }
288            }
289        });
290    }
291
292    fn connect(
293        &self,
294        delegate: AgentServerDelegate,
295        project: Entity<Project>,
296        cx: &mut App,
297    ) -> Task<Result<Rc<dyn AgentConnection>>> {
298        let agent_id = self.agent_id();
299        let display_name = delegate
300            .store
301            .read(cx)
302            .agent_display_name(&agent_id)
303            .unwrap_or_else(|| agent_id.0.clone());
304        let default_mode = self.default_mode(cx);
305        let default_model = self.default_model(cx);
306        let is_registry_agent = is_registry_agent(agent_id.clone(), cx);
307        let default_config_options = cx.read_global(|settings: &SettingsStore, _| {
308            settings
309                .get::<AllAgentServersSettings>(None)
310                .get(self.agent_id().as_ref())
311                .map(|s| match s {
312                    project::agent_server_store::CustomAgentServerSettings::Custom {
313                        default_config_options,
314                        ..
315                    }
316                    | project::agent_server_store::CustomAgentServerSettings::Extension {
317                        default_config_options,
318                        ..
319                    }
320                    | project::agent_server_store::CustomAgentServerSettings::Registry {
321                        default_config_options,
322                        ..
323                    } => default_config_options.clone(),
324                })
325                .unwrap_or_default()
326        });
327
328        if is_registry_agent {
329            if let Some(registry_store) = project::AgentRegistryStore::try_global(cx) {
330                registry_store.update(cx, |store, cx| store.refresh_if_stale(cx));
331            }
332        }
333
334        let mut extra_env = load_proxy_env(cx);
335        if delegate.store.read(cx).no_browser() {
336            extra_env.insert("NO_BROWSER".to_owned(), "1".to_owned());
337        }
338        if is_registry_agent {
339            match agent_id.as_ref() {
340                CLAUDE_AGENT_ID => {
341                    extra_env.insert("ANTHROPIC_API_KEY".into(), "".into());
342                }
343                CODEX_ID => {
344                    if let Ok(api_key) = std::env::var("CODEX_API_KEY") {
345                        extra_env.insert("CODEX_API_KEY".into(), api_key);
346                    }
347                    if let Ok(api_key) = std::env::var("OPEN_AI_API_KEY") {
348                        extra_env.insert("OPEN_AI_API_KEY".into(), api_key);
349                    }
350                }
351                GEMINI_ID => {
352                    extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
353                }
354                _ => {}
355            }
356        }
357        let store = delegate.store.downgrade();
358        cx.spawn(async move |cx| {
359            if is_registry_agent && agent_id.as_ref() == GEMINI_ID {
360                if let Some(api_key) = cx.update(api_key_for_gemini_cli).await.ok() {
361                    extra_env.insert("GEMINI_API_KEY".into(), api_key);
362                }
363            }
364            let command = store
365                .update(cx, |store, cx| {
366                    let agent = store.get_external_agent(&agent_id).with_context(|| {
367                        format!("Custom agent server `{}` is not registered", agent_id)
368                    })?;
369                    anyhow::Ok(agent.get_command(
370                        extra_env,
371                        delegate.new_version_available,
372                        &mut cx.to_async(),
373                    ))
374                })??
375                .await?;
376            let connection = crate::acp::connect(
377                agent_id,
378                project,
379                display_name,
380                command,
381                default_mode,
382                default_model,
383                default_config_options,
384                cx,
385            )
386            .await?;
387            Ok(connection)
388        })
389    }
390
391    fn into_any(self: Rc<Self>) -> Rc<dyn std::any::Any> {
392        self
393    }
394}
395
396fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
397    let env_var = EnvVar::new("GEMINI_API_KEY".into()).or(EnvVar::new("GOOGLE_AI_API_KEY".into()));
398    if let Some(key) = env_var.value {
399        return Task::ready(Ok(key));
400    }
401    let credentials_provider = <dyn CredentialsProvider>::global(cx);
402    let api_url = google_ai::API_URL.to_string();
403    cx.spawn(async move |cx| {
404        Ok(
405            ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
406                .await?
407                .key()
408                .to_string(),
409        )
410    })
411}
412
413fn is_registry_agent(agent_id: impl Into<AgentId>, cx: &App) -> bool {
414    let agent_id = agent_id.into();
415    let is_in_registry = project::AgentRegistryStore::try_global(cx)
416        .map(|store| store.read(cx).agent(&agent_id).is_some())
417        .unwrap_or(false);
418    let is_settings_registry = cx.read_global(|settings: &SettingsStore, _| {
419        settings
420            .get::<AllAgentServersSettings>(None)
421            .get(agent_id.as_ref())
422            .is_some_and(|s| {
423                matches!(
424                    s,
425                    project::agent_server_store::CustomAgentServerSettings::Registry { .. }
426                )
427            })
428    });
429    is_in_registry || is_settings_registry
430}
431
432fn default_settings_for_agent(
433    agent_id: impl Into<AgentId>,
434    cx: &App,
435) -> settings::CustomAgentServerSettings {
436    if is_registry_agent(agent_id, cx) {
437        settings::CustomAgentServerSettings::Registry {
438            default_model: None,
439            default_mode: None,
440            env: Default::default(),
441            favorite_models: Vec::new(),
442            default_config_options: Default::default(),
443            favorite_config_option_values: Default::default(),
444        }
445    } else {
446        settings::CustomAgentServerSettings::Extension {
447            default_model: None,
448            default_mode: None,
449            env: Default::default(),
450            favorite_models: Vec::new(),
451            default_config_options: Default::default(),
452            favorite_config_option_values: Default::default(),
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use collections::HashMap;
    use gpui::TestAppContext;
    use project::agent_registry_store::{
        AgentRegistryStore, RegistryAgent, RegistryAgentMetadata, RegistryNpxAgent,
    };
    use settings::Settings as _;
    use ui::SharedString;

    /// Installs a test `SettingsStore` as the global settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
    }

    /// Builds a minimal npx registry entry whose id, name and package all equal `agent_id`.
    fn npx_registry_agent(agent_id: &str) -> RegistryAgent {
        let name = SharedString::from(agent_id.to_string());
        RegistryAgent::Npx(RegistryNpxAgent {
            metadata: RegistryAgentMetadata {
                id: AgentId::new(name.clone()),
                name: name.clone(),
                description: SharedString::from(""),
                version: SharedString::from("1.0.0"),
                repository: None,
                website: None,
                icon_path: None,
            },
            package: name,
            args: Vec::new(),
            env: HashMap::default(),
        })
    }

    /// Installs a test `AgentRegistryStore` global that lists exactly `agent_ids`.
    fn init_registry_with_agents(cx: &mut TestAppContext, agent_ids: &[&str]) {
        let mut agents = Vec::with_capacity(agent_ids.len());
        for agent_id in agent_ids {
            agents.push(npx_registry_agent(agent_id));
        }
        cx.update(|cx| {
            AgentRegistryStore::init_test_global(cx, agents);
        });
    }

    /// Replaces the global `AllAgentServersSettings` with exactly `entries`.
    fn set_agent_server_settings(
        cx: &mut TestAppContext,
        entries: Vec<(&str, settings::CustomAgentServerSettings)>,
    ) {
        let servers = entries
            .into_iter()
            .map(|(name, settings)| (name.to_string(), settings.into()))
            .collect();
        cx.update(|cx| {
            AllAgentServersSettings::override_global(
                project::agent_server_store::AllAgentServersSettings(servers),
                cx,
            );
        });
    }

    #[gpui::test]
    fn test_unknown_agent_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        let is_registry = cx.update(|cx| is_registry_agent("my-custom-agent", cx));
        assert!(!is_registry);
    }

    #[gpui::test]
    fn test_agent_in_registry_store_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["some-new-registry-agent"]);
        let (listed, unlisted) = cx.update(|cx| {
            (
                is_registry_agent("some-new-registry-agent", cx),
                is_registry_agent("not-in-registry", cx),
            )
        });
        assert!(listed);
        assert!(!unlisted);
    }

    #[gpui::test]
    fn test_agent_with_registry_settings_type_is_registry(cx: &mut TestAppContext) {
        init_test(cx);
        let registry_entry = settings::CustomAgentServerSettings::Registry {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        };
        set_agent_server_settings(cx, vec![("agent-from-settings", registry_entry)]);
        let is_registry = cx.update(|cx| is_registry_agent("agent-from-settings", cx));
        assert!(is_registry);
    }

    #[gpui::test]
    fn test_agent_with_extension_settings_type_is_not_registry(cx: &mut TestAppContext) {
        init_test(cx);
        let extension_entry = settings::CustomAgentServerSettings::Extension {
            env: HashMap::default(),
            default_mode: None,
            default_model: None,
            favorite_models: Vec::new(),
            default_config_options: HashMap::default(),
            favorite_config_option_values: HashMap::default(),
        };
        set_agent_server_settings(cx, vec![("my-extension-agent", extension_entry)]);
        let is_registry = cx.update(|cx| is_registry_agent("my-extension-agent", cx));
        assert!(!is_registry);
    }

    #[gpui::test]
    fn test_default_settings_for_extension_agent(cx: &mut TestAppContext) {
        init_test(cx);
        let defaults = cx.update(|cx| default_settings_for_agent("some-extension-agent", cx));
        assert!(matches!(
            defaults,
            settings::CustomAgentServerSettings::Extension { .. }
        ));
    }

    #[gpui::test]
    fn test_default_settings_for_agent_in_registry(cx: &mut TestAppContext) {
        init_test(cx);
        init_registry_with_agents(cx, &["new-registry-agent"]);
        let (listed_defaults, unlisted_defaults) = cx.update(|cx| {
            (
                default_settings_for_agent("new-registry-agent", cx),
                default_settings_for_agent("not-in-registry", cx),
            )
        });
        assert!(matches!(
            listed_defaults,
            settings::CustomAgentServerSettings::Registry { .. }
        ));
        assert!(matches!(
            unlisted_defaults,
            settings::CustomAgentServerSettings::Extension { .. }
        ));
    }
}