agent.rs

  1use acp_thread::ModelSelector;
  2use agent_client_protocol as acp;
  3use anyhow::Result;
  4use gpui::{App, AppContext, AsyncApp, Entity, Task};
  5use language_model::{LanguageModel, LanguageModelRegistry};
  6use project::Project;
  7use std::collections::HashMap;
  8use std::path::Path;
  9use std::rc::Rc;
 10use std::sync::Arc;
 11
 12use crate::{templates::Templates, Thread};
 13
 14pub struct Agent {
 15    /// Session ID -> Thread entity mapping
 16    sessions: HashMap<acp::SessionId, Entity<Thread>>,
 17    /// Shared templates for all threads
 18    templates: Arc<Templates>,
 19}
 20
 21impl Agent {
 22    pub fn new(templates: Arc<Templates>) -> Self {
 23        Self {
 24            sessions: HashMap::new(),
 25            templates,
 26        }
 27    }
 28}
 29
 30/// Wrapper struct that implements the AgentConnection trait
 31#[derive(Clone)]
 32pub struct AgentConnection(pub Entity<Agent>);
 33
 34impl ModelSelector for AgentConnection {
 35    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
 36        let result = cx.update(|cx| {
 37            let registry = LanguageModelRegistry::read_global(cx);
 38            let models = registry.available_models(cx).collect::<Vec<_>>();
 39            if models.is_empty() {
 40                Err(anyhow::anyhow!("No models available"))
 41            } else {
 42                Ok(models)
 43            }
 44        });
 45        Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
 46    }
 47
 48    fn select_model(
 49        &self,
 50        session_id: &acp::SessionId,
 51        model: Arc<dyn LanguageModel>,
 52        cx: &mut AsyncApp,
 53    ) -> Task<Result<()>> {
 54        let agent = self.0.clone();
 55        let result = agent.update(cx, |agent, cx| {
 56            if let Some(thread) = agent.sessions.get(session_id) {
 57                thread.update(cx, |thread, _| {
 58                    thread.selected_model = model;
 59                });
 60                Ok(())
 61            } else {
 62                Err(anyhow::anyhow!("Session not found"))
 63            }
 64        });
 65        Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
 66    }
 67
 68    fn selected_model(
 69        &self,
 70        session_id: &acp::SessionId,
 71        cx: &mut AsyncApp,
 72    ) -> Task<Result<Arc<dyn LanguageModel>>> {
 73        let agent = self.0.clone();
 74        let thread_result = agent
 75            .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
 76            .ok()
 77            .flatten()
 78            .ok_or_else(|| anyhow::anyhow!("Session not found"));
 79
 80        match thread_result {
 81            Ok(thread) => {
 82                let selected = thread
 83                    .read_with(cx, |thread, _| thread.selected_model.clone())
 84                    .unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
 85                Task::ready(Ok(selected))
 86            }
 87            Err(e) => Task::ready(Err(e)),
 88        }
 89    }
 90}
 91
 92impl acp_thread::AgentConnection for AgentConnection {
 93    fn new_thread(
 94        self: Rc<Self>,
 95        project: Entity<Project>,
 96        cwd: &Path,
 97        cx: &mut AsyncApp,
 98    ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
 99        let _cwd = cwd.to_owned();
100        let agent = self.0.clone();
101
102        cx.spawn(async move |cx| {
103            // Create Thread and store in Agent
104            let (session_id, _thread) =
105                agent.update(cx, |agent, cx: &mut gpui::Context<Agent>| {
106                    // Fetch default model
107                    let default_model = LanguageModelRegistry::read_global(cx)
108                        .available_models(cx)
109                        .next()
110                        .unwrap_or_else(|| panic!("No default model available"));
111
112                    let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
113                    let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
114                    agent.sessions.insert(session_id.clone(), thread.clone());
115                    (session_id, thread)
116                })?;
117
118            // Create AcpThread
119            let acp_thread = cx.update(|cx| {
120                cx.new(|cx| {
121                    acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx)
122                })
123            })?;
124
125            Ok(acp_thread)
126        })
127    }
128
129    fn auth_methods(&self) -> &[acp::AuthMethod] {
130        &[] // No auth for in-process
131    }
132
133    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
134        Task::ready(Ok(()))
135    }
136
137    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
138        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
139    }
140
141    fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
142        let session_id = params.session_id.clone();
143        let agent = self.0.clone();
144
145        cx.spawn(async move |cx| {
146            // Get thread
147            let thread: Entity<Thread> = agent
148                .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
149                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
150
151            // Convert prompt to message
152            let message = convert_prompt_to_message(params.prompt);
153
154            // Get model using the ModelSelector capability (always available for agent2)
155            // Get the selected model from the thread directly
156            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
157
158            // Send to thread
159            thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
160
161            Ok(())
162        })
163    }
164
165    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
166        self.0.update(cx, |agent, _cx| {
167            agent.sessions.remove(session_id);
168        });
169    }
170}
171
172/// Convert ACP content blocks to a message string
173fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
174    let mut message = String::new();
175
176    for block in blocks {
177        match block {
178            acp::ContentBlock::Text(text) => {
179                message.push_str(&text.text);
180            }
181            acp::ContentBlock::ResourceLink(link) => {
182                message.push_str(&format!(" @{} ", link.uri));
183            }
184            acp::ContentBlock::Image(_) => {
185                message.push_str(" [image] ");
186            }
187            acp::ContentBlock::Audio(_) => {
188                message.push_str(" [audio] ");
189            }
190            acp::ContentBlock::Resource(resource) => {
191                message.push_str(&format!(" [resource: {:?}] ", resource.resource));
192            }
193        }
194    }
195
196    message
197}