agent.rs

  1use acp_thread::ModelSelector;
  2use agent_client_protocol as acp;
  3use anyhow::{anyhow, Result};
  4use gpui::{App, AppContext, AsyncApp, Entity, Task};
  5use language_model::{LanguageModel, LanguageModelRegistry};
  6use project::Project;
  7use std::collections::HashMap;
  8use std::path::Path;
  9use std::rc::Rc;
 10use std::sync::Arc;
 11
 12use crate::{templates::Templates, Thread};
 13
 14/// Holds both the internal Thread and the AcpThread for a session
 15#[derive(Clone)]
 16struct Session {
 17    /// The internal thread that processes messages
 18    thread: Entity<Thread>,
 19    /// The ACP thread that handles protocol communication
 20    acp_thread: Entity<acp_thread::AcpThread>,
 21}
 22
 23pub struct NativeAgent {
 24    /// Session ID -> Session mapping
 25    sessions: HashMap<acp::SessionId, Session>,
 26    /// Shared templates for all threads
 27    templates: Arc<Templates>,
 28}
 29
 30impl NativeAgent {
 31    pub fn new(templates: Arc<Templates>) -> Self {
 32        log::info!("Creating new NativeAgent");
 33        Self {
 34            sessions: HashMap::new(),
 35            templates,
 36        }
 37    }
 38}
 39
 40/// Wrapper struct that implements the AgentConnection trait
 41#[derive(Clone)]
 42pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 43
 44impl ModelSelector for NativeAgentConnection {
 45    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
 46        log::debug!("NativeAgentConnection::list_models called");
 47        cx.spawn(async move |cx| {
 48            cx.update(|cx| {
 49                let registry = LanguageModelRegistry::read_global(cx);
 50                let models = registry.available_models(cx).collect::<Vec<_>>();
 51                log::info!("Found {} available models", models.len());
 52                if models.is_empty() {
 53                    Err(anyhow::anyhow!("No models available"))
 54                } else {
 55                    Ok(models)
 56                }
 57            })?
 58        })
 59    }
 60
 61    fn select_model(
 62        &self,
 63        session_id: acp::SessionId,
 64        model: Arc<dyn LanguageModel>,
 65        cx: &mut AsyncApp,
 66    ) -> Task<Result<()>> {
 67        log::info!(
 68            "Setting model for session {}: {:?}",
 69            session_id,
 70            model.name()
 71        );
 72        let agent = self.0.clone();
 73
 74        cx.spawn(async move |cx| {
 75            agent.update(cx, |agent, cx| {
 76                if let Some(session) = agent.sessions.get(&session_id) {
 77                    session.thread.update(cx, |thread, _cx| {
 78                        thread.selected_model = model;
 79                    });
 80                    Ok(())
 81                } else {
 82                    Err(anyhow!("Session not found"))
 83                }
 84            })?
 85        })
 86    }
 87
 88    fn selected_model(
 89        &self,
 90        session_id: &acp::SessionId,
 91        cx: &mut AsyncApp,
 92    ) -> Task<Result<Arc<dyn LanguageModel>>> {
 93        let agent = self.0.clone();
 94        let session_id = session_id.clone();
 95        cx.spawn(async move |cx| {
 96            let session = agent
 97                .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
 98                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
 99            let selected = session
100                .thread
101                .read_with(cx, |thread, _| thread.selected_model.clone())?;
102            Ok(selected)
103        })
104    }
105}
106
107impl acp_thread::AgentConnection for NativeAgentConnection {
108    fn new_thread(
109        self: Rc<Self>,
110        project: Entity<Project>,
111        cwd: &Path,
112        cx: &mut AsyncApp,
113    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
114        let agent = self.0.clone();
115        log::info!("Creating new thread for project at: {:?}", cwd);
116
117        cx.spawn(async move |cx| {
118            log::debug!("Starting thread creation in async context");
119            // Create Thread
120            let (session_id, thread) = agent.update(
121                cx,
122                |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
123                    // Fetch default model from registry settings
124                    let registry = LanguageModelRegistry::read_global(cx);
125
126                    // Log available models for debugging
127                    let available_count = registry.available_models(cx).count();
128                    log::debug!("Total available models: {}", available_count);
129
130                    let default_model = registry
131                        .default_model()
132                        .map(|configured| {
133                            log::info!(
134                                "Using configured default model: {:?} from provider: {:?}",
135                                configured.model.name(),
136                                configured.provider.name()
137                            );
138                            configured.model
139                        })
140                        .ok_or_else(|| {
141                            log::warn!("No default model configured in settings");
142                            anyhow!("No default model configured. Please configure a default model in settings.")
143                        })?;
144
145                    let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
146
147                    // Generate session ID
148                    let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
149                    log::info!("Created session with ID: {}", session_id);
150                    Ok((session_id, thread))
151                },
152            )??;
153
154            // Create AcpThread
155            let acp_thread = cx.update(|cx| {
156                cx.new(|cx| {
157                    acp_thread::AcpThread::new("agent2", self.clone(), project, session_id.clone(), cx)
158                })
159            })?;
160
161            // Store the session
162            agent.update(cx, |agent, _cx| {
163                agent.sessions.insert(
164                    session_id,
165                    Session {
166                        thread,
167                        acp_thread: acp_thread.clone(),
168                    },
169                );
170            })?;
171
172            Ok(acp_thread)
173        })
174    }
175
176    fn auth_methods(&self) -> &[acp::AuthMethod] {
177        &[] // No auth for in-process
178    }
179
180    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
181        Task::ready(Ok(()))
182    }
183
184    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
185        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
186    }
187
188    fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
189        let session_id = params.session_id.clone();
190        let agent = self.0.clone();
191        log::info!("Received prompt request for session: {}", session_id);
192        log::debug!("Prompt blocks count: {}", params.prompt.len());
193
194        cx.spawn(async move |cx| {
195            // Get session
196            let session = agent
197                .read_with(cx, |agent, _| {
198                    agent.sessions.get(&session_id).map(|s| Session {
199                        thread: s.thread.clone(),
200                        acp_thread: s.acp_thread.clone(),
201                    })
202                })?
203                .ok_or_else(|| {
204                    log::error!("Session not found: {}", session_id);
205                    anyhow::anyhow!("Session not found")
206                })?;
207            log::debug!("Found session for: {}", session_id);
208
209            // Convert prompt to message
210            let message = convert_prompt_to_message(params.prompt);
211            log::info!("Converted prompt to message: {} chars", message.len());
212            log::debug!("Message content: {}", message);
213
214            // Get model using the ModelSelector capability (always available for agent2)
215            // Get the selected model from the thread directly
216            let model = session
217                .thread
218                .read_with(cx, |thread, _| thread.selected_model.clone())?;
219
220            // Send to thread
221            log::info!("Sending message to thread with model: {:?}", model.name());
222            let response_stream = session
223                .thread
224                .update(cx, |thread, cx| thread.send(model, message, cx))?;
225
226            // Handle response stream and forward to session.acp_thread
227            let acp_thread = session.acp_thread.clone();
228            cx.spawn(async move |cx| {
229                use futures::StreamExt;
230                use language_model::LanguageModelCompletionEvent;
231
232                let mut response_stream = response_stream;
233
234                while let Some(result) = response_stream.next().await {
235                    match result {
236                        Ok(event) => {
237                            log::trace!("Received completion event: {:?}", event);
238
239                            match event {
240                                LanguageModelCompletionEvent::Text(text) => {
241                                    // Send text chunk as agent message
242                                    acp_thread.update(cx, |thread, cx| {
243                                        thread.handle_session_update(
244                                            acp::SessionUpdate::AgentMessageChunk {
245                                                content: acp::ContentBlock::Text(
246                                                    acp::TextContent {
247                                                        text: text.into(),
248                                                        annotations: None,
249                                                    },
250                                                ),
251                                            },
252                                            cx,
253                                        )
254                                    })??;
255                                }
256                                LanguageModelCompletionEvent::ToolUse(tool_use) => {
257                                    // Convert LanguageModelToolUse to ACP ToolCall
258                                    acp_thread.update(cx, |thread, cx| {
259                                        thread.handle_session_update(
260                                            acp::SessionUpdate::ToolCall(acp::ToolCall {
261                                                id: acp::ToolCallId(tool_use.id.to_string().into()),
262                                                label: tool_use.name.to_string(),
263                                                kind: acp::ToolKind::Other,
264                                                status: acp::ToolCallStatus::Pending,
265                                                content: vec![],
266                                                locations: vec![],
267                                                raw_input: Some(tool_use.input),
268                                            }),
269                                            cx,
270                                        )
271                                    })??;
272                                }
273                                LanguageModelCompletionEvent::StartMessage { .. } => {
274                                    log::debug!("Started new assistant message");
275                                }
276                                LanguageModelCompletionEvent::UsageUpdate(usage) => {
277                                    log::debug!("Token usage update: {:?}", usage);
278                                }
279                                LanguageModelCompletionEvent::Thinking { text, .. } => {
280                                    // Send thinking text as agent thought chunk
281                                    acp_thread.update(cx, |thread, cx| {
282                                        thread.handle_session_update(
283                                            acp::SessionUpdate::AgentThoughtChunk {
284                                                content: acp::ContentBlock::Text(
285                                                    acp::TextContent {
286                                                        text: text.into(),
287                                                        annotations: None,
288                                                    },
289                                                ),
290                                            },
291                                            cx,
292                                        )
293                                    })??;
294                                }
295                                LanguageModelCompletionEvent::StatusUpdate(status) => {
296                                    log::trace!("Status update: {:?}", status);
297                                }
298                                LanguageModelCompletionEvent::Stop(stop_reason) => {
299                                    log::debug!("Assistant message complete: {:?}", stop_reason);
300                                }
301                                LanguageModelCompletionEvent::RedactedThinking { .. } => {
302                                    log::trace!("Redacted thinking event");
303                                }
304                                LanguageModelCompletionEvent::ToolUseJsonParseError {
305                                    id,
306                                    tool_name,
307                                    raw_input,
308                                    json_parse_error,
309                                } => {
310                                    log::error!(
311                                        "Tool use JSON parse error for tool '{}' (id: {}): {} - input: {}",
312                                        tool_name,
313                                        id,
314                                        json_parse_error,
315                                        raw_input
316                                    );
317                                }
318                            }
319                        }
320                        Err(e) => {
321                            log::error!("Error in model response stream: {:?}", e);
322                            // TODO: Consider sending an error message to the UI
323                            break;
324                        }
325                    }
326                }
327
328                log::info!("Response stream completed");
329                anyhow::Ok(())
330            })
331            .detach();
332
333            log::info!("Successfully sent prompt to thread and started response handler");
334            Ok(())
335        })
336    }
337
338    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
339        log::info!("Cancelling session: {}", session_id);
340        self.0.update(cx, |agent, _cx| {
341            agent.sessions.remove(session_id);
342        });
343    }
344}
345
346/// Convert ACP content blocks to a message string
347fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
348    log::debug!("Converting {} content blocks to message", blocks.len());
349    let mut message = String::new();
350
351    for block in blocks {
352        match block {
353            acp::ContentBlock::Text(text) => {
354                log::trace!("Processing text block: {} chars", text.text.len());
355                message.push_str(&text.text);
356            }
357            acp::ContentBlock::ResourceLink(link) => {
358                log::trace!("Processing resource link: {}", link.uri);
359                message.push_str(&format!(" @{} ", link.uri));
360            }
361            acp::ContentBlock::Image(_) => {
362                log::trace!("Processing image block");
363                message.push_str(" [image] ");
364            }
365            acp::ContentBlock::Audio(_) => {
366                log::trace!("Processing audio block");
367                message.push_str(" [audio] ");
368            }
369            acp::ContentBlock::Resource(resource) => {
370                log::trace!("Processing resource block: {:?}", resource.resource);
371                message.push_str(&format!(" [resource: {:?}] ", resource.resource));
372            }
373        }
374    }
375
376    message
377}