agent.rs

  1use acp_thread::ModelSelector;
  2use agent_client_protocol as acp;
  3use anyhow::{anyhow, Result};
  4use futures::StreamExt;
  5use gpui::{App, AppContext, AsyncApp, Entity, Task};
  6use language_model::{LanguageModel, LanguageModelRegistry};
  7use project::Project;
  8use std::collections::HashMap;
  9use std::path::Path;
 10use std::rc::Rc;
 11use std::sync::Arc;
 12
 13use crate::{templates::Templates, AgentResponseEvent, Thread};
 14
 15/// Holds both the internal Thread and the AcpThread for a session
 16struct Session {
 17    /// The internal thread that processes messages
 18    thread: Entity<Thread>,
 19    /// The ACP thread that handles protocol communication
 20    acp_thread: Entity<acp_thread::AcpThread>,
 21}
 22
 23pub struct NativeAgent {
 24    /// Session ID -> Session mapping
 25    sessions: HashMap<acp::SessionId, Session>,
 26    /// Shared templates for all threads
 27    templates: Arc<Templates>,
 28}
 29
 30impl NativeAgent {
 31    pub fn new(templates: Arc<Templates>) -> Self {
 32        log::info!("Creating new NativeAgent");
 33        Self {
 34            sessions: HashMap::new(),
 35            templates,
 36        }
 37    }
 38}
 39
 40/// Wrapper struct that implements the AgentConnection trait
 41#[derive(Clone)]
 42pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 43
 44impl ModelSelector for NativeAgentConnection {
 45    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
 46        log::debug!("NativeAgentConnection::list_models called");
 47        cx.spawn(async move |cx| {
 48            cx.update(|cx| {
 49                let registry = LanguageModelRegistry::read_global(cx);
 50                let models = registry.available_models(cx).collect::<Vec<_>>();
 51                log::info!("Found {} available models", models.len());
 52                if models.is_empty() {
 53                    Err(anyhow::anyhow!("No models available"))
 54                } else {
 55                    Ok(models)
 56                }
 57            })?
 58        })
 59    }
 60
 61    fn select_model(
 62        &self,
 63        session_id: acp::SessionId,
 64        model: Arc<dyn LanguageModel>,
 65        cx: &mut AsyncApp,
 66    ) -> Task<Result<()>> {
 67        log::info!(
 68            "Setting model for session {}: {:?}",
 69            session_id,
 70            model.name()
 71        );
 72        let agent = self.0.clone();
 73
 74        cx.spawn(async move |cx| {
 75            agent.update(cx, |agent, cx| {
 76                if let Some(session) = agent.sessions.get(&session_id) {
 77                    session.thread.update(cx, |thread, _cx| {
 78                        thread.selected_model = model;
 79                    });
 80                    Ok(())
 81                } else {
 82                    Err(anyhow!("Session not found"))
 83                }
 84            })?
 85        })
 86    }
 87
 88    fn selected_model(
 89        &self,
 90        session_id: &acp::SessionId,
 91        cx: &mut AsyncApp,
 92    ) -> Task<Result<Arc<dyn LanguageModel>>> {
 93        let agent = self.0.clone();
 94        let session_id = session_id.clone();
 95        cx.spawn(async move |cx| {
 96            let thread = agent
 97                .read_with(cx, |agent, _| {
 98                    agent
 99                        .sessions
100                        .get(&session_id)
101                        .map(|session| session.thread.clone())
102                })?
103                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
104            let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
105            Ok(selected)
106        })
107    }
108}
109
110impl acp_thread::AgentConnection for NativeAgentConnection {
111    fn new_thread(
112        self: Rc<Self>,
113        project: Entity<Project>,
114        cwd: &Path,
115        cx: &mut AsyncApp,
116    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
117        let agent = self.0.clone();
118        log::info!("Creating new thread for project at: {:?}", cwd);
119
120        cx.spawn(async move |cx| {
121            log::debug!("Starting thread creation in async context");
122            // Create Thread
123            let (session_id, thread) = agent.update(
124                cx,
125                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
126                    // Fetch default model from registry settings
127                    let registry = LanguageModelRegistry::read_global(cx);
128
129                    // Log available models for debugging
130                    let available_count = registry.available_models(cx).count();
131                    log::debug!("Total available models: {}", available_count);
132
133                    let default_model = registry
134                        .default_model()
135                        .map(|configured| {
136                            log::info!(
137                                "Using configured default model: {:?} from provider: {:?}",
138                                configured.model.name(),
139                                configured.provider.name()
140                            );
141                            configured.model
142                        })
143                        .ok_or_else(|| {
144                            log::warn!("No default model configured in settings");
145                            anyhow!("No default model configured. Please configure a default model in settings.")
146                        })?;
147
148                    let thread = cx.new(|_| Thread::new(project.clone(), agent.templates.clone(), default_model));
149
150                    // Generate session ID
151                    let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
152                    log::info!("Created session with ID: {}", session_id);
153                    Ok((session_id, thread))
154                },
155            )??;
156
157            // Create AcpThread
158            let acp_thread = cx.update(|cx| {
159                cx.new(|cx| {
160                    acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
161                })
162            })?;
163
164            // Store the session
165            agent.update(cx, |agent, _cx| {
166                agent.sessions.insert(
167                    session_id,
168                    Session {
169                        thread,
170                        acp_thread: acp_thread.clone(),
171                    },
172                );
173            })?;
174
175            Ok(acp_thread)
176        })
177    }
178
179    fn auth_methods(&self) -> &[acp::AuthMethod] {
180        &[] // No auth for in-process
181    }
182
183    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
184        Task::ready(Ok(()))
185    }
186
187    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
188        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
189    }
190
191    fn prompt(
192        &self,
193        params: acp::PromptRequest,
194        cx: &mut App,
195    ) -> Task<Result<acp::PromptResponse>> {
196        let session_id = params.session_id.clone();
197        let agent = self.0.clone();
198        log::info!("Received prompt request for session: {}", session_id);
199        log::debug!("Prompt blocks count: {}", params.prompt.len());
200
201        cx.spawn(async move |cx| {
202            // Get session
203            let (thread, acp_thread) = agent
204                .update(cx, |agent, _| {
205                    agent
206                        .sessions
207                        .get_mut(&session_id)
208                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
209                })?
210                .ok_or_else(|| {
211                    log::error!("Session not found: {}", session_id);
212                    anyhow::anyhow!("Session not found")
213                })?;
214            log::debug!("Found session for: {}", session_id);
215
216            // Convert prompt to message
217            let message = convert_prompt_to_message(params.prompt);
218            log::info!("Converted prompt to message: {} chars", message.len());
219            log::debug!("Message content: {}", message);
220
221            // Get model using the ModelSelector capability (always available for agent2)
222            // Get the selected model from the thread directly
223            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
224
225            // Send to thread
226            log::info!("Sending message to thread with model: {:?}", model.name());
227            let mut response_stream =
228                thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
229
230            // Handle response stream and forward to session.acp_thread
231            while let Some(result) = response_stream.next().await {
232                match result {
233                    Ok(event) => {
234                        log::trace!("Received completion event: {:?}", event);
235
236                        match event {
237                            AgentResponseEvent::Text(text) => {
238                                acp_thread.update(cx, |thread, cx| {
239                                    thread.handle_session_update(
240                                        acp::SessionUpdate::AgentMessageChunk {
241                                            content: acp::ContentBlock::Text(acp::TextContent {
242                                                text,
243                                                annotations: None,
244                                            }),
245                                        },
246                                        cx,
247                                    )
248                                })??;
249                            }
250                            AgentResponseEvent::Thinking(text) => {
251                                acp_thread.update(cx, |thread, cx| {
252                                    thread.handle_session_update(
253                                        acp::SessionUpdate::AgentThoughtChunk {
254                                            content: acp::ContentBlock::Text(acp::TextContent {
255                                                text,
256                                                annotations: None,
257                                            }),
258                                        },
259                                        cx,
260                                    )
261                                })??;
262                            }
263                            AgentResponseEvent::ToolCall(tool_call) => {
264                                acp_thread.update(cx, |thread, cx| {
265                                    thread.handle_session_update(
266                                        acp::SessionUpdate::ToolCall(tool_call),
267                                        cx,
268                                    )
269                                })??;
270                            }
271                            AgentResponseEvent::ToolCallUpdate(tool_call_update) => {
272                                acp_thread.update(cx, |thread, cx| {
273                                    thread.handle_session_update(
274                                        acp::SessionUpdate::ToolCallUpdate(tool_call_update),
275                                        cx,
276                                    )
277                                })??;
278                            }
279                            AgentResponseEvent::Stop(stop_reason) => {
280                                log::debug!("Assistant message complete: {:?}", stop_reason);
281                                return Ok(acp::PromptResponse { stop_reason });
282                            }
283                        }
284                    }
285                    Err(e) => {
286                        log::error!("Error in model response stream: {:?}", e);
287                        // TODO: Consider sending an error message to the UI
288                        break;
289                    }
290                }
291            }
292
293            log::info!("Response stream completed");
294            anyhow::Ok(acp::PromptResponse {
295                stop_reason: acp::StopReason::EndTurn,
296            })
297        })
298    }
299
300    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
301        log::info!("Cancelling on session: {}", session_id);
302        self.0.update(cx, |agent, cx| {
303            if let Some(agent) = agent.sessions.get(session_id) {
304                agent.thread.update(cx, |thread, _cx| thread.cancel());
305            }
306        });
307    }
308}
309
310/// Convert ACP content blocks to a message string
311fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
312    log::debug!("Converting {} content blocks to message", blocks.len());
313    let mut message = String::new();
314
315    for block in blocks {
316        match block {
317            acp::ContentBlock::Text(text) => {
318                log::trace!("Processing text block: {} chars", text.text.len());
319                message.push_str(&text.text);
320            }
321            acp::ContentBlock::ResourceLink(link) => {
322                log::trace!("Processing resource link: {}", link.uri);
323                message.push_str(&format!(" @{} ", link.uri));
324            }
325            acp::ContentBlock::Image(_) => {
326                log::trace!("Processing image block");
327                message.push_str(" [image] ");
328            }
329            acp::ContentBlock::Audio(_) => {
330                log::trace!("Processing audio block");
331                message.push_str(" [audio] ");
332            }
333            acp::ContentBlock::Resource(resource) => {
334                log::trace!("Processing resource block: {:?}", resource.resource);
335                message.push_str(&format!(" [resource: {:?}] ", resource.resource));
336            }
337        }
338    }
339
340    message
341}