native_agent_server.rs

  1use std::{any::Any, rc::Rc, sync::Arc};
  2
  3use agent_client_protocol as acp;
  4use agent_servers::{AgentServer, AgentServerDelegate};
  5use agent_settings::AgentSettings;
  6use anyhow::Result;
  7use collections::HashSet;
  8use fs::Fs;
  9use gpui::{App, Entity, Task};
 10use project::AgentId;
 11use prompt_store::PromptStore;
 12use settings::{LanguageModelSelection, Settings as _, update_settings_file};
 13
 14use crate::{NativeAgent, NativeAgentConnection, ThreadStore, templates::Templates};
 15
 16#[derive(Clone)]
 17pub struct NativeAgentServer {
 18    fs: Arc<dyn Fs>,
 19    thread_store: Entity<ThreadStore>,
 20}
 21
 22impl NativeAgentServer {
 23    pub fn new(fs: Arc<dyn Fs>, thread_store: Entity<ThreadStore>) -> Self {
 24        Self { fs, thread_store }
 25    }
 26}
 27
 28impl AgentServer for NativeAgentServer {
 29    fn agent_id(&self) -> AgentId {
 30        crate::ZED_AGENT_ID.clone()
 31    }
 32
 33    fn logo(&self) -> ui::IconName {
 34        ui::IconName::ZedAgent
 35    }
 36
 37    fn connect(
 38        &self,
 39        _delegate: AgentServerDelegate,
 40        cx: &mut App,
 41    ) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
 42        log::debug!("NativeAgentServer::connect");
 43        let fs = self.fs.clone();
 44        let thread_store = self.thread_store.clone();
 45        let prompt_store = PromptStore::global(cx);
 46        cx.spawn(async move |cx| {
 47            log::debug!("Creating templates for native agent");
 48            let templates = Templates::new();
 49            let prompt_store = prompt_store.await?;
 50
 51            log::debug!("Creating native agent entity");
 52            let agent = cx
 53                .update(|cx| NativeAgent::new(thread_store, templates, Some(prompt_store), fs, cx));
 54
 55            // Create the connection wrapper
 56            let connection = NativeAgentConnection(agent);
 57            log::debug!("NativeAgentServer connection established successfully");
 58
 59            Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
 60        })
 61    }
 62
 63    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 64        self
 65    }
 66
 67    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
 68        AgentSettings::get_global(cx).favorite_model_ids()
 69    }
 70
 71    fn toggle_favorite_model(
 72        &self,
 73        model_id: acp::ModelId,
 74        should_be_favorite: bool,
 75        fs: Arc<dyn Fs>,
 76        cx: &App,
 77    ) {
 78        let selection = model_id_to_selection(&model_id);
 79        update_settings_file(fs, cx, move |settings, _| {
 80            let agent = settings.agent.get_or_insert_default();
 81            if should_be_favorite {
 82                agent.add_favorite_model(selection.clone());
 83            } else {
 84                agent.remove_favorite_model(&selection);
 85            }
 86        });
 87    }
 88}
 89
 90/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
 91fn model_id_to_selection(model_id: &acp::ModelId) -> LanguageModelSelection {
 92    let id = model_id.0.as_ref();
 93    let (provider, model) = id.split_once('/').unwrap_or(("", id));
 94    LanguageModelSelection {
 95        provider: provider.to_owned().into(),
 96        model: model.to_owned(),
 97        enable_thinking: false,
 98        effort: None,
 99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    // Trait import; presumably needed for `cx.new(...)` below — NOTE(review): confirm.
    use gpui::AppContext;

    // Instantiates the shared end-to-end suite from `agent_servers` for this server.
    // The async closure builds the `NativeAgentServer` under test.
    agent_servers::e2e_tests::common_e2e_tests!(
        async |fs, cx| {
            // Initialize the prompt store, then start authenticating the Anthropic
            // provider. The resulting task is awaited outside of `cx.update`.
            let auth = cx.update(|cx| {
                prompt_store::init(cx);
                let registry = language_model::LanguageModelRegistry::read_global(cx);
                let auth = registry
                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
                    .unwrap()
                    .authenticate(cx);

                cx.spawn(async move |_| auth.await)
            });

            // Panics if authentication fails. Presumably this needs Anthropic
            // credentials in the test environment — NOTE(review): confirm.
            auth.await.unwrap();

            // Make `claude-sonnet-4-latest` from the Anthropic provider the
            // registry's default model.
            cx.update(|cx| {
                let registry = language_model::LanguageModelRegistry::global(cx);

                registry.update(cx, |registry, cx| {
                    registry.select_default_model(
                        Some(&language_model::SelectedModel {
                            provider: language_model::ANTHROPIC_PROVIDER_ID,
                            model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
                        }),
                        cx,
                    );
                });
            });

            let thread_store = cx.update(|cx| cx.new(|cx| ThreadStore::new(cx)));

            NativeAgentServer::new(fs.clone(), thread_store)
        },
        // Option id passed to the shared suite. Presumably this is the permission
        // option it selects to approve tool calls — NOTE(review): confirm against
        // `common_e2e_tests!`.
        allow_option_id = "allow"
    );
}