1use crate::{templates::Templates, AgentResponseEvent, Thread};
2use crate::{FindPathTool, ThinkingTool, ToolCallAuthorization};
3use acp_thread::ModelSelector;
4use agent_client_protocol as acp;
5use anyhow::{anyhow, Context as _, Result};
6use futures::{future, StreamExt};
7use gpui::{
8 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
9};
10use language_model::{LanguageModel, LanguageModelRegistry};
11use project::{Project, ProjectItem, ProjectPath, Worktree};
12use prompt_store::{
13 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
14};
15use std::cell::RefCell;
16use std::collections::HashMap;
17use std::path::Path;
18use std::rc::Rc;
19use std::sync::Arc;
20use util::ResultExt;
21
/// Candidate rules-file names, relative to a worktree root.
///
/// Order matters: when loading a worktree's rules, the names are tried in this order and the
/// first one that exists as a regular file is used. Changes to any of these exact paths also
/// trigger a rebuild of the agent's project context.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
33
34pub struct RulesLoadingError {
35 pub message: SharedString,
36}
37
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly: the agent observes this thread's release and removes the whole session
    /// from `NativeAgent::sessions` when it goes away.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Keeps the release observer for `acp_thread` registered for as long as the session
    /// exists (presumably dropping the `Subscription` unregisters it — gpui semantics).
    _subscription: Subscription,
}
46
/// In-process agent that owns every session created through `NativeAgentConnection`,
/// together with the project context and templates those sessions share.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled whenever the project context should be rebuilt (worktrees added/removed,
    /// a rules file changed, or the prompt store reported an update).
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds `project_context` each time a refresh is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Project whose visible worktrees feed the project context.
    project: Entity<Project>,
    /// Source of default user rules, if one was provided.
    prompt_store: Option<Entity<PromptStore>>,
    /// Subscriptions to project (and prompt-store) events; kept alive with the agent.
    _subscriptions: Vec<Subscription>,
}
60
61impl NativeAgent {
62 pub async fn new(
63 project: Entity<Project>,
64 templates: Arc<Templates>,
65 prompt_store: Option<Entity<PromptStore>>,
66 cx: &mut AsyncApp,
67 ) -> Result<Entity<NativeAgent>> {
68 log::info!("Creating new NativeAgent");
69
70 let project_context = cx
71 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
72 .await;
73
74 cx.new(|cx| {
75 let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
76 if let Some(prompt_store) = prompt_store.as_ref() {
77 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
78 }
79
80 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
81 watch::channel(());
82 Self {
83 sessions: HashMap::new(),
84 project_context: Rc::new(RefCell::new(project_context)),
85 project_context_needs_refresh: project_context_needs_refresh_tx,
86 _maintain_project_context: cx.spawn(async move |this, cx| {
87 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
88 }),
89 templates,
90 project,
91 prompt_store,
92 _subscriptions: subscriptions,
93 }
94 })
95 }
96
97 async fn maintain_project_context(
98 this: WeakEntity<Self>,
99 mut needs_refresh: watch::Receiver<()>,
100 cx: &mut AsyncApp,
101 ) -> Result<()> {
102 while needs_refresh.changed().await.is_ok() {
103 let project_context = this
104 .update(cx, |this, cx| {
105 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
106 })?
107 .await;
108 this.update(cx, |this, _| this.project_context.replace(project_context))?;
109 }
110
111 Ok(())
112 }
113
114 fn build_project_context(
115 project: &Entity<Project>,
116 prompt_store: Option<&Entity<PromptStore>>,
117 cx: &mut App,
118 ) -> Task<ProjectContext> {
119 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
120 let worktree_tasks = worktrees
121 .into_iter()
122 .map(|worktree| {
123 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
124 })
125 .collect::<Vec<_>>();
126 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
127 prompt_store.read_with(cx, |prompt_store, cx| {
128 let prompts = prompt_store.default_prompt_metadata();
129 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
130 let contents = prompt_store.load(prompt_metadata.id, cx);
131 async move { (contents.await, prompt_metadata) }
132 });
133 cx.background_spawn(future::join_all(load_tasks))
134 })
135 } else {
136 Task::ready(vec![])
137 };
138
139 cx.spawn(async move |_cx| {
140 let (worktrees, default_user_rules) =
141 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
142
143 let worktrees = worktrees
144 .into_iter()
145 .map(|(worktree, _rules_error)| {
146 // TODO: show error message
147 // if let Some(rules_error) = rules_error {
148 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
149 // }
150 worktree
151 })
152 .collect::<Vec<_>>();
153
154 let default_user_rules = default_user_rules
155 .into_iter()
156 .flat_map(|(contents, prompt_metadata)| match contents {
157 Ok(contents) => Some(UserRulesContext {
158 uuid: match prompt_metadata.id {
159 PromptId::User { uuid } => uuid,
160 PromptId::EditWorkflow => return None,
161 },
162 title: prompt_metadata.title.map(|title| title.to_string()),
163 contents,
164 }),
165 Err(_err) => {
166 // TODO: show error message
167 // this.update(cx, |_, cx| {
168 // cx.emit(RulesLoadingError {
169 // message: format!("{err:?}").into(),
170 // });
171 // })
172 // .ok();
173 None
174 }
175 })
176 .collect::<Vec<_>>();
177
178 ProjectContext::new(worktrees, default_user_rules)
179 })
180 }
181
182 fn load_worktree_info_for_system_prompt(
183 worktree: Entity<Worktree>,
184 project: Entity<Project>,
185 cx: &mut App,
186 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
187 let tree = worktree.read(cx);
188 let root_name = tree.root_name().into();
189 let abs_path = tree.abs_path();
190
191 let mut context = WorktreeContext {
192 root_name,
193 abs_path,
194 rules_file: None,
195 };
196
197 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
198 let Some(rules_task) = rules_task else {
199 return Task::ready((context, None));
200 };
201
202 cx.spawn(async move |_| {
203 let (rules_file, rules_file_error) = match rules_task.await {
204 Ok(rules_file) => (Some(rules_file), None),
205 Err(err) => (
206 None,
207 Some(RulesLoadingError {
208 message: format!("{err}").into(),
209 }),
210 ),
211 };
212 context.rules_file = rules_file;
213 (context, rules_file_error)
214 })
215 }
216
217 fn load_worktree_rules_file(
218 worktree: Entity<Worktree>,
219 project: Entity<Project>,
220 cx: &mut App,
221 ) -> Option<Task<Result<RulesFileContext>>> {
222 let worktree = worktree.read(cx);
223 let worktree_id = worktree.id();
224 let selected_rules_file = RULES_FILE_NAMES
225 .into_iter()
226 .filter_map(|name| {
227 worktree
228 .entry_for_path(name)
229 .filter(|entry| entry.is_file())
230 .map(|entry| entry.path.clone())
231 })
232 .next();
233
234 // Note that Cline supports `.clinerules` being a directory, but that is not currently
235 // supported. This doesn't seem to occur often in GitHub repositories.
236 selected_rules_file.map(|path_in_worktree| {
237 let project_path = ProjectPath {
238 worktree_id,
239 path: path_in_worktree.clone(),
240 };
241 let buffer_task =
242 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
243 let rope_task = cx.spawn(async move |cx| {
244 buffer_task.await?.read_with(cx, |buffer, cx| {
245 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
246 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
247 })?
248 });
249 // Build a string from the rope on a background thread.
250 cx.background_spawn(async move {
251 let (project_entry_id, rope) = rope_task.await?;
252 anyhow::Ok(RulesFileContext {
253 path_in_worktree,
254 text: rope.to_string().trim().to_string(),
255 project_entry_id: project_entry_id.to_usize(),
256 })
257 })
258 })
259 }
260
261 fn handle_project_event(
262 &mut self,
263 _project: Entity<Project>,
264 event: &project::Event,
265 _cx: &mut Context<Self>,
266 ) {
267 match event {
268 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
269 self.project_context_needs_refresh.send(()).ok();
270 }
271 project::Event::WorktreeUpdatedEntries(_, items) => {
272 if items.iter().any(|(path, _, _)| {
273 RULES_FILE_NAMES
274 .iter()
275 .any(|name| path.as_ref() == Path::new(name))
276 }) {
277 self.project_context_needs_refresh.send(()).ok();
278 }
279 }
280 _ => {}
281 }
282 }
283
284 fn handle_prompts_updated_event(
285 &mut self,
286 _prompt_store: Entity<PromptStore>,
287 _event: &prompt_store::PromptsUpdatedEvent,
288 _cx: &mut Context<Self>,
289 ) {
290 self.project_context_needs_refresh.send(()).ok();
291 }
292}
293
/// Wrapper struct that implements the AgentConnection trait
///
/// Also implements `ModelSelector`, so callers can list models and choose the model used
/// by each session. Clones share the same underlying `NativeAgent` entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
297
298impl ModelSelector for NativeAgentConnection {
299 fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
300 log::debug!("NativeAgentConnection::list_models called");
301 cx.spawn(async move |cx| {
302 cx.update(|cx| {
303 let registry = LanguageModelRegistry::read_global(cx);
304 let models = registry.available_models(cx).collect::<Vec<_>>();
305 log::info!("Found {} available models", models.len());
306 if models.is_empty() {
307 Err(anyhow::anyhow!("No models available"))
308 } else {
309 Ok(models)
310 }
311 })?
312 })
313 }
314
315 fn select_model(
316 &self,
317 session_id: acp::SessionId,
318 model: Arc<dyn LanguageModel>,
319 cx: &mut AsyncApp,
320 ) -> Task<Result<()>> {
321 log::info!(
322 "Setting model for session {}: {:?}",
323 session_id,
324 model.name()
325 );
326 let agent = self.0.clone();
327
328 cx.spawn(async move |cx| {
329 agent.update(cx, |agent, cx| {
330 if let Some(session) = agent.sessions.get(&session_id) {
331 session.thread.update(cx, |thread, _cx| {
332 thread.selected_model = model;
333 });
334 Ok(())
335 } else {
336 Err(anyhow!("Session not found"))
337 }
338 })?
339 })
340 }
341
342 fn selected_model(
343 &self,
344 session_id: &acp::SessionId,
345 cx: &mut AsyncApp,
346 ) -> Task<Result<Arc<dyn LanguageModel>>> {
347 let agent = self.0.clone();
348 let session_id = session_id.clone();
349 cx.spawn(async move |cx| {
350 let thread = agent
351 .read_with(cx, |agent, _| {
352 agent
353 .sessions
354 .get(&session_id)
355 .map(|session| session.thread.clone())
356 })?
357 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
358 let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
359 Ok(selected)
360 })
361 }
362}
363
364impl acp_thread::AgentConnection for NativeAgentConnection {
365 fn new_thread(
366 self: Rc<Self>,
367 project: Entity<Project>,
368 cwd: &Path,
369 cx: &mut AsyncApp,
370 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
371 let agent = self.0.clone();
372 log::info!("Creating new thread for project at: {:?}", cwd);
373
374 cx.spawn(async move |cx| {
375 log::debug!("Starting thread creation in async context");
376
377 // Generate session ID
378 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
379 log::info!("Created session with ID: {}", session_id);
380
381 // Create AcpThread
382 let acp_thread = cx.update(|cx| {
383 cx.new(|cx| {
384 acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
385 })
386 })?;
387 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
388
389 // Create Thread
390 let thread = agent.update(
391 cx,
392 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
393 // Fetch default model from registry settings
394 let registry = LanguageModelRegistry::read_global(cx);
395
396 // Log available models for debugging
397 let available_count = registry.available_models(cx).count();
398 log::debug!("Total available models: {}", available_count);
399
400 let default_model = registry
401 .default_model()
402 .map(|configured| {
403 log::info!(
404 "Using configured default model: {:?} from provider: {:?}",
405 configured.model.name(),
406 configured.provider.name()
407 );
408 configured.model
409 })
410 .ok_or_else(|| {
411 log::warn!("No default model configured in settings");
412 anyhow!("No default model configured. Please configure a default model in settings.")
413 })?;
414
415 let thread = cx.new(|_| {
416 let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log, agent.templates.clone(), default_model);
417 thread.add_tool(ThinkingTool);
418 thread.add_tool(FindPathTool::new(project.clone()));
419 thread
420 });
421
422 Ok(thread)
423 },
424 )??;
425
426 // Store the session
427 agent.update(cx, |agent, cx| {
428 agent.sessions.insert(
429 session_id,
430 Session {
431 thread,
432 acp_thread: acp_thread.downgrade(),
433 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
434 this.sessions.remove(acp_thread.session_id());
435 })
436 },
437 );
438 })?;
439
440 Ok(acp_thread)
441 })
442 }
443
444 fn auth_methods(&self) -> &[acp::AuthMethod] {
445 &[] // No auth for in-process
446 }
447
448 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
449 Task::ready(Ok(()))
450 }
451
452 fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
453 Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
454 }
455
456 fn prompt(
457 &self,
458 params: acp::PromptRequest,
459 cx: &mut App,
460 ) -> Task<Result<acp::PromptResponse>> {
461 let session_id = params.session_id.clone();
462 let agent = self.0.clone();
463 log::info!("Received prompt request for session: {}", session_id);
464 log::debug!("Prompt blocks count: {}", params.prompt.len());
465
466 cx.spawn(async move |cx| {
467 // Get session
468 let (thread, acp_thread) = agent
469 .update(cx, |agent, _| {
470 agent
471 .sessions
472 .get_mut(&session_id)
473 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
474 })?
475 .ok_or_else(|| {
476 log::error!("Session not found: {}", session_id);
477 anyhow::anyhow!("Session not found")
478 })?;
479 log::debug!("Found session for: {}", session_id);
480
481 // Convert prompt to message
482 let message = convert_prompt_to_message(params.prompt);
483 log::info!("Converted prompt to message: {} chars", message.len());
484 log::debug!("Message content: {}", message);
485
486 // Get model using the ModelSelector capability (always available for agent2)
487 // Get the selected model from the thread directly
488 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
489
490 // Send to thread
491 log::info!("Sending message to thread with model: {:?}", model.name());
492 let mut response_stream =
493 thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
494
495 // Handle response stream and forward to session.acp_thread
496 while let Some(result) = response_stream.next().await {
497 match result {
498 Ok(event) => {
499 log::trace!("Received completion event: {:?}", event);
500
501 match event {
502 AgentResponseEvent::Text(text) => {
503 acp_thread.update(cx, |thread, cx| {
504 thread.handle_session_update(
505 acp::SessionUpdate::AgentMessageChunk {
506 content: acp::ContentBlock::Text(acp::TextContent {
507 text,
508 annotations: None,
509 }),
510 },
511 cx,
512 )
513 })??;
514 }
515 AgentResponseEvent::Thinking(text) => {
516 acp_thread.update(cx, |thread, cx| {
517 thread.handle_session_update(
518 acp::SessionUpdate::AgentThoughtChunk {
519 content: acp::ContentBlock::Text(acp::TextContent {
520 text,
521 annotations: None,
522 }),
523 },
524 cx,
525 )
526 })??;
527 }
528 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
529 tool_call,
530 options,
531 response,
532 }) => {
533 let recv = acp_thread.update(cx, |thread, cx| {
534 thread.request_tool_call_authorization(tool_call, options, cx)
535 })?;
536 cx.background_spawn(async move {
537 if let Some(option) = recv
538 .await
539 .context("authorization sender was dropped")
540 .log_err()
541 {
542 response
543 .send(option)
544 .map(|_| anyhow!("authorization receiver was dropped"))
545 .log_err();
546 }
547 })
548 .detach();
549 }
550 AgentResponseEvent::ToolCall(tool_call) => {
551 acp_thread.update(cx, |thread, cx| {
552 thread.handle_session_update(
553 acp::SessionUpdate::ToolCall(tool_call),
554 cx,
555 )
556 })??;
557 }
558 AgentResponseEvent::ToolCallUpdate(tool_call_update) => {
559 acp_thread.update(cx, |thread, cx| {
560 thread.handle_session_update(
561 acp::SessionUpdate::ToolCallUpdate(tool_call_update),
562 cx,
563 )
564 })??;
565 }
566 AgentResponseEvent::Stop(stop_reason) => {
567 log::debug!("Assistant message complete: {:?}", stop_reason);
568 return Ok(acp::PromptResponse { stop_reason });
569 }
570 }
571 }
572 Err(e) => {
573 log::error!("Error in model response stream: {:?}", e);
574 // TODO: Consider sending an error message to the UI
575 break;
576 }
577 }
578 }
579
580 log::info!("Response stream completed");
581 anyhow::Ok(acp::PromptResponse {
582 stop_reason: acp::StopReason::EndTurn,
583 })
584 })
585 }
586
587 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
588 log::info!("Cancelling on session: {}", session_id);
589 self.0.update(cx, |agent, cx| {
590 if let Some(agent) = agent.sessions.get(session_id) {
591 agent.thread.update(cx, |thread, _cx| thread.cancel());
592 }
593 });
594 }
595}
596
597/// Convert ACP content blocks to a message string
598fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
599 log::debug!("Converting {} content blocks to message", blocks.len());
600 let mut message = String::new();
601
602 for block in blocks {
603 match block {
604 acp::ContentBlock::Text(text) => {
605 log::trace!("Processing text block: {} chars", text.text.len());
606 message.push_str(&text.text);
607 }
608 acp::ContentBlock::ResourceLink(link) => {
609 log::trace!("Processing resource link: {}", link.uri);
610 message.push_str(&format!(" @{} ", link.uri));
611 }
612 acp::ContentBlock::Image(_) => {
613 log::trace!("Processing image block");
614 message.push_str(" [image] ");
615 }
616 acp::ContentBlock::Audio(_) => {
617 log::trace!("Processing audio block");
618 message.push_str(" [audio] ");
619 }
620 acp::ContentBlock::Resource(resource) => {
621 log::trace!("Processing resource block: {:?}", resource.resource);
622 message.push_str(&format!(" [resource: {:?}] ", resource.resource));
623 }
624 }
625 }
626
627 message
628}
629
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Verifies that `NativeAgent` keeps its shared project context in sync as a worktree is
    /// added and then as a `.rules` file appears inside it.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        // Fake filesystem containing a single empty directory `/a`.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees, so the initial context lists none.
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
            .await
            .unwrap();
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree triggers a refresh; `run_until_parked` lets the background
        // rebuild finish before asserting.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            // The expected entry id is whatever the worktree assigned to the new file.
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Shared test setup: installs a test `SettingsStore` as a global and registers project
    /// and language settings. `env_logger::try_init` errors are ignored (e.g. a logger is
    /// already installed by another test).
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
}