1use std::{any::Any, rc::Rc, sync::Arc};
2use util::ResultExt as _;
3
4use agent_client_protocol as acp;
5use agent_servers::{AgentServer, AgentServerDelegate};
6use agent_settings::AgentSettings;
7use anyhow::Result;
8use collections::HashSet;
9use fs::Fs;
10use gpui::{App, Entity, Task};
11use project::{AgentId, Project};
12use prompt_store::PromptStore;
13use settings::{LanguageModelSelection, Settings as _, update_settings_file};
14
15use crate::{NativeAgent, NativeAgentConnection, ThreadStore, templates::Templates};
16
17#[derive(Clone)]
18pub struct NativeAgentServer {
19 fs: Arc<dyn Fs>,
20 thread_store: Entity<ThreadStore>,
21}
22
23impl NativeAgentServer {
24 pub fn new(fs: Arc<dyn Fs>, thread_store: Entity<ThreadStore>) -> Self {
25 Self { fs, thread_store }
26 }
27}
28
29impl AgentServer for NativeAgentServer {
30 fn agent_id(&self) -> AgentId {
31 crate::ZED_AGENT_ID.clone()
32 }
33
34 fn logo(&self) -> ui::IconName {
35 ui::IconName::ZedAgent
36 }
37
38 fn connect(
39 &self,
40 _delegate: AgentServerDelegate,
41 _project: Entity<Project>,
42 cx: &mut App,
43 ) -> Task<Result<Rc<dyn acp_thread::AgentConnection>>> {
44 log::debug!("NativeAgentServer::connect");
45 let fs = self.fs.clone();
46 let thread_store = self.thread_store.clone();
47 let prompt_store = PromptStore::global(cx);
48 cx.spawn(async move |cx| {
49 let templates = Templates::new();
50 let prompt_store = prompt_store.await.log_err();
51
52 let agent =
53 cx.update(|cx| NativeAgent::new(thread_store, templates, prompt_store, fs, cx));
54
55 let connection = NativeAgentConnection(agent);
56
57 Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
58 })
59 }
60
61 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
62 self
63 }
64
65 fn favorite_model_ids(&self, cx: &mut App) -> HashSet<acp::ModelId> {
66 AgentSettings::get_global(cx).favorite_model_ids()
67 }
68
69 fn toggle_favorite_model(
70 &self,
71 model_id: acp::ModelId,
72 should_be_favorite: bool,
73 fs: Arc<dyn Fs>,
74 cx: &App,
75 ) {
76 let selection = model_id_to_selection(&model_id);
77 update_settings_file(fs, cx, move |settings, _| {
78 let agent = settings.agent.get_or_insert_default();
79 if should_be_favorite {
80 agent.add_favorite_model(selection.clone());
81 } else {
82 agent.remove_favorite_model(&selection);
83 }
84 });
85 }
86}
87
88/// Convert a ModelId (e.g. "anthropic/claude-3-5-sonnet") to a LanguageModelSelection.
89fn model_id_to_selection(model_id: &acp::ModelId) -> LanguageModelSelection {
90 let id = model_id.0.as_ref();
91 let (provider, model) = id.split_once('/').unwrap_or(("", id));
92 LanguageModelSelection {
93 provider: provider.to_owned().into(),
94 model: model.to_owned(),
95 enable_thinking: false,
96 effort: None,
97 }
98}
99
/// End-to-end tests for [`NativeAgentServer`], generated by the shared
/// `common_e2e_tests!` macro from `agent_servers`.
#[cfg(test)]
mod tests {
    use super::*;

    use gpui::AppContext;

    agent_servers::e2e_tests::common_e2e_tests!(
        // Builds the server under test. It authenticates the Anthropic provider,
        // selects a default model, and creates a fresh thread store.
        async |fs, cx| {
            let auth = cx.update(|cx| {
                prompt_store::init(cx);
                let registry = language_model::LanguageModelRegistry::read_global(cx);
                // Panics if the Anthropic provider is not registered in this test app.
                let auth = registry
                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
                    .unwrap()
                    .authenticate(cx);

                cx.spawn(async move |_| auth.await)
            });

            // Authentication must succeed; presumably this needs real
            // Anthropic credentials in the environment (TODO confirm).
            auth.await.unwrap();

            cx.update(|cx| {
                let registry = language_model::LanguageModelRegistry::global(cx);

                registry.update(cx, |registry, cx| {
                    registry.select_default_model(
                        Some(&language_model::SelectedModel {
                            provider: language_model::ANTHROPIC_PROVIDER_ID,
                            model: language_model::LanguageModelId("claude-sonnet-4-latest".into()),
                        }),
                        cx,
                    );
                });
            });

            let thread_store = cx.update(|cx| cx.new(|cx| ThreadStore::new(cx)));

            NativeAgentServer::new(fs.clone(), thread_store)
        },
        // NOTE(review): presumably the permission option id the shared tests pick
        // to approve tool calls; verify against `common_e2e_tests!`.
        allow_option_id = "allow"
    );
}