1use crate::{AgentResponseEvent, Thread, templates::Templates};
2use crate::{
3 CopyPathTool, CreateDirectoryTool, EditFileTool, FindPathTool, GrepTool, ListDirectoryTool,
4 MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool,
5 ToolCallAuthorization, WebSearchTool,
6};
7use acp_thread::ModelSelector;
8use agent_client_protocol as acp;
9use anyhow::{Context as _, Result, anyhow};
10use futures::{StreamExt, future};
11use gpui::{
12 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
13};
14use language_model::{LanguageModel, LanguageModelRegistry};
15use project::{Project, ProjectItem, ProjectPath, Worktree};
16use prompt_store::{
17 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
18};
19use std::cell::RefCell;
20use std::collections::HashMap;
21use std::path::Path;
22use std::rc::Rc;
23use std::sync::Arc;
24use util::ResultExt;
25
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// When loading a worktree's rules, the first entry that exists as a file wins.
/// The same list is used to decide whether a worktree change should trigger a
/// project-context refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
37
38pub struct RulesLoadingError {
39 pub message: SharedString,
40}
41
/// Holds both the internal Thread and the AcpThread for a session.
///
/// Stored in `NativeAgent::sessions`, keyed by the ACP session ID.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Held weakly, so the session does not keep the `AcpThread` alive.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Observes release of the `AcpThread`; its callback removes this session from the agent.
    _subscription: Subscription,
}
50
/// In-process agent that owns the sessions created through `NativeAgentConnection`
/// and the project context shared by all of their threads.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads.
    /// Replaced in place whenever the maintenance task rebuilds it.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled by event handlers to request a rebuild of `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds `project_context` each time a refresh is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Project whose worktrees and rules files feed `project_context`.
    project: Entity<Project>,
    /// Optional source of default user rules; `None` means no user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    /// Subscriptions to project and prompt-store events (kept alive for the agent's lifetime).
    _subscriptions: Vec<Subscription>,
}
64
65impl NativeAgent {
66 pub async fn new(
67 project: Entity<Project>,
68 templates: Arc<Templates>,
69 prompt_store: Option<Entity<PromptStore>>,
70 cx: &mut AsyncApp,
71 ) -> Result<Entity<NativeAgent>> {
72 log::info!("Creating new NativeAgent");
73
74 let project_context = cx
75 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
76 .await;
77
78 cx.new(|cx| {
79 let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
80 if let Some(prompt_store) = prompt_store.as_ref() {
81 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
82 }
83
84 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
85 watch::channel(());
86 Self {
87 sessions: HashMap::new(),
88 project_context: Rc::new(RefCell::new(project_context)),
89 project_context_needs_refresh: project_context_needs_refresh_tx,
90 _maintain_project_context: cx.spawn(async move |this, cx| {
91 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
92 }),
93 templates,
94 project,
95 prompt_store,
96 _subscriptions: subscriptions,
97 }
98 })
99 }
100
101 async fn maintain_project_context(
102 this: WeakEntity<Self>,
103 mut needs_refresh: watch::Receiver<()>,
104 cx: &mut AsyncApp,
105 ) -> Result<()> {
106 while needs_refresh.changed().await.is_ok() {
107 let project_context = this
108 .update(cx, |this, cx| {
109 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
110 })?
111 .await;
112 this.update(cx, |this, _| this.project_context.replace(project_context))?;
113 }
114
115 Ok(())
116 }
117
118 fn build_project_context(
119 project: &Entity<Project>,
120 prompt_store: Option<&Entity<PromptStore>>,
121 cx: &mut App,
122 ) -> Task<ProjectContext> {
123 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
124 let worktree_tasks = worktrees
125 .into_iter()
126 .map(|worktree| {
127 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
128 })
129 .collect::<Vec<_>>();
130 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
131 prompt_store.read_with(cx, |prompt_store, cx| {
132 let prompts = prompt_store.default_prompt_metadata();
133 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
134 let contents = prompt_store.load(prompt_metadata.id, cx);
135 async move { (contents.await, prompt_metadata) }
136 });
137 cx.background_spawn(future::join_all(load_tasks))
138 })
139 } else {
140 Task::ready(vec![])
141 };
142
143 cx.spawn(async move |_cx| {
144 let (worktrees, default_user_rules) =
145 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
146
147 let worktrees = worktrees
148 .into_iter()
149 .map(|(worktree, _rules_error)| {
150 // TODO: show error message
151 // if let Some(rules_error) = rules_error {
152 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
153 // }
154 worktree
155 })
156 .collect::<Vec<_>>();
157
158 let default_user_rules = default_user_rules
159 .into_iter()
160 .flat_map(|(contents, prompt_metadata)| match contents {
161 Ok(contents) => Some(UserRulesContext {
162 uuid: match prompt_metadata.id {
163 PromptId::User { uuid } => uuid,
164 PromptId::EditWorkflow => return None,
165 },
166 title: prompt_metadata.title.map(|title| title.to_string()),
167 contents,
168 }),
169 Err(_err) => {
170 // TODO: show error message
171 // this.update(cx, |_, cx| {
172 // cx.emit(RulesLoadingError {
173 // message: format!("{err:?}").into(),
174 // });
175 // })
176 // .ok();
177 None
178 }
179 })
180 .collect::<Vec<_>>();
181
182 ProjectContext::new(worktrees, default_user_rules)
183 })
184 }
185
186 fn load_worktree_info_for_system_prompt(
187 worktree: Entity<Worktree>,
188 project: Entity<Project>,
189 cx: &mut App,
190 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
191 let tree = worktree.read(cx);
192 let root_name = tree.root_name().into();
193 let abs_path = tree.abs_path();
194
195 let mut context = WorktreeContext {
196 root_name,
197 abs_path,
198 rules_file: None,
199 };
200
201 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
202 let Some(rules_task) = rules_task else {
203 return Task::ready((context, None));
204 };
205
206 cx.spawn(async move |_| {
207 let (rules_file, rules_file_error) = match rules_task.await {
208 Ok(rules_file) => (Some(rules_file), None),
209 Err(err) => (
210 None,
211 Some(RulesLoadingError {
212 message: format!("{err}").into(),
213 }),
214 ),
215 };
216 context.rules_file = rules_file;
217 (context, rules_file_error)
218 })
219 }
220
221 fn load_worktree_rules_file(
222 worktree: Entity<Worktree>,
223 project: Entity<Project>,
224 cx: &mut App,
225 ) -> Option<Task<Result<RulesFileContext>>> {
226 let worktree = worktree.read(cx);
227 let worktree_id = worktree.id();
228 let selected_rules_file = RULES_FILE_NAMES
229 .into_iter()
230 .filter_map(|name| {
231 worktree
232 .entry_for_path(name)
233 .filter(|entry| entry.is_file())
234 .map(|entry| entry.path.clone())
235 })
236 .next();
237
238 // Note that Cline supports `.clinerules` being a directory, but that is not currently
239 // supported. This doesn't seem to occur often in GitHub repositories.
240 selected_rules_file.map(|path_in_worktree| {
241 let project_path = ProjectPath {
242 worktree_id,
243 path: path_in_worktree.clone(),
244 };
245 let buffer_task =
246 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
247 let rope_task = cx.spawn(async move |cx| {
248 buffer_task.await?.read_with(cx, |buffer, cx| {
249 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
250 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
251 })?
252 });
253 // Build a string from the rope on a background thread.
254 cx.background_spawn(async move {
255 let (project_entry_id, rope) = rope_task.await?;
256 anyhow::Ok(RulesFileContext {
257 path_in_worktree,
258 text: rope.to_string().trim().to_string(),
259 project_entry_id: project_entry_id.to_usize(),
260 })
261 })
262 })
263 }
264
265 fn handle_project_event(
266 &mut self,
267 _project: Entity<Project>,
268 event: &project::Event,
269 _cx: &mut Context<Self>,
270 ) {
271 match event {
272 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
273 self.project_context_needs_refresh.send(()).ok();
274 }
275 project::Event::WorktreeUpdatedEntries(_, items) => {
276 if items.iter().any(|(path, _, _)| {
277 RULES_FILE_NAMES
278 .iter()
279 .any(|name| path.as_ref() == Path::new(name))
280 }) {
281 self.project_context_needs_refresh.send(()).ok();
282 }
283 }
284 _ => {}
285 }
286 }
287
288 fn handle_prompts_updated_event(
289 &mut self,
290 _prompt_store: Entity<PromptStore>,
291 _event: &prompt_store::PromptsUpdatedEvent,
292 _cx: &mut Context<Self>,
293 ) {
294 self.project_context_needs_refresh.send(()).ok();
295 }
296}
297
/// Wrapper struct that implements the AgentConnection trait.
///
/// Cloning is cheap: it clones the `Entity` handle, so all clones refer to the same
/// `NativeAgent`.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
301
302impl ModelSelector for NativeAgentConnection {
303 fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
304 log::debug!("NativeAgentConnection::list_models called");
305 cx.spawn(async move |cx| {
306 cx.update(|cx| {
307 let registry = LanguageModelRegistry::read_global(cx);
308 let models = registry.available_models(cx).collect::<Vec<_>>();
309 log::info!("Found {} available models", models.len());
310 if models.is_empty() {
311 Err(anyhow::anyhow!("No models available"))
312 } else {
313 Ok(models)
314 }
315 })?
316 })
317 }
318
319 fn select_model(
320 &self,
321 session_id: acp::SessionId,
322 model: Arc<dyn LanguageModel>,
323 cx: &mut AsyncApp,
324 ) -> Task<Result<()>> {
325 log::info!(
326 "Setting model for session {}: {:?}",
327 session_id,
328 model.name()
329 );
330 let agent = self.0.clone();
331
332 cx.spawn(async move |cx| {
333 agent.update(cx, |agent, cx| {
334 if let Some(session) = agent.sessions.get(&session_id) {
335 session.thread.update(cx, |thread, _cx| {
336 thread.selected_model = model;
337 });
338 Ok(())
339 } else {
340 Err(anyhow!("Session not found"))
341 }
342 })?
343 })
344 }
345
346 fn selected_model(
347 &self,
348 session_id: &acp::SessionId,
349 cx: &mut AsyncApp,
350 ) -> Task<Result<Arc<dyn LanguageModel>>> {
351 let agent = self.0.clone();
352 let session_id = session_id.clone();
353 cx.spawn(async move |cx| {
354 let thread = agent
355 .read_with(cx, |agent, _| {
356 agent
357 .sessions
358 .get(&session_id)
359 .map(|session| session.thread.clone())
360 })?
361 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
362 let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
363 Ok(selected)
364 })
365 }
366}
367
368impl acp_thread::AgentConnection for NativeAgentConnection {
369 fn new_thread(
370 self: Rc<Self>,
371 project: Entity<Project>,
372 cwd: &Path,
373 cx: &mut AsyncApp,
374 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
375 let agent = self.0.clone();
376 log::info!("Creating new thread for project at: {:?}", cwd);
377
378 cx.spawn(async move |cx| {
379 log::debug!("Starting thread creation in async context");
380
381 // Generate session ID
382 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
383 log::info!("Created session with ID: {}", session_id);
384
385 // Create AcpThread
386 let acp_thread = cx.update(|cx| {
387 cx.new(|cx| {
388 acp_thread::AcpThread::new("agent2", self.clone(), project.clone(), session_id.clone(), cx)
389 })
390 })?;
391 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
392
393 // Create Thread
394 let thread = agent.update(
395 cx,
396 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
397 // Fetch default model from registry settings
398 let registry = LanguageModelRegistry::read_global(cx);
399
400 // Log available models for debugging
401 let available_count = registry.available_models(cx).count();
402 log::debug!("Total available models: {}", available_count);
403
404 let default_model = registry
405 .default_model()
406 .map(|configured| {
407 log::info!(
408 "Using configured default model: {:?} from provider: {:?}",
409 configured.model.name(),
410 configured.provider.name()
411 );
412 configured.model
413 })
414 .ok_or_else(|| {
415 log::warn!("No default model configured in settings");
416 anyhow!("No default model configured. Please configure a default model in settings.")
417 })?;
418
419 let thread = cx.new(|cx| {
420 let mut thread = Thread::new(project.clone(), agent.project_context.clone(), action_log.clone(), agent.templates.clone(), default_model);
421 thread.add_tool(CreateDirectoryTool::new(project.clone()));
422 thread.add_tool(CopyPathTool::new(project.clone()));
423 thread.add_tool(MovePathTool::new(project.clone()));
424 thread.add_tool(ListDirectoryTool::new(project.clone()));
425 thread.add_tool(OpenTool::new(project.clone()));
426 thread.add_tool(ThinkingTool);
427 thread.add_tool(FindPathTool::new(project.clone()));
428 thread.add_tool(GrepTool::new(project.clone()));
429 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
430 thread.add_tool(EditFileTool::new(cx.entity()));
431 thread.add_tool(NowTool);
432 thread.add_tool(TerminalTool::new(project.clone(), cx));
433 // TODO: Needs to be conditional based on zed model or not
434 thread.add_tool(WebSearchTool);
435 thread
436 });
437
438 Ok(thread)
439 },
440 )??;
441
442 // Store the session
443 agent.update(cx, |agent, cx| {
444 agent.sessions.insert(
445 session_id,
446 Session {
447 thread,
448 acp_thread: acp_thread.downgrade(),
449 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
450 this.sessions.remove(acp_thread.session_id());
451 })
452 },
453 );
454 })?;
455
456 Ok(acp_thread)
457 })
458 }
459
460 fn auth_methods(&self) -> &[acp::AuthMethod] {
461 &[] // No auth for in-process
462 }
463
464 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
465 Task::ready(Ok(()))
466 }
467
468 fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
469 Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
470 }
471
472 fn prompt(
473 &self,
474 params: acp::PromptRequest,
475 cx: &mut App,
476 ) -> Task<Result<acp::PromptResponse>> {
477 let session_id = params.session_id.clone();
478 let agent = self.0.clone();
479 log::info!("Received prompt request for session: {}", session_id);
480 log::debug!("Prompt blocks count: {}", params.prompt.len());
481
482 cx.spawn(async move |cx| {
483 // Get session
484 let (thread, acp_thread) = agent
485 .update(cx, |agent, _| {
486 agent
487 .sessions
488 .get_mut(&session_id)
489 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
490 })?
491 .ok_or_else(|| {
492 log::error!("Session not found: {}", session_id);
493 anyhow::anyhow!("Session not found")
494 })?;
495 log::debug!("Found session for: {}", session_id);
496
497 // Convert prompt to message
498 let message = convert_prompt_to_message(params.prompt);
499 log::info!("Converted prompt to message: {} chars", message.len());
500 log::debug!("Message content: {}", message);
501
502 // Get model using the ModelSelector capability (always available for agent2)
503 // Get the selected model from the thread directly
504 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
505
506 // Send to thread
507 log::info!("Sending message to thread with model: {:?}", model.name());
508 let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
509
510 // Handle response stream and forward to session.acp_thread
511 while let Some(result) = response_stream.next().await {
512 match result {
513 Ok(event) => {
514 log::trace!("Received completion event: {:?}", event);
515
516 match event {
517 AgentResponseEvent::Text(text) => {
518 acp_thread.update(cx, |thread, cx| {
519 thread.push_assistant_content_block(
520 acp::ContentBlock::Text(acp::TextContent {
521 text,
522 annotations: None,
523 }),
524 false,
525 cx,
526 )
527 })?;
528 }
529 AgentResponseEvent::Thinking(text) => {
530 acp_thread.update(cx, |thread, cx| {
531 thread.push_assistant_content_block(
532 acp::ContentBlock::Text(acp::TextContent {
533 text,
534 annotations: None,
535 }),
536 true,
537 cx,
538 )
539 })?;
540 }
541 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
542 tool_call,
543 options,
544 response,
545 }) => {
546 let recv = acp_thread.update(cx, |thread, cx| {
547 thread.request_tool_call_authorization(tool_call, options, cx)
548 })?;
549 cx.background_spawn(async move {
550 if let Some(option) = recv
551 .await
552 .context("authorization sender was dropped")
553 .log_err()
554 {
555 response
556 .send(option)
557 .map(|_| anyhow!("authorization receiver was dropped"))
558 .log_err();
559 }
560 })
561 .detach();
562 }
563 AgentResponseEvent::ToolCall(tool_call) => {
564 acp_thread.update(cx, |thread, cx| {
565 thread.upsert_tool_call(tool_call, cx)
566 })?;
567 }
568 AgentResponseEvent::ToolCallUpdate(update) => {
569 acp_thread.update(cx, |thread, cx| {
570 thread.update_tool_call(update, cx)
571 })??;
572 }
573 AgentResponseEvent::Stop(stop_reason) => {
574 log::debug!("Assistant message complete: {:?}", stop_reason);
575 return Ok(acp::PromptResponse { stop_reason });
576 }
577 }
578 }
579 Err(e) => {
580 log::error!("Error in model response stream: {:?}", e);
581 // TODO: Consider sending an error message to the UI
582 break;
583 }
584 }
585 }
586
587 log::info!("Response stream completed");
588 anyhow::Ok(acp::PromptResponse {
589 stop_reason: acp::StopReason::EndTurn,
590 })
591 })
592 }
593
594 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
595 log::info!("Cancelling on session: {}", session_id);
596 self.0.update(cx, |agent, cx| {
597 if let Some(agent) = agent.sessions.get(session_id) {
598 agent.thread.update(cx, |thread, _cx| thread.cancel());
599 }
600 });
601 }
602}
603
604/// Convert ACP content blocks to a message string
605fn convert_prompt_to_message(blocks: Vec<acp::ContentBlock>) -> String {
606 log::debug!("Converting {} content blocks to message", blocks.len());
607 let mut message = String::new();
608
609 for block in blocks {
610 match block {
611 acp::ContentBlock::Text(text) => {
612 log::trace!("Processing text block: {} chars", text.text.len());
613 message.push_str(&text.text);
614 }
615 acp::ContentBlock::ResourceLink(link) => {
616 log::trace!("Processing resource link: {}", link.uri);
617 message.push_str(&format!(" @{} ", link.uri));
618 }
619 acp::ContentBlock::Image(_) => {
620 log::trace!("Processing image block");
621 message.push_str(" [image] ");
622 }
623 acp::ContentBlock::Audio(_) => {
624 log::trace!("Processing audio block");
625 message.push_str(" [audio] ");
626 }
627 acp::ContentBlock::Resource(resource) => {
628 log::trace!("Processing resource block: {:?}", resource.resource);
629 message.push_str(&format!(" [resource: {:?}] ", resource.resource));
630 }
631 }
632 }
633
634 message
635}
636
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Checks that `NativeAgent` keeps its shared `ProjectContext` in sync with the project.
    ///
    /// Adding a worktree adds a `WorktreeContext`. Creating a `.rules` file in that
    /// worktree then populates the context's `rules_file`.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees, so the initial context has none.
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
            .await
            .unwrap();
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a visible worktree should trigger a refresh. `run_until_parked` lets the
        // asynchronous rebuild complete before the assertion.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Installs the test settings store and initializes the project and language
    /// settings used by the tests above. Logger initialization is best-effort.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
}