native_agent_server.rs

  1use std::{any::Any, rc::Rc, sync::Arc};
  2
  3use agent_client_protocol as acp;
  4use agent_servers::{AgentServer, AgentServerDelegate};
  5use agent_settings::{AgentSettings, language_model_to_selection};
  6use anyhow::Result;
  7use collections::HashSet;
  8use fs::Fs;
  9use gpui::{App, Entity, Task};
 10use language_model::{LanguageModelId, LanguageModelProviderId, LanguageModelRegistry};
 11use project::{AgentId, Project};
 12use prompt_store::PromptStore;
 13use settings::{LanguageModelSelection, Settings as _, update_settings_file};
 14
 15use crate::{NativeAgent, NativeAgentConnection, ThreadStore, templates::Templates};
 16
 17#[derive(Clone)]
 18pub struct NativeAgentServer {
 19    fs: Arc<dyn Fs>,
 20    thread_store: Entity<ThreadStore>,
 21}
 22
 23impl NativeAgentServer {
 24    pub fn new(fs: Arc<dyn Fs>, thread_store: Entity<ThreadStore>) -> Self {
 25        Self { fs, thread_store }
 26    }
 27}
 28
 29impl AgentServer for NativeAgentServer {
 30    fn agent_id(&self) -> AgentId {
 31        crate::ZED_AGENT_ID.clone()
 32    }
 33
 34    fn logo(&self) -> ui::IconName {
 35        ui::IconName::ZedAgent
 36    }
 37
 38    fn connect(
 39        &self,
 40        _delegate: AgentServerDelegate,
 41        _project: Entity<Project>,
 42        cx: &mut App,
 43    ) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
 44        log::debug!("NativeAgentServer::connect");
 45        let fs = self.fs.clone();
 46        let thread_store = self.thread_store.clone();
 47        let prompt_store = PromptStore::global(cx);
 48        cx.spawn(async move |cx| {
 49            log::debug!("Creating templates for native agent");
 50            let templates = Templates::new();
 51            let prompt_store = prompt_store.await?;
 52
 53            log::debug!("Creating native agent entity");
 54            let agent = cx
 55                .update(|cx| NativeAgent::new(thread_store, templates, Some(prompt_store), fs, cx));
 56
 57            // Create the connection wrapper
 58            let connection = NativeAgentConnection(agent);
 59            log::debug!("NativeAgentServer connection established successfully");
 60
 61            Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
 62        })
 63    }
 64
 65    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 66        self
 67    }
 68
 69    fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
 70        AgentSettings::get_global(cx).favorite_model_ids()
 71    }
 72
 73    fn toggle_favorite_model(
 74        &self,
 75        model_id: acp::ModelId,
 76        should_be_favorite: bool,
 77        fs: Arc<dyn Fs>,
 78        cx: &App,
 79    ) {
 80        let selection = model_id_to_selection(&model_id, cx);
 81        update_settings_file(fs, cx, move |settings, _| {
 82            let agent = settings.agent.get_or_insert_default();
 83            if should_be_favorite {
 84                agent.add_favorite_model(selection.clone());
 85            } else {
 86                agent.remove_favorite_model(&selection);
 87            }
 88        });
 89    }
 90}
 91
 92/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
 93fn model_id_to_selection(model_id: &acp::ModelId, cx: &App) -> LanguageModelSelection {
 94    let id = model_id.0.as_ref();
 95    let (provider, model) = id.split_once('/').unwrap_or(("", id));
 96
 97    let provider_id = LanguageModelProviderId(provider.to_string().into());
 98    let model_id_typed = LanguageModelId(model.to_string().into());
 99    let resolved = LanguageModelRegistry::global(cx)
100        .read(cx)
101        .provider(&provider_id)
102        .and_then(|p| {
103            p.provided_models(cx)
104                .into_iter()
105                .find(|m| m.id() == model_id_typed)
106        });
107
108    let Some(resolved) = resolved else {
109        return LanguageModelSelection {
110            provider: provider.to_owned().into(),
111            model: model.to_owned(),
112            enable_thinking: false,
113            effort: None,
114            speed: None,
115        };
116    };
117
118    let current_user_selection = AgentSettings::get_global(cx)
119        .default_model
120        .as_ref()
121        .filter(|selection| {
122            selection.provider.0 == resolved.provider_id().0.as_ref()
123                && selection.model == resolved.id().0.as_ref()
124        })
125        .cloned();
126
127    language_model_to_selection(&resolved, current_user_selection.as_ref())
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    // Needed for `cx.new(...)` when creating the `ThreadStore` entity below.
    use gpui::AppContext;

    // Instantiates the shared `common_e2e_tests!` suite from `agent_servers` for the
    // native agent server.
    agent_servers::e2e_tests::common_e2e_tests!(
        // Setup closure: prepares the Anthropic provider and returns the
        // `NativeAgentServer` that the suite presumably exercises — confirm against the macro.
        async |fs, cx| {
            // Initialize the prompt store and start authenticating the Anthropic provider.
            // The authentication future is moved into a spawned task so it can be
            // awaited outside the `update` closure.
            let auth = cx.update(|cx| {
                prompt_store::init(cx);
                let registry = language_model::LanguageModelRegistry::read_global(cx);
                let auth = registry
                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
                    .unwrap()
                    .authenticate(cx);

                cx.spawn(async move |_| auth.await)
            });

            // The test fails outright if Anthropic authentication fails.
            auth.await.unwrap();

            // Make `claude-sonnet-4-latest` the registry's default model.
            cx.update(|cx| {
                let registry = language_model::LanguageModelRegistry::global(cx);

                registry.update(cx, |registry, cx| {
                    registry.select_default_model(
                        Some(&language_model::SelectedModel {
                            provider: language_model::ANTHROPIC_PROVIDER_ID,
                            model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
                        }),
                        cx,
                    );
                });
            });

            let thread_store = cx.update(|cx| cx.new(|cx| ThreadStore::new(cx)));

            NativeAgentServer::new(fs.clone(), thread_store)
        },
        // NOTE(review): presumably the permission option id the shared tests choose
        // when a tool call asks for approval — confirm against the macro definition.
        allow_option_id = "allow"
    );
}
171}