1use crate::{AgentResponseEvent, Thread, templates::Templates};
2use crate::{
3 CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool,
4 GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool,
5 ThinkingTool, ToolCallAuthorization, WebSearchTool,
6};
7use acp_thread::ModelSelector;
8use agent_client_protocol as acp;
9use anyhow::{Context as _, Result, anyhow};
10use futures::{StreamExt, future};
11use gpui::{
12 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
13};
14use language_model::{LanguageModel, LanguageModelRegistry};
15use project::{Project, ProjectItem, ProjectPath, Worktree};
16use prompt_store::{
17 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
18};
19use std::cell::RefCell;
20use std::collections::HashMap;
21use std::path::Path;
22use std::rc::Rc;
23use std::sync::Arc;
24use util::ResultExt;
25
/// Candidate rules-file paths, relative to a worktree root.
///
/// When a worktree contains more than one of these, the first entry in this list that
/// exists as a file is the one that gets loaded, so the order expresses priority.
/// Changes to any of these paths also trigger a rebuild of the project context.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
37
38pub struct RulesLoadingError {
39 pub message: SharedString,
40}
41
42/// Holds both the internal Thread and the AcpThread for a session
43struct Session {
44 /// The internal thread that processes messages
45 thread: Entity<Thread>,
46 /// The ACP thread that handles protocol communication
47 acp_thread: WeakEntity<acp_thread::AcpThread>,
48 _subscription: Subscription,
49}
50
51pub struct NativeAgent {
52 /// Session ID -> Session mapping
53 sessions: HashMap<acp::SessionId, Session>,
54 /// Shared project context for all threads
55 project_context: Rc<RefCell<ProjectContext>>,
56 project_context_needs_refresh: watch::Sender<()>,
57 _maintain_project_context: Task<Result<()>>,
58 /// Shared templates for all threads
59 templates: Arc<Templates>,
60 project: Entity<Project>,
61 prompt_store: Option<Entity<PromptStore>>,
62 _subscriptions: Vec<Subscription>,
63}
64
65impl NativeAgent {
66 pub async fn new(
67 project: Entity<Project>,
68 templates: Arc<Templates>,
69 prompt_store: Option<Entity<PromptStore>>,
70 cx: &mut AsyncApp,
71 ) -> Result<Entity<NativeAgent>> {
72 log::info!("Creating new NativeAgent");
73
74 let project_context = cx
75 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
76 .await;
77
78 cx.new(|cx| {
79 let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
80 if let Some(prompt_store) = prompt_store.as_ref() {
81 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
82 }
83
84 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
85 watch::channel(());
86 Self {
87 sessions: HashMap::new(),
88 project_context: Rc::new(RefCell::new(project_context)),
89 project_context_needs_refresh: project_context_needs_refresh_tx,
90 _maintain_project_context: cx.spawn(async move |this, cx| {
91 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
92 }),
93 templates,
94 project,
95 prompt_store,
96 _subscriptions: subscriptions,
97 }
98 })
99 }
100
101 async fn maintain_project_context(
102 this: WeakEntity<Self>,
103 mut needs_refresh: watch::Receiver<()>,
104 cx: &mut AsyncApp,
105 ) -> Result<()> {
106 while needs_refresh.changed().await.is_ok() {
107 let project_context = this
108 .update(cx, |this, cx| {
109 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
110 })?
111 .await;
112 this.update(cx, |this, _| this.project_context.replace(project_context))?;
113 }
114
115 Ok(())
116 }
117
118 fn build_project_context(
119 project: &Entity<Project>,
120 prompt_store: Option<&Entity<PromptStore>>,
121 cx: &mut App,
122 ) -> Task<ProjectContext> {
123 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
124 let worktree_tasks = worktrees
125 .into_iter()
126 .map(|worktree| {
127 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
128 })
129 .collect::<Vec<_>>();
130 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
131 prompt_store.read_with(cx, |prompt_store, cx| {
132 let prompts = prompt_store.default_prompt_metadata();
133 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
134 let contents = prompt_store.load(prompt_metadata.id, cx);
135 async move { (contents.await, prompt_metadata) }
136 });
137 cx.background_spawn(future::join_all(load_tasks))
138 })
139 } else {
140 Task::ready(vec![])
141 };
142
143 cx.spawn(async move |_cx| {
144 let (worktrees, default_user_rules) =
145 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
146
147 let worktrees = worktrees
148 .into_iter()
149 .map(|(worktree, _rules_error)| {
150 // TODO: show error message
151 // if let Some(rules_error) = rules_error {
152 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
153 // }
154 worktree
155 })
156 .collect::<Vec<_>>();
157
158 let default_user_rules = default_user_rules
159 .into_iter()
160 .flat_map(|(contents, prompt_metadata)| match contents {
161 Ok(contents) => Some(UserRulesContext {
162 uuid: match prompt_metadata.id {
163 PromptId::User { uuid } => uuid,
164 PromptId::EditWorkflow => return None,
165 },
166 title: prompt_metadata.title.map(|title| title.to_string()),
167 contents,
168 }),
169 Err(_err) => {
170 // TODO: show error message
171 // this.update(cx, |_, cx| {
172 // cx.emit(RulesLoadingError {
173 // message: format!("{err:?}").into(),
174 // });
175 // })
176 // .ok();
177 None
178 }
179 })
180 .collect::<Vec<_>>();
181
182 ProjectContext::new(worktrees, default_user_rules)
183 })
184 }
185
186 fn load_worktree_info_for_system_prompt(
187 worktree: Entity<Worktree>,
188 project: Entity<Project>,
189 cx: &mut App,
190 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
191 let tree = worktree.read(cx);
192 let root_name = tree.root_name().into();
193 let abs_path = tree.abs_path();
194
195 let mut context = WorktreeContext {
196 root_name,
197 abs_path,
198 rules_file: None,
199 };
200
201 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
202 let Some(rules_task) = rules_task else {
203 return Task::ready((context, None));
204 };
205
206 cx.spawn(async move |_| {
207 let (rules_file, rules_file_error) = match rules_task.await {
208 Ok(rules_file) => (Some(rules_file), None),
209 Err(err) => (
210 None,
211 Some(RulesLoadingError {
212 message: format!("{err}").into(),
213 }),
214 ),
215 };
216 context.rules_file = rules_file;
217 (context, rules_file_error)
218 })
219 }
220
221 fn load_worktree_rules_file(
222 worktree: Entity<Worktree>,
223 project: Entity<Project>,
224 cx: &mut App,
225 ) -> Option<Task<Result<RulesFileContext>>> {
226 let worktree = worktree.read(cx);
227 let worktree_id = worktree.id();
228 let selected_rules_file = RULES_FILE_NAMES
229 .into_iter()
230 .filter_map(|name| {
231 worktree
232 .entry_for_path(name)
233 .filter(|entry| entry.is_file())
234 .map(|entry| entry.path.clone())
235 })
236 .next();
237
238 // Note that Cline supports `.clinerules` being a directory, but that is not currently
239 // supported. This doesn't seem to occur often in GitHub repositories.
240 selected_rules_file.map(|path_in_worktree| {
241 let project_path = ProjectPath {
242 worktree_id,
243 path: path_in_worktree.clone(),
244 };
245 let buffer_task =
246 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
247 let rope_task = cx.spawn(async move |cx| {
248 buffer_task.await?.read_with(cx, |buffer, cx| {
249 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
250 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
251 })?
252 });
253 // Build a string from the rope on a background thread.
254 cx.background_spawn(async move {
255 let (project_entry_id, rope) = rope_task.await?;
256 anyhow::Ok(RulesFileContext {
257 path_in_worktree,
258 text: rope.to_string().trim().to_string(),
259 project_entry_id: project_entry_id.to_usize(),
260 })
261 })
262 })
263 }
264
265 fn handle_project_event(
266 &mut self,
267 _project: Entity<Project>,
268 event: &project::Event,
269 _cx: &mut Context<Self>,
270 ) {
271 match event {
272 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
273 self.project_context_needs_refresh.send(()).ok();
274 }
275 project::Event::WorktreeUpdatedEntries(_, items) => {
276 if items.iter().any(|(path, _, _)| {
277 RULES_FILE_NAMES
278 .iter()
279 .any(|name| path.as_ref() == Path::new(name))
280 }) {
281 self.project_context_needs_refresh.send(()).ok();
282 }
283 }
284 _ => {}
285 }
286 }
287
288 fn handle_prompts_updated_event(
289 &mut self,
290 _prompt_store: Entity<PromptStore>,
291 _event: &prompt_store::PromptsUpdatedEvent,
292 _cx: &mut Context<Self>,
293 ) {
294 self.project_context_needs_refresh.send(()).ok();
295 }
296}
297
298/// Wrapper struct that implements the AgentConnection trait
299#[derive(Clone)]
300pub struct NativeAgentConnection(pub Entity<NativeAgent>);
301
302impl ModelSelector for NativeAgentConnection {
303 fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
304 log::debug!("NativeAgentConnection::list_models called");
305 cx.spawn(async move |cx| {
306 cx.update(|cx| {
307 let registry = LanguageModelRegistry::read_global(cx);
308 let models = registry.available_models(cx).collect::<Vec<_>>();
309 log::info!("Found {} available models", models.len());
310 if models.is_empty() {
311 Err(anyhow::anyhow!("No models available"))
312 } else {
313 Ok(models)
314 }
315 })?
316 })
317 }
318
319 fn select_model(
320 &self,
321 session_id: acp::SessionId,
322 model: Arc<dyn LanguageModel>,
323 cx: &mut AsyncApp,
324 ) -> Task<Result<()>> {
325 log::info!(
326 "Setting model for session {}: {:?}",
327 session_id,
328 model.name()
329 );
330 let agent = self.0.clone();
331
332 cx.spawn(async move |cx| {
333 agent.update(cx, |agent, cx| {
334 if let Some(session) = agent.sessions.get(&session_id) {
335 session.thread.update(cx, |thread, _cx| {
336 thread.selected_model = model;
337 });
338 Ok(())
339 } else {
340 Err(anyhow!("Session not found"))
341 }
342 })?
343 })
344 }
345
346 fn selected_model(
347 &self,
348 session_id: &acp::SessionId,
349 cx: &mut AsyncApp,
350 ) -> Task<Result<Arc<dyn LanguageModel>>> {
351 let agent = self.0.clone();
352 let session_id = session_id.clone();
353 cx.spawn(async move |cx| {
354 let thread = agent
355 .read_with(cx, |agent, _| {
356 agent
357 .sessions
358 .get(&session_id)
359 .map(|session| session.thread.clone())
360 })?
361 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
362 let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
363 Ok(selected)
364 })
365 }
366}
367
368impl acp_thread::AgentConnection for NativeAgentConnection {
369 fn new_thread(
370 self: Rc<Self>,
371 project: Entity<Project>,
372 cwd: &Path,
373 cx: &mut AsyncApp,
374 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
375 let agent = self.0.clone();
376 log::info!("Creating new thread for project at: {:?}", cwd);
377
378 cx.spawn(async move |cx| {
379 log::debug!("Starting thread creation in async context");
380
381 // Generate session ID
382 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
383 log::info!("Created session with ID: {}", session_id);
384
385 // Create AcpThread
386 let acp_thread = cx.update(|cx| {
387 cx.new(|cx| {
388 acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
389 })
390 })?;
391 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
392
393 // Create Thread
394 let thread = agent.update(
395 cx,
396 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
397 // Fetch default model from registry settings
398 let registry = LanguageModelRegistry::read_global(cx);
399
400 // Log available models for debugging
401 let available_count = registry.available_models(cx).count();
402 log::debug!("Total available models: {}", available_count);
403
404 let default_model = registry
405 .default_model()
406 .map(|configured| {
407 log::info!(
408 "Using configured default model: {:?} from provider: {:?}",
409 configured.model.name(),
410 configured.provider.name()
411 );
412 configured.model
413 })
414 .ok_or_else(|| {
415 log::warn!("No default model configured in settings");
416 anyhow!("No default model configured. Please configure a default model in settings.")
417 })?;
418
419 let thread = cx.new(|cx| {
420 let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
421 thread.add_tool(CreateDirectoryTool::new(project.clone()));
422 thread.add_tool(CopyPathTool::new(project.clone()));
423 thread.add_tool(DiagnosticsTool::new(project.clone()));
424 thread.add_tool(MovePathTool::new(project.clone()));
425 thread.add_tool(ListDirectoryTool::new(project.clone()));
426 thread.add_tool(OpenTool::new(project.clone()));
427 thread.add_tool(ThinkingTool);
428 thread.add_tool(FindPathTool::new(project.clone()));
429 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
430 thread.add_tool(GrepTool::new(project.clone()));
431 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
432 thread.add_tool(EditFileTool::new(cx.entity()));
433 thread.add_tool(NowTool);
434 thread.add_tool(TerminalTool::new(project.clone(), cx));
435 // TODO: Needs to be conditional based on zed model or not
436 thread.add_tool(WebSearchTool);
437 thread
438 });
439
440 Ok(thread)
441 },
442 )??;
443
444 // Store the session
445 agent.update(cx, |agent, cx| {
446 agent.sessions.insert(
447 session_id,
448 Session {
449 thread,
450 acp_thread: acp_thread.downgrade(),
451 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
452 this.sessions.remove(acp_thread.session_id());
453 })
454 },
455 );
456 })?;
457
458 Ok(acp_thread)
459 })
460 }
461
462 fn auth_methods(&self) -> &[acp::AuthMethod] {
463 &[] // No auth for in-process
464 }
465
466 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
467 Task::ready(Ok(()))
468 }
469
470 fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
471 Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
472 }
473
474 fn prompt(
475 &self,
476 params: acp::PromptRequest,
477 cx: &mut App,
478 ) -> Task<Result<acp::PromptResponse>> {
479 let session_id = params.session_id.clone();
480 let agent = self.0.clone();
481 log::info!("Received prompt request for session: {}", session_id);
482 log::debug!("Prompt blocks count: {}", params.prompt.len());
483
484 cx.spawn(async move |cx| {
485 // Get session
486 let (thread, acp_thread) = agent
487 .update(cx, |agent, _| {
488 agent
489 .sessions
490 .get_mut(&session_id)
491 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
492 })?
493 .ok_or_else(|| {
494 log::error!("Session not found: {}", session_id);
495 anyhow::anyhow!("Session not found")
496 })?;
497 log::debug!("Found session for: {}", session_id);
498
499 // Convert prompt to message
500 let message = convert_prompt_to_message(params.prompt);
501 log::info!("Converted prompt to message: {} chars", message.len());
502 log::debug!("Message content: {}", message);
503
504 // Get model using the ModelSelector capability (always available for agent2)
505 // Get the selected model from the thread directly
506 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
507
508 // Send to thread
509 log::info!("Sending message to thread with model: {:?}", model.name());
510 let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
511
512 // Handle response stream and forward to session.acp_thread
513 while let Some(result) = response_stream.next().await {
514 match result {
515 Ok(event) => {
516 log::trace!("Received completion event: {:?}", event);
517
518 match event {
519 AgentResponseEvent::Text(text) => {
520 acp_thread.update(cx, |thread, cx| {
521 thread.push_assistant_content_block(
522 acp::ContentBlock::Text(acp::TextContent {
523 text,
524 annotations: None,
525 }),
526 false,
527 cx,
528 )
529 })?;
530 }
531 AgentResponseEvent::Thinking(text) => {
532 acp_thread.update(cx, |thread, cx| {
533 thread.push_assistant_content_block(
534 acp::ContentBlock::Text(acp::TextContent {
535 text,
536 annotations: None,
537 }),
538 true,
539 cx,
540 )
541 })?;
542 }
543 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
544 tool_call,
545 options,
546 response,
547 }) => {
548 let recv = acp_thread.update(cx, |thread, cx| {
549 thread.request_tool_call_authorization(tool_call, options, cx)
550 })?;
551 cx.background_spawn(async move {
552 if let Some(option) = recv
553 .await
554 .context("authorization sender was dropped")
555 .log_err()
556 {
557 response
558 .send(option)
559 .map(|_| anyhow!("authorization receiver was dropped"))
560 .log_err();
561 }
562 })
563 .detach();
564 }
565 AgentResponseEvent::ToolCall(tool_call) => {
566 acp_thread.update(cx, |thread, cx| {
567 thread.upsert_tool_call(tool_call, cx)
568 })?;
569 }
570 AgentResponseEvent::ToolCallUpdate(update) => {
571 acp_thread.update(cx, |thread, cx| {
572 thread.update_tool_call(update, cx)
573 })??;
574 }
575 AgentResponseEvent::Stop(stop_reason) => {
576 log::debug!("Assistant message complete: {:?}", stop_reason);
577 return Ok(acp::PromptResponse { stop_reason });
578 }
579 }
580 }
581 Err(e) => {
582 log::error!("Error in model response stream: {:?}", e);
583 // TODO: Consider sending an error message to the UI
584 break;
585 }
586 }
587 }
588
589 log::info!("Response stream completed");
590 anyhow::Ok(acp::PromptResponse {
591 stop_reason: acp::StopReason::EndTurn,
592 })
593 })
594 }
595
596 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
597 log::info!("Cancelling on session: {}", session_id);
598 self.0.update(cx, |agent, cx| {
599 if let Some(agent) = agent.sessions.get(session_id) {
600 agent.thread.update(cx, |thread, _cx| thread.cancel());
601 }
602 });
603 }
604}
605
606/// Convert ACP content blocks to a message string
607fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
608 log::debug!("Converting {} content blocks to message", blocks.len());
609 let mut message = String::new();
610
611 for block in blocks {
612 match block {
613 acp::ContentBlock::Text(text) => {
614 log::trace!("Processing text block: {} chars", text.text.len());
615 message.push_str(&text.text);
616 }
617 acp::ContentBlock::ResourceLink(link) => {
618 log::trace!("Processing resource link: {}", link.uri);
619 message.push_str(&format!(" @{} ", link.uri));
620 }
621 acp::ContentBlock::Image(_) => {
622 log::trace!("Processing image block");
623 message.push_str(" [image] ");
624 }
625 acp::ContentBlock::Audio(_) => {
626 log::trace!("Processing audio block");
627 message.push_str(" [audio] ");
628 }
629 acp::ContentBlock::Resource(resource) => {
630 log::trace!("Processing resource block: {:?}", resource.resource);
631 message.push_str(&format!(" [resource: {:?}] ", resource.resource));
632 }
633 }
634 }
635
636 message
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Checks that `NativeAgent` keeps its shared project context in sync with the
    /// project: first with no worktrees, then after adding a worktree, then after creating
    /// a `.rules` file inside it.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Fake filesystem containing a single empty directory `/a`.
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
            .await
            .unwrap();
        // The project was created with no worktrees, so the initial context lists none.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree signals a refresh. `run_until_parked` lets the background
        // maintenance task rebuild the context before it is inspected.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            // The expected entry id is read from the worktree rather than hard-coded.
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Shared test setup. Initializes logging (ignoring the error if it is already
    /// initialized), installs a test `SettingsStore` as a global, and runs the project
    /// settings and language initializers.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
}