1use acp_thread::ModelSelector;
2use agent_client_protocol as acp;
3use anyhow::Result;
4use gpui::{App, AppContext, AsyncApp, Entity, Task};
5use language_model::{LanguageModel, LanguageModelRegistry};
6use project::Project;
7use std::collections::HashMap;
8use std::path::Path;
9use std::rc::Rc;
10use std::sync::Arc;
11
12use crate::{templates::Templates, Thread};
13
14pub struct Agent {
15 /// Session ID -> Thread entity mapping
16 sessions: HashMap<acp::SessionId, Entity<Thread>>,
17 /// Shared templates for all threads
18 templates: Arc<Templates>,
19}
20
21impl Agent {
22 pub fn new(templates: Arc<Templates>) -> Self {
23 Self {
24 sessions: HashMap::new(),
25 templates,
26 }
27 }
28}
29
/// Cheaply clonable handle to an [`Agent`] entity.
///
/// This wrapper is the type that implements both `acp_thread::AgentConnection` and
/// [`ModelSelector`]. Cloning it only clones the inner `Entity` handle.
#[derive(Clone)]
pub struct AgentConnection(pub Entity<Agent>);
33
34impl ModelSelector for AgentConnection {
35 fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
36 let result = cx.update(|cx| {
37 let registry = LanguageModelRegistry::read_global(cx);
38 let models = registry.available_models(cx).collect::<Vec<_>>();
39 if models.is_empty() {
40 Err(anyhow::anyhow!("No models available"))
41 } else {
42 Ok(models)
43 }
44 });
45 Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
46 }
47
48 fn select_model(
49 &self,
50 session_id: &acp::SessionId,
51 model: Arc<dyn LanguageModel>,
52 cx: &mut AsyncApp,
53 ) -> Task<Result<()>> {
54 let agent = self.0.clone();
55 let result = agent.update(cx, |agent, cx| {
56 if let Some(thread) = agent.sessions.get(session_id) {
57 thread.update(cx, |thread, _| {
58 thread.selected_model = model;
59 });
60 Ok(())
61 } else {
62 Err(anyhow::anyhow!("Session not found"))
63 }
64 });
65 Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
66 }
67
68 fn selected_model(
69 &self,
70 session_id: &acp::SessionId,
71 cx: &mut AsyncApp,
72 ) -> Task<Result<Arc<dyn LanguageModel>>> {
73 let agent = self.0.clone();
74 let thread_result = agent
75 .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
76 .ok()
77 .flatten()
78 .ok_or_else(|| anyhow::anyhow!("Session not found"));
79
80 match thread_result {
81 Ok(thread) => {
82 let selected = thread
83 .read_with(cx, |thread, _| thread.selected_model.clone())
84 .unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
85 Task::ready(Ok(selected))
86 }
87 Err(e) => Task::ready(Err(e)),
88 }
89 }
90}
91
92impl acp_thread::AgentConnection for AgentConnection {
93 fn new_thread(
94 self: Rc<Self>,
95 project: Entity<Project>,
96 cwd: &Path,
97 cx: &mut AsyncApp,
98 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
99 let _cwd = cwd.to_owned();
100 let agent = self.0.clone();
101
102 cx.spawn(async move |cx| {
103 // Create Thread and store in Agent
104 let (session_id, _thread) =
105 agent.update(cx, |agent, cx: &mut gpui::Context<Agent>| {
106 // Fetch default model
107 let default_model = LanguageModelRegistry::read_global(cx)
108 .available_models(cx)
109 .next()
110 .unwrap_or_else(|| panic!("No default model available"));
111
112 let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
113 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
114 agent.sessions.insert(session_id.clone(), thread.clone());
115 (session_id, thread)
116 })?;
117
118 // Create AcpThread
119 let acp_thread = cx.update(|cx| {
120 cx.new(|cx| {
121 acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx)
122 })
123 })?;
124
125 Ok(acp_thread)
126 })
127 }
128
129 fn auth_methods(&self) -> &[acp::AuthMethod] {
130 &[] // No auth for in-process
131 }
132
133 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
134 Task::ready(Ok(()))
135 }
136
137 fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
138 Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
139 }
140
141 fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
142 let session_id = params.session_id.clone();
143 let agent = self.0.clone();
144
145 cx.spawn(async move |cx| {
146 // Get thread
147 let thread: Entity<Thread> = agent
148 .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
149 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
150
151 // Convert prompt to message
152 let message = convert_prompt_to_message(params.prompt);
153
154 // Get model using the ModelSelector capability (always available for agent2)
155 // Get the selected model from the thread directly
156 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
157
158 // Send to thread
159 thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
160
161 Ok(())
162 })
163 }
164
165 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
166 self.0.update(cx, |agent, _cx| {
167 agent.sessions.remove(session_id);
168 });
169 }
170}
171
172/// Convert ACP content blocks to a message string
173fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
174 let mut message = String::new();
175
176 for block in blocks {
177 match block {
178 acp::ContentBlock::Text(text) => {
179 message.push_str(&text.text);
180 }
181 acp::ContentBlock::ResourceLink(link) => {
182 message.push_str(&format!(" @{} ", link.uri));
183 }
184 acp::ContentBlock::Image(_) => {
185 message.push_str(" [image] ");
186 }
187 acp::ContentBlock::Audio(_) => {
188 message.push_str(" [audio] ");
189 }
190 acp::ContentBlock::Resource(resource) => {
191 message.push_str(&format!(" [resource: {:?}] ", resource.resource));
192 }
193 }
194 }
195
196 message
197}