1use std::{any::Any, rc::Rc, sync::Arc};
2
3use agent_client_protocol as acp;
4use agent_servers::{AgentServer, AgentServerDelegate};
5use agent_settings::AgentSettings;
6use anyhow::Result;
7use collections::HashSet;
8use fs::Fs;
9use gpui::{App, Entity, SharedString, Task};
10use prompt_store::PromptStore;
11use settings::{LanguageModelSelection, Settings as _, update_settings_file};
12
13use crate::{NativeAgent, NativeAgentConnection, ThreadStore, templates::Templates};
14
15#[derive(Clone)]
16pub struct NativeAgentServer {
17 fs: Arc<dyn Fs>,
18 thread_store: Entity<ThreadStore>,
19}
20
21impl NativeAgentServer {
22 pub fn new(fs: Arc<dyn Fs>, thread_store: Entity<ThreadStore>) -> Self {
23 Self { fs, thread_store }
24 }
25}
26
27impl AgentServer for NativeAgentServer {
28 fn name(&self) -> SharedString {
29 "Zed Agent".into()
30 }
31
32 fn logo(&self) -> ui::IconName {
33 ui::IconName::ZedAgent
34 }
35
36 fn connect(
37 &self,
38 delegate: AgentServerDelegate,
39 cx: &mut App,
40 ) -> Task<
41 Result<(
42 Rc<dyn acp_thread::AgentConnection>,
43 Option<task::SpawnInTerminal>,
44 )>,
45 > {
46 log::debug!("NativeAgentServer::connect");
47 let project = delegate.project().clone();
48 let fs = self.fs.clone();
49 let thread_store = self.thread_store.clone();
50 let prompt_store = PromptStore::global(cx);
51 cx.spawn(async move |cx| {
52 log::debug!("Creating templates for native agent");
53 let templates = Templates::new();
54 let prompt_store = prompt_store.await?;
55
56 log::debug!("Creating native agent entity");
57 let agent =
58 NativeAgent::new(project, thread_store, templates, Some(prompt_store), fs, cx)
59 .await?;
60
61 // Create the connection wrapper
62 let connection = NativeAgentConnection(agent);
63 log::debug!("NativeAgentServer connection established successfully");
64
65 Ok((
66 Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>,
67 None,
68 ))
69 })
70 }
71
72 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
73 self
74 }
75
76 fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
77 AgentSettings::get_global(cx).favorite_model_ids()
78 }
79
80 fn toggle_favorite_model(
81 &self,
82 model_id: acp::ModelId,
83 should_be_favorite: bool,
84 fs: Arc<dyn Fs>,
85 cx: &App,
86 ) {
87 let selection = model_id_to_selection(&model_id);
88 update_settings_file(fs, cx, move |settings, _| {
89 let agent = settings.agent.get_or_insert_default();
90 if should_be_favorite {
91 agent.add_favorite_model(selection.clone());
92 } else {
93 agent.remove_favorite_model(&selection);
94 }
95 });
96 }
97}
98
99/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
100fn model_id_to_selection(model_id: &acp::ModelId) -> LanguageModelSelection {
101 let id = model_id.0.as_ref();
102 let (provider, model) = id.split_once('/').unwrap_or(("", id));
103 LanguageModelSelection {
104 provider: provider.to_owned().into(),
105 model: model.to_owned(),
106 enable_thinking: false,
107 effort: None,
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    use gpui::AppContext;

    // Instantiates the shared end-to-end agent-server test suite for the native agent.
    // The async closure builds the server under test. `allow_option_id` is forwarded to
    // the shared suite; presumably it is the permission option id the tests choose to
    // approve tool calls (NOTE(review): confirm against `common_e2e_tests!`).
    agent_servers::e2e_tests::common_e2e_tests!(
        async |fs, cx| {
            // Initialize the prompt store, then start authenticating the Anthropic provider.
            // `.unwrap()` panics if the registry has no Anthropic provider.
            let auth = cx.update(|cx| {
                prompt_store::init(cx);
                let registry = language_model::LanguageModelRegistry::read_global(cx);
                let auth = registry
                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
                    .unwrap()
                    .authenticate(cx);

                cx.spawn(async move |_| auth.await)
            });

            // Authentication must succeed for the suite to run; credentials are
            // presumably supplied by the test environment (TODO confirm).
            auth.await.unwrap();

            // Select Anthropic `claude-sonnet-4-latest` as the registry's default model.
            cx.update(|cx| {
                let registry = language_model::LanguageModelRegistry::global(cx);

                registry.update(cx, |registry, cx| {
                    registry.select_default_model(
                        Some(&language_model::SelectedModel {
                            provider: language_model::ANTHROPIC_PROVIDER_ID,
                            model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
                        }),
                        cx,
                    );
                });
            });

            // Fresh thread store for the server under test.
            let thread_store = cx.update(|cx| cx.new(|cx| ThreadStore::new(cx)));

            NativeAgentServer::new(fs.clone(), thread_store)
        },
        allow_option_id = "allow"
    );
}