native_agent_server.rs

  1use std::{any::Any, rc::Rc, sync::Arc};
  2
  3use agent_client_protocol as acp;
  4use agent_servers::{AgentServer, AgentServerDelegate};
  5use agent_settings::AgentSettings;
  6use anyhow::Result;
  7use collections::HashSet;
  8use fs::Fs;
  9use gpui::{App, Entity, SharedString, Task};
 10use prompt_store::PromptStore;
 11use settings::{LanguageModelSelection, Settings as _, update_settings_file};
 12
 13use crate::{NativeAgent, NativeAgentConnection, ThreadStore, templates::Templates};
 14
/// [`AgentServer`] implementation that reports itself as "Zed Agent".
///
/// On `connect` it builds a [`NativeAgent`] from the handles stored here and
/// returns it wrapped in a [`NativeAgentConnection`].
#[derive(Clone)]
pub struct NativeAgentServer {
    /// Filesystem handle; cloned and passed to `NativeAgent::new` when connecting.
    fs: Arc<dyn Fs>,
    /// Thread store entity; cloned and passed to `NativeAgent::new` when connecting.
    thread_store: Entity<ThreadStore>,
}
 20
 21impl NativeAgentServer {
 22    pub fn new(fs: Arc<dyn Fs>, thread_store: Entity<ThreadStore>) -> Self {
 23        Self { fs, thread_store }
 24    }
 25}
 26
 27impl AgentServer for NativeAgentServer {
 28    fn name(&self) -> SharedString {
 29        "Zed Agent".into()
 30    }
 31
 32    fn logo(&self) -> ui::IconName {
 33        ui::IconName::ZedAgent
 34    }
 35
 36    fn connect(
 37        &self,
 38        delegate: AgentServerDelegate,
 39        cx: &mut App,
 40    ) -> Task<
 41        Result<(
 42            Rc<dyn acp_thread::AgentConnection>,
 43            Option<task::SpawnInTerminal>,
 44        )>,
 45    > {
 46        log::debug!("NativeAgentServer::connect");
 47        let project = delegate.project().clone();
 48        let fs = self.fs.clone();
 49        let thread_store = self.thread_store.clone();
 50        let prompt_store = PromptStore::global(cx);
 51        cx.spawn(async move |cx| {
 52            log::debug!("Creating templates for native agent");
 53            let templates = Templates::new();
 54            let prompt_store = prompt_store.await?;
 55
 56            log::debug!("Creating native agent entity");
 57            let agent =
 58                NativeAgent::new(project, thread_store, templates, Some(prompt_store), fs, cx)
 59                    .await?;
 60
 61            // Create the connection wrapper
 62            let connection = NativeAgentConnection(agent);
 63            log::debug!("NativeAgentServer connection established successfully");
 64
 65            Ok((
 66                Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>,
 67                None,
 68            ))
 69        })
 70    }
 71
 72    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 73        self
 74    }
 75
 76    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
 77        AgentSettings::get_global(cx).favorite_model_ids()
 78    }
 79
 80    fn toggle_favorite_model(
 81        &self,
 82        model_id: acp::ModelId,
 83        should_be_favorite: bool,
 84        fs: Arc<dyn Fs>,
 85        cx: &App,
 86    ) {
 87        let selection = model_id_to_selection(&model_id);
 88        update_settings_file(fs, cx, move |settings, _| {
 89            let agent = settings.agent.get_or_insert_default();
 90            if should_be_favorite {
 91                agent.add_favorite_model(selection.clone());
 92            } else {
 93                agent.remove_favorite_model(&selection);
 94            }
 95        });
 96    }
 97}
 98
 99/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
100fn model_id_to_selection(model_id: &acp::ModelId) -> LanguageModelSelection {
101    let id = model_id.0.as_ref();
102    let (provider, model) = id.split_once('/').unwrap_or(("", id));
103    LanguageModelSelection {
104        provider: provider.to_owned().into(),
105        model: model.to_owned(),
106        enable_thinking: false,
107        effort: None,
108    }
109}
110
// End-to-end tests for `NativeAgentServer`. The test cases themselves come from
// the `common_e2e_tests!` macro in `agent_servers` (not shown in this file);
// this module only supplies the setup closure that builds the server under test.
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): presumably needed for `cx.new(...)` below — confirm.
    use gpui::AppContext;

    agent_servers::e2e_tests::common_e2e_tests!(
        // Setup closure: prepares global state and returns the server to test.
        async |fs, cx| {
            // Initialize the prompt store and start authenticating the Anthropic
            // provider; the authentication future is spawned so it can be awaited
            // outside of `cx.update`.
            let auth = cx.update(|cx| {
                prompt_store::init(cx);
                let registry = language_model::LanguageModelRegistry::read_global(cx);
                let auth = registry
                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
                    .unwrap()
                    .authenticate(cx);

                cx.spawn(async move |_| auth.await)
            });

            // Setup fails (panics) if authentication does not succeed.
            auth.await.unwrap();

            // Select the Anthropic "claude-sonnet-4-latest" model as the default.
            cx.update(|cx| {
                let registry = language_model::LanguageModelRegistry::global(cx);

                registry.update(cx, |registry, cx| {
                    registry.select_default_model(
                        Some(&language_model::SelectedModel {
                            provider: language_model::ANTHROPIC_PROVIDER_ID,
                            model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
                        }),
                        cx,
                    );
                });
            });

            // Fresh thread store entity for the server under test.
            let thread_store = cx.update(|cx| cx.new(|cx| ThreadStore::new(cx)));

            NativeAgentServer::new(fs.clone(), thread_store)
        },
        // NOTE(review): looks like the ID of the permission option the shared
        // tests pick to allow a tool call — confirm against the macro definition.
        allow_option_id = "allow"
    );
}