1use crate::{AgentResponseEvent, Thread, templates::Templates};
2use crate::{
3 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
4 FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
5 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
6};
7use acp_thread::ModelSelector;
8use agent_client_protocol as acp;
9use anyhow::{Context as _, Result, anyhow};
10use futures::{StreamExt, future};
11use gpui::{
12 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
13};
14use language_model::{LanguageModel, LanguageModelRegistry};
15use project::{Project, ProjectItem, ProjectPath, Worktree};
16use prompt_store::{
17 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
18};
19use std::cell::RefCell;
20use std::collections::HashMap;
21use std::path::Path;
22use std::rc::Rc;
23use std::sync::Arc;
24use util::ResultExt;
25
/// Worktree-relative file names that are recognised as rules files.
///
/// The order is significant: when looking for a worktree's rules file, the names are tried
/// in this order and the first one that resolves to a file wins.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
37
/// Describes a failure to load a worktree's rules file.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt` when the task that reads
/// the rules file resolves to an error.
pub struct RulesLoadingError {
    /// The error rendered with its `Display` implementation.
    pub message: SharedString,
}
41
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Held weakly; the session is removed from `NativeAgent::sessions` when the ACP thread
    /// is released (see `_subscription`).
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Release observer on the ACP thread whose callback removes this session from the
    /// agent's session map. Stored here so it lives as long as the session does.
    _subscription: Subscription,
}
50
/// In-process agent that owns the sessions for a project together with the state shared by
/// all of their threads (project context, templates, context server registry).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Shared project context for all threads
    project_context: Rc<RefCell<ProjectContext>>,
    /// Signalled (by the project and prompt-store event handlers) whenever the shared
    /// project context should be rebuilt.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that waits on the refresh signal and replaces `project_context`
    /// with a freshly built value.
    _maintain_project_context: Task<Result<()>>,
    /// Registry built from the project's context server store; handed to every new thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// The project whose worktrees feed the project context.
    project: Entity<Project>,
    /// Optional prompt store from which default user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    /// Subscriptions to project events and, when a prompt store is present, to its updates.
    _subscriptions: Vec<Subscription>,
}
65
66impl NativeAgent {
67 pub async fn new(
68 project: Entity<Project>,
69 templates: Arc<Templates>,
70 prompt_store: Option<Entity<PromptStore>>,
71 cx: &mut AsyncApp,
72 ) -> Result<Entity<NativeAgent>> {
73 log::info!("Creating new NativeAgent");
74
75 let project_context = cx
76 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
77 .await;
78
79 cx.new(|cx| {
80 let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)];
81 if let Some(prompt_store) = prompt_store.as_ref() {
82 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
83 }
84
85 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
86 watch::channel(());
87 Self {
88 sessions: HashMap::new(),
89 project_context: Rc::new(RefCell::new(project_context)),
90 project_context_needs_refresh: project_context_needs_refresh_tx,
91 _maintain_project_context: cx.spawn(async move |this, cx| {
92 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
93 }),
94 context_server_registry: cx.new(|cx| {
95 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
96 }),
97 templates,
98 project,
99 prompt_store,
100 _subscriptions: subscriptions,
101 }
102 })
103 }
104
105 async fn maintain_project_context(
106 this: WeakEntity<Self>,
107 mut needs_refresh: watch::Receiver<()>,
108 cx: &mut AsyncApp,
109 ) -> Result<()> {
110 while needs_refresh.changed().await.is_ok() {
111 let project_context = this
112 .update(cx, |this, cx| {
113 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
114 })?
115 .await;
116 this.update(cx, |this, _| this.project_context.replace(project_context))?;
117 }
118
119 Ok(())
120 }
121
122 fn build_project_context(
123 project: &Entity<Project>,
124 prompt_store: Option<&Entity<PromptStore>>,
125 cx: &mut App,
126 ) -> Task<ProjectContext> {
127 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
128 let worktree_tasks = worktrees
129 .into_iter()
130 .map(|worktree| {
131 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
132 })
133 .collect::<Vec<_>>();
134 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
135 prompt_store.read_with(cx, |prompt_store, cx| {
136 let prompts = prompt_store.default_prompt_metadata();
137 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
138 let contents = prompt_store.load(prompt_metadata.id, cx);
139 async move { (contents.await, prompt_metadata) }
140 });
141 cx.background_spawn(future::join_all(load_tasks))
142 })
143 } else {
144 Task::ready(vec![])
145 };
146
147 cx.spawn(async move |_cx| {
148 let (worktrees, default_user_rules) =
149 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
150
151 let worktrees = worktrees
152 .into_iter()
153 .map(|(worktree, _rules_error)| {
154 // TODO: show error message
155 // if let Some(rules_error) = rules_error {
156 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
157 // }
158 worktree
159 })
160 .collect::<Vec<_>>();
161
162 let default_user_rules = default_user_rules
163 .into_iter()
164 .flat_map(|(contents, prompt_metadata)| match contents {
165 Ok(contents) => Some(UserRulesContext {
166 uuid: match prompt_metadata.id {
167 PromptId::User { uuid } => uuid,
168 PromptId::EditWorkflow => return None,
169 },
170 title: prompt_metadata.title.map(|title| title.to_string()),
171 contents,
172 }),
173 Err(_err) => {
174 // TODO: show error message
175 // this.update(cx, |_, cx| {
176 // cx.emit(RulesLoadingError {
177 // message: format!("{err:?}").into(),
178 // });
179 // })
180 // .ok();
181 None
182 }
183 })
184 .collect::<Vec<_>>();
185
186 ProjectContext::new(worktrees, default_user_rules)
187 })
188 }
189
190 fn load_worktree_info_for_system_prompt(
191 worktree: Entity<Worktree>,
192 project: Entity<Project>,
193 cx: &mut App,
194 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
195 let tree = worktree.read(cx);
196 let root_name = tree.root_name().into();
197 let abs_path = tree.abs_path();
198
199 let mut context = WorktreeContext {
200 root_name,
201 abs_path,
202 rules_file: None,
203 };
204
205 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
206 let Some(rules_task) = rules_task else {
207 return Task::ready((context, None));
208 };
209
210 cx.spawn(async move |_| {
211 let (rules_file, rules_file_error) = match rules_task.await {
212 Ok(rules_file) => (Some(rules_file), None),
213 Err(err) => (
214 None,
215 Some(RulesLoadingError {
216 message: format!("{err}").into(),
217 }),
218 ),
219 };
220 context.rules_file = rules_file;
221 (context, rules_file_error)
222 })
223 }
224
225 fn load_worktree_rules_file(
226 worktree: Entity<Worktree>,
227 project: Entity<Project>,
228 cx: &mut App,
229 ) -> Option<Task<Result<RulesFileContext>>> {
230 let worktree = worktree.read(cx);
231 let worktree_id = worktree.id();
232 let selected_rules_file = RULES_FILE_NAMES
233 .into_iter()
234 .filter_map(|name| {
235 worktree
236 .entry_for_path(name)
237 .filter(|entry| entry.is_file())
238 .map(|entry| entry.path.clone())
239 })
240 .next();
241
242 // Note that Cline supports `.clinerules` being a directory, but that is not currently
243 // supported. This doesn't seem to occur often in GitHub repositories.
244 selected_rules_file.map(|path_in_worktree| {
245 let project_path = ProjectPath {
246 worktree_id,
247 path: path_in_worktree.clone(),
248 };
249 let buffer_task =
250 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
251 let rope_task = cx.spawn(async move |cx| {
252 buffer_task.await?.read_with(cx, |buffer, cx| {
253 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
254 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
255 })?
256 });
257 // Build a string from the rope on a background thread.
258 cx.background_spawn(async move {
259 let (project_entry_id, rope) = rope_task.await?;
260 anyhow::Ok(RulesFileContext {
261 path_in_worktree,
262 text: rope.to_string().trim().to_string(),
263 project_entry_id: project_entry_id.to_usize(),
264 })
265 })
266 })
267 }
268
269 fn handle_project_event(
270 &mut self,
271 _project: Entity<Project>,
272 event: &project::Event,
273 _cx: &mut Context<Self>,
274 ) {
275 match event {
276 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
277 self.project_context_needs_refresh.send(()).ok();
278 }
279 project::Event::WorktreeUpdatedEntries(_, items) => {
280 if items.iter().any(|(path, _, _)| {
281 RULES_FILE_NAMES
282 .iter()
283 .any(|name| path.as_ref() == Path::new(name))
284 }) {
285 self.project_context_needs_refresh.send(()).ok();
286 }
287 }
288 _ => {}
289 }
290 }
291
292 fn handle_prompts_updated_event(
293 &mut self,
294 _prompt_store: Entity<PromptStore>,
295 _event: &prompt_store::PromptsUpdatedEvent,
296 _cx: &mut Context<Self>,
297 ) {
298 self.project_context_needs_refresh.send(()).ok();
299 }
300}
301
/// Wrapper struct that implements the AgentConnection trait
///
/// Wraps the shared [`NativeAgent`] entity; cloning the connection clones the entity handle.
/// It also implements `ModelSelector`, so the same value serves as the session model picker.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
305
306impl ModelSelector for NativeAgentConnection {
307 fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
308 log::debug!("NativeAgentConnection::list_models called");
309 cx.spawn(async move |cx| {
310 cx.update(|cx| {
311 let registry = LanguageModelRegistry::read_global(cx);
312 let models = registry.available_models(cx).collect::<Vec<_>>();
313 log::info!("Found {} available models", models.len());
314 if models.is_empty() {
315 Err(anyhow::anyhow!("No models available"))
316 } else {
317 Ok(models)
318 }
319 })?
320 })
321 }
322
323 fn select_model(
324 &self,
325 session_id: acp::SessionId,
326 model: Arc<dyn LanguageModel>,
327 cx: &mut AsyncApp,
328 ) -> Task<Result<()>> {
329 log::info!(
330 "Setting model for session {}: {:?}",
331 session_id,
332 model.name()
333 );
334 let agent = self.0.clone();
335
336 cx.spawn(async move |cx| {
337 agent.update(cx, |agent, cx| {
338 if let Some(session) = agent.sessions.get(&session_id) {
339 session.thread.update(cx, |thread, _cx| {
340 thread.selected_model = model;
341 });
342 Ok(())
343 } else {
344 Err(anyhow!("Session not found"))
345 }
346 })?
347 })
348 }
349
350 fn selected_model(
351 &self,
352 session_id: &acp::SessionId,
353 cx: &mut AsyncApp,
354 ) -> Task<Result<Arc<dyn LanguageModel>>> {
355 let agent = self.0.clone();
356 let session_id = session_id.clone();
357 cx.spawn(async move |cx| {
358 let thread = agent
359 .read_with(cx, |agent, _| {
360 agent
361 .sessions
362 .get(&session_id)
363 .map(|session| session.thread.clone())
364 })?
365 .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
366 let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
367 Ok(selected)
368 })
369 }
370}
371
372impl acp_thread::AgentConnection for NativeAgentConnection {
373 fn new_thread(
374 self: Rc<Self>,
375 project: Entity<Project>,
376 cwd: &Path,
377 cx: &mut AsyncApp,
378 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
379 let agent = self.0.clone();
380 log::info!("Creating new thread for project at: {:?}", cwd);
381
382 cx.spawn(async move |cx| {
383 log::debug!("Starting thread creation in async context");
384
385 // Generate session ID
386 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
387 log::info!("Created session with ID: {}", session_id);
388
389 // Create AcpThread
390 let acp_thread = cx.update(|cx| {
391 cx.new(|cx| {
392 acp_thread::AcpThread::new(
393 "agent2",
394 self.clone(),
395 project.clone(),
396 session_id.clone(),
397 cx,
398 )
399 })
400 })?;
401 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
402
403 // Create Thread
404 let thread = agent.update(
405 cx,
406 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
407 // Fetch default model from registry settings
408 let registry = LanguageModelRegistry::read_global(cx);
409
410 // Log available models for debugging
411 let available_count = registry.available_models(cx).count();
412 log::debug!("Total available models: {}", available_count);
413
414 let default_model = registry
415 .default_model()
416 .map(|configured| {
417 log::info!(
418 "Using configured default model: {:?} from provider: {:?}",
419 configured.model.name(),
420 configured.provider.name()
421 );
422 configured.model
423 })
424 .ok_or_else(|| {
425 log::warn!("No default model configured in settings");
426 anyhow!(
427 "No default model. Please configure a default model in settings."
428 )
429 })?;
430
431 let thread = cx.new(|cx| {
432 let mut thread = Thread::new(
433 project.clone(),
434 agent.project_context.clone(),
435 agent.context_server_registry.clone(),
436 action_log.clone(),
437 agent.templates.clone(),
438 default_model,
439 cx,
440 );
441 thread.add_tool(CreateDirectoryTool::new(project.clone()));
442 thread.add_tool(CopyPathTool::new(project.clone()));
443 thread.add_tool(DiagnosticsTool::new(project.clone()));
444 thread.add_tool(MovePathTool::new(project.clone()));
445 thread.add_tool(ListDirectoryTool::new(project.clone()));
446 thread.add_tool(OpenTool::new(project.clone()));
447 thread.add_tool(ThinkingTool);
448 thread.add_tool(FindPathTool::new(project.clone()));
449 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
450 thread.add_tool(GrepTool::new(project.clone()));
451 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
452 thread.add_tool(EditFileTool::new(cx.entity()));
453 thread.add_tool(NowTool);
454 thread.add_tool(TerminalTool::new(project.clone(), cx));
455 // TODO: Needs to be conditional based on zed model or not
456 thread.add_tool(WebSearchTool);
457 thread
458 });
459
460 Ok(thread)
461 },
462 )??;
463
464 // Store the session
465 agent.update(cx, |agent, cx| {
466 agent.sessions.insert(
467 session_id,
468 Session {
469 thread,
470 acp_thread: acp_thread.downgrade(),
471 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
472 this.sessions.remove(acp_thread.session_id());
473 }),
474 },
475 );
476 })?;
477
478 Ok(acp_thread)
479 })
480 }
481
482 fn auth_methods(&self) -> &[acp::AuthMethod] {
483 &[] // No auth for in-process
484 }
485
486 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
487 Task::ready(Ok(()))
488 }
489
490 fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
491 Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
492 }
493
494 fn prompt(
495 &self,
496 params: acp::PromptRequest,
497 cx: &mut App,
498 ) -> Task<Result<acp::PromptResponse>> {
499 let session_id = params.session_id.clone();
500 let agent = self.0.clone();
501 log::info!("Received prompt request for session: {}", session_id);
502 log::debug!("Prompt blocks count: {}", params.prompt.len());
503
504 cx.spawn(async move |cx| {
505 // Get session
506 let (thread, acp_thread) = agent
507 .update(cx, |agent, _| {
508 agent
509 .sessions
510 .get_mut(&session_id)
511 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
512 })?
513 .ok_or_else(|| {
514 log::error!("Session not found: {}", session_id);
515 anyhow::anyhow!("Session not found")
516 })?;
517 log::debug!("Found session for: {}", session_id);
518
519 let message: Vec<MessageContent> = params
520 .prompt
521 .into_iter()
522 .map(Into::into)
523 .collect::<Vec<_>>();
524 log::info!("Converted prompt to message: {} chars", message.len());
525 log::debug!("Message content: {:?}", message);
526
527 // Get model using the ModelSelector capability (always available for agent2)
528 // Get the selected model from the thread directly
529 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
530
531 // Send to thread
532 log::info!("Sending message to thread with model: {:?}", model.name());
533 let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
534
535 // Handle response stream and forward to session.acp_thread
536 while let Some(result) = response_stream.next().await {
537 match result {
538 Ok(event) => {
539 log::trace!("Received completion event: {:?}", event);
540
541 match event {
542 AgentResponseEvent::Text(text) => {
543 acp_thread.update(cx, |thread, cx| {
544 thread.push_assistant_content_block(
545 acp::ContentBlock::Text(acp::TextContent {
546 text,
547 annotations: None,
548 }),
549 false,
550 cx,
551 )
552 })?;
553 }
554 AgentResponseEvent::Thinking(text) => {
555 acp_thread.update(cx, |thread, cx| {
556 thread.push_assistant_content_block(
557 acp::ContentBlock::Text(acp::TextContent {
558 text,
559 annotations: None,
560 }),
561 true,
562 cx,
563 )
564 })?;
565 }
566 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
567 tool_call,
568 options,
569 response,
570 }) => {
571 let recv = acp_thread.update(cx, |thread, cx| {
572 thread.request_tool_call_authorization(tool_call, options, cx)
573 })?;
574 cx.background_spawn(async move {
575 if let Some(option) = recv
576 .await
577 .context("authorization sender was dropped")
578 .log_err()
579 {
580 response
581 .send(option)
582 .map(|_| anyhow!("authorization receiver was dropped"))
583 .log_err();
584 }
585 })
586 .detach();
587 }
588 AgentResponseEvent::ToolCall(tool_call) => {
589 acp_thread.update(cx, |thread, cx| {
590 thread.upsert_tool_call(tool_call, cx)
591 })?;
592 }
593 AgentResponseEvent::ToolCallUpdate(update) => {
594 acp_thread.update(cx, |thread, cx| {
595 thread.update_tool_call(update, cx)
596 })??;
597 }
598 AgentResponseEvent::Stop(stop_reason) => {
599 log::debug!("Assistant message complete: {:?}", stop_reason);
600 return Ok(acp::PromptResponse { stop_reason });
601 }
602 }
603 }
604 Err(e) => {
605 log::error!("Error in model response stream: {:?}", e);
606 // TODO: Consider sending an error message to the UI
607 break;
608 }
609 }
610 }
611
612 log::info!("Response stream completed");
613 anyhow::Ok(acp::PromptResponse {
614 stop_reason: acp::StopReason::EndTurn,
615 })
616 })
617 }
618
619 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
620 log::info!("Cancelling on session: {}", session_id);
621 self.0.update(cx, |agent, cx| {
622 if let Some(agent) = agent.sessions.get(session_id) {
623 agent.thread.update(cx, |thread, _cx| thread.cancel());
624 }
625 });
626 }
627}
628
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Checks that the agent's project context follows the project: it starts empty, picks
    /// up a newly created worktree, and picks up a rules file created inside it.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        /// Expected context for the `/a` worktree with the given rules file.
        fn worktree_a(rules_file: Option<RulesFileContext>) -> WorktreeContext {
            WorktreeContext {
                root_name: "a".into(),
                abs_path: Path::new("/a").into(),
                rules_file,
            }
        }

        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(project.clone(), Templates::new(), None, &mut cx.to_async())
            .await
            .unwrap();

        // The project has no worktrees yet, so neither does the context.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree updates the project context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![worktree_a(None)]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            let expected_rules = RulesFileContext {
                path_in_worktree: Path::new(".rules").into(),
                text: "".into(),
                project_entry_id: rules_entry.id.to_usize(),
            };
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![worktree_a(Some(expected_rules))]
            )
        });
    }

    /// Installs the globals the test needs: logging, a test settings store, and the
    /// project and language settings.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
}