1use acp_thread::ModelSelector;
2use agent_client_protocol as acp;
3use anyhow::{anyhow, Result};
4use futures::StreamExt;
5use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity};
6use language_model::{LanguageModel, LanguageModelRegistry};
7use project::Project;
8use std::collections::HashMap;
9use std::path::Path;
10use std::rc::Rc;
11use std::sync::Arc;
12
13use crate::{templates::Templates, AgentResponseEvent, Thread};
14
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages; it also holds the session's
    /// currently selected model (`Thread::selected_model`).
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held as a weak handle, so this session does not keep the `AcpThread` alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Release observer for the `AcpThread`; its callback removes this session from
    /// `NativeAgent::sessions`. Never read — stored so the observation lives as long as
    /// the session (presumably gpui cancels the observation when this is dropped).
    _subscription: Subscription,
}
23
/// In-process agent that tracks one `Session` (internal `Thread` + `AcpThread`) per
/// open ACP session.
pub struct NativeAgent {
    /// Session ID -> Session mapping.
    ///
    /// Entries are inserted by `new_thread` and removed when the session's `AcpThread`
    /// entity is released.
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
}
30
31impl NativeAgent {
32 pub fn new(templates: Arc<Templates>) -> Self {
33 log::info!("Creating new NativeAgent");
34 Self {
35 sessions: HashMap::new(),
36 templates,
37 }
38 }
39}
40
/// Cloneable handle to a [`NativeAgent`] entity.
///
/// Implements both [`acp_thread::AgentConnection`] (sessions, prompting, cancellation)
/// and [`ModelSelector`] (listing and choosing models per session).
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
44
45impl ModelSelector for NativeAgentConnection {
46 fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
47 log::debug!("NativeAgentConnection::list_models called");
48 cx.spawn(async move |cx| {
49 cx.update(|cx| {
50 let registry = LanguageModelRegistry::read_global(cx);
51 let models = registry.available_models(cx).collect::<Vec<_>>();
52 log::info!("Found {} available models", models.len());
53 if models.is_empty() {
54 Err(anyhow::anyhow!("No models available"))
55 } else {
56 Ok(models)
57 }
58 })?
59 })
60 }
61
62 fn select_model(
63 &self,
64 session_id: acp::SessionId,
65 model: Arc<dyn LanguageModel>,
66 cx: &mut AsyncApp,
67 ) -> Task<Result<()>> {
68 log::info!(
69 "Setting model for session {}: {:?}",
70 session_id,
71 model.name()
72 );
73 let agent = self.0.clone();
74
75 cx.spawn(async move |cx| {
76 agent.update(cx, |agent, cx| {
77 if let Some(session) = agent.sessions.get(&session_id) {
78 session.thread.update(cx, |thread, _cx| {
79 thread.selected_model = model;
80 });
81 Ok(())
82 } else {
83 Err(anyhow!("Session not found"))
84 }
85 })?
86 })
87 }
88
89 fn selected_model(
90 &self,
91 session_id: &acp::SessionId,
92 cx: &mut AsyncApp,
93 ) -> Task<Result<Arc<dyn LanguageModel>>> {
94 let agent = self.0.clone();
95 let session_id = session_id.clone();
96 cx.spawn(async move |cx| {
97 let thread = agent
98 .read_with(cx, |agent, _| {
99 agent
100 .sessions
101 .get(&session_id)
102 .map(|session| session.thread.clone())
103 })?
104 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
105 let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
106 Ok(selected)
107 })
108 }
109}
110
111impl acp_thread::AgentConnection for NativeAgentConnection {
112 fn new_thread(
113 self: Rc<Self>,
114 project: Entity<Project>,
115 cwd: &Path,
116 cx: &mut AsyncApp,
117 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
118 let agent = self.0.clone();
119 log::info!("Creating new thread for project at: {:?}", cwd);
120
121 cx.spawn(async move |cx| {
122 log::debug!("Starting thread creation in async context");
123 // Create Thread
124 let (session_id, thread) = agent.update(
125 cx,
126 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
127 // Fetch default model from registry settings
128 let registry = LanguageModelRegistry::read_global(cx);
129
130 // Log available models for debugging
131 let available_count = registry.available_models(cx).count();
132 log::debug!("Total available models: {}", available_count);
133
134 let default_model = registry
135 .default_model()
136 .map(|configured| {
137 log::info!(
138 "Using configured default model: {:?} from provider: {:?}",
139 configured.model.name(),
140 configured.provider.name()
141 );
142 configured.model
143 })
144 .ok_or_else(|| {
145 log::warn!("No default model configured in settings");
146 anyhow!("No default model configured. Please configure a default model in settings.")
147 })?;
148
149 let thread = cx.new(|_| Thread::new(project.clone(), agent.templates.clone(), default_model));
150
151 // Generate session ID
152 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
153 log::info!("Created session with ID: {}", session_id);
154 Ok((session_id, thread))
155 },
156 )??;
157
158 // Create AcpThread
159 let acp_thread = cx.update(|cx| {
160 cx.new(|cx| {
161 acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
162 })
163 })?;
164
165 // Store the session
166 agent.update(cx, |agent, cx| {
167 agent.sessions.insert(
168 session_id,
169 Session {
170 thread,
171 acp_thread: acp_thread.downgrade(),
172 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
173 this.sessions.remove(acp_thread.session_id());
174 })
175 },
176 );
177 })?;
178
179 Ok(acp_thread)
180 })
181 }
182
183 fn auth_methods(&self) -> &[acp::AuthMethod] {
184 &[] // No auth for in-process
185 }
186
187 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
188 Task::ready(Ok(()))
189 }
190
191 fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
192 Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
193 }
194
195 fn prompt(
196 &self,
197 params: acp::PromptRequest,
198 cx: &mut App,
199 ) -> Task<Result<acp::PromptResponse>> {
200 let session_id = params.session_id.clone();
201 let agent = self.0.clone();
202 log::info!("Received prompt request for session: {}", session_id);
203 log::debug!("Prompt blocks count: {}", params.prompt.len());
204
205 cx.spawn(async move |cx| {
206 // Get session
207 let (thread, acp_thread) = agent
208 .update(cx, |agent, _| {
209 agent
210 .sessions
211 .get_mut(&session_id)
212 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
213 })?
214 .ok_or_else(|| {
215 log::error!("Session not found: {}", session_id);
216 anyhow::anyhow!("Session not found")
217 })?;
218 log::debug!("Found session for: {}", session_id);
219
220 // Convert prompt to message
221 let message = convert_prompt_to_message(params.prompt);
222 log::info!("Converted prompt to message: {} chars", message.len());
223 log::debug!("Message content: {}", message);
224
225 // Get model using the ModelSelector capability (always available for agent2)
226 // Get the selected model from the thread directly
227 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
228
229 // Send to thread
230 log::info!("Sending message to thread with model: {:?}", model.name());
231 let mut response_stream =
232 thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
233
234 // Handle response stream and forward to session.acp_thread
235 while let Some(result) = response_stream.next().await {
236 match result {
237 Ok(event) => {
238 log::trace!("Received completion event: {:?}", event);
239
240 match event {
241 AgentResponseEvent::Text(text) => {
242 acp_thread.update(cx, |thread, cx| {
243 thread.handle_session_update(
244 acp::SessionUpdate::AgentMessageChunk {
245 content: acp::ContentBlock::Text(acp::TextContent {
246 text,
247 annotations: None,
248 }),
249 },
250 cx,
251 )
252 })??;
253 }
254 AgentResponseEvent::Thinking(text) => {
255 acp_thread.update(cx, |thread, cx| {
256 thread.handle_session_update(
257 acp::SessionUpdate::AgentThoughtChunk {
258 content: acp::ContentBlock::Text(acp::TextContent {
259 text,
260 annotations: None,
261 }),
262 },
263 cx,
264 )
265 })??;
266 }
267 AgentResponseEvent::ToolCall(tool_call) => {
268 acp_thread.update(cx, |thread, cx| {
269 thread.handle_session_update(
270 acp::SessionUpdate::ToolCall(tool_call),
271 cx,
272 )
273 })??;
274 }
275 AgentResponseEvent::ToolCallUpdate(tool_call_update) => {
276 acp_thread.update(cx, |thread, cx| {
277 thread.handle_session_update(
278 acp::SessionUpdate::ToolCallUpdate(tool_call_update),
279 cx,
280 )
281 })??;
282 }
283 AgentResponseEvent::Stop(stop_reason) => {
284 log::debug!("Assistant message complete: {:?}", stop_reason);
285 return Ok(acp::PromptResponse { stop_reason });
286 }
287 }
288 }
289 Err(e) => {
290 log::error!("Error in model response stream: {:?}", e);
291 // TODO: Consider sending an error message to the UI
292 break;
293 }
294 }
295 }
296
297 log::info!("Response stream completed");
298 anyhow::Ok(acp::PromptResponse {
299 stop_reason: acp::StopReason::EndTurn,
300 })
301 })
302 }
303
304 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
305 log::info!("Cancelling on session: {}", session_id);
306 self.0.update(cx, |agent, cx| {
307 if let Some(agent) = agent.sessions.get(session_id) {
308 agent.thread.update(cx, |thread, _cx| thread.cancel());
309 }
310 });
311 }
312}
313
314/// Convert ACP content blocks to a message string
315fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
316 log::debug!("Converting {} content blocks to message", blocks.len());
317 let mut message = String::new();
318
319 for block in blocks {
320 match block {
321 acp::ContentBlock::Text(text) => {
322 log::trace!("Processing text block: {} chars", text.text.len());
323 message.push_str(&text.text);
324 }
325 acp::ContentBlock::ResourceLink(link) => {
326 log::trace!("Processing resource link: {}", link.uri);
327 message.push_str(&format!(" @{} ", link.uri));
328 }
329 acp::ContentBlock::Image(_) => {
330 log::trace!("Processing image block");
331 message.push_str(" [image] ");
332 }
333 acp::ContentBlock::Audio(_) => {
334 log::trace!("Processing audio block");
335 message.push_str(" [audio] ");
336 }
337 acp::ContentBlock::Resource(resource) => {
338 log::trace!("Processing resource block: {:?}", resource.resource);
339 message.push_str(&format!(" [resource: {:?}] ", resource.resource));
340 }
341 }
342 }
343
344 message
345}