1use crate::{AgentResponseEvent, Thread, templates::Templates};
2use crate::{
3 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
4 EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
5 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
6 WebSearchTool,
7};
8use acp_thread::AgentModelSelector;
9use agent_client_protocol as acp;
10use agent_settings::AgentSettings;
11use anyhow::{Context as _, Result, anyhow};
12use collections::{HashSet, IndexMap};
13use fs::Fs;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::update_settings_file;
24use std::cell::RefCell;
25use std::collections::HashMap;
26use std::path::Path;
27use std::rc::Rc;
28use std::sync::Arc;
29use util::ResultExt;
30
/// File names that are recognized as a worktree's rules file.
///
/// The order is significant: when looking for a rules file, the names are
/// tried in this order and the first one that exists as a file wins.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
42
/// Describes a failure to load a rules file.
///
/// `load_worktree_info_for_system_prompt` produces one of these when the task
/// that reads a worktree's rules file resolves to an error.
pub struct RulesLoadingError {
    /// Human-readable description, formatted from the underlying error.
    pub message: SharedString,
}
46
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    /// Held weakly; the session itself is removed when the `AcpThread` is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Release observer on the `AcpThread` (registered in `new_thread`) whose
    /// callback removes this session from the agent's session map.
    _subscription: Subscription,
}
55
56pub struct LanguageModels {
57 /// Access language model by ID
58 models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
59 /// Cached list for returning language model information
60 model_list: acp_thread::AgentModelList,
61 refresh_models_rx: watch::Receiver<()>,
62 refresh_models_tx: watch::Sender<()>,
63}
64
65impl LanguageModels {
66 fn new(cx: &App) -> Self {
67 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
68 let mut this = Self {
69 models: HashMap::default(),
70 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
71 refresh_models_rx,
72 refresh_models_tx,
73 };
74 this.refresh_list(cx);
75 this
76 }
77
78 fn refresh_list(&mut self, cx: &App) {
79 let providers = LanguageModelRegistry::global(cx)
80 .read(cx)
81 .providers()
82 .into_iter()
83 .filter(|provider| provider.is_authenticated(cx))
84 .collect::<Vec<_>>();
85
86 let mut language_model_list = IndexMap::default();
87 let mut recommended_models = HashSet::default();
88
89 let mut recommended = Vec::new();
90 for provider in &providers {
91 for model in provider.recommended_models(cx) {
92 recommended_models.insert(model.id());
93 recommended.push(Self::map_language_model_to_info(&model, &provider));
94 }
95 }
96 if !recommended.is_empty() {
97 language_model_list.insert(
98 acp_thread::AgentModelGroupName("Recommended".into()),
99 recommended,
100 );
101 }
102
103 let mut models = HashMap::default();
104 for provider in providers {
105 let mut provider_models = Vec::new();
106 for model in provider.provided_models(cx) {
107 let model_info = Self::map_language_model_to_info(&model, &provider);
108 let model_id = model_info.id.clone();
109 if !recommended_models.contains(&model.id()) {
110 provider_models.push(model_info);
111 }
112 models.insert(model_id, model);
113 }
114 if !provider_models.is_empty() {
115 language_model_list.insert(
116 acp_thread::AgentModelGroupName(provider.name().0.clone()),
117 provider_models,
118 );
119 }
120 }
121
122 self.models = models;
123 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
124 self.refresh_models_tx.send(()).ok();
125 }
126
127 fn watch(&self) -> watch::Receiver<()> {
128 self.refresh_models_rx.clone()
129 }
130
131 pub fn model_from_id(
132 &self,
133 model_id: &acp_thread::AgentModelId,
134 ) -> Option<Arc<dyn LanguageModel>> {
135 self.models.get(model_id).cloned()
136 }
137
138 fn map_language_model_to_info(
139 model: &Arc<dyn LanguageModel>,
140 provider: &Arc<dyn LanguageModelProvider>,
141 ) -> acp_thread::AgentModelInfo {
142 acp_thread::AgentModelInfo {
143 id: Self::model_id(model),
144 name: model.name().0,
145 icon: Some(provider.icon()),
146 }
147 }
148
149 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
150 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
151 }
152}
153
154pub struct NativeAgent {
155 /// Session ID -> Session mapping
156 sessions: HashMap<acp::SessionId, Session>,
157 /// Shared project context for all threads
158 project_context: Rc<RefCell<ProjectContext>>,
159 project_context_needs_refresh: watch::Sender<()>,
160 _maintain_project_context: Task<Result<()>>,
161 context_server_registry: Entity<ContextServerRegistry>,
162 /// Shared templates for all threads
163 templates: Arc<Templates>,
164 /// Cached model information
165 models: LanguageModels,
166 project: Entity<Project>,
167 prompt_store: Option<Entity<PromptStore>>,
168 fs: Arc<dyn Fs>,
169 _subscriptions: Vec<Subscription>,
170}
171
172impl NativeAgent {
173 pub async fn new(
174 project: Entity<Project>,
175 templates: Arc<Templates>,
176 prompt_store: Option<Entity<PromptStore>>,
177 fs: Arc<dyn Fs>,
178 cx: &mut AsyncApp,
179 ) -> Result<Entity<NativeAgent>> {
180 log::info!("Creating new NativeAgent");
181
182 let project_context = cx
183 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
184 .await;
185
186 cx.new(|cx| {
187 let mut subscriptions = vec![
188 cx.subscribe(&project, Self::handle_project_event),
189 cx.subscribe(
190 &LanguageModelRegistry::global(cx),
191 Self::handle_models_updated_event,
192 ),
193 ];
194 if let Some(prompt_store) = prompt_store.as_ref() {
195 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
196 }
197
198 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
199 watch::channel(());
200 Self {
201 sessions: HashMap::new(),
202 project_context: Rc::new(RefCell::new(project_context)),
203 project_context_needs_refresh: project_context_needs_refresh_tx,
204 _maintain_project_context: cx.spawn(async move |this, cx| {
205 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
206 }),
207 context_server_registry: cx.new(|cx| {
208 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
209 }),
210 templates,
211 models: LanguageModels::new(cx),
212 project,
213 prompt_store,
214 fs,
215 _subscriptions: subscriptions,
216 }
217 })
218 }
219
220 pub fn models(&self) -> &LanguageModels {
221 &self.models
222 }
223
224 async fn maintain_project_context(
225 this: WeakEntity<Self>,
226 mut needs_refresh: watch::Receiver<()>,
227 cx: &mut AsyncApp,
228 ) -> Result<()> {
229 while needs_refresh.changed().await.is_ok() {
230 let project_context = this
231 .update(cx, |this, cx| {
232 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
233 })?
234 .await;
235 this.update(cx, |this, _| this.project_context.replace(project_context))?;
236 }
237
238 Ok(())
239 }
240
241 fn build_project_context(
242 project: &Entity<Project>,
243 prompt_store: Option<&Entity<PromptStore>>,
244 cx: &mut App,
245 ) -> Task<ProjectContext> {
246 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
247 let worktree_tasks = worktrees
248 .into_iter()
249 .map(|worktree| {
250 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
251 })
252 .collect::<Vec<_>>();
253 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
254 prompt_store.read_with(cx, |prompt_store, cx| {
255 let prompts = prompt_store.default_prompt_metadata();
256 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
257 let contents = prompt_store.load(prompt_metadata.id, cx);
258 async move { (contents.await, prompt_metadata) }
259 });
260 cx.background_spawn(future::join_all(load_tasks))
261 })
262 } else {
263 Task::ready(vec![])
264 };
265
266 cx.spawn(async move |_cx| {
267 let (worktrees, default_user_rules) =
268 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
269
270 let worktrees = worktrees
271 .into_iter()
272 .map(|(worktree, _rules_error)| {
273 // TODO: show error message
274 // if let Some(rules_error) = rules_error {
275 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
276 // }
277 worktree
278 })
279 .collect::<Vec<_>>();
280
281 let default_user_rules = default_user_rules
282 .into_iter()
283 .flat_map(|(contents, prompt_metadata)| match contents {
284 Ok(contents) => Some(UserRulesContext {
285 uuid: match prompt_metadata.id {
286 PromptId::User { uuid } => uuid,
287 PromptId::EditWorkflow => return None,
288 },
289 title: prompt_metadata.title.map(|title| title.to_string()),
290 contents,
291 }),
292 Err(_err) => {
293 // TODO: show error message
294 // this.update(cx, |_, cx| {
295 // cx.emit(RulesLoadingError {
296 // message: format!("{err:?}").into(),
297 // });
298 // })
299 // .ok();
300 None
301 }
302 })
303 .collect::<Vec<_>>();
304
305 ProjectContext::new(worktrees, default_user_rules)
306 })
307 }
308
309 fn load_worktree_info_for_system_prompt(
310 worktree: Entity<Worktree>,
311 project: Entity<Project>,
312 cx: &mut App,
313 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
314 let tree = worktree.read(cx);
315 let root_name = tree.root_name().into();
316 let abs_path = tree.abs_path();
317
318 let mut context = WorktreeContext {
319 root_name,
320 abs_path,
321 rules_file: None,
322 };
323
324 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
325 let Some(rules_task) = rules_task else {
326 return Task::ready((context, None));
327 };
328
329 cx.spawn(async move |_| {
330 let (rules_file, rules_file_error) = match rules_task.await {
331 Ok(rules_file) => (Some(rules_file), None),
332 Err(err) => (
333 None,
334 Some(RulesLoadingError {
335 message: format!("{err}").into(),
336 }),
337 ),
338 };
339 context.rules_file = rules_file;
340 (context, rules_file_error)
341 })
342 }
343
344 fn load_worktree_rules_file(
345 worktree: Entity<Worktree>,
346 project: Entity<Project>,
347 cx: &mut App,
348 ) -> Option<Task<Result<RulesFileContext>>> {
349 let worktree = worktree.read(cx);
350 let worktree_id = worktree.id();
351 let selected_rules_file = RULES_FILE_NAMES
352 .into_iter()
353 .filter_map(|name| {
354 worktree
355 .entry_for_path(name)
356 .filter(|entry| entry.is_file())
357 .map(|entry| entry.path.clone())
358 })
359 .next();
360
361 // Note that Cline supports `.clinerules` being a directory, but that is not currently
362 // supported. This doesn't seem to occur often in GitHub repositories.
363 selected_rules_file.map(|path_in_worktree| {
364 let project_path = ProjectPath {
365 worktree_id,
366 path: path_in_worktree.clone(),
367 };
368 let buffer_task =
369 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
370 let rope_task = cx.spawn(async move |cx| {
371 buffer_task.await?.read_with(cx, |buffer, cx| {
372 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
373 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
374 })?
375 });
376 // Build a string from the rope on a background thread.
377 cx.background_spawn(async move {
378 let (project_entry_id, rope) = rope_task.await?;
379 anyhow::Ok(RulesFileContext {
380 path_in_worktree,
381 text: rope.to_string().trim().to_string(),
382 project_entry_id: project_entry_id.to_usize(),
383 })
384 })
385 })
386 }
387
388 fn handle_project_event(
389 &mut self,
390 _project: Entity<Project>,
391 event: &project::Event,
392 _cx: &mut Context<Self>,
393 ) {
394 match event {
395 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
396 self.project_context_needs_refresh.send(()).ok();
397 }
398 project::Event::WorktreeUpdatedEntries(_, items) => {
399 if items.iter().any(|(path, _, _)| {
400 RULES_FILE_NAMES
401 .iter()
402 .any(|name| path.as_ref() == Path::new(name))
403 }) {
404 self.project_context_needs_refresh.send(()).ok();
405 }
406 }
407 _ => {}
408 }
409 }
410
411 fn handle_prompts_updated_event(
412 &mut self,
413 _prompt_store: Entity<PromptStore>,
414 _event: &prompt_store::PromptsUpdatedEvent,
415 _cx: &mut Context<Self>,
416 ) {
417 self.project_context_needs_refresh.send(()).ok();
418 }
419
420 fn handle_models_updated_event(
421 &mut self,
422 _registry: Entity<LanguageModelRegistry>,
423 _event: &language_model::Event,
424 cx: &mut Context<Self>,
425 ) {
426 self.models.refresh_list(cx);
427 for session in self.sessions.values_mut() {
428 session.thread.update(cx, |thread, _| {
429 let model_id = LanguageModels::model_id(&thread.selected_model);
430 if let Some(model) = self.models.model_from_id(&model_id) {
431 thread.selected_model = model.clone();
432 }
433 });
434 }
435 }
436}
437
/// Wrapper struct that implements the AgentConnection trait
///
/// It also implements `AgentModelSelector`; both implementations read their
/// state from the wrapped `NativeAgent` entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
441
442impl AgentModelSelector for NativeAgentConnection {
443 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
444 log::debug!("NativeAgentConnection::list_models called");
445 let list = self.0.read(cx).models.model_list.clone();
446 Task::ready(if list.is_empty() {
447 Err(anyhow::anyhow!("No models available"))
448 } else {
449 Ok(list)
450 })
451 }
452
453 fn select_model(
454 &self,
455 session_id: acp::SessionId,
456 model_id: acp_thread::AgentModelId,
457 cx: &mut App,
458 ) -> Task<Result<()>> {
459 log::info!("Setting model for session {}: {}", session_id, model_id);
460 let Some(thread) = self
461 .0
462 .read(cx)
463 .sessions
464 .get(&session_id)
465 .map(|session| session.thread.clone())
466 else {
467 return Task::ready(Err(anyhow!("Session not found")));
468 };
469
470 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
471 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
472 };
473
474 thread.update(cx, |thread, _cx| {
475 thread.selected_model = model.clone();
476 });
477
478 update_settings_file::<AgentSettings>(
479 self.0.read(cx).fs.clone(),
480 cx,
481 move |settings, _cx| {
482 settings.set_model(model);
483 },
484 );
485
486 Task::ready(Ok(()))
487 }
488
489 fn selected_model(
490 &self,
491 session_id: &acp::SessionId,
492 cx: &mut App,
493 ) -> Task<Result<acp_thread::AgentModelInfo>> {
494 let session_id = session_id.clone();
495
496 let Some(thread) = self
497 .0
498 .read(cx)
499 .sessions
500 .get(&session_id)
501 .map(|session| session.thread.clone())
502 else {
503 return Task::ready(Err(anyhow!("Session not found")));
504 };
505 let model = thread.read(cx).selected_model.clone();
506 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
507 else {
508 return Task::ready(Err(anyhow!("Provider not found")));
509 };
510 Task::ready(Ok(LanguageModels::map_language_model_to_info(
511 &model, &provider,
512 )))
513 }
514
515 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
516 self.0.read(cx).models.watch()
517 }
518}
519
520impl acp_thread::AgentConnection for NativeAgentConnection {
521 fn new_thread(
522 self: Rc<Self>,
523 project: Entity<Project>,
524 cwd: &Path,
525 cx: &mut AsyncApp,
526 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
527 let agent = self.0.clone();
528 log::info!("Creating new thread for project at: {:?}", cwd);
529
530 cx.spawn(async move |cx| {
531 log::debug!("Starting thread creation in async context");
532
533 // Generate session ID
534 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
535 log::info!("Created session with ID: {}", session_id);
536
537 // Create AcpThread
538 let acp_thread = cx.update(|cx| {
539 cx.new(|cx| {
540 acp_thread::AcpThread::new(
541 "agent2",
542 self.clone(),
543 project.clone(),
544 session_id.clone(),
545 cx,
546 )
547 })
548 })?;
549 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
550
551 // Create Thread
552 let thread = agent.update(
553 cx,
554 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
555 // Fetch default model from registry settings
556 let registry = LanguageModelRegistry::read_global(cx);
557
558 // Log available models for debugging
559 let available_count = registry.available_models(cx).count();
560 log::debug!("Total available models: {}", available_count);
561
562 let default_model = registry
563 .default_model()
564 .and_then(|default_model| {
565 agent
566 .models
567 .model_from_id(&LanguageModels::model_id(&default_model.model))
568 })
569 .ok_or_else(|| {
570 log::warn!("No default model configured in settings");
571 anyhow!(
572 "No default model. Please configure a default model in settings."
573 )
574 })?;
575
576 let thread = cx.new(|cx| {
577 let mut thread = Thread::new(
578 project.clone(),
579 agent.project_context.clone(),
580 agent.context_server_registry.clone(),
581 action_log.clone(),
582 agent.templates.clone(),
583 default_model,
584 cx,
585 );
586 thread.add_tool(CopyPathTool::new(project.clone()));
587 thread.add_tool(CreateDirectoryTool::new(project.clone()));
588 thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
589 thread.add_tool(DiagnosticsTool::new(project.clone()));
590 thread.add_tool(EditFileTool::new(cx.entity()));
591 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
592 thread.add_tool(FindPathTool::new(project.clone()));
593 thread.add_tool(GrepTool::new(project.clone()));
594 thread.add_tool(ListDirectoryTool::new(project.clone()));
595 thread.add_tool(MovePathTool::new(project.clone()));
596 thread.add_tool(NowTool);
597 thread.add_tool(OpenTool::new(project.clone()));
598 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
599 thread.add_tool(TerminalTool::new(project.clone(), cx));
600 thread.add_tool(ThinkingTool);
601 thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
602 thread
603 });
604
605 Ok(thread)
606 },
607 )??;
608
609 // Store the session
610 agent.update(cx, |agent, cx| {
611 agent.sessions.insert(
612 session_id,
613 Session {
614 thread,
615 acp_thread: acp_thread.downgrade(),
616 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
617 this.sessions.remove(acp_thread.session_id());
618 }),
619 },
620 );
621 })?;
622
623 Ok(acp_thread)
624 })
625 }
626
627 fn auth_methods(&self) -> &[acp::AuthMethod] {
628 &[] // No auth for in-process
629 }
630
631 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
632 Task::ready(Ok(()))
633 }
634
635 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
636 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
637 }
638
639 fn prompt(
640 &self,
641 id: Option<acp_thread::UserMessageId>,
642 params: acp::PromptRequest,
643 cx: &mut App,
644 ) -> Task<Result<acp::PromptResponse>> {
645 let id = id.expect("UserMessageId is required");
646 let session_id = params.session_id.clone();
647 let agent = self.0.clone();
648 log::info!("Received prompt request for session: {}", session_id);
649 log::debug!("Prompt blocks count: {}", params.prompt.len());
650
651 cx.spawn(async move |cx| {
652 // Get session
653 let (thread, acp_thread) = agent
654 .update(cx, |agent, _| {
655 agent
656 .sessions
657 .get_mut(&session_id)
658 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
659 })?
660 .ok_or_else(|| {
661 log::error!("Session not found: {}", session_id);
662 anyhow::anyhow!("Session not found")
663 })?;
664 log::debug!("Found session for: {}", session_id);
665
666 let content: Vec<UserMessageContent> = params
667 .prompt
668 .into_iter()
669 .map(Into::into)
670 .collect::<Vec<_>>();
671 log::info!("Converted prompt to message: {} chars", content.len());
672 log::debug!("Message id: {:?}", id);
673 log::debug!("Message content: {:?}", content);
674
675 // Get model using the ModelSelector capability (always available for agent2)
676 // Get the selected model from the thread directly
677 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
678
679 // Send to thread
680 log::info!("Sending message to thread with model: {:?}", model.name());
681 let mut response_stream =
682 thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
683
684 // Handle response stream and forward to session.acp_thread
685 while let Some(result) = response_stream.next().await {
686 match result {
687 Ok(event) => {
688 log::trace!("Received completion event: {:?}", event);
689
690 match event {
691 AgentResponseEvent::Text(text) => {
692 acp_thread.update(cx, |thread, cx| {
693 thread.push_assistant_content_block(
694 acp::ContentBlock::Text(acp::TextContent {
695 text,
696 annotations: None,
697 }),
698 false,
699 cx,
700 )
701 })?;
702 }
703 AgentResponseEvent::Thinking(text) => {
704 acp_thread.update(cx, |thread, cx| {
705 thread.push_assistant_content_block(
706 acp::ContentBlock::Text(acp::TextContent {
707 text,
708 annotations: None,
709 }),
710 true,
711 cx,
712 )
713 })?;
714 }
715 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
716 tool_call,
717 options,
718 response,
719 }) => {
720 let recv = acp_thread.update(cx, |thread, cx| {
721 thread.request_tool_call_authorization(tool_call, options, cx)
722 })?;
723 cx.background_spawn(async move {
724 if let Some(option) = recv
725 .await
726 .context("authorization sender was dropped")
727 .log_err()
728 {
729 response
730 .send(option)
731 .map(|_| anyhow!("authorization receiver was dropped"))
732 .log_err();
733 }
734 })
735 .detach();
736 }
737 AgentResponseEvent::ToolCall(tool_call) => {
738 acp_thread.update(cx, |thread, cx| {
739 thread.upsert_tool_call(tool_call, cx)
740 })?;
741 }
742 AgentResponseEvent::ToolCallUpdate(update) => {
743 acp_thread.update(cx, |thread, cx| {
744 thread.update_tool_call(update, cx)
745 })??;
746 }
747 AgentResponseEvent::Stop(stop_reason) => {
748 log::debug!("Assistant message complete: {:?}", stop_reason);
749 return Ok(acp::PromptResponse { stop_reason });
750 }
751 }
752 }
753 Err(e) => {
754 log::error!("Error in model response stream: {:?}", e);
755 // TODO: Consider sending an error message to the UI
756 break;
757 }
758 }
759 }
760
761 log::info!("Response stream completed");
762 anyhow::Ok(acp::PromptResponse {
763 stop_reason: acp::StopReason::EndTurn,
764 })
765 })
766 }
767
768 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
769 log::info!("Cancelling on session: {}", session_id);
770 self.0.update(cx, |agent, cx| {
771 if let Some(agent) = agent.sessions.get(session_id) {
772 agent.thread.update(cx, |thread, _cx| thread.cancel());
773 }
774 });
775 }
776
777 fn session_editor(
778 &self,
779 session_id: &agent_client_protocol::SessionId,
780 cx: &mut App,
781 ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
782 self.0.update(cx, |agent, _cx| {
783 agent
784 .sessions
785 .get(session_id)
786 .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
787 })
788 }
789}
790
791struct NativeAgentSessionEditor(Entity<Thread>);
792
793impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
794 fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
795 Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
796 }
797}
798
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// The project context follows worktree additions and rules-file creation.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project starts without worktrees, so the context is empty.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![])
        });

        // Adding a worktree updates the project context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Listing models returns the single fake model grouped under its provider.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model updates the session's thread and the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed the settings file with a different default model ("foo"/"bar").
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(
                    project.clone(),
                    Path::new("/a"),
                    &mut cx.to_async(),
                )
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.selected_model.id().0, "fake");
            });
        });

        // Let pending work run before reading the settings file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Shared setup: installs a test settings store, initializes the settings
    /// and language modules, and installs the test language-model registry.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` + `ok()`: an initialization error is ignored.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}