1use crate::{AgentResponseEvent, Thread, templates::Templates};
2use crate::{
3 ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
4 FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
5 OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
6};
7use acp_thread::AgentModelSelector;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::{StreamExt, future};
14use gpui::{
15 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
16};
17use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
18use project::{Project, ProjectItem, ProjectPath, Worktree};
19use prompt_store::{
20 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
21};
22use settings::update_settings_file;
23use std::cell::RefCell;
24use std::collections::HashMap;
25use std::path::Path;
26use std::rc::Rc;
27use std::sync::Arc;
28use util::ResultExt;
29
/// Candidate rules-file names, relative to a worktree root, in priority order.
///
/// When loading a worktree's rules file, the first name in this list that resolves to a
/// file in the worktree is used. A change to an entry at any of these paths triggers a
/// refresh of the agent's project context.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
41
42pub struct RulesLoadingError {
43 pub message: SharedString,
44}
45
46/// Holds both the internal Thread and the AcpThread for a session
47struct Session {
48 /// The internal thread that processes messages
49 thread: Entity<Thread>,
50 /// The ACP thread that handles protocol communication
51 acp_thread: WeakEntity<acp_thread::AcpThread>,
52 _subscription: Subscription,
53}
54
55pub struct LanguageModels {
56 /// Access language model by ID
57 models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
58 /// Cached list for returning language model information
59 model_list: acp_thread::AgentModelList,
60 refresh_models_rx: watch::Receiver<()>,
61 refresh_models_tx: watch::Sender<()>,
62}
63
64impl LanguageModels {
65 fn new(cx: &App) -> Self {
66 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
67 let mut this = Self {
68 models: HashMap::default(),
69 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
70 refresh_models_rx,
71 refresh_models_tx,
72 };
73 this.refresh_list(cx);
74 this
75 }
76
77 fn refresh_list(&mut self, cx: &App) {
78 let providers = LanguageModelRegistry::global(cx)
79 .read(cx)
80 .providers()
81 .into_iter()
82 .filter(|provider| provider.is_authenticated(cx))
83 .collect::<Vec<_>>();
84
85 let mut language_model_list = IndexMap::default();
86 let mut recommended_models = HashSet::default();
87
88 let mut recommended = Vec::new();
89 for provider in &providers {
90 for model in provider.recommended_models(cx) {
91 recommended_models.insert(model.id());
92 recommended.push(Self::map_language_model_to_info(&model, &provider));
93 }
94 }
95 if !recommended.is_empty() {
96 language_model_list.insert(
97 acp_thread::AgentModelGroupName("Recommended".into()),
98 recommended,
99 );
100 }
101
102 let mut models = HashMap::default();
103 for provider in providers {
104 let mut provider_models = Vec::new();
105 for model in provider.provided_models(cx) {
106 let model_info = Self::map_language_model_to_info(&model, &provider);
107 let model_id = model_info.id.clone();
108 if !recommended_models.contains(&model.id()) {
109 provider_models.push(model_info);
110 }
111 models.insert(model_id, model);
112 }
113 if !provider_models.is_empty() {
114 language_model_list.insert(
115 acp_thread::AgentModelGroupName(provider.name().0.clone()),
116 provider_models,
117 );
118 }
119 }
120
121 self.models = models;
122 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
123 self.refresh_models_tx.send(()).ok();
124 }
125
126 fn watch(&self) -> watch::Receiver<()> {
127 self.refresh_models_rx.clone()
128 }
129
130 pub fn model_from_id(
131 &self,
132 model_id: &acp_thread::AgentModelId,
133 ) -> Option<Arc<dyn LanguageModel>> {
134 self.models.get(model_id).cloned()
135 }
136
137 fn map_language_model_to_info(
138 model: &Arc<dyn LanguageModel>,
139 provider: &Arc<dyn LanguageModelProvider>,
140 ) -> acp_thread::AgentModelInfo {
141 acp_thread::AgentModelInfo {
142 id: Self::model_id(model),
143 name: model.name().0,
144 icon: Some(provider.icon()),
145 }
146 }
147
148 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
149 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
150 }
151}
152
153pub struct NativeAgent {
154 /// Session ID -> Session mapping
155 sessions: HashMap<acp::SessionId, Session>,
156 /// Shared project context for all threads
157 project_context: Rc<RefCell<ProjectContext>>,
158 project_context_needs_refresh: watch::Sender<()>,
159 _maintain_project_context: Task<Result<()>>,
160 context_server_registry: Entity<ContextServerRegistry>,
161 /// Shared templates for all threads
162 templates: Arc<Templates>,
163 /// Cached model information
164 models: LanguageModels,
165 project: Entity<Project>,
166 prompt_store: Option<Entity<PromptStore>>,
167 fs: Arc<dyn Fs>,
168 _subscriptions: Vec<Subscription>,
169}
170
171impl NativeAgent {
172 pub async fn new(
173 project: Entity<Project>,
174 templates: Arc<Templates>,
175 prompt_store: Option<Entity<PromptStore>>,
176 fs: Arc<dyn Fs>,
177 cx: &mut AsyncApp,
178 ) -> Result<Entity<NativeAgent>> {
179 log::info!("Creating new NativeAgent");
180
181 let project_context = cx
182 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
183 .await;
184
185 cx.new(|cx| {
186 let mut subscriptions = vec![
187 cx.subscribe(&project, Self::handle_project_event),
188 cx.subscribe(
189 &LanguageModelRegistry::global(cx),
190 Self::handle_models_updated_event,
191 ),
192 ];
193 if let Some(prompt_store) = prompt_store.as_ref() {
194 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
195 }
196
197 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
198 watch::channel(());
199 Self {
200 sessions: HashMap::new(),
201 project_context: Rc::new(RefCell::new(project_context)),
202 project_context_needs_refresh: project_context_needs_refresh_tx,
203 _maintain_project_context: cx.spawn(async move |this, cx| {
204 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
205 }),
206 context_server_registry: cx.new(|cx| {
207 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
208 }),
209 templates,
210 models: LanguageModels::new(cx),
211 project,
212 prompt_store,
213 fs,
214 _subscriptions: subscriptions,
215 }
216 })
217 }
218
219 pub fn models(&self) -> &LanguageModels {
220 &self.models
221 }
222
223 async fn maintain_project_context(
224 this: WeakEntity<Self>,
225 mut needs_refresh: watch::Receiver<()>,
226 cx: &mut AsyncApp,
227 ) -> Result<()> {
228 while needs_refresh.changed().await.is_ok() {
229 let project_context = this
230 .update(cx, |this, cx| {
231 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
232 })?
233 .await;
234 this.update(cx, |this, _| this.project_context.replace(project_context))?;
235 }
236
237 Ok(())
238 }
239
240 fn build_project_context(
241 project: &Entity<Project>,
242 prompt_store: Option<&Entity<PromptStore>>,
243 cx: &mut App,
244 ) -> Task<ProjectContext> {
245 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
246 let worktree_tasks = worktrees
247 .into_iter()
248 .map(|worktree| {
249 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
250 })
251 .collect::<Vec<_>>();
252 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
253 prompt_store.read_with(cx, |prompt_store, cx| {
254 let prompts = prompt_store.default_prompt_metadata();
255 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
256 let contents = prompt_store.load(prompt_metadata.id, cx);
257 async move { (contents.await, prompt_metadata) }
258 });
259 cx.background_spawn(future::join_all(load_tasks))
260 })
261 } else {
262 Task::ready(vec![])
263 };
264
265 cx.spawn(async move |_cx| {
266 let (worktrees, default_user_rules) =
267 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
268
269 let worktrees = worktrees
270 .into_iter()
271 .map(|(worktree, _rules_error)| {
272 // TODO: show error message
273 // if let Some(rules_error) = rules_error {
274 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
275 // }
276 worktree
277 })
278 .collect::<Vec<_>>();
279
280 let default_user_rules = default_user_rules
281 .into_iter()
282 .flat_map(|(contents, prompt_metadata)| match contents {
283 Ok(contents) => Some(UserRulesContext {
284 uuid: match prompt_metadata.id {
285 PromptId::User { uuid } => uuid,
286 PromptId::EditWorkflow => return None,
287 },
288 title: prompt_metadata.title.map(|title| title.to_string()),
289 contents,
290 }),
291 Err(_err) => {
292 // TODO: show error message
293 // this.update(cx, |_, cx| {
294 // cx.emit(RulesLoadingError {
295 // message: format!("{err:?}").into(),
296 // });
297 // })
298 // .ok();
299 None
300 }
301 })
302 .collect::<Vec<_>>();
303
304 ProjectContext::new(worktrees, default_user_rules)
305 })
306 }
307
308 fn load_worktree_info_for_system_prompt(
309 worktree: Entity<Worktree>,
310 project: Entity<Project>,
311 cx: &mut App,
312 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
313 let tree = worktree.read(cx);
314 let root_name = tree.root_name().into();
315 let abs_path = tree.abs_path();
316
317 let mut context = WorktreeContext {
318 root_name,
319 abs_path,
320 rules_file: None,
321 };
322
323 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
324 let Some(rules_task) = rules_task else {
325 return Task::ready((context, None));
326 };
327
328 cx.spawn(async move |_| {
329 let (rules_file, rules_file_error) = match rules_task.await {
330 Ok(rules_file) => (Some(rules_file), None),
331 Err(err) => (
332 None,
333 Some(RulesLoadingError {
334 message: format!("{err}").into(),
335 }),
336 ),
337 };
338 context.rules_file = rules_file;
339 (context, rules_file_error)
340 })
341 }
342
343 fn load_worktree_rules_file(
344 worktree: Entity<Worktree>,
345 project: Entity<Project>,
346 cx: &mut App,
347 ) -> Option<Task<Result<RulesFileContext>>> {
348 let worktree = worktree.read(cx);
349 let worktree_id = worktree.id();
350 let selected_rules_file = RULES_FILE_NAMES
351 .into_iter()
352 .filter_map(|name| {
353 worktree
354 .entry_for_path(name)
355 .filter(|entry| entry.is_file())
356 .map(|entry| entry.path.clone())
357 })
358 .next();
359
360 // Note that Cline supports `.clinerules` being a directory, but that is not currently
361 // supported. This doesn't seem to occur often in GitHub repositories.
362 selected_rules_file.map(|path_in_worktree| {
363 let project_path = ProjectPath {
364 worktree_id,
365 path: path_in_worktree.clone(),
366 };
367 let buffer_task =
368 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
369 let rope_task = cx.spawn(async move |cx| {
370 buffer_task.await?.read_with(cx, |buffer, cx| {
371 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
372 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
373 })?
374 });
375 // Build a string from the rope on a background thread.
376 cx.background_spawn(async move {
377 let (project_entry_id, rope) = rope_task.await?;
378 anyhow::Ok(RulesFileContext {
379 path_in_worktree,
380 text: rope.to_string().trim().to_string(),
381 project_entry_id: project_entry_id.to_usize(),
382 })
383 })
384 })
385 }
386
387 fn handle_project_event(
388 &mut self,
389 _project: Entity<Project>,
390 event: &project::Event,
391 _cx: &mut Context<Self>,
392 ) {
393 match event {
394 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
395 self.project_context_needs_refresh.send(()).ok();
396 }
397 project::Event::WorktreeUpdatedEntries(_, items) => {
398 if items.iter().any(|(path, _, _)| {
399 RULES_FILE_NAMES
400 .iter()
401 .any(|name| path.as_ref() == Path::new(name))
402 }) {
403 self.project_context_needs_refresh.send(()).ok();
404 }
405 }
406 _ => {}
407 }
408 }
409
410 fn handle_prompts_updated_event(
411 &mut self,
412 _prompt_store: Entity<PromptStore>,
413 _event: &prompt_store::PromptsUpdatedEvent,
414 _cx: &mut Context<Self>,
415 ) {
416 self.project_context_needs_refresh.send(()).ok();
417 }
418
419 fn handle_models_updated_event(
420 &mut self,
421 _registry: Entity<LanguageModelRegistry>,
422 _event: &language_model::Event,
423 cx: &mut Context<Self>,
424 ) {
425 self.models.refresh_list(cx);
426 for session in self.sessions.values_mut() {
427 session.thread.update(cx, |thread, _| {
428 let model_id = LanguageModels::model_id(&thread.selected_model);
429 if let Some(model) = self.models.model_from_id(&model_id) {
430 thread.selected_model = model.clone();
431 }
432 });
433 }
434 }
435}
436
437/// Wrapper struct that implements the AgentConnection trait
438#[derive(Clone)]
439pub struct NativeAgentConnection(pub Entity<NativeAgent>);
440
441impl AgentModelSelector for NativeAgentConnection {
442 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
443 log::debug!("NativeAgentConnection::list_models called");
444 let list = self.0.read(cx).models.model_list.clone();
445 Task::ready(if list.is_empty() {
446 Err(anyhow::anyhow!("No models available"))
447 } else {
448 Ok(list)
449 })
450 }
451
452 fn select_model(
453 &self,
454 session_id: acp::SessionId,
455 model_id: acp_thread::AgentModelId,
456 cx: &mut App,
457 ) -> Task<Result<()>> {
458 log::info!("Setting model for session {}: {}", session_id, model_id);
459 let Some(thread) = self
460 .0
461 .read(cx)
462 .sessions
463 .get(&session_id)
464 .map(|session| session.thread.clone())
465 else {
466 return Task::ready(Err(anyhow!("Session not found")));
467 };
468
469 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
470 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
471 };
472
473 thread.update(cx, |thread, _cx| {
474 thread.selected_model = model.clone();
475 });
476
477 update_settings_file::<AgentSettings>(
478 self.0.read(cx).fs.clone(),
479 cx,
480 move |settings, _cx| {
481 settings.set_model(model);
482 },
483 );
484
485 Task::ready(Ok(()))
486 }
487
488 fn selected_model(
489 &self,
490 session_id: &acp::SessionId,
491 cx: &mut App,
492 ) -> Task<Result<acp_thread::AgentModelInfo>> {
493 let session_id = session_id.clone();
494
495 let Some(thread) = self
496 .0
497 .read(cx)
498 .sessions
499 .get(&session_id)
500 .map(|session| session.thread.clone())
501 else {
502 return Task::ready(Err(anyhow!("Session not found")));
503 };
504 let model = thread.read(cx).selected_model.clone();
505 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
506 else {
507 return Task::ready(Err(anyhow!("Provider not found")));
508 };
509 Task::ready(Ok(LanguageModels::map_language_model_to_info(
510 &model, &provider,
511 )))
512 }
513
514 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
515 self.0.read(cx).models.watch()
516 }
517}
518
519impl acp_thread::AgentConnection for NativeAgentConnection {
520 fn new_thread(
521 self: Rc<Self>,
522 project: Entity<Project>,
523 cwd: &Path,
524 cx: &mut AsyncApp,
525 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
526 let agent = self.0.clone();
527 log::info!("Creating new thread for project at: {:?}", cwd);
528
529 cx.spawn(async move |cx| {
530 log::debug!("Starting thread creation in async context");
531
532 // Generate session ID
533 let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
534 log::info!("Created session with ID: {}", session_id);
535
536 // Create AcpThread
537 let acp_thread = cx.update(|cx| {
538 cx.new(|cx| {
539 acp_thread::AcpThread::new(
540 "agent2",
541 self.clone(),
542 project.clone(),
543 session_id.clone(),
544 cx,
545 )
546 })
547 })?;
548 let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
549
550 // Create Thread
551 let thread = agent.update(
552 cx,
553 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
554 // Fetch default model from registry settings
555 let registry = LanguageModelRegistry::read_global(cx);
556
557 // Log available models for debugging
558 let available_count = registry.available_models(cx).count();
559 log::debug!("Total available models: {}", available_count);
560
561 let default_model = registry
562 .default_model()
563 .and_then(|default_model| {
564 agent
565 .models
566 .model_from_id(&LanguageModels::model_id(&default_model.model))
567 })
568 .ok_or_else(|| {
569 log::warn!("No default model configured in settings");
570 anyhow!(
571 "No default model. Please configure a default model in settings."
572 )
573 })?;
574
575 let thread = cx.new(|cx| {
576 let mut thread = Thread::new(
577 project.clone(),
578 agent.project_context.clone(),
579 agent.context_server_registry.clone(),
580 action_log.clone(),
581 agent.templates.clone(),
582 default_model,
583 cx,
584 );
585 thread.add_tool(CreateDirectoryTool::new(project.clone()));
586 thread.add_tool(CopyPathTool::new(project.clone()));
587 thread.add_tool(DiagnosticsTool::new(project.clone()));
588 thread.add_tool(MovePathTool::new(project.clone()));
589 thread.add_tool(ListDirectoryTool::new(project.clone()));
590 thread.add_tool(OpenTool::new(project.clone()));
591 thread.add_tool(ThinkingTool);
592 thread.add_tool(FindPathTool::new(project.clone()));
593 thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
594 thread.add_tool(GrepTool::new(project.clone()));
595 thread.add_tool(ReadFileTool::new(project.clone(), action_log));
596 thread.add_tool(EditFileTool::new(cx.entity()));
597 thread.add_tool(NowTool);
598 thread.add_tool(TerminalTool::new(project.clone(), cx));
599 // TODO: Needs to be conditional based on zed model or not
600 thread.add_tool(WebSearchTool);
601 thread
602 });
603
604 Ok(thread)
605 },
606 )??;
607
608 // Store the session
609 agent.update(cx, |agent, cx| {
610 agent.sessions.insert(
611 session_id,
612 Session {
613 thread,
614 acp_thread: acp_thread.downgrade(),
615 _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
616 this.sessions.remove(acp_thread.session_id());
617 }),
618 },
619 );
620 })?;
621
622 Ok(acp_thread)
623 })
624 }
625
626 fn auth_methods(&self) -> &[acp::AuthMethod] {
627 &[] // No auth for in-process
628 }
629
630 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
631 Task::ready(Ok(()))
632 }
633
634 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
635 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
636 }
637
638 fn prompt(
639 &self,
640 params: acp::PromptRequest,
641 cx: &mut App,
642 ) -> Task<Result<acp::PromptResponse>> {
643 let session_id = params.session_id.clone();
644 let agent = self.0.clone();
645 log::info!("Received prompt request for session: {}", session_id);
646 log::debug!("Prompt blocks count: {}", params.prompt.len());
647
648 cx.spawn(async move |cx| {
649 // Get session
650 let (thread, acp_thread) = agent
651 .update(cx, |agent, _| {
652 agent
653 .sessions
654 .get_mut(&session_id)
655 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
656 })?
657 .ok_or_else(|| {
658 log::error!("Session not found: {}", session_id);
659 anyhow::anyhow!("Session not found")
660 })?;
661 log::debug!("Found session for: {}", session_id);
662
663 let message: Vec<MessageContent> = params
664 .prompt
665 .into_iter()
666 .map(Into::into)
667 .collect::<Vec<_>>();
668 log::info!("Converted prompt to message: {} chars", message.len());
669 log::debug!("Message content: {:?}", message);
670
671 // Get model using the ModelSelector capability (always available for agent2)
672 // Get the selected model from the thread directly
673 let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
674
675 // Send to thread
676 log::info!("Sending message to thread with model: {:?}", model.name());
677 let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
678
679 // Handle response stream and forward to session.acp_thread
680 while let Some(result) = response_stream.next().await {
681 match result {
682 Ok(event) => {
683 log::trace!("Received completion event: {:?}", event);
684
685 match event {
686 AgentResponseEvent::Text(text) => {
687 acp_thread.update(cx, |thread, cx| {
688 thread.push_assistant_content_block(
689 acp::ContentBlock::Text(acp::TextContent {
690 text,
691 annotations: None,
692 }),
693 false,
694 cx,
695 )
696 })?;
697 }
698 AgentResponseEvent::Thinking(text) => {
699 acp_thread.update(cx, |thread, cx| {
700 thread.push_assistant_content_block(
701 acp::ContentBlock::Text(acp::TextContent {
702 text,
703 annotations: None,
704 }),
705 true,
706 cx,
707 )
708 })?;
709 }
710 AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
711 tool_call,
712 options,
713 response,
714 }) => {
715 let recv = acp_thread.update(cx, |thread, cx| {
716 thread.request_tool_call_authorization(tool_call, options, cx)
717 })?;
718 cx.background_spawn(async move {
719 if let Some(option) = recv
720 .await
721 .context("authorization sender was dropped")
722 .log_err()
723 {
724 response
725 .send(option)
726 .map(|_| anyhow!("authorization receiver was dropped"))
727 .log_err();
728 }
729 })
730 .detach();
731 }
732 AgentResponseEvent::ToolCall(tool_call) => {
733 acp_thread.update(cx, |thread, cx| {
734 thread.upsert_tool_call(tool_call, cx)
735 })?;
736 }
737 AgentResponseEvent::ToolCallUpdate(update) => {
738 acp_thread.update(cx, |thread, cx| {
739 thread.update_tool_call(update, cx)
740 })??;
741 }
742 AgentResponseEvent::Stop(stop_reason) => {
743 log::debug!("Assistant message complete: {:?}", stop_reason);
744 return Ok(acp::PromptResponse { stop_reason });
745 }
746 }
747 }
748 Err(e) => {
749 log::error!("Error in model response stream: {:?}", e);
750 // TODO: Consider sending an error message to the UI
751 break;
752 }
753 }
754 }
755
756 log::info!("Response stream completed");
757 anyhow::Ok(acp::PromptResponse {
758 stop_reason: acp::StopReason::EndTurn,
759 })
760 })
761 }
762
763 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
764 log::info!("Cancelling on session: {}", session_id);
765 self.0.update(cx, |agent, cx| {
766 if let Some(agent) = agent.sessions.get(session_id) {
767 agent.thread.update(cx, |thread, _cx| thread.cancel());
768 }
769 });
770 }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use serde_json::json;
    use settings::SettingsStore;

    /// Builds a `NativeAgent` for `project` with default templates and no prompt store.
    async fn create_agent(
        project: &Entity<Project>,
        fs: &Arc<FakeFs>,
        cx: &TestAppContext,
    ) -> Entity<NativeAgent> {
        let mut async_cx = cx.to_async();
        NativeAgent::new(
            project.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut async_cx,
        )
        .await
        .unwrap()
    }

    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let agent = create_agent(&project, &fs, cx).await;

        // Without worktrees the project context is empty.
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.project_context.borrow().worktrees, vec![]);
        });

        // Adding a visible worktree is reflected once the maintenance task has run.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, _| {
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            );
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.borrow().worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            );
        });
    }

    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let connection = NativeAgentConnection(create_agent(&project, &fs, cx).await);

        let model_list = cx.update(|cx| connection.list_models(cx)).await.unwrap();
        let acp_thread::AgentModelList::Grouped(groups) = model_list else {
            panic!("Unexpected model group");
        };

        let expected = IndexMap::from_iter([(
            AgentModelGroupName("Fake".into()),
            vec![AgentModelInfo {
                id: AgentModelId("fake/fake".into()),
                name: "Fake".into(),
                icon: Some(ui::IconName::ZedAssistant),
            }],
        )]);
        assert_eq!(groups, expected);
    }

    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        let settings_path = paths::settings_file();
        fs.create_dir(settings_path.parent().unwrap())
            .await
            .unwrap();
        let initial_settings = json!({
            "agent": {
                "default_model": {
                    "provider": "foo",
                    "model": "bar"
                }
            }
        });
        fs.insert_file(settings_path, initial_settings.to_string().into_bytes())
            .await;
        let project = Project::test(fs.clone(), [], cx).await;

        // Create the agent, a connection to it, and one session.
        let agent = create_agent(&project, &fs, cx).await;
        let connection = NativeAgentConnection(agent.clone());
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(
                    project.clone(),
                    Path::new("/a"),
                    &mut cx.to_async(),
                )
            })
            .await
            .unwrap();
        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Selecting a model updates the session's thread immediately...
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.selected_model.id().0, "fake");
            });
        });

        // ...and is written to the settings file once pending work completes.
        cx.run_until_parked();
        let settings_content = fs.load(settings_path).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
        let default_model = &settings_json["agent"]["default_model"];
        assert_eq!(default_model["model"], json!("fake"));
        assert_eq!(default_model["provider"], json!("fake"));
    }

    /// Installs the test settings store, agent settings, languages and the fake model registry.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}