native_agent_server.rs

  1use std::{any::Any, path::Path, rc::Rc, sync::Arc};
  2
  3use agent_servers::{AgentServer, AgentServerDelegate};
  4use anyhow::Result;
  5use fs::Fs;
  6use gpui::{App, Entity, SharedString, Task};
  7use prompt_store::PromptStore;
  8
  9use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates};
 10
 11#[derive(Clone)]
 12pub struct NativeAgentServer {
 13    fs: Arc<dyn Fs>,
 14    history: Entity<HistoryStore>,
 15}
 16
 17impl NativeAgentServer {
 18    pub fn new(fs: Arc<dyn Fs>, history: Entity<HistoryStore>) -> Self {
 19        Self { fs, history }
 20    }
 21}
 22
 23impl AgentServer for NativeAgentServer {
 24    fn name(&self) -> SharedString {
 25        "Zed Agent".into()
 26    }
 27
 28    fn logo(&self) -> ui::IconName {
 29        ui::IconName::ZedAgent
 30    }
 31
 32    fn connect(
 33        &self,
 34        _root_dir: Option<&Path>,
 35        delegate: AgentServerDelegate,
 36        cx: &mut App,
 37    ) -> Task<
 38        Result<(
 39            Rc<dyn acp_thread::AgentConnection>,
 40            Option<task::SpawnInTerminal>,
 41        )>,
 42    > {
 43        log::debug!(
 44            "NativeAgentServer::connect called for path: {:?}",
 45            _root_dir
 46        );
 47        let project = delegate.project().clone();
 48        let fs = self.fs.clone();
 49        let history = self.history.clone();
 50        let prompt_store = PromptStore::global(cx);
 51        cx.spawn(async move |cx| {
 52            log::debug!("Creating templates for native agent");
 53            let templates = Templates::new();
 54            let prompt_store = prompt_store.await?;
 55
 56            log::debug!("Creating native agent entity");
 57            let agent =
 58                NativeAgent::new(project, history, templates, Some(prompt_store), fs, cx).await?;
 59
 60            // Create the connection wrapper
 61            let connection = NativeAgentConnection(agent);
 62            log::debug!("NativeAgentServer connection established successfully");
 63
 64            Ok((
 65                Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>,
 66                None,
 67            ))
 68        })
 69    }
 70
 71    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 72        self
 73    }
 74}
 75
 76#[cfg(test)]
 77mod tests {
 78    use super::*;
 79
 80    use assistant_text_thread::TextThreadStore;
 81    use gpui::AppContext;
 82
 83    agent_servers::e2e_tests::common_e2e_tests!(
 84        async |fs, project, cx| {
 85            let auth = cx.update(|cx| {
 86                prompt_store::init(cx);
 87                let registry = language_model::LanguageModelRegistry::read_global(cx);
 88                let auth = registry
 89                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
 90                    .unwrap()
 91                    .authenticate(cx);
 92
 93                cx.spawn(async move |_| auth.await)
 94            });
 95
 96            auth.await.unwrap();
 97
 98            cx.update(|cx| {
 99                let registry = language_model::LanguageModelRegistry::global(cx);
100
101                registry.update(cx, |registry, cx| {
102                    registry.select_default_model(
103                        Some(&language_model::SelectedModel {
104                            provider: language_model::ANTHROPIC_PROVIDER_ID,
105                            model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
106                        }),
107                        cx,
108                    );
109                });
110            });
111
112            let history = cx.update(|cx| {
113                let text_thread_store =
114                    cx.new(move |cx| TextThreadStore::fake(project.clone(), cx));
115                cx.new(move |cx| HistoryStore::new(text_thread_store, cx))
116            });
117
118            NativeAgentServer::new(fs.clone(), history)
119        },
120        allow_option_id = "allow"
121    );
122}