agent.rs

  1use acp_thread::ModelSelector;
  2use agent_client_protocol as acp;
  3use anyhow::{anyhow, Result};
  4use futures::StreamExt;
  5use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity};
  6use language_model::{LanguageModel, LanguageModelRegistry};
  7use project::Project;
  8use std::collections::HashMap;
  9use std::path::Path;
 10use std::rc::Rc;
 11use std::sync::Arc;
 12
 13use crate::{templates::Templates, AgentResponseEvent, Thread};
 14
 15/// Holds both the internal Thread and the AcpThread for a session
 16struct Session {
 17    /// The internal thread that processes messages
 18    thread: Entity<Thread>,
 19    /// The ACP thread that handles protocol communication
 20    acp_thread: WeakEntity<acp_thread::AcpThread>,
 21    _subscription: Subscription,
 22}
 23
 24pub struct NativeAgent {
 25    /// Session ID -> Session mapping
 26    sessions: HashMap<acp::SessionId, Session>,
 27    /// Shared templates for all threads
 28    templates: Arc<Templates>,
 29}
 30
 31impl NativeAgent {
 32    pub fn new(templates: Arc<Templates>) -> Self {
 33        log::info!("Creating new NativeAgent");
 34        Self {
 35            sessions: HashMap::new(),
 36            templates,
 37        }
 38    }
 39}
 40
 41/// Wrapper struct that implements the AgentConnection trait
 42#[derive(Clone)]
 43pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 44
 45impl ModelSelector for NativeAgentConnection {
 46    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
 47        log::debug!("NativeAgentConnection::list_models called");
 48        cx.spawn(async move |cx| {
 49            cx.update(|cx| {
 50                let registry = LanguageModelRegistry::read_global(cx);
 51                let models = registry.available_models(cx).collect::<Vec<_>>();
 52                log::info!("Found {} available models", models.len());
 53                if models.is_empty() {
 54                    Err(anyhow::anyhow!("No models available"))
 55                } else {
 56                    Ok(models)
 57                }
 58            })?
 59        })
 60    }
 61
 62    fn select_model(
 63        &self,
 64        session_id: acp::SessionId,
 65        model: Arc<dyn LanguageModel>,
 66        cx: &mut AsyncApp,
 67    ) -> Task<Result<()>> {
 68        log::info!(
 69            "Setting model for session {}: {:?}",
 70            session_id,
 71            model.name()
 72        );
 73        let agent = self.0.clone();
 74
 75        cx.spawn(async move |cx| {
 76            agent.update(cx, |agent, cx| {
 77                if let Some(session) = agent.sessions.get(&session_id) {
 78                    session.thread.update(cx, |thread, _cx| {
 79                        thread.selected_model = model;
 80                    });
 81                    Ok(())
 82                } else {
 83                    Err(anyhow!("Session not found"))
 84                }
 85            })?
 86        })
 87    }
 88
 89    fn selected_model(
 90        &self,
 91        session_id: &acp::SessionId,
 92        cx: &mut AsyncApp,
 93    ) -> Task<Result<Arc<dyn LanguageModel>>> {
 94        let agent = self.0.clone();
 95        let session_id = session_id.clone();
 96        cx.spawn(async move |cx| {
 97            let thread = agent
 98                .read_with(cx, |agent, _| {
 99                    agent
100                        .sessions
101                        .get(&session_id)
102                        .map(|session| session.thread.clone())
103                })?
104                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
105            let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
106            Ok(selected)
107        })
108    }
109}
110
111impl acp_thread::AgentConnection for NativeAgentConnection {
112    fn new_thread(
113        self: Rc<Self>,
114        project: Entity<Project>,
115        cwd: &Path,
116        cx: &mut AsyncApp,
117    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
118        let agent = self.0.clone();
119        log::info!("Creating new thread for project at: {:?}", cwd);
120
121        cx.spawn(async move |cx| {
122            log::debug!("Starting thread creation in async context");
123            // Create Thread
124            let (session_id, thread) = agent.update(
125                cx,
126                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
127                    // Fetch default model from registry settings
128                    let registry = LanguageModelRegistry::read_global(cx);
129
130                    // Log available models for debugging
131                    let available_count = registry.available_models(cx).count();
132                    log::debug!("Total available models: {}", available_count);
133
134                    let default_model = registry
135                        .default_model()
136                        .map(|configured| {
137                            log::info!(
138                                "Using configured default model: {:?} from provider: {:?}",
139                                configured.model.name(),
140                                configured.provider.name()
141                            );
142                            configured.model
143                        })
144                        .ok_or_else(|| {
145                            log::warn!("No default model configured in settings");
146                            anyhow!("No default model configured. Please configure a default model in settings.")
147                        })?;
148
149                    let thread = cx.new(|_| Thread::new(project.clone(), agent.templates.clone(), default_model));
150
151                    // Generate session ID
152                    let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
153                    log::info!("Created session with ID: {}", session_id);
154                    Ok((session_id, thread))
155                },
156            )??;
157
158            // Create AcpThread
159            let acp_thread = cx.update(|cx| {
160                cx.new(|cx| {
161                    acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
162                })
163            })?;
164
165            // Store the session
166            agent.update(cx, |agent, cx| {
167                agent.sessions.insert(
168                    session_id,
169                    Session {
170                        thread,
171                        acp_thread: acp_thread.downgrade(),
172                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
173                            this.sessions.remove(acp_thread.session_id());
174                        })
175                    },
176                );
177            })?;
178
179            Ok(acp_thread)
180        })
181    }
182
183    fn auth_methods(&self) -> &[acp::AuthMethod] {
184        &[] // No auth for in-process
185    }
186
187    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
188        Task::ready(Ok(()))
189    }
190
191    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
192        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
193    }
194
195    fn prompt(
196        &self,
197        params: acp::PromptRequest,
198        cx: &mut App,
199    ) -> Task<Result<acp::PromptResponse>> {
200        let session_id = params.session_id.clone();
201        let agent = self.0.clone();
202        log::info!("Received prompt request for session: {}", session_id);
203        log::debug!("Prompt blocks count: {}", params.prompt.len());
204
205        cx.spawn(async move |cx| {
206            // Get session
207            let (thread, acp_thread) = agent
208                .update(cx, |agent, _| {
209                    agent
210                        .sessions
211                        .get_mut(&session_id)
212                        .map(|s| (s.thread.clone(), s.acp_thread.clone()))
213                })?
214                .ok_or_else(|| {
215                    log::error!("Session not found: {}", session_id);
216                    anyhow::anyhow!("Session not found")
217                })?;
218            log::debug!("Found session for: {}", session_id);
219
220            // Convert prompt to message
221            let message = convert_prompt_to_message(params.prompt);
222            log::info!("Converted prompt to message: {} chars", message.len());
223            log::debug!("Message content: {}", message);
224
225            // Get model using the ModelSelector capability (always available for agent2)
226            // Get the selected model from the thread directly
227            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
228
229            // Send to thread
230            log::info!("Sending message to thread with model: {:?}", model.name());
231            let mut response_stream =
232                thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
233
234            // Handle response stream and forward to session.acp_thread
235            while let Some(result) = response_stream.next().await {
236                match result {
237                    Ok(event) => {
238                        log::trace!("Received completion event: {:?}", event);
239
240                        match event {
241                            AgentResponseEvent::Text(text) => {
242                                acp_thread.update(cx, |thread, cx| {
243                                    thread.handle_session_update(
244                                        acp::SessionUpdate::AgentMessageChunk {
245                                            content: acp::ContentBlock::Text(acp::TextContent {
246                                                text,
247                                                annotations: None,
248                                            }),
249                                        },
250                                        cx,
251                                    )
252                                })??;
253                            }
254                            AgentResponseEvent::Thinking(text) => {
255                                acp_thread.update(cx, |thread, cx| {
256                                    thread.handle_session_update(
257                                        acp::SessionUpdate::AgentThoughtChunk {
258                                            content: acp::ContentBlock::Text(acp::TextContent {
259                                                text,
260                                                annotations: None,
261                                            }),
262                                        },
263                                        cx,
264                                    )
265                                })??;
266                            }
267                            AgentResponseEvent::ToolCall(tool_call) => {
268                                acp_thread.update(cx, |thread, cx| {
269                                    thread.handle_session_update(
270                                        acp::SessionUpdate::ToolCall(tool_call),
271                                        cx,
272                                    )
273                                })??;
274                            }
275                            AgentResponseEvent::ToolCallUpdate(tool_call_update) => {
276                                acp_thread.update(cx, |thread, cx| {
277                                    thread.handle_session_update(
278                                        acp::SessionUpdate::ToolCallUpdate(tool_call_update),
279                                        cx,
280                                    )
281                                })??;
282                            }
283                            AgentResponseEvent::Stop(stop_reason) => {
284                                log::debug!("Assistant message complete: {:?}", stop_reason);
285                                return Ok(acp::PromptResponse { stop_reason });
286                            }
287                        }
288                    }
289                    Err(e) => {
290                        log::error!("Error in model response stream: {:?}", e);
291                        // TODO: Consider sending an error message to the UI
292                        break;
293                    }
294                }
295            }
296
297            log::info!("Response stream completed");
298            anyhow::Ok(acp::PromptResponse {
299                stop_reason: acp::StopReason::EndTurn,
300            })
301        })
302    }
303
304    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
305        log::info!("Cancelling on session: {}", session_id);
306        self.0.update(cx, |agent, cx| {
307            if let Some(agent) = agent.sessions.get(session_id) {
308                agent.thread.update(cx, |thread, _cx| thread.cancel());
309            }
310        });
311    }
312}
313
314/// Convert ACP content blocks to a message string
315fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
316    log::debug!("Converting {} content blocks to message", blocks.len());
317    let mut message = String::new();
318
319    for block in blocks {
320        match block {
321            acp::ContentBlock::Text(text) => {
322                log::trace!("Processing text block: {} chars", text.text.len());
323                message.push_str(&text.text);
324            }
325            acp::ContentBlock::ResourceLink(link) => {
326                log::trace!("Processing resource link: {}", link.uri);
327                message.push_str(&format!(" @{} ", link.uri));
328            }
329            acp::ContentBlock::Image(_) => {
330                log::trace!("Processing image block");
331                message.push_str(" [image] ");
332            }
333            acp::ContentBlock::Audio(_) => {
334                log::trace!("Processing audio block");
335                message.push_str(" [audio] ");
336            }
337            acp::ContentBlock::Resource(resource) => {
338                log::trace!("Processing resource block: {:?}", resource.resource);
339                message.push_str(&format!(" [resource: {:?}] ", resource.resource));
340            }
341        }
342    }
343
344    message
345}