1use crate::{AgentResponseEvent, Thread, templates::Templates};
2use crate::{
3 CopyPathTool, CreateDirectoryTool, EditFileTool, FindPathTool, ListDirectoryTool, MovePathTool,
4 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization,
5};
6use acp_thread::ModelSelector;
7use agent_client_protocol as acp;
8use anyhow::{Context as _, Result, anyhow};
9use futures::{StreamExt, future};
10use gpui::{
11 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
12};
13use language_model::{LanguageModel, LanguageModelRegistry};
14use project::{Project, ProjectItem, ProjectPath, Worktree};
15use prompt_store::{
16 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
17};
18use std::cell::RefCell;
19use std::collections::HashMap;
20use std::path::Path;
21use std::rc::Rc;
22use std::sync::Arc;
23use util::ResultExt;
24
/// File names, relative to a worktree root, that are treated as rules files.
///
/// The list is ordered by priority: the rules-file lookup in this module walks it front to back
/// and uses the first name that resolves to a file entry.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
36
37pub struct RulesLoadingError {
38 pub message: SharedString,
39}
40
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Held weakly, so the session does not keep the ACP thread alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Subscription kept alive for as long as the session exists. `new_thread` stores the
    /// `observe_release` handle here, whose callback removes the session from the agent.
    _subscription: Subscription,
}
49
/// In-process agent that owns all sessions and the project context they share.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Sender half of the watch channel; the event handlers send on it to request a rebuild
    /// of `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`, which rebuilds the context on each request.
    _maintain_project_context: Task<Result<()>>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Project whose visible worktrees feed the project context.
    project: Entity<Project>,
    /// Optional prompt store from which default user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    /// Subscriptions to project events and, when a prompt store is present, to its update events.
    _subscriptions: Vec<Subscription>,
}
63
64impl NativeAgent {
65 pub async fn new(
66 project: Entity<Project>,
67 templates: Arc<Templates>,
68 prompt_store: Option<Entity<PromptStore>>,
69 cx: &mut AsyncApp,
70 ) -> Result<Entity<NativeAgent>> {
71 log::info!("Creating new NativeAgent");
72
73 let project_context = cx
74 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
75 .await;
76
77 cx.new(|cx| {
78 let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
79 if let Some(prompt_store) = prompt_store.as_ref() {
80 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
81 }
82
83 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
84 watch::channel(());
85 Self {
86 sessions: HashMap::new(),
87 project_context: Rc::new(RefCell::new(project_context)),
88 project_context_needs_refresh: project_context_needs_refresh_tx,
89 _maintain_project_context: cx.spawn(async move |this, cx| {
90 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
91 }),
92 templates,
93 project,
94 prompt_store,
95 _subscriptions: subscriptions,
96 }
97 })
98 }
99
100 async fn maintain_project_context(
101 this: WeakEntity<Self>,
102 mut needs_refresh: watch::Receiver<()>,
103 cx: &mut AsyncApp,
104 ) -> Result<()> {
105 while needs_refresh.changed().await.is_ok() {
106 let project_context = this
107 .update(cx, |this, cx| {
108 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
109 })?
110 .await;
111 this.update(cx, |this, _| this.project_context.replace(project_context))?;
112 }
113
114 Ok(())
115 }
116
117 fn build_project_context(
118 project: &Entity<Project>,
119 prompt_store: Option<&Entity<PromptStore>>,
120 cx: &mut App,
121 ) -> Task<ProjectContext> {
122 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
123 let worktree_tasks = worktrees
124 .into_iter()
125 .map(|worktree| {
126 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
127 })
128 .collect::<Vec<_>>();
129 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
130 prompt_store.read_with(cx, |prompt_store, cx| {
131 let prompts = prompt_store.default_prompt_metadata();
132 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
133 let contents = prompt_store.load(prompt_metadata.id, cx);
134 async move { (contents.await, prompt_metadata) }
135 });
136 cx.background_spawn(future::join_all(load_tasks))
137 })
138 } else {
139 Task::ready(vec![])
140 };
141
142 cx.spawn(async move |_cx| {
143 let (worktrees, default_user_rules) =
144 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
145
146 let worktrees = worktrees
147 .into_iter()
148 .map(|(worktree, _rules_error)| {
149 // TODO: show error message
150 // if let Some(rules_error) = rules_error {
151 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
152 // }
153 worktree
154 })
155 .collect::<Vec<_>>();
156
157 let default_user_rules = default_user_rules
158 .into_iter()
159 .flat_map(|(contents, prompt_metadata)| match contents {
160 Ok(contents) => Some(UserRulesContext {
161 uuid: match prompt_metadata.id {
162 PromptId::User { uuid } => uuid,
163 PromptId::EditWorkflow => return None,
164 },
165 title: prompt_metadata.title.map(|title| title.to_string()),
166 contents,
167 }),
168 Err(_err) => {
169 // TODO: show error message
170 // this.update(cx, |_, cx| {
171 // cx.emit(RulesLoadingError {
172 // message: format!("{err:?}").into(),
173 // });
174 // })
175 // .ok();
176 None
177 }
178 })
179 .collect::<Vec<_>>();
180
181 ProjectContext::new(worktrees, default_user_rules)
182 })
183 }
184
185 fn load_worktree_info_for_system_prompt(
186 worktree: Entity<Worktree>,
187 project: Entity<Project>,
188 cx: &mut App,
189 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
190 let tree = worktree.read(cx);
191 let root_name = tree.root_name().into();
192 let abs_path = tree.abs_path();
193
194 let mut context = WorktreeContext {
195 root_name,
196 abs_path,
197 rules_file: None,
198 };
199
200 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
201 let Some(rules_task) = rules_task else {
202 return Task::ready((context, None));
203 };
204
205 cx.spawn(async move |_| {
206 let (rules_file, rules_file_error) = match rules_task.await {
207 Ok(rules_file) => (Some(rules_file), None),
208 Err(err) => (
209 None,
210 Some(RulesLoadingError {
211 message: format!("{err}").into(),
212 }),
213 ),
214 };
215 context.rules_file = rules_file;
216 (context, rules_file_error)
217 })
218 }
219
220 fn load_worktree_rules_file(
221 worktree: Entity<Worktree>,
222 project: Entity<Project>,
223 cx: &mut App,
224 ) -> Option<Task<Result<RulesFileContext>>> {
225 let worktree = worktree.read(cx);
226 let worktree_id = worktree.id();
227 let selected_rules_file = RULES_FILE_NAMES
228 .into_iter()
229 .filter_map(|name| {
230 worktree
231 .entry_for_path(name)
232 .filter(|entry| entry.is_file())
233 .map(|entry| entry.path.clone())
234 })
235 .next();
236
237 // Note that Cline supports `.clinerules` being a directory, but that is not currently
238 // supported. This doesn't seem to occur often in GitHub repositories.
239 selected_rules_file.map(|path_in_worktree| {
240 let project_path = ProjectPath {
241 worktree_id,
242 path: path_in_worktree.clone(),
243 };
244 let buffer_task =
245 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
246 let rope_task = cx.spawn(async move |cx| {
247 buffer_task.await?.read_with(cx, |buffer, cx| {
248 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
249 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
250 })?
251 });
252 // Build a string from the rope on a background thread.
253 cx.background_spawn(async move {
254 let (project_entry_id, rope) = rope_task.await?;
255 anyhow::Ok(RulesFileContext {
256 path_in_worktree,
257 text: rope.to_string().trim().to_string(),
258 project_entry_id: project_entry_id.to_usize(),
259 })
260 })
261 })
262 }
263
264 fn handle_project_event(
265 &mut self,
266 _project: Entity<Project>,
267 event: &project::Event,
268 _cx: &mut Context<Self>,
269 ) {
270 match event {
271 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
272 self.project_context_needs_refresh.send(()).ok();
273 }
274 project::Event::WorktreeUpdatedEntries(_, items) => {
275 if items.iter().any(|(path, _, _)| {
276 RULES_FILE_NAMES
277 .iter()
278 .any(|name| path.as_ref() == Path::new(name))
279 }) {
280 self.project_context_needs_refresh.send(()).ok();
281 }
282 }
283 _ => {}
284 }
285 }
286
287 fn handle_prompts_updated_event(
288 &mut self,
289 _prompt_store: Entity<PromptStore>,
290 _event: &prompt_store::PromptsUpdatedEvent,
291 _cx: &mut Context<Self>,
292 ) {
293 self.project_context_needs_refresh.send(()).ok();
294 }
295}
296
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds a handle to the `NativeAgent` entity. `Clone` is derived, so a clone wraps the
/// same entity handle.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
300
301impl ModelSelector for NativeAgentConnection {
302 fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
303 log::debug!("NativeAgentConnection::list_models called");
304 cx.spawn(async move |cx| {
305 cx.update(|cx| {
306 let registry = LanguageModelRegistry::read_global(cx);
307 let models = registry.available_models(cx).collect::<Vec<_>>();
308 log::info!("Found {} available models", models.len());
309 if models.is_empty() {
310 Err(anyhow::anyhow!("No models available"))
311 } else {
312 Ok(models)
313 }
314 })?
315 })
316 }
317
318 fn select_model(
319 &self,
320 session_id: acp::SessionId,
321 model: Arc<dyn LanguageModel>,
322 cx: &mut AsyncApp,
323 ) -> Task<Result<()>> {
324 log::info!(
325 "Setting model for session {}: {:?}",
326 session_id,
327 model.name()
328 );
329 let agent = self.0.clone();
330
331 cx.spawn(async move |cx| {
332 agent.update(cx, |agent, cx| {
333 if let Some(session) = agent.sessions.get(&session_id) {
334 session.thread.update(cx, |thread, _cx| {
335 thread.selected_model = model;
336 });
337 Ok(())
338 } else {
339 Err(anyhow!("Session not found"))
340 }
341 })?
342 })
343 }
344
345 fn selected_model(
346 &self,
347 session_id: &acp::SessionId,
348 cx: &mut AsyncApp,
349 ) -> Task<Result<Arc<dyn LanguageModel>>> {
350 let agent = self.0.clone();
351 let session_id = session_id.clone();
352 cx.spawn(async move |cx| {
353 let thread = agent
354 .read_with(cx, |agent, _| {
355 agent
356 .sessions
357 .get(&session_id)
358 .map(|session| session.thread.clone())
359 })?
360 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
361 let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
362 Ok(selected)
363 })
364 }
365}
366
367impl acp_thread::AgentConnection for NativeAgentConnection {
368 fn new_thread(
369 self: Rc<Self>,
370 project: Entity<Project>,
371 cwd: &Path,
372 cx: &mut AsyncApp,
373 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
374 let agent = self.0.clone();
375 log::info!("Creating new thread for project at: {:?}", cwd);
376
377 cx.spawn(async move |cx| {
378 log::debug!("Starting thread creation in async context");
379
380 // Generate session ID
381 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
382 log::info!("Created session with ID: {}", session_id);
383
384 // Create AcpThread
385 let acp_thread = cx.update(|cx| {
386 cx.new(|cx| {
387 acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
388 })
389 })?;
390 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
391
392 // Create Thread
393 let thread = agent.update(
394 cx,
395 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
396 // Fetch default model from registry settings
397 let registry = LanguageModelRegistry::read_global(cx);
398
399 // Log available models for debugging
400 let available_count = registry.available_models(cx).count();
401 log::debug!("Total available models: {}", available_count);
402
403 let default_model = registry
404 .default_model()
405 .map(|configured| {
406 log::info!(
407 "Using configured default model: {:?} from provider: {:?}",
408 configured.model.name(),
409 configured.provider.name()
410 );
411 configured.model
412 })
413 .ok_or_else(|| {
414 log::warn!("No default model configured in settings");
415 anyhow!("No default model configured. Please configure a default model in settings.")
416 })?;
417
418 let thread = cx.new(|cx| {
419 let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
420 thread.add_tool(CreateDirectoryTool::new(project.clone()));
421 thread.add_tool(CopyPathTool::new(project.clone()));
422 thread.add_tool(MovePathTool::new(project.clone()));
423 thread.add_tool(ListDirectoryTool::new(project.clone()));
424 thread.add_tool(OpenTool::new(project.clone()));
425 thread.add_tool(ThinkingTool);
426 thread.add_tool(FindPathTool::new(project.clone()));
427 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
428 thread.add_tool(EditFileTool::new(cx.entity()));
429 thread.add_tool(TerminalTool::new(project.clone(), cx));
430 thread
431 });
432
433 Ok(thread)
434 },
435 )??;
436
437 // Store the session
438 agent.update(cx, |agent, cx| {
439 agent.sessions.insert(
440 session_id,
441 Session {
442 thread,
443 acp_thread: acp_thread.downgrade(),
444 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
445 this.sessions.remove(acp_thread.session_id());
446 })
447 },
448 );
449 })?;
450
451 Ok(acp_thread)
452 })
453 }
454
/// Returns the authentication methods offered by this connection; the list is always empty.
fn auth_methods(&self) -> &[acp::AuthMethod] {
    &[] // No auth for in-process
}
458
/// Resolves immediately with `Ok(())`; the requested method and the context are ignored.
fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
    Task::ready(Ok(()))
}
462
463 fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
464 Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
465 }
466
467 fn prompt(
468 &self,
469 params: acp::PromptRequest,
470 cx: &mut App,
471 ) -> Task<Result<acp::PromptResponse>> {
472 let session_id = params.session_id.clone();
473 let agent = self.0.clone();
474 log::info!("Received prompt request for session: {}", session_id);
475 log::debug!("Prompt blocks count: {}", params.prompt.len());
476
477 cx.spawn(async move |cx| {
478 // Get session
479 let (thread, acp_thread) = agent
480 .update(cx, |agent, _| {
481 agent
482 .sessions
483 .get_mut(&session_id)
484 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
485 })?
486 .ok_or_else(|| {
487 log::error!("Session not found: {}", session_id);
488 anyhow::anyhow!("Session not found")
489 })?;
490 log::debug!("Found session for: {}", session_id);
491
492 // Convert prompt to message
493 let message = convert_prompt_to_message(params.prompt);
494 log::info!("Converted prompt to message: {} chars", message.len());
495 log::debug!("Message content: {}", message);
496
497 // Get model using the ModelSelector capability (always available for agent2)
498 // Get the selected model from the thread directly
499 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
500
501 // Send to thread
502 log::info!("Sending message to thread with model: {:?}", model.name());
503 let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
504
505 // Handle response stream and forward to session.acp_thread
506 while let Some(result) = response_stream.next().await {
507 match result {
508 Ok(event) => {
509 log::trace!("Received completion event: {:?}", event);
510
511 match event {
512 AgentResponseEvent::Text(text) => {
513 acp_thread.update(cx, |thread, cx| {
514 thread.push_assistant_content_block(
515 acp::ContentBlock::Text(acp::TextContent {
516 text,
517 annotations: None,
518 }),
519 false,
520 cx,
521 )
522 })?;
523 }
524 AgentResponseEvent::Thinking(text) => {
525 acp_thread.update(cx, |thread, cx| {
526 thread.push_assistant_content_block(
527 acp::ContentBlock::Text(acp::TextContent {
528 text,
529 annotations: None,
530 }),
531 true,
532 cx,
533 )
534 })?;
535 }
536 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
537 tool_call,
538 options,
539 response,
540 }) => {
541 let recv = acp_thread.update(cx, |thread, cx| {
542 thread.request_tool_call_authorization(tool_call, options, cx)
543 })?;
544 cx.background_spawn(async move {
545 if let Some(option) = recv
546 .await
547 .context("authorization sender was dropped")
548 .log_err()
549 {
550 response
551 .send(option)
552 .map(|_| anyhow!("authorization receiver was dropped"))
553 .log_err();
554 }
555 })
556 .detach();
557 }
558 AgentResponseEvent::ToolCall(tool_call) => {
559 acp_thread.update(cx, |thread, cx| {
560 thread.upsert_tool_call(tool_call, cx)
561 })?;
562 }
563 AgentResponseEvent::ToolCallUpdate(update) => {
564 acp_thread.update(cx, |thread, cx| {
565 thread.update_tool_call(update, cx)
566 })??;
567 }
568 AgentResponseEvent::Stop(stop_reason) => {
569 log::debug!("Assistant message complete: {:?}", stop_reason);
570 return Ok(acp::PromptResponse { stop_reason });
571 }
572 }
573 }
574 Err(e) => {
575 log::error!("Error in model response stream: {:?}", e);
576 // TODO: Consider sending an error message to the UI
577 break;
578 }
579 }
580 }
581
582 log::info!("Response stream completed");
583 anyhow::Ok(acp::PromptResponse {
584 stop_reason: acp::StopReason::EndTurn,
585 })
586 })
587 }
588
589 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
590 log::info!("Cancelling on session: {}", session_id);
591 self.0.update(cx, |agent, cx| {
592 if let Some(agent) = agent.sessions.get(session_id) {
593 agent.thread.update(cx, |thread, _cx| thread.cancel());
594 }
595 });
596 }
597}
598
599/// Convert ACP content blocks to a message string
600fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
601 log::debug!("Converting {} content blocks to message", blocks.len());
602 let mut message = String::new();
603
604 for block in blocks {
605 match block {
606 acp::ContentBlock::Text(text) => {
607 log::trace!("Processing text block: {} chars", text.text.len());
608 message.push_str(&text.text);
609 }
610 acp::ContentBlock::ResourceLink(link) => {
611 log::trace!("Processing resource link: {}", link.uri);
612 message.push_str(&format!(" @{} ", link.uri));
613 }
614 acp::ContentBlock::Image(_) => {
615 log::trace!("Processing image block");
616 message.push_str(" [image] ");
617 }
618 acp::ContentBlock::Audio(_) => {
619 log::trace!("Processing audio block");
620 message.push_str(" [audio] ");
621 }
622 acp::ContentBlock::Resource(resource) => {
623 log::trace!("Processing resource block: {:?}", resource.resource);
624 message.push_str(&format!(" [resource: {:?}] ", resource.resource));
625 }
626 }
627 }
628
629 message
630}
631
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Checks that the agent's shared project context follows worktree and rules-file changes.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts without worktrees, so the initial context has none either.
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
            .await
            .unwrap();
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree for `/a` updates the project context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Let pending tasks run before inspecting the context.
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Initializes logging and installs the test settings store, project settings, and
    /// language support.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` fails if a logger is already installed; that failure is ignored.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
}