1use acp_thread::ModelSelector;
2use agent_client_protocol as acp;
3use anyhow::Result;
4use gpui::{App, AppContext, AsyncApp, Entity, Task};
5use language_model::{LanguageModel, LanguageModelRegistry};
6use project::Project;
7use std::collections::HashMap;
8use std::path::Path;
9use std::rc::Rc;
10use std::sync::Arc;
11
12use crate::{templates::Templates, Thread};
13
14pub struct NativeAgent {
15 /// Session ID -> Thread entity mapping
16 sessions: HashMap<acp::SessionId, Entity<Thread>>,
17 /// Shared templates for all threads
18 templates: Arc<Templates>,
19}
20
21impl NativeAgent {
22 pub fn new(templates: Arc<Templates>) -> Self {
23 Self {
24 sessions: HashMap::new(),
25 templates,
26 }
27 }
28}
29
30/// Wrapper struct that implements the AgentConnection trait
31#[derive(Clone)]
32pub struct NativeAgentConnection(pub Entity<NativeAgent>);
33
34impl ModelSelector for NativeAgentConnection {
35 fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
36 cx.spawn(async move |cx| {
37 cx.update(|cx| {
38 let registry = LanguageModelRegistry::read_global(cx);
39 let models = registry.available_models(cx).collect::<Vec<_>>();
40 if models.is_empty() {
41 Err(anyhow::anyhow!("No models available"))
42 } else {
43 Ok(models)
44 }
45 })?
46 })
47 }
48
49 fn select_model(
50 &self,
51 session_id: &acp::SessionId,
52 model: Arc<dyn LanguageModel>,
53 cx: &mut AsyncApp,
54 ) -> Task<Result<()>> {
55 let agent = self.0.clone();
56 let session_id = session_id.clone();
57 cx.spawn(async move |cx| {
58 agent.update(cx, |agent, cx| {
59 if let Some(thread) = agent.sessions.get(&session_id) {
60 thread.update(cx, |thread, _| {
61 thread.selected_model = model;
62 });
63 Ok(())
64 } else {
65 Err(anyhow::anyhow!("Session not found"))
66 }
67 })?
68 })
69 }
70
71 fn selected_model(
72 &self,
73 session_id: &acp::SessionId,
74 cx: &mut AsyncApp,
75 ) -> Task<Result<Arc<dyn LanguageModel>>> {
76 let agent = self.0.clone();
77 let session_id = session_id.clone();
78 cx.spawn(async move |cx| {
79 let thread = agent
80 .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
81 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
82 let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
83 Ok(selected)
84 })
85 }
86}
87
88impl acp_thread::AgentConnection for NativeAgentConnection {
89 fn new_thread(
90 self: Rc<Self>,
91 project: Entity<Project>,
92 cwd: &Path,
93 cx: &mut AsyncApp,
94 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
95 let _cwd = cwd.to_owned();
96 let agent = self.0.clone();
97
98 cx.spawn(async move |cx| {
99 // Create Thread and store in Agent
100 let (session_id, _thread) =
101 agent.update(cx, |agent, cx: &mut gpui::Context<NativeAgent>| {
102 // Fetch default model
103 let default_model = LanguageModelRegistry::read_global(cx)
104 .available_models(cx)
105 .next()
106 .unwrap_or_else(|| panic!("No default model available"));
107
108 let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
109 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
110 agent.sessions.insert(session_id.clone(), thread.clone());
111 (session_id, thread)
112 })?;
113
114 // Create AcpThread
115 let acp_thread = cx.update(|cx| {
116 cx.new(|cx| {
117 acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx)
118 })
119 })?;
120
121 Ok(acp_thread)
122 })
123 }
124
125 fn auth_methods(&self) -> &[acp::AuthMethod] {
126 &[] // No auth for in-process
127 }
128
129 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
130 Task::ready(Ok(()))
131 }
132
133 fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
134 Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
135 }
136
137 fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
138 let session_id = params.session_id.clone();
139 let agent = self.0.clone();
140
141 cx.spawn(async move |cx| {
142 // Get thread
143 let thread: Entity<Thread> = agent
144 .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
145 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
146
147 // Convert prompt to message
148 let message = convert_prompt_to_message(params.prompt);
149
150 // Get model using the ModelSelector capability (always available for agent2)
151 // Get the selected model from the thread directly
152 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
153
154 // Send to thread
155 thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
156
157 Ok(())
158 })
159 }
160
161 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
162 self.0.update(cx, |agent, _cx| {
163 agent.sessions.remove(session_id);
164 });
165 }
166}
167
168/// Convert ACP content blocks to a message string
169fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
170 let mut message = String::new();
171
172 for block in blocks {
173 match block {
174 acp::ContentBlock::Text(text) => {
175 message.push_str(&text.text);
176 }
177 acp::ContentBlock::ResourceLink(link) => {
178 message.push_str(&format!(" @{} ", link.uri));
179 }
180 acp::ContentBlock::Image(_) => {
181 message.push_str(" [image] ");
182 }
183 acp::ContentBlock::Audio(_) => {
184 message.push_str(" [audio] ");
185 }
186 acp::ContentBlock::Resource(resource) => {
187 message.push_str(&format!(" [resource: {:?}] ", resource.resource));
188 }
189 }
190 }
191
192 message
193}