1use std::{any::Any, path::Path, rc::Rc, sync::Arc};
2
3use agent_client_protocol as acp;
4use agent_servers::{AgentServer, AgentServerDelegate};
5use agent_settings::AgentSettings;
6use anyhow::Result;
7use collections::HashSet;
8use fs::Fs;
9use gpui::{App, Entity, SharedString, Task};
10use prompt_store::PromptStore;
11use settings::{LanguageModelSelection, Settings as _, update_settings_file};
12
13use crate::{HistoryStore, NativeAgent, NativeAgentConnection, templates::Templates};
14
15#[derive(Clone)]
16pub struct NativeAgentServer {
17 fs: Arc<dyn Fs>,
18 history: Entity<HistoryStore>,
19}
20
21impl NativeAgentServer {
22 pub fn new(fs: Arc<dyn Fs>, history: Entity<HistoryStore>) -> Self {
23 Self { fs, history }
24 }
25}
26
27impl AgentServer for NativeAgentServer {
28 fn name(&self) -> SharedString {
29 "Zed Agent".into()
30 }
31
32 fn logo(&self) -> ui::IconName {
33 ui::IconName::ZedAgent
34 }
35
36 fn connect(
37 &self,
38 _root_dir: Option<&Path>,
39 delegate: AgentServerDelegate,
40 cx: &mut App,
41 ) -> Task<
42 Result<(
43 Rc<dyn acp_thread::AgentConnection>,
44 Option<task::SpawnInTerminal>,
45 )>,
46 > {
47 log::debug!(
48 "NativeAgentServer::connect called for path: {:?}",
49 _root_dir
50 );
51 let project = delegate.project().clone();
52 let fs = self.fs.clone();
53 let history = self.history.clone();
54 let prompt_store = PromptStore::global(cx);
55 cx.spawn(async move |cx| {
56 log::debug!("Creating templates for native agent");
57 let templates = Templates::new();
58 let prompt_store = prompt_store.await?;
59
60 log::debug!("Creating native agent entity");
61 let agent =
62 NativeAgent::new(project, history, templates, Some(prompt_store), fs, cx).await?;
63
64 // Create the connection wrapper
65 let connection = NativeAgentConnection(agent);
66 log::debug!("NativeAgentServer connection established successfully");
67
68 Ok((
69 Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>,
70 None,
71 ))
72 })
73 }
74
75 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
76 self
77 }
78
79 fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
80 AgentSettings::get_global(cx).favorite_model_ids()
81 }
82
83 fn toggle_favorite_model(
84 &self,
85 model_id: acp::ModelId,
86 should_be_favorite: bool,
87 fs: Arc<dyn Fs>,
88 cx: &App,
89 ) {
90 let selection = model_id_to_selection(&model_id);
91 update_settings_file(fs, cx, move |settings, _| {
92 let agent = settings.agent.get_or_insert_default();
93 if should_be_favorite {
94 agent.add_favorite_model(selection.clone());
95 } else {
96 agent.remove_favorite_model(&selection);
97 }
98 });
99 }
100}
101
102/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
103fn model_id_to_selection(model_id: &acp::ModelId) -> LanguageModelSelection {
104 let id = model_id.0.as_ref();
105 let (provider, model) = id.split_once('/').unwrap_or(("", id));
106 LanguageModelSelection {
107 provider: provider.to_owned().into(),
108 model: model.to_owned(),
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    use assistant_text_thread::TextThreadStore;
    use gpui::AppContext;

    // Instantiates the shared agent-server end-to-end tests from
    // `agent_servers::e2e_tests` for the native agent. The async closure
    // builds the `NativeAgentServer` under test.
    // NOTE(review): the closure unwraps the Anthropic `authenticate` result,
    // so this presumably needs real Anthropic credentials at runtime.
    // Confirm how CI provides them.
    agent_servers::e2e_tests::common_e2e_tests!(
        async |fs, project, cx| {
            // Initialize the prompt store, then start authenticating the
            // Anthropic provider. `authenticate` returns a future that is
            // awaited outside of `cx.update`.
            let auth = cx.update(|cx| {
                prompt_store::init(cx);
                let registry = language_model::LanguageModelRegistry::read_global(cx);
                let auth = registry
                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
                    .unwrap()
                    .authenticate(cx);

                cx.spawn(async move |_| auth.await)
            });

            auth.await.unwrap();

            // Make `claude-sonnet-4-latest` the registry's default model for the run.
            cx.update(|cx| {
                let registry = language_model::LanguageModelRegistry::global(cx);

                registry.update(cx, |registry, cx| {
                    registry.select_default_model(
                        Some(&language_model::SelectedModel {
                            provider: language_model::ANTHROPIC_PROVIDER_ID,
                            model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
                        }),
                        cx,
                    );
                });
            });

            // Back the history store with a fake text-thread store for the test project.
            let history = cx.update(|cx| {
                let text_thread_store =
                    cx.new(move |cx| TextThreadStore::fake(project.clone(), cx));
                cx.new(move |cx| HistoryStore::new(text_thread_store, cx))
            });

            NativeAgentServer::new(fs.clone(), history)
        },
        // NOTE(review): presumably the permission option ID the shared tests
        // choose when granting tool calls. Verify against the macro definition.
        allow_option_id = "allow"
    );
}