1use crate::{
2 ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
3 UserMessageContent, templates::Templates,
4};
5use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated};
6use acp_thread::{AcpThread, AgentModelSelector};
7use action_log::ActionLog;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::{mpsc, oneshot};
14use futures::future::Shared;
15use futures::{StreamExt, future};
16use gpui::{
17 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
18};
19use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
20use project::{Project, ProjectItem, ProjectPath, Worktree};
21use prompt_store::{
22 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
23};
24use settings::update_settings_file;
25use std::any::Any;
26use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28use std::rc::Rc;
29use std::sync::Arc;
30use util::ResultExt;
31
/// Candidate rules-file paths, relative to a worktree root, in priority order.
///
/// `load_worktree_rules_file` loads only the first entry that exists as a regular
/// file, so earlier names win. `handle_project_event` also triggers a project-context
/// refresh when any of these paths changes.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
43
/// Error produced when a worktree's rules file could not be loaded.
///
/// Built by `load_worktree_info_for_system_prompt`. It is currently discarded rather
/// than shown to the user.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
47
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly: when the `AcpThread` is released, the observer registered in
    /// `register_session` removes this session from `NativeAgent::sessions`.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// The most recent database save. It is overwritten on every save, which drops
    /// (and presumably cancels) the previous task — TODO confirm gpui `Task` drop semantics.
    pending_save: Task<()>,
    /// Keeps the thread / ACP-thread observers alive for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
57
/// Cache of the language models offered by authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID (formatted as `"<provider_id>/<model_id>"`).
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information, grouped by
    /// "Recommended" followed by one group per provider.
    model_list: acp_thread::AgentModelList,
    /// Receiver cloned out by `watch()`; it observes every `refresh_list` call.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every `refresh_list` call.
    refresh_models_tx: watch::Sender<()>,
    /// Background task that authenticates every registered provider. It is stored here
    /// so it stays alive for as long as this cache does.
    _authenticate_all_providers_task: Task<()>,
}
67
68impl LanguageModels {
69 fn new(cx: &mut App) -> Self {
70 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
71
72 let mut this = Self {
73 models: HashMap::default(),
74 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
75 refresh_models_rx,
76 refresh_models_tx,
77 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
78 };
79 this.refresh_list(cx);
80 this
81 }
82
83 fn refresh_list(&mut self, cx: &App) {
84 let providers = LanguageModelRegistry::global(cx)
85 .read(cx)
86 .providers()
87 .into_iter()
88 .filter(|provider| provider.is_authenticated(cx))
89 .collect::<Vec<_>>();
90
91 let mut language_model_list = IndexMap::default();
92 let mut recommended_models = HashSet::default();
93
94 let mut recommended = Vec::new();
95 for provider in &providers {
96 for model in provider.recommended_models(cx) {
97 recommended_models.insert((model.provider_id(), model.id()));
98 recommended.push(Self::map_language_model_to_info(&model, provider));
99 }
100 }
101 if !recommended.is_empty() {
102 language_model_list.insert(
103 acp_thread::AgentModelGroupName("Recommended".into()),
104 recommended,
105 );
106 }
107
108 let mut models = HashMap::default();
109 for provider in providers {
110 let mut provider_models = Vec::new();
111 for model in provider.provided_models(cx) {
112 let model_info = Self::map_language_model_to_info(&model, &provider);
113 let model_id = model_info.id.clone();
114 if !recommended_models.contains(&(model.provider_id(), model.id())) {
115 provider_models.push(model_info);
116 }
117 models.insert(model_id, model);
118 }
119 if !provider_models.is_empty() {
120 language_model_list.insert(
121 acp_thread::AgentModelGroupName(provider.name().0.clone()),
122 provider_models,
123 );
124 }
125 }
126
127 self.models = models;
128 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
129 self.refresh_models_tx.send(()).ok();
130 }
131
132 fn watch(&self) -> watch::Receiver<()> {
133 self.refresh_models_rx.clone()
134 }
135
136 pub fn model_from_id(
137 &self,
138 model_id: &acp_thread::AgentModelId,
139 ) -> Option<Arc<dyn LanguageModel>> {
140 self.models.get(model_id).cloned()
141 }
142
143 fn map_language_model_to_info(
144 model: &Arc<dyn LanguageModel>,
145 provider: &Arc<dyn LanguageModelProvider>,
146 ) -> acp_thread::AgentModelInfo {
147 acp_thread::AgentModelInfo {
148 id: Self::model_id(model),
149 name: model.name().0,
150 icon: Some(provider.icon()),
151 }
152 }
153
154 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
155 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
156 }
157
158 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
159 let authenticate_all_providers = LanguageModelRegistry::global(cx)
160 .read(cx)
161 .providers()
162 .iter()
163 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
164 .collect::<Vec<_>>();
165
166 cx.background_spawn(async move {
167 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
168 if let Err(err) = authenticate_task.await {
169 if matches!(err, language_model::AuthenticateError::CredentialsNotFound) {
170 // Since we're authenticating these providers in the
171 // background for the purposes of populating the
172 // language selector, we don't care about providers
173 // where the credentials are not found.
174 } else {
175 // Some providers have noisy failure states that we
176 // don't want to spam the logs with every time the
177 // language model selector is initialized.
178 //
179 // Ideally these should have more clear failure modes
180 // that we know are safe to ignore here, like what we do
181 // with `CredentialsNotFound` above.
182 match provider_id.0.as_ref() {
183 "lmstudio" | "ollama" => {
184 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
185 //
186 // These fail noisily, so we don't log them.
187 }
188 "copilot_chat" => {
189 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
190 }
191 _ => {
192 log::error!(
193 "Failed to authenticate provider: {}: {err}",
194 provider_name.0
195 );
196 }
197 }
198 }
199 }
200 }
201 })
202 }
203}
204
/// In-process agent that owns every native thread session and the state they share.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history; reloaded after each thread save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signalled to ask `maintain_project_context` to rebuild `project_context`
    /// (worktree changes, rules-file changes, prompt-store updates).
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`; it stays alive while this field exists.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Optional source of default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem used when persisting settings (e.g. the selected model).
    fs: Arc<dyn Fs>,
    /// Project, model-registry and prompt-store subscriptions.
    _subscriptions: Vec<Subscription>,
}
223
224impl NativeAgent {
225 pub async fn new(
226 project: Entity<Project>,
227 history: Entity<HistoryStore>,
228 templates: Arc<Templates>,
229 prompt_store: Option<Entity<PromptStore>>,
230 fs: Arc<dyn Fs>,
231 cx: &mut AsyncApp,
232 ) -> Result<Entity<NativeAgent>> {
233 log::debug!("Creating new NativeAgent");
234
235 let project_context = cx
236 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
237 .await;
238
239 cx.new(|cx| {
240 let mut subscriptions = vec![
241 cx.subscribe(&project, Self::handle_project_event),
242 cx.subscribe(
243 &LanguageModelRegistry::global(cx),
244 Self::handle_models_updated_event,
245 ),
246 ];
247 if let Some(prompt_store) = prompt_store.as_ref() {
248 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
249 }
250
251 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
252 watch::channel(());
253 Self {
254 sessions: HashMap::new(),
255 history,
256 project_context: cx.new(|_| project_context),
257 project_context_needs_refresh: project_context_needs_refresh_tx,
258 _maintain_project_context: cx.spawn(async move |this, cx| {
259 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
260 }),
261 context_server_registry: cx.new(|cx| {
262 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
263 }),
264 templates,
265 models: LanguageModels::new(cx),
266 project,
267 prompt_store,
268 fs,
269 _subscriptions: subscriptions,
270 }
271 })
272 }
273
274 fn register_session(
275 &mut self,
276 thread_handle: Entity<Thread>,
277 cx: &mut Context<Self>,
278 ) -> Entity<AcpThread> {
279 let connection = Rc::new(NativeAgentConnection(cx.entity()));
280
281 let thread = thread_handle.read(cx);
282 let session_id = thread.id().clone();
283 let title = thread.title();
284 let project = thread.project.clone();
285 let action_log = thread.action_log.clone();
286 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
287 let acp_thread = cx.new(|cx| {
288 acp_thread::AcpThread::new(
289 title,
290 connection,
291 project.clone(),
292 action_log.clone(),
293 session_id.clone(),
294 prompt_capabilities_rx,
295 vec![],
296 cx,
297 )
298 });
299
300 let registry = LanguageModelRegistry::read_global(cx);
301 let summarization_model = registry.thread_summary_model().map(|c| c.model);
302
303 thread_handle.update(cx, |thread, cx| {
304 thread.set_summarization_model(summarization_model, cx);
305 thread.add_default_tools(
306 Rc::new(AcpThreadEnvironment {
307 acp_thread: acp_thread.downgrade(),
308 }) as _,
309 cx,
310 )
311 });
312
313 let subscriptions = vec![
314 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
315 this.sessions.remove(acp_thread.session_id());
316 }),
317 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
318 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
319 cx.observe(&thread_handle, move |this, thread, cx| {
320 this.save_thread(thread, cx)
321 }),
322 ];
323
324 self.sessions.insert(
325 session_id,
326 Session {
327 thread: thread_handle,
328 acp_thread: acp_thread.downgrade(),
329 _subscriptions: subscriptions,
330 pending_save: Task::ready(()),
331 },
332 );
333 acp_thread
334 }
335
336 pub fn models(&self) -> &LanguageModels {
337 &self.models
338 }
339
340 async fn maintain_project_context(
341 this: WeakEntity<Self>,
342 mut needs_refresh: watch::Receiver<()>,
343 cx: &mut AsyncApp,
344 ) -> Result<()> {
345 while needs_refresh.changed().await.is_ok() {
346 let project_context = this
347 .update(cx, |this, cx| {
348 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
349 })?
350 .await;
351 this.update(cx, |this, cx| {
352 this.project_context = cx.new(|_| project_context);
353 })?;
354 }
355
356 Ok(())
357 }
358
359 fn build_project_context(
360 project: &Entity<Project>,
361 prompt_store: Option<&Entity<PromptStore>>,
362 cx: &mut App,
363 ) -> Task<ProjectContext> {
364 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
365 let worktree_tasks = worktrees
366 .into_iter()
367 .map(|worktree| {
368 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
369 })
370 .collect::<Vec<_>>();
371 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
372 prompt_store.read_with(cx, |prompt_store, cx| {
373 let prompts = prompt_store.default_prompt_metadata();
374 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
375 let contents = prompt_store.load(prompt_metadata.id, cx);
376 async move { (contents.await, prompt_metadata) }
377 });
378 cx.background_spawn(future::join_all(load_tasks))
379 })
380 } else {
381 Task::ready(vec![])
382 };
383
384 cx.spawn(async move |_cx| {
385 let (worktrees, default_user_rules) =
386 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
387
388 let worktrees = worktrees
389 .into_iter()
390 .map(|(worktree, _rules_error)| {
391 // TODO: show error message
392 // if let Some(rules_error) = rules_error {
393 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
394 // }
395 worktree
396 })
397 .collect::<Vec<_>>();
398
399 let default_user_rules = default_user_rules
400 .into_iter()
401 .flat_map(|(contents, prompt_metadata)| match contents {
402 Ok(contents) => Some(UserRulesContext {
403 uuid: match prompt_metadata.id {
404 PromptId::User { uuid } => uuid,
405 PromptId::EditWorkflow => return None,
406 },
407 title: prompt_metadata.title.map(|title| title.to_string()),
408 contents,
409 }),
410 Err(_err) => {
411 // TODO: show error message
412 // this.update(cx, |_, cx| {
413 // cx.emit(RulesLoadingError {
414 // message: format!("{err:?}").into(),
415 // });
416 // })
417 // .ok();
418 None
419 }
420 })
421 .collect::<Vec<_>>();
422
423 ProjectContext::new(worktrees, default_user_rules)
424 })
425 }
426
427 fn load_worktree_info_for_system_prompt(
428 worktree: Entity<Worktree>,
429 project: Entity<Project>,
430 cx: &mut App,
431 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
432 let tree = worktree.read(cx);
433 let root_name = tree.root_name().into();
434 let abs_path = tree.abs_path();
435
436 let mut context = WorktreeContext {
437 root_name,
438 abs_path,
439 rules_file: None,
440 };
441
442 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
443 let Some(rules_task) = rules_task else {
444 return Task::ready((context, None));
445 };
446
447 cx.spawn(async move |_| {
448 let (rules_file, rules_file_error) = match rules_task.await {
449 Ok(rules_file) => (Some(rules_file), None),
450 Err(err) => (
451 None,
452 Some(RulesLoadingError {
453 message: format!("{err}").into(),
454 }),
455 ),
456 };
457 context.rules_file = rules_file;
458 (context, rules_file_error)
459 })
460 }
461
462 fn load_worktree_rules_file(
463 worktree: Entity<Worktree>,
464 project: Entity<Project>,
465 cx: &mut App,
466 ) -> Option<Task<Result<RulesFileContext>>> {
467 let worktree = worktree.read(cx);
468 let worktree_id = worktree.id();
469 let selected_rules_file = RULES_FILE_NAMES
470 .into_iter()
471 .filter_map(|name| {
472 worktree
473 .entry_for_path(name)
474 .filter(|entry| entry.is_file())
475 .map(|entry| entry.path.clone())
476 })
477 .next();
478
479 // Note that Cline supports `.clinerules` being a directory, but that is not currently
480 // supported. This doesn't seem to occur often in GitHub repositories.
481 selected_rules_file.map(|path_in_worktree| {
482 let project_path = ProjectPath {
483 worktree_id,
484 path: path_in_worktree.clone(),
485 };
486 let buffer_task =
487 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
488 let rope_task = cx.spawn(async move |cx| {
489 buffer_task.await?.read_with(cx, |buffer, cx| {
490 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
491 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
492 })?
493 });
494 // Build a string from the rope on a background thread.
495 cx.background_spawn(async move {
496 let (project_entry_id, rope) = rope_task.await?;
497 anyhow::Ok(RulesFileContext {
498 path_in_worktree,
499 text: rope.to_string().trim().to_string(),
500 project_entry_id: project_entry_id.to_usize(),
501 })
502 })
503 })
504 }
505
506 fn handle_thread_title_updated(
507 &mut self,
508 thread: Entity<Thread>,
509 _: &TitleUpdated,
510 cx: &mut Context<Self>,
511 ) {
512 let session_id = thread.read(cx).id();
513 let Some(session) = self.sessions.get(session_id) else {
514 return;
515 };
516 let thread = thread.downgrade();
517 let acp_thread = session.acp_thread.clone();
518 cx.spawn(async move |_, cx| {
519 let title = thread.read_with(cx, |thread, _| thread.title())?;
520 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
521 task.await
522 })
523 .detach_and_log_err(cx);
524 }
525
526 fn handle_thread_token_usage_updated(
527 &mut self,
528 thread: Entity<Thread>,
529 usage: &TokenUsageUpdated,
530 cx: &mut Context<Self>,
531 ) {
532 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
533 return;
534 };
535 session
536 .acp_thread
537 .update(cx, |acp_thread, cx| {
538 acp_thread.update_token_usage(usage.0.clone(), cx);
539 })
540 .ok();
541 }
542
543 fn handle_project_event(
544 &mut self,
545 _project: Entity<Project>,
546 event: &project::Event,
547 _cx: &mut Context<Self>,
548 ) {
549 match event {
550 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
551 self.project_context_needs_refresh.send(()).ok();
552 }
553 project::Event::WorktreeUpdatedEntries(_, items) => {
554 if items.iter().any(|(path, _, _)| {
555 RULES_FILE_NAMES
556 .iter()
557 .any(|name| path.as_ref() == Path::new(name))
558 }) {
559 self.project_context_needs_refresh.send(()).ok();
560 }
561 }
562 _ => {}
563 }
564 }
565
566 fn handle_prompts_updated_event(
567 &mut self,
568 _prompt_store: Entity<PromptStore>,
569 _event: &prompt_store::PromptsUpdatedEvent,
570 _cx: &mut Context<Self>,
571 ) {
572 self.project_context_needs_refresh.send(()).ok();
573 }
574
575 fn handle_models_updated_event(
576 &mut self,
577 _registry: Entity<LanguageModelRegistry>,
578 _event: &language_model::Event,
579 cx: &mut Context<Self>,
580 ) {
581 self.models.refresh_list(cx);
582
583 let registry = LanguageModelRegistry::read_global(cx);
584 let default_model = registry.default_model().map(|m| m.model);
585 let summarization_model = registry.thread_summary_model().map(|m| m.model);
586
587 for session in self.sessions.values_mut() {
588 session.thread.update(cx, |thread, cx| {
589 if thread.model().is_none()
590 && let Some(model) = default_model.clone()
591 {
592 thread.set_model(model, cx);
593 cx.notify();
594 }
595 thread.set_summarization_model(summarization_model.clone(), cx);
596 });
597 }
598 }
599
600 pub fn open_thread(
601 &mut self,
602 id: acp::SessionId,
603 cx: &mut Context<Self>,
604 ) -> Task<Result<Entity<AcpThread>>> {
605 let database_future = ThreadsDatabase::connect(cx);
606 cx.spawn(async move |this, cx| {
607 let database = database_future.await.map_err(|err| anyhow!(err))?;
608 let db_thread = database
609 .load_thread(id.clone())
610 .await?
611 .with_context(|| format!("no thread found with ID: {id:?}"))?;
612
613 let thread = this.update(cx, |this, cx| {
614 let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
615 cx.new(|cx| {
616 Thread::from_db(
617 id.clone(),
618 db_thread,
619 this.project.clone(),
620 this.project_context.clone(),
621 this.context_server_registry.clone(),
622 action_log.clone(),
623 this.templates.clone(),
624 cx,
625 )
626 })
627 })?;
628 let acp_thread =
629 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
630 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
631 cx.update(|cx| {
632 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
633 })?
634 .await?;
635 Ok(acp_thread)
636 })
637 }
638
639 pub fn thread_summary(
640 &mut self,
641 id: acp::SessionId,
642 cx: &mut Context<Self>,
643 ) -> Task<Result<SharedString>> {
644 let thread = self.open_thread(id.clone(), cx);
645 cx.spawn(async move |this, cx| {
646 let acp_thread = thread.await?;
647 let result = this
648 .update(cx, |this, cx| {
649 this.sessions
650 .get(&id)
651 .unwrap()
652 .thread
653 .update(cx, |thread, cx| thread.summary(cx))
654 })?
655 .await?;
656 drop(acp_thread);
657 Ok(result)
658 })
659 }
660
661 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
662 if thread.read(cx).is_empty() {
663 return;
664 }
665
666 let database_future = ThreadsDatabase::connect(cx);
667 let (id, db_thread) =
668 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
669 let Some(session) = self.sessions.get_mut(&id) else {
670 return;
671 };
672 let history = self.history.clone();
673 session.pending_save = cx.spawn(async move |_, cx| {
674 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
675 return;
676 };
677 let db_thread = db_thread.await;
678 database.save_thread(id, db_thread).await.log_err();
679 history.update(cx, |history, cx| history.reload(cx)).ok();
680 });
681 }
682}
683
/// Wrapper struct that implements the AgentConnection trait
///
/// It is a cheap, cloneable handle to the shared `NativeAgent` entity; every clone
/// refers to the same agent and sessions.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
687
688impl NativeAgentConnection {
689 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
690 self.0
691 .read(cx)
692 .sessions
693 .get(session_id)
694 .map(|session| session.thread.clone())
695 }
696
697 fn run_turn(
698 &self,
699 session_id: acp::SessionId,
700 cx: &mut App,
701 f: impl 'static
702 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
703 ) -> Task<Result<acp::PromptResponse>> {
704 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
705 agent
706 .sessions
707 .get_mut(&session_id)
708 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
709 }) else {
710 return Task::ready(Err(anyhow!("Session not found")));
711 };
712 log::debug!("Found session for: {}", session_id);
713
714 let response_stream = match f(thread, cx) {
715 Ok(stream) => stream,
716 Err(err) => return Task::ready(Err(err)),
717 };
718 Self::handle_thread_events(response_stream, acp_thread, cx)
719 }
720
721 fn handle_thread_events(
722 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
723 acp_thread: WeakEntity<AcpThread>,
724 cx: &App,
725 ) -> Task<Result<acp::PromptResponse>> {
726 cx.spawn(async move |cx| {
727 // Handle response stream and forward to session.acp_thread
728 while let Some(result) = events.next().await {
729 match result {
730 Ok(event) => {
731 log::trace!("Received completion event: {:?}", event);
732
733 match event {
734 ThreadEvent::UserMessage(message) => {
735 acp_thread.update(cx, |thread, cx| {
736 for content in message.content {
737 thread.push_user_content_block(
738 Some(message.id.clone()),
739 content.into(),
740 cx,
741 );
742 }
743 })?;
744 }
745 ThreadEvent::AgentText(text) => {
746 acp_thread.update(cx, |thread, cx| {
747 thread.push_assistant_content_block(
748 acp::ContentBlock::Text(acp::TextContent {
749 text,
750 annotations: None,
751 }),
752 false,
753 cx,
754 )
755 })?;
756 }
757 ThreadEvent::AgentThinking(text) => {
758 acp_thread.update(cx, |thread, cx| {
759 thread.push_assistant_content_block(
760 acp::ContentBlock::Text(acp::TextContent {
761 text,
762 annotations: None,
763 }),
764 true,
765 cx,
766 )
767 })?;
768 }
769 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
770 tool_call,
771 options,
772 response,
773 }) => {
774 let outcome_task = acp_thread.update(cx, |thread, cx| {
775 thread.request_tool_call_authorization(tool_call, options, cx)
776 })??;
777 cx.background_spawn(async move {
778 if let acp::RequestPermissionOutcome::Selected { option_id } =
779 outcome_task.await
780 {
781 response
782 .send(option_id)
783 .map(|_| anyhow!("authorization receiver was dropped"))
784 .log_err();
785 }
786 })
787 .detach();
788 }
789 ThreadEvent::ToolCall(tool_call) => {
790 acp_thread.update(cx, |thread, cx| {
791 thread.upsert_tool_call(tool_call, cx)
792 })??;
793 }
794 ThreadEvent::ToolCallUpdate(update) => {
795 acp_thread.update(cx, |thread, cx| {
796 thread.update_tool_call(update, cx)
797 })??;
798 }
799 ThreadEvent::Retry(status) => {
800 acp_thread.update(cx, |thread, cx| {
801 thread.update_retry_status(status, cx)
802 })?;
803 }
804 ThreadEvent::Stop(stop_reason) => {
805 log::debug!("Assistant message complete: {:?}", stop_reason);
806 return Ok(acp::PromptResponse { stop_reason });
807 }
808 }
809 }
810 Err(e) => {
811 log::error!("Error in model response stream: {:?}", e);
812 return Err(e);
813 }
814 }
815 }
816
817 log::debug!("Response stream completed");
818 anyhow::Ok(acp::PromptResponse {
819 stop_reason: acp::StopReason::EndTurn,
820 })
821 })
822 }
823}
824
825impl AgentModelSelector for NativeAgentConnection {
826 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
827 log::debug!("NativeAgentConnection::list_models called");
828 let list = self.0.read(cx).models.model_list.clone();
829 Task::ready(if list.is_empty() {
830 Err(anyhow::anyhow!("No models available"))
831 } else {
832 Ok(list)
833 })
834 }
835
836 fn select_model(
837 &self,
838 session_id: acp::SessionId,
839 model_id: acp_thread::AgentModelId,
840 cx: &mut App,
841 ) -> Task<Result<()>> {
842 log::debug!("Setting model for session {}: {}", session_id, model_id);
843 let Some(thread) = self
844 .0
845 .read(cx)
846 .sessions
847 .get(&session_id)
848 .map(|session| session.thread.clone())
849 else {
850 return Task::ready(Err(anyhow!("Session not found")));
851 };
852
853 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
854 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
855 };
856
857 thread.update(cx, |thread, cx| {
858 thread.set_model(model.clone(), cx);
859 });
860
861 update_settings_file::<AgentSettings>(
862 self.0.read(cx).fs.clone(),
863 cx,
864 move |settings, _cx| {
865 settings.set_model(model);
866 },
867 );
868
869 Task::ready(Ok(()))
870 }
871
872 fn selected_model(
873 &self,
874 session_id: &acp::SessionId,
875 cx: &mut App,
876 ) -> Task<Result<acp_thread::AgentModelInfo>> {
877 let session_id = session_id.clone();
878
879 let Some(thread) = self
880 .0
881 .read(cx)
882 .sessions
883 .get(&session_id)
884 .map(|session| session.thread.clone())
885 else {
886 return Task::ready(Err(anyhow!("Session not found")));
887 };
888 let Some(model) = thread.read(cx).model() else {
889 return Task::ready(Err(anyhow!("Model not found")));
890 };
891 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
892 else {
893 return Task::ready(Err(anyhow!("Provider not found")));
894 };
895 Task::ready(Ok(LanguageModels::map_language_model_to_info(
896 model, &provider,
897 )))
898 }
899
900 fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
901 self.0.read(cx).models.watch()
902 }
903}
904
905impl acp_thread::AgentConnection for NativeAgentConnection {
906 fn new_thread(
907 self: Rc<Self>,
908 project: Entity<Project>,
909 cwd: &Path,
910 cx: &mut App,
911 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
912 let agent = self.0.clone();
913 log::debug!("Creating new thread for project at: {:?}", cwd);
914
915 cx.spawn(async move |cx| {
916 log::debug!("Starting thread creation in async context");
917
918 // Create Thread
919 let thread = agent.update(
920 cx,
921 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
922 // Fetch default model from registry settings
923 let registry = LanguageModelRegistry::read_global(cx);
924 // Log available models for debugging
925 let available_count = registry.available_models(cx).count();
926 log::debug!("Total available models: {}", available_count);
927
928 let default_model = registry.default_model().and_then(|default_model| {
929 agent
930 .models
931 .model_from_id(&LanguageModels::model_id(&default_model.model))
932 });
933 Ok(cx.new(|cx| {
934 Thread::new(
935 project.clone(),
936 agent.project_context.clone(),
937 agent.context_server_registry.clone(),
938 agent.templates.clone(),
939 default_model,
940 cx,
941 )
942 }))
943 },
944 )??;
945 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
946 })
947 }
948
949 fn auth_methods(&self) -> &[acp::AuthMethod] {
950 &[] // No auth for in-process
951 }
952
953 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
954 Task::ready(Ok(()))
955 }
956
957 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
958 Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
959 }
960
961 fn prompt(
962 &self,
963 id: Option<acp_thread::UserMessageId>,
964 params: acp::PromptRequest,
965 cx: &mut App,
966 ) -> Task<Result<acp::PromptResponse>> {
967 let id = id.expect("UserMessageId is required");
968 let session_id = params.session_id.clone();
969 log::info!("Received prompt request for session: {}", session_id);
970 log::debug!("Prompt blocks count: {}", params.prompt.len());
971
972 self.run_turn(session_id, cx, |thread, cx| {
973 let content: Vec<UserMessageContent> = params
974 .prompt
975 .into_iter()
976 .map(Into::into)
977 .collect::<Vec<_>>();
978 log::debug!("Converted prompt to message: {} chars", content.len());
979 log::debug!("Message id: {:?}", id);
980 log::debug!("Message content: {:?}", content);
981
982 thread.update(cx, |thread, cx| thread.send(id, content, cx))
983 })
984 }
985
986 fn resume(
987 &self,
988 session_id: &acp::SessionId,
989 _cx: &App,
990 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
991 Some(Rc::new(NativeAgentSessionResume {
992 connection: self.clone(),
993 session_id: session_id.clone(),
994 }) as _)
995 }
996
997 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
998 log::info!("Cancelling on session: {}", session_id);
999 self.0.update(cx, |agent, cx| {
1000 if let Some(agent) = agent.sessions.get(session_id) {
1001 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1002 }
1003 });
1004 }
1005
1006 fn truncate(
1007 &self,
1008 session_id: &agent_client_protocol::SessionId,
1009 cx: &App,
1010 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1011 self.0.read_with(cx, |agent, _cx| {
1012 agent.sessions.get(session_id).map(|session| {
1013 Rc::new(NativeAgentSessionTruncate {
1014 thread: session.thread.clone(),
1015 acp_thread: session.acp_thread.clone(),
1016 }) as _
1017 })
1018 })
1019 }
1020
1021 fn set_title(
1022 &self,
1023 session_id: &acp::SessionId,
1024 _cx: &App,
1025 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1026 Some(Rc::new(NativeAgentSessionSetTitle {
1027 connection: self.clone(),
1028 session_id: session_id.clone(),
1029 }) as _)
1030 }
1031
1032 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1033 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1034 }
1035
1036 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1037 self
1038 }
1039}
1040
1041impl acp_thread::AgentTelemetry for NativeAgentConnection {
1042 fn agent_name(&self) -> String {
1043 "Zed".into()
1044 }
1045
1046 fn thread_data(
1047 &self,
1048 session_id: &acp::SessionId,
1049 cx: &mut App,
1050 ) -> Task<Result<serde_json::Value>> {
1051 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1052 return Task::ready(Err(anyhow!("Session not found")));
1053 };
1054
1055 let task = session.thread.read(cx).to_db(cx);
1056 cx.background_spawn(async move {
1057 serde_json::to_value(task.await).context("Failed to serialize thread")
1058 })
1059 }
1060}
1061
/// Truncation handle for a single session, returned by `NativeAgentConnection::truncate`.
struct NativeAgentSessionTruncate {
    /// Internal thread to truncate.
    thread: Entity<Thread>,
    /// ACP thread whose token usage is refreshed after truncating (if still alive).
    acp_thread: WeakEntity<AcpThread>,
}
1066
1067impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1068 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1069 match self.thread.update(cx, |thread, cx| {
1070 thread.truncate(message_id.clone(), cx)?;
1071 Ok(thread.latest_token_usage())
1072 }) {
1073 Ok(usage) => {
1074 self.acp_thread
1075 .update(cx, |thread, cx| {
1076 thread.update_token_usage(usage, cx);
1077 })
1078 .ok();
1079 Task::ready(Ok(()))
1080 }
1081 Err(error) => Task::ready(Err(error)),
1082 }
1083 }
1084}
1085
/// Resume handle for a single session, returned by `NativeAgentConnection::resume`.
struct NativeAgentSessionResume {
    /// Connection used to run the resumed turn.
    connection: NativeAgentConnection,
    /// Session to resume.
    session_id: acp::SessionId,
}
1090
1091impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1092 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1093 self.connection
1094 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1095 thread.update(cx, |thread, cx| thread.resume(cx))
1096 })
1097 }
1098}
1099
1100struct NativeAgentSessionSetTitle {
1101 connection: NativeAgentConnection,
1102 session_id: acp::SessionId,
1103}
1104
1105impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1106 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1107 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1108 return Task::ready(Err(anyhow!("session not found")));
1109 };
1110 let thread = session.thread.clone();
1111 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1112 Task::ready(Ok(()))
1113 }
1114}
1115
1116pub struct AcpThreadEnvironment {
1117 acp_thread: WeakEntity<AcpThread>,
1118}
1119
1120impl ThreadEnvironment for AcpThreadEnvironment {
1121 fn create_terminal(
1122 &self,
1123 command: String,
1124 cwd: Option<PathBuf>,
1125 output_byte_limit: Option<u64>,
1126 cx: &mut AsyncApp,
1127 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1128 let task = self.acp_thread.update(cx, |thread, cx| {
1129 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1130 });
1131
1132 let acp_thread = self.acp_thread.clone();
1133 cx.spawn(async move |cx| {
1134 let terminal = task?.await?;
1135
1136 let (drop_tx, drop_rx) = oneshot::channel();
1137 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1138
1139 cx.spawn(async move |cx| {
1140 drop_rx.await.ok();
1141 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1142 })
1143 .detach();
1144
1145 let handle = AcpTerminalHandle {
1146 terminal,
1147 _drop_tx: Some(drop_tx),
1148 };
1149
1150 Ok(Rc::new(handle) as _)
1151 })
1152 }
1153}
1154
1155pub struct AcpTerminalHandle {
1156 terminal: Entity<acp_thread::Terminal>,
1157 _drop_tx: Option<oneshot::Sender<()>>,
1158}
1159
1160impl TerminalHandle for AcpTerminalHandle {
1161 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1162 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1163 }
1164
1165 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1166 self.terminal
1167 .read_with(cx, |term, _cx| term.wait_for_exit())
1168 }
1169
1170 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1171 self.terminal
1172 .read_with(cx, |term, cx| term.current_output(cx))
1173 }
1174}
1175
// Tests for the native agent: project-context maintenance, model listing and
// selection (including persistence to settings), and thread save/load.
#[cfg(test)]
mod tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{
        AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
    };
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    // The agent's project context starts with no worktrees, gains an entry when a
    // worktree is added, and picks up a `.rules` file created inside that worktree.
    // NOTE(review): unlike `test_save_load_thread`, these paths are not wrapped in
    // `path!`.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // No worktrees yet, so the project context is empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree for `/a` (with no rules file) is reflected in the context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: Path::new(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // `list_models` returns a grouped list with exactly one group, "Fake",
    // containing the single model `fake/fake`.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: AgentModelId("fake/fake".into()),
                    name: "Fake".into(),
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Selecting a model for a session updates the session's thread and writes
    // the choice back to the user settings file, replacing the initial
    // `foo`/`bar` default with `fake`/`fake`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let model_id = AgentModelId("fake/fake".into());
        cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        // NOTE(review): the inner read goes through the outer test `cx` rather
        // than the `App` passed to the closure (currently ignored as `_`).
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // End-to-end persistence: an empty thread is not saved; after one exchange
    // and a generated summary, dropping the ACP thread drops the session, the
    // history lists the thread under its summary title, and reopening it
    // restores the same conversation.
    #[gpui::test]
    #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a message that mentions `/a/b.md` via a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so the test can drive the fake models while the turn runs.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant reply, then the summary; the summary text is
        // the title expected in the history below.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    What does [@b.md](file:///a/b.md) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                "Explaining /a/b.md".into()
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    What does [@b.md](file:///a/b.md) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    // Returns `(id, title)` pairs for every entry currently in `history`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    // Shared setup: installs a test `SettingsStore` global and initializes
    // project, agent, and language settings plus the test model registry.
    // NOTE(review): `LanguageModelRegistry::test` presumably registers the fake
    // provider that `test_listing_models` expects ("Fake" / "fake/fake"); confirm.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}