1use std::{any::Any, rc::Rc, sync::Arc};
2
3use agent_client_protocol as acp;
4use agent_servers::{AgentServer, AgentServerDelegate};
5use agent_settings::{AgentSettings, language_model_to_selection};
6use anyhow::Result;
7use collections::HashSet;
8use fs::Fs;
9use gpui::{App, Entity, Task};
10use language_model::{LanguageModelId, LanguageModelProviderId, LanguageModelRegistry};
11use project::{AgentId, Project};
12use prompt_store::PromptStore;
13use settings::{LanguageModelSelection, Settings as _, update_settings_file};
14
15use crate::{NativeAgent, NativeAgentConnection, ThreadStore, templates::Templates};
16
17#[derive(Clone)]
18pub struct NativeAgentServer {
19 fs: Arc<dyn Fs>,
20 thread_store: Entity<ThreadStore>,
21}
22
23impl NativeAgentServer {
24 pub fn new(fs: Arc<dyn Fs>, thread_store: Entity<ThreadStore>) -> Self {
25 Self { fs, thread_store }
26 }
27}
28
29impl AgentServer for NativeAgentServer {
30 fn agent_id(&self) -> AgentId {
31 crate::ZED_AGENT_ID.clone()
32 }
33
34 fn logo(&self) -> ui::IconName {
35 ui::IconName::ZedAgent
36 }
37
38 fn connect(
39 &self,
40 _delegate: AgentServerDelegate,
41 _project: Entity<Project>,
42 cx: &mut App,
43 ) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
44 log::debug!("NativeAgentServer::connect");
45 let fs = self.fs.clone();
46 let thread_store = self.thread_store.clone();
47 let prompt_store = PromptStore::global(cx);
48 cx.spawn(async move |cx| {
49 log::debug!("Creating templates for native agent");
50 let templates = Templates::new();
51 let prompt_store = prompt_store.await?;
52
53 log::debug!("Creating native agent entity");
54 let agent = cx
55 .update(|cx| NativeAgent::new(thread_store, templates, Some(prompt_store), fs, cx));
56
57 // Create the connection wrapper
58 let connection = NativeAgentConnection(agent);
59 log::debug!("NativeAgentServer connection established successfully");
60
61 Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
62 })
63 }
64
65 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
66 self
67 }
68
69 fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
70 AgentSettings::get_global(cx).favorite_model_ids()
71 }
72
73 fn toggle_favorite_model(
74 &self,
75 model_id: acp::ModelId,
76 should_be_favorite: bool,
77 fs: Arc<dyn Fs>,
78 cx: &App,
79 ) {
80 let selection = model_id_to_selection(&model_id, cx);
81 update_settings_file(fs, cx, move |settings, _| {
82 let agent = settings.agent.get_or_insert_default();
83 if should_be_favorite {
84 agent.add_favorite_model(selection.clone());
85 } else {
86 agent.remove_favorite_model(&selection);
87 }
88 });
89 }
90}
91
92/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
93fn model_id_to_selection(model_id: &acp::ModelId, cx: &App) -> LanguageModelSelection {
94 let id = model_id.0.as_ref();
95 let (provider, model) = id.split_once('/').unwrap_or(("", id));
96
97 let provider_id = LanguageModelProviderId(provider.to_string().into());
98 let model_id_typed = LanguageModelId(model.to_string().into());
99 let resolved = LanguageModelRegistry::global(cx)
100 .read(cx)
101 .provider(&provider_id)
102 .and_then(|p| {
103 p.provided_models(cx)
104 .into_iter()
105 .find(|m| m.id() == model_id_typed)
106 });
107
108 let Some(resolved) = resolved else {
109 return LanguageModelSelection {
110 provider: provider.to_owned().into(),
111 model: model.to_owned(),
112 enable_thinking: false,
113 effort: None,
114 speed: None,
115 };
116 };
117
118 let current_user_selection = AgentSettings::get_global(cx)
119 .default_model
120 .as_ref()
121 .filter(|selection| {
122 selection.provider.0 == resolved.provider_id().0.as_ref()
123 && selection.model == resolved.id().0.as_ref()
124 })
125 .cloned();
126
127 language_model_to_selection(&resolved, current_user_selection.as_ref())
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    // Presumably needed for the `cx.new(...)` call below.
    use gpui::AppContext;

    // Instantiates the shared agent-server end-to-end test suite for the
    // native agent. The closure builds the `NativeAgentServer` under test.
    // `allow_option_id` is the permission option id the suite is given
    // (its exact use is defined by the macro, which is not visible here).
    agent_servers::e2e_tests::common_e2e_tests!(
        async |fs, cx| {
            // Initialize the prompt store and start authenticating the
            // Anthropic provider. The provider must already be registered,
            // otherwise the `unwrap` panics.
            let auth = cx.update(|cx| {
                prompt_store::init(cx);
                let registry = language_model::LanguageModelRegistry::read_global(cx);
                let auth = registry
                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
                    .unwrap()
                    .authenticate(cx);

                cx.spawn(async move |_| auth.await)
            });

            // Fail the test if authentication does not succeed.
            auth.await.unwrap();

            // Make Anthropic's "claude-sonnet-4-latest" the registry's default model.
            cx.update(|cx| {
                let registry = language_model::LanguageModelRegistry::global(cx);

                registry.update(cx, |registry, cx| {
                    registry.select_default_model(
                        Some(&language_model::SelectedModel {
                            provider: language_model::ANTHROPIC_PROVIDER_ID,
                            model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
                        }),
                        cx,
                    );
                });
            });

            // Fresh thread store for the server under test.
            let thread_store = cx.update(|cx| cx.new(|cx| ThreadStore::new(cx)));

            NativeAgentServer::new(fs.clone(), thread_store)
        },
        allow_option_id = "allow"
    );
}