native_agent_server.rs

  1use std::{any::Any, path::Path, rc::Rc, sync::Arc};
  2
  3use agent_client_protocol as acp;
  4use agent_servers::{AgentServer, AgentServerDelegate};
  5use agent_settings::AgentSettings;
  6use anyhow::Result;
  7use collections::HashSet;
  8use fs::Fs;
  9use gpui::{App, Entity, SharedString, Task};
 10use prompt_store::PromptStore;
 11use settings::{LanguageModelSelection, Settings as _, update_settings_file};
 12
 13use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates};
 14
 15#[derive(Clone)]
 16pub struct NativeAgentServer {
 17    fs: Arc<dyn Fs>,
 18    history: Entity<HistoryStore>,
 19}
 20
 21impl NativeAgentServer {
 22    pub fn new(fs: Arc<dyn Fs>, history: Entity<HistoryStore>) -> Self {
 23        Self { fs, history }
 24    }
 25}
 26
 27impl AgentServer for NativeAgentServer {
 28    fn name(&self) -> SharedString {
 29        "Zed Agent".into()
 30    }
 31
 32    fn logo(&self) -> ui::IconName {
 33        ui::IconName::ZedAgent
 34    }
 35
 36    fn connect(
 37        &self,
 38        _root_dir: Option<&Path>,
 39        delegate: AgentServerDelegate,
 40        cx: &mut App,
 41    ) -> Task<
 42        Result<(
 43            Rc<dyn acp_thread::AgentConnection>,
 44            Option<task::SpawnInTerminal>,
 45        )>,
 46    > {
 47        log::debug!(
 48            "NativeAgentServer::connect called for path: {:?}",
 49            _root_dir
 50        );
 51        let project = delegate.project().clone();
 52        let fs = self.fs.clone();
 53        let history = self.history.clone();
 54        let prompt_store = PromptStore::global(cx);
 55        cx.spawn(async move |cx| {
 56            log::debug!("Creating templates for native agent");
 57            let templates = Templates::new();
 58            let prompt_store = prompt_store.await?;
 59
 60            log::debug!("Creating native agent entity");
 61            let agent =
 62                NativeAgent::new(project, history, templates, Some(prompt_store), fs, cx).await?;
 63
 64            // Create the connection wrapper
 65            let connection = NativeAgentConnection(agent);
 66            log::debug!("NativeAgentServer connection established successfully");
 67
 68            Ok((
 69                Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>,
 70                None,
 71            ))
 72        })
 73    }
 74
 75    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 76        self
 77    }
 78
 79    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
 80        AgentSettings::get_global(cx).favorite_model_ids()
 81    }
 82
 83    fn toggle_favorite_model(
 84        &self,
 85        model_id: acp::ModelId,
 86        should_be_favorite: bool,
 87        fs: Arc<dyn Fs>,
 88        cx: &App,
 89    ) {
 90        let selection = model_id_to_selection(&model_id);
 91        update_settings_file(fs, cx, move |settings, _| {
 92            let agent = settings.agent.get_or_insert_default();
 93            if should_be_favorite {
 94                agent.add_favorite_model(selection.clone());
 95            } else {
 96                agent.remove_favorite_model(&selection);
 97            }
 98        });
 99    }
100}
101
102/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
103fn model_id_to_selection(model_id: &acp::ModelId) -> LanguageModelSelection {
104    let id = model_id.0.as_ref();
105    let (provider, model) = id.split_once('/').unwrap_or(("", id));
106    LanguageModelSelection {
107        provider: provider.to_owned().into(),
108        model: model.to_owned(),
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    use assistant_text_thread::TextThreadStore;
    use gpui::AppContext;

    // Invokes the shared `common_e2e_tests!` macro with a closure that builds
    // the `NativeAgentServer` under test. NOTE(review): the macro definition is
    // not visible here; presumably it expands to the common agent-server e2e
    // suite — confirm in `agent_servers::e2e_tests`.
    agent_servers::e2e_tests::common_e2e_tests!(
        async |fs, project, cx| {
            // Initialize the prompt store, then kick off authentication with
            // the Anthropic provider.
            let authentication = cx.update(|cx| {
                prompt_store::init(cx);
                let provider = language_model::LanguageModelRegistry::read_global(cx)
                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
                    .unwrap();
                let pending = provider.authenticate(cx);

                cx.spawn(async move |_| pending.await)
            });

            authentication.await.unwrap();

            // Select the default model used by the tests.
            cx.update(|cx| {
                let default_model = language_model::SelectedModel {
                    provider: language_model::ANTHROPIC_PROVIDER_ID,
                    model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
                };
                let registry = language_model::LanguageModelRegistry::global(cx);

                registry.update(cx, |registry, cx| {
                    registry.select_default_model(Some(&default_model), cx);
                });
            });

            // Build a history store on top of a fake text thread store.
            let history = cx.update(|cx| {
                let text_threads = cx.new(move |cx| TextThreadStore::fake(project.clone(), cx));
                cx.new(move |cx| HistoryStore::new(text_threads, cx))
            });

            NativeAgentServer::new(fs.clone(), history)
        },
        allow_option_id = "allow"
    );
}