1use std::{any::Any, rc::Rc, sync::Arc};
2
3use agent_client_protocol as acp;
4use agent_servers::{AgentServer, AgentServerDelegate};
5use agent_settings::AgentSettings;
6use anyhow::Result;
7use collections::HashSet;
8use fs::Fs;
9use gpui::{App, Entity, Task};
10use project::AgentId;
11use prompt_store::PromptStore;
12use settings::{LanguageModelSelection, Settings as _, update_settings_file};
13
14use crate::{NativeAgent, NativeAgentConnection, ThreadStore, templates::Templates};
15
16#[derive(Clone)]
17pub struct NativeAgentServer {
18 fs: Arc<dyn Fs>,
19 thread_store: Entity<ThreadStore>,
20}
21
22impl NativeAgentServer {
23 pub fn new(fs: Arc<dyn Fs>, thread_store: Entity<ThreadStore>) -> Self {
24 Self { fs, thread_store }
25 }
26}
27
28impl AgentServer for NativeAgentServer {
29 fn agent_id(&self) -> AgentId {
30 crate::ZED_AGENT_ID.clone()
31 }
32
33 fn logo(&self) -> ui::IconName {
34 ui::IconName::ZedAgent
35 }
36
37 fn connect(
38 &self,
39 _delegate: AgentServerDelegate,
40 cx: &mut App,
41 ) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
42 log::debug!("NativeAgentServer::connect");
43 let fs = self.fs.clone();
44 let thread_store = self.thread_store.clone();
45 let prompt_store = PromptStore::global(cx);
46 cx.spawn(async move |cx| {
47 log::debug!("Creating templates for native agent");
48 let templates = Templates::new();
49 let prompt_store = prompt_store.await?;
50
51 log::debug!("Creating native agent entity");
52 let agent = cx
53 .update(|cx| NativeAgent::new(thread_store, templates, Some(prompt_store), fs, cx));
54
55 // Create the connection wrapper
56 let connection = NativeAgentConnection(agent);
57 log::debug!("NativeAgentServer connection established successfully");
58
59 Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
60 })
61 }
62
63 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
64 self
65 }
66
67 fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
68 AgentSettings::get_global(cx).favorite_model_ids()
69 }
70
71 fn toggle_favorite_model(
72 &self,
73 model_id: acp::ModelId,
74 should_be_favorite: bool,
75 fs: Arc<dyn Fs>,
76 cx: &App,
77 ) {
78 let selection = model_id_to_selection(&model_id);
79 update_settings_file(fs, cx, move |settings, _| {
80 let agent = settings.agent.get_or_insert_default();
81 if should_be_favorite {
82 agent.add_favorite_model(selection.clone());
83 } else {
84 agent.remove_favorite_model(&selection);
85 }
86 });
87 }
88}
89
90/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
91fn model_id_to_selection(model_id: &acp::ModelId) -> LanguageModelSelection {
92 let id = model_id.0.as_ref();
93 let (provider, model) = id.split_once('/').unwrap_or(("", id));
94 LanguageModelSelection {
95 provider: provider.to_owned().into(),
96 model: model.to_owned(),
97 enable_thinking: false,
98 effort: None,
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    use gpui::AppContext;

    // Runs the shared end-to-end test suite against the native agent server.
    // The closure below builds the server under test.
    agent_servers::e2e_tests::common_e2e_tests!(
        async |fs, cx| {
            // Initialise the prompt store and start authenticating with the
            // Anthropic provider.
            let authenticate_task = cx.update(|cx| {
                prompt_store::init(cx);
                let model_registry = language_model::LanguageModelRegistry::read_global(cx);
                let authentication = model_registry
                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
                    .unwrap()
                    .authenticate(cx);

                cx.spawn(async move |_| authentication.await)
            });

            authenticate_task.await.unwrap();

            // Select the default model used by the tests.
            cx.update(|cx| {
                let model_registry = language_model::LanguageModelRegistry::global(cx);

                model_registry.update(cx, |this, cx| {
                    let default_model = language_model::SelectedModel {
                        provider: language_model::ANTHROPIC_PROVIDER_ID,
                        model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
                    };
                    this.select_default_model(Some(&default_model), cx);
                });
            });

            let thread_store = cx.update(|cx| cx.new(|cx| ThreadStore::new(cx)));

            NativeAgentServer::new(fs.clone(), thread_store)
        },
        allow_option_id = "allow"
    );
}