agent.rs

  1use acp_thread::ModelSelector;
  2use agent_client_protocol as acp;
  3use anyhow::Result;
  4use gpui::{App, AppContext, AsyncApp, Entity, Task};
  5use language_model::{LanguageModel, LanguageModelRegistry};
  6use project::Project;
  7use std::collections::HashMap;
  8use std::path::Path;
  9use std::rc::Rc;
 10use std::sync::Arc;
 11
 12use crate::{templates::Templates, Thread};
 13
 14pub struct NativeAgent {
 15    /// Session ID -> Thread entity mapping
 16    sessions: HashMap<acp::SessionId, Entity<Thread>>,
 17    /// Shared templates for all threads
 18    templates: Arc<Templates>,
 19}
 20
 21impl NativeAgent {
 22    pub fn new(templates: Arc<Templates>) -> Self {
 23        Self {
 24            sessions: HashMap::new(),
 25            templates,
 26        }
 27    }
 28}
 29
 30/// Wrapper struct that implements the AgentConnection trait
 31#[derive(Clone)]
 32pub struct NativeAgentConnection(pub Entity<NativeAgent>);
 33
 34impl ModelSelector for NativeAgentConnection {
 35    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
 36        cx.spawn(async move |cx| {
 37            cx.update(|cx| {
 38                let registry = LanguageModelRegistry::read_global(cx);
 39                let models = registry.available_models(cx).collect::<Vec<_>>();
 40                if models.is_empty() {
 41                    Err(anyhow::anyhow!("No models available"))
 42                } else {
 43                    Ok(models)
 44                }
 45            })?
 46        })
 47    }
 48
 49    fn select_model(
 50        &self,
 51        session_id: &acp::SessionId,
 52        model: Arc<dyn LanguageModel>,
 53        cx: &mut AsyncApp,
 54    ) -> Task<Result<()>> {
 55        let agent = self.0.clone();
 56        let session_id = session_id.clone();
 57        cx.spawn(async move |cx| {
 58            agent.update(cx, |agent, cx| {
 59                if let Some(thread) = agent.sessions.get(&session_id) {
 60                    thread.update(cx, |thread, _| {
 61                        thread.selected_model = model;
 62                    });
 63                    Ok(())
 64                } else {
 65                    Err(anyhow::anyhow!("Session not found"))
 66                }
 67            })?
 68        })
 69    }
 70
 71    fn selected_model(
 72        &self,
 73        session_id: &acp::SessionId,
 74        cx: &mut AsyncApp,
 75    ) -> Task<Result<Arc<dyn LanguageModel>>> {
 76        let agent = self.0.clone();
 77        let session_id = session_id.clone();
 78        cx.spawn(async move |cx| {
 79            let thread = agent
 80                .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
 81                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
 82            let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
 83            Ok(selected)
 84        })
 85    }
 86}
 87
 88impl acp_thread::AgentConnection for NativeAgentConnection {
 89    fn new_thread(
 90        self: Rc<Self>,
 91        project: Entity<Project>,
 92        cwd: &Path,
 93        cx: &mut AsyncApp,
 94    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 95        let _cwd = cwd.to_owned();
 96        let agent = self.0.clone();
 97
 98        cx.spawn(async move |cx| {
 99            // Create Thread and store in Agent
100            let (session_id, _thread) =
101                agent.update(cx, |agent, cx: &mut gpui::Context<NativeAgent>| {
102                    // Fetch default model
103                    let default_model = LanguageModelRegistry::read_global(cx)
104                        .available_models(cx)
105                        .next()
106                        .unwrap_or_else(|| panic!("No default model available"));
107
108                    let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
109                    let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
110                    agent.sessions.insert(session_id.clone(), thread.clone());
111                    (session_id, thread)
112                })?;
113
114            // Create AcpThread
115            let acp_thread = cx.update(|cx| {
116                cx.new(|cx| {
117                    acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx)
118                })
119            })?;
120
121            Ok(acp_thread)
122        })
123    }
124
125    fn auth_methods(&self) -> &[acp::AuthMethod] {
126        &[] // No auth for in-process
127    }
128
129    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
130        Task::ready(Ok(()))
131    }
132
133    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
134        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
135    }
136
137    fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
138        let session_id = params.session_id.clone();
139        let agent = self.0.clone();
140
141        cx.spawn(async move |cx| {
142            // Get thread
143            let thread: Entity<Thread> = agent
144                .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
145                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
146
147            // Convert prompt to message
148            let message = convert_prompt_to_message(params.prompt);
149
150            // Get model using the ModelSelector capability (always available for agent2)
151            // Get the selected model from the thread directly
152            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
153
154            // Send to thread
155            thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
156
157            Ok(())
158        })
159    }
160
161    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
162        self.0.update(cx, |agent, _cx| {
163            agent.sessions.remove(session_id);
164        });
165    }
166}
167
168/// Convert ACP content blocks to a message string
169fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
170    let mut message = String::new();
171
172    for block in blocks {
173        match block {
174            acp::ContentBlock::Text(text) => {
175                message.push_str(&text.text);
176            }
177            acp::ContentBlock::ResourceLink(link) => {
178                message.push_str(&format!(" @{} ", link.uri));
179            }
180            acp::ContentBlock::Image(_) => {
181                message.push_str(" [image] ");
182            }
183            acp::ContentBlock::Audio(_) => {
184                message.push_str(" [audio] ");
185            }
186            acp::ContentBlock::Resource(resource) => {
187                message.push_str(&format!(" [resource: {:?}] ", resource.resource));
188            }
189        }
190    }
191
192    message
193}