1use crate::{
2 ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
3 UserMessageContent, templates::Templates,
4};
5use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated};
6use acp_thread::{AcpThread, AgentModelSelector};
7use action_log::ActionLog;
8use agent_client_protocol as acp;
9use anyhow::{Context as _, Result, anyhow};
10use collections::{HashSet, IndexMap};
11use fs::Fs;
12use futures::channel::{mpsc, oneshot};
13use futures::future::Shared;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::{LanguageModelSelection, update_settings_file};
24use std::any::Any;
25use std::collections::HashMap;
26use std::path::{Path, PathBuf};
27use std::rc::Rc;
28use std::sync::{Arc, LazyLock};
29use util::ResultExt;
30use util::rel_path::RelPath;
31
32static RULES_FILE_NAMES: LazyLock<[&RelPath; 9]> = LazyLock::new(|| {
33 [
34 RelPath::unix(".rules").unwrap(),
35 RelPath::unix(".cursorrules").unwrap(),
36 RelPath::unix(".windsurfrules").unwrap(),
37 RelPath::unix(".clinerules").unwrap(),
38 RelPath::unix(".github/copilot-instructions.md").unwrap(),
39 RelPath::unix("CLAUDE.md").unwrap(),
40 RelPath::unix("AGENT.md").unwrap(),
41 RelPath::unix("AGENTS.md").unwrap(),
42 RelPath::unix("GEMINI.md").unwrap(),
43 ]
44});
45
46pub struct RulesLoadingError {
47 pub message: SharedString,
48}
49
50/// Holds both the internal Thread and the AcpThread for a session
51struct Session {
52 /// The internal thread that processes messages
53 thread: Entity<Thread>,
54 /// The ACP thread that handles protocol communication
55 acp_thread: WeakEntity<acp_thread::AcpThread>,
56 pending_save: Task<()>,
57 _subscriptions: Vec<Subscription>,
58}
59
60pub struct LanguageModels {
61 /// Access language model by ID
62 models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
63 /// Cached list for returning language model information
64 model_list: acp_thread::AgentModelList,
65 refresh_models_rx: watch::Receiver<()>,
66 refresh_models_tx: watch::Sender<()>,
67 _authenticate_all_providers_task: Task<()>,
68}
69
70impl LanguageModels {
71 fn new(cx: &mut App) -> Self {
72 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
73
74 let mut this = Self {
75 models: HashMap::default(),
76 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
77 refresh_models_rx,
78 refresh_models_tx,
79 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
80 };
81 this.refresh_list(cx);
82 this
83 }
84
85 fn refresh_list(&mut self, cx: &App) {
86 let providers = LanguageModelRegistry::global(cx)
87 .read(cx)
88 .providers()
89 .into_iter()
90 .filter(|provider| provider.is_authenticated(cx))
91 .collect::<Vec<_>>();
92
93 let mut language_model_list = IndexMap::default();
94 let mut recommended_models = HashSet::default();
95
96 let mut recommended = Vec::new();
97 for provider in &providers {
98 for model in provider.recommended_models(cx) {
99 recommended_models.insert((model.provider_id(), model.id()));
100 recommended.push(Self::map_language_model_to_info(&model, provider));
101 }
102 }
103 if !recommended.is_empty() {
104 language_model_list.insert(
105 acp_thread::AgentModelGroupName("Recommended".into()),
106 recommended,
107 );
108 }
109
110 let mut models = HashMap::default();
111 for provider in providers {
112 let mut provider_models = Vec::new();
113 for model in provider.provided_models(cx) {
114 let model_info = Self::map_language_model_to_info(&model, &provider);
115 let model_id = model_info.id.clone();
116 if !recommended_models.contains(&(model.provider_id(), model.id())) {
117 provider_models.push(model_info);
118 }
119 models.insert(model_id, model);
120 }
121 if !provider_models.is_empty() {
122 language_model_list.insert(
123 acp_thread::AgentModelGroupName(provider.name().0.clone()),
124 provider_models,
125 );
126 }
127 }
128
129 self.models = models;
130 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
131 self.refresh_models_tx.send(()).ok();
132 }
133
134 fn watch(&self) -> watch::Receiver<()> {
135 self.refresh_models_rx.clone()
136 }
137
138 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
139 self.models.get(model_id).cloned()
140 }
141
142 fn map_language_model_to_info(
143 model: &Arc<dyn LanguageModel>,
144 provider: &Arc<dyn LanguageModelProvider>,
145 ) -> acp_thread::AgentModelInfo {
146 acp_thread::AgentModelInfo {
147 id: Self::model_id(model),
148 name: model.name().0,
149 description: None,
150 icon: Some(provider.icon()),
151 }
152 }
153
154 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
155 acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
156 }
157
158 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
159 let authenticate_all_providers = LanguageModelRegistry::global(cx)
160 .read(cx)
161 .providers()
162 .iter()
163 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
164 .collect::<Vec<_>>();
165
166 cx.background_spawn(async move {
167 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
168 if let Err(err) = authenticate_task.await {
169 match err {
170 language_model::AuthenticateError::CredentialsNotFound => {
171 // Since we're authenticating these providers in the
172 // background for the purposes of populating the
173 // language selector, we don't care about providers
174 // where the credentials are not found.
175 }
176 language_model::AuthenticateError::ConnectionRefused => {
177 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
178 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
179 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
180 }
181 _ => {
182 // Some providers have noisy failure states that we
183 // don't want to spam the logs with every time the
184 // language model selector is initialized.
185 //
186 // Ideally these should have more clear failure modes
187 // that we know are safe to ignore here, like what we do
188 // with `CredentialsNotFound` above.
189 match provider_id.0.as_ref() {
190 "lmstudio" | "ollama" => {
191 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
192 //
193 // These fail noisily, so we don't log them.
194 }
195 "copilot_chat" => {
196 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
197 }
198 _ => {
199 log::error!(
200 "Failed to authenticate provider: {}: {err}",
201 provider_name.0
202 );
203 }
204 }
205 }
206 }
207 }
208 }
209 })
210 }
211}
212
213pub struct NativeAgent {
214 /// Session ID -> Session mapping
215 sessions: HashMap<acp::SessionId, Session>,
216 history: Entity<HistoryStore>,
217 /// Shared project context for all threads
218 project_context: Entity<ProjectContext>,
219 project_context_needs_refresh: watch::Sender<()>,
220 _maintain_project_context: Task<Result<()>>,
221 context_server_registry: Entity<ContextServerRegistry>,
222 /// Shared templates for all threads
223 templates: Arc<Templates>,
224 /// Cached model information
225 models: LanguageModels,
226 project: Entity<Project>,
227 prompt_store: Option<Entity<PromptStore>>,
228 fs: Arc<dyn Fs>,
229 _subscriptions: Vec<Subscription>,
230}
231
232impl NativeAgent {
233 pub async fn new(
234 project: Entity<Project>,
235 history: Entity<HistoryStore>,
236 templates: Arc<Templates>,
237 prompt_store: Option<Entity<PromptStore>>,
238 fs: Arc<dyn Fs>,
239 cx: &mut AsyncApp,
240 ) -> Result<Entity<NativeAgent>> {
241 log::debug!("Creating new NativeAgent");
242
243 let project_context = cx
244 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
245 .await;
246
247 cx.new(|cx| {
248 let mut subscriptions = vec![
249 cx.subscribe(&project, Self::handle_project_event),
250 cx.subscribe(
251 &LanguageModelRegistry::global(cx),
252 Self::handle_models_updated_event,
253 ),
254 ];
255 if let Some(prompt_store) = prompt_store.as_ref() {
256 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
257 }
258
259 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
260 watch::channel(());
261 Self {
262 sessions: HashMap::new(),
263 history,
264 project_context: cx.new(|_| project_context),
265 project_context_needs_refresh: project_context_needs_refresh_tx,
266 _maintain_project_context: cx.spawn(async move |this, cx| {
267 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
268 }),
269 context_server_registry: cx.new(|cx| {
270 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
271 }),
272 templates,
273 models: LanguageModels::new(cx),
274 project,
275 prompt_store,
276 fs,
277 _subscriptions: subscriptions,
278 }
279 })
280 }
281
282 fn register_session(
283 &mut self,
284 thread_handle: Entity<Thread>,
285 cx: &mut Context<Self>,
286 ) -> Entity<AcpThread> {
287 let connection = Rc::new(NativeAgentConnection(cx.entity()));
288
289 let thread = thread_handle.read(cx);
290 let session_id = thread.id().clone();
291 let title = thread.title();
292 let project = thread.project.clone();
293 let action_log = thread.action_log.clone();
294 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
295 let acp_thread = cx.new(|cx| {
296 acp_thread::AcpThread::new(
297 title,
298 connection,
299 project.clone(),
300 action_log.clone(),
301 session_id.clone(),
302 prompt_capabilities_rx,
303 cx,
304 )
305 });
306
307 let registry = LanguageModelRegistry::read_global(cx);
308 let summarization_model = registry.thread_summary_model().map(|c| c.model);
309
310 thread_handle.update(cx, |thread, cx| {
311 thread.set_summarization_model(summarization_model, cx);
312 thread.add_default_tools(
313 Rc::new(AcpThreadEnvironment {
314 acp_thread: acp_thread.downgrade(),
315 }) as _,
316 cx,
317 )
318 });
319
320 let subscriptions = vec![
321 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
322 this.sessions.remove(acp_thread.session_id());
323 }),
324 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
325 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
326 cx.observe(&thread_handle, move |this, thread, cx| {
327 this.save_thread(thread, cx)
328 }),
329 ];
330
331 self.sessions.insert(
332 session_id,
333 Session {
334 thread: thread_handle,
335 acp_thread: acp_thread.downgrade(),
336 _subscriptions: subscriptions,
337 pending_save: Task::ready(()),
338 },
339 );
340 acp_thread
341 }
342
343 pub fn models(&self) -> &LanguageModels {
344 &self.models
345 }
346
347 async fn maintain_project_context(
348 this: WeakEntity<Self>,
349 mut needs_refresh: watch::Receiver<()>,
350 cx: &mut AsyncApp,
351 ) -> Result<()> {
352 while needs_refresh.changed().await.is_ok() {
353 let project_context = this
354 .update(cx, |this, cx| {
355 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
356 })?
357 .await;
358 this.update(cx, |this, cx| {
359 this.project_context = cx.new(|_| project_context);
360 })?;
361 }
362
363 Ok(())
364 }
365
366 fn build_project_context(
367 project: &Entity<Project>,
368 prompt_store: Option<&Entity<PromptStore>>,
369 cx: &mut App,
370 ) -> Task<ProjectContext> {
371 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
372 let worktree_tasks = worktrees
373 .into_iter()
374 .map(|worktree| {
375 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
376 })
377 .collect::<Vec<_>>();
378 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
379 prompt_store.read_with(cx, |prompt_store, cx| {
380 let prompts = prompt_store.default_prompt_metadata();
381 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
382 let contents = prompt_store.load(prompt_metadata.id, cx);
383 async move { (contents.await, prompt_metadata) }
384 });
385 cx.background_spawn(future::join_all(load_tasks))
386 })
387 } else {
388 Task::ready(vec![])
389 };
390
391 cx.spawn(async move |_cx| {
392 let (worktrees, default_user_rules) =
393 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
394
395 let worktrees = worktrees
396 .into_iter()
397 .map(|(worktree, _rules_error)| {
398 // TODO: show error message
399 // if let Some(rules_error) = rules_error {
400 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
401 // }
402 worktree
403 })
404 .collect::<Vec<_>>();
405
406 let default_user_rules = default_user_rules
407 .into_iter()
408 .flat_map(|(contents, prompt_metadata)| match contents {
409 Ok(contents) => Some(UserRulesContext {
410 uuid: match prompt_metadata.id {
411 PromptId::User { uuid } => uuid,
412 PromptId::EditWorkflow => return None,
413 },
414 title: prompt_metadata.title.map(|title| title.to_string()),
415 contents,
416 }),
417 Err(_err) => {
418 // TODO: show error message
419 // this.update(cx, |_, cx| {
420 // cx.emit(RulesLoadingError {
421 // message: format!("{err:?}").into(),
422 // });
423 // })
424 // .ok();
425 None
426 }
427 })
428 .collect::<Vec<_>>();
429
430 ProjectContext::new(worktrees, default_user_rules)
431 })
432 }
433
434 fn load_worktree_info_for_system_prompt(
435 worktree: Entity<Worktree>,
436 project: Entity<Project>,
437 cx: &mut App,
438 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
439 let tree = worktree.read(cx);
440 let root_name = tree.root_name_str().into();
441 let abs_path = tree.abs_path();
442
443 let mut context = WorktreeContext {
444 root_name,
445 abs_path,
446 rules_file: None,
447 };
448
449 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
450 let Some(rules_task) = rules_task else {
451 return Task::ready((context, None));
452 };
453
454 cx.spawn(async move |_| {
455 let (rules_file, rules_file_error) = match rules_task.await {
456 Ok(rules_file) => (Some(rules_file), None),
457 Err(err) => (
458 None,
459 Some(RulesLoadingError {
460 message: format!("{err}").into(),
461 }),
462 ),
463 };
464 context.rules_file = rules_file;
465 (context, rules_file_error)
466 })
467 }
468
469 fn load_worktree_rules_file(
470 worktree: Entity<Worktree>,
471 project: Entity<Project>,
472 cx: &mut App,
473 ) -> Option<Task<Result<RulesFileContext>>> {
474 let worktree = worktree.read(cx);
475 let worktree_id = worktree.id();
476 let selected_rules_file = RULES_FILE_NAMES
477 .into_iter()
478 .filter_map(|name| {
479 worktree
480 .entry_for_path(name)
481 .filter(|entry| entry.is_file())
482 .map(|entry| entry.path.clone())
483 })
484 .next();
485
486 // Note that Cline supports `.clinerules` being a directory, but that is not currently
487 // supported. This doesn't seem to occur often in GitHub repositories.
488 selected_rules_file.map(|path_in_worktree| {
489 let project_path = ProjectPath {
490 worktree_id,
491 path: path_in_worktree.clone(),
492 };
493 let buffer_task =
494 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
495 let rope_task = cx.spawn(async move |cx| {
496 buffer_task.await?.read_with(cx, |buffer, cx| {
497 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
498 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
499 })?
500 });
501 // Build a string from the rope on a background thread.
502 cx.background_spawn(async move {
503 let (project_entry_id, rope) = rope_task.await?;
504 anyhow::Ok(RulesFileContext {
505 path_in_worktree,
506 text: rope.to_string().trim().to_string(),
507 project_entry_id: project_entry_id.to_usize(),
508 })
509 })
510 })
511 }
512
513 fn handle_thread_title_updated(
514 &mut self,
515 thread: Entity<Thread>,
516 _: &TitleUpdated,
517 cx: &mut Context<Self>,
518 ) {
519 let session_id = thread.read(cx).id();
520 let Some(session) = self.sessions.get(session_id) else {
521 return;
522 };
523 let thread = thread.downgrade();
524 let acp_thread = session.acp_thread.clone();
525 cx.spawn(async move |_, cx| {
526 let title = thread.read_with(cx, |thread, _| thread.title())?;
527 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
528 task.await
529 })
530 .detach_and_log_err(cx);
531 }
532
533 fn handle_thread_token_usage_updated(
534 &mut self,
535 thread: Entity<Thread>,
536 usage: &TokenUsageUpdated,
537 cx: &mut Context<Self>,
538 ) {
539 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
540 return;
541 };
542 session
543 .acp_thread
544 .update(cx, |acp_thread, cx| {
545 acp_thread.update_token_usage(usage.0.clone(), cx);
546 })
547 .ok();
548 }
549
550 fn handle_project_event(
551 &mut self,
552 _project: Entity<Project>,
553 event: &project::Event,
554 _cx: &mut Context<Self>,
555 ) {
556 match event {
557 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
558 self.project_context_needs_refresh.send(()).ok();
559 }
560 project::Event::WorktreeUpdatedEntries(_, items) => {
561 if items
562 .iter()
563 .any(|(path, _, _)| RULES_FILE_NAMES.iter().any(|name| path.as_ref() == *name))
564 {
565 self.project_context_needs_refresh.send(()).ok();
566 }
567 }
568 _ => {}
569 }
570 }
571
572 fn handle_prompts_updated_event(
573 &mut self,
574 _prompt_store: Entity<PromptStore>,
575 _event: &prompt_store::PromptsUpdatedEvent,
576 _cx: &mut Context<Self>,
577 ) {
578 self.project_context_needs_refresh.send(()).ok();
579 }
580
581 fn handle_models_updated_event(
582 &mut self,
583 _registry: Entity<LanguageModelRegistry>,
584 _event: &language_model::Event,
585 cx: &mut Context<Self>,
586 ) {
587 self.models.refresh_list(cx);
588
589 let registry = LanguageModelRegistry::read_global(cx);
590 let default_model = registry.default_model().map(|m| m.model);
591 let summarization_model = registry.thread_summary_model().map(|m| m.model);
592
593 for session in self.sessions.values_mut() {
594 session.thread.update(cx, |thread, cx| {
595 if thread.model().is_none()
596 && let Some(model) = default_model.clone()
597 {
598 thread.set_model(model, cx);
599 cx.notify();
600 }
601 thread.set_summarization_model(summarization_model.clone(), cx);
602 });
603 }
604 }
605
606 pub fn open_thread(
607 &mut self,
608 id: acp::SessionId,
609 cx: &mut Context<Self>,
610 ) -> Task<Result<Entity<AcpThread>>> {
611 let database_future = ThreadsDatabase::connect(cx);
612 cx.spawn(async move |this, cx| {
613 let database = database_future.await.map_err(|err| anyhow!(err))?;
614 let db_thread = database
615 .load_thread(id.clone())
616 .await?
617 .with_context(|| format!("no thread found with ID: {id:?}"))?;
618
619 let thread = this.update(cx, |this, cx| {
620 let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
621 cx.new(|cx| {
622 Thread::from_db(
623 id.clone(),
624 db_thread,
625 this.project.clone(),
626 this.project_context.clone(),
627 this.context_server_registry.clone(),
628 action_log.clone(),
629 this.templates.clone(),
630 cx,
631 )
632 })
633 })?;
634 let acp_thread =
635 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
636 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
637 cx.update(|cx| {
638 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
639 })?
640 .await?;
641 Ok(acp_thread)
642 })
643 }
644
645 pub fn thread_summary(
646 &mut self,
647 id: acp::SessionId,
648 cx: &mut Context<Self>,
649 ) -> Task<Result<SharedString>> {
650 let thread = self.open_thread(id.clone(), cx);
651 cx.spawn(async move |this, cx| {
652 let acp_thread = thread.await?;
653 let result = this
654 .update(cx, |this, cx| {
655 this.sessions
656 .get(&id)
657 .unwrap()
658 .thread
659 .update(cx, |thread, cx| thread.summary(cx))
660 })?
661 .await?;
662 drop(acp_thread);
663 Ok(result)
664 })
665 }
666
667 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
668 if thread.read(cx).is_empty() {
669 return;
670 }
671
672 let database_future = ThreadsDatabase::connect(cx);
673 let (id, db_thread) =
674 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
675 let Some(session) = self.sessions.get_mut(&id) else {
676 return;
677 };
678 let history = self.history.clone();
679 session.pending_save = cx.spawn(async move |_, cx| {
680 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
681 return;
682 };
683 let db_thread = db_thread.await;
684 database.save_thread(id, db_thread).await.log_err();
685 history.update(cx, |history, cx| history.reload(cx)).ok();
686 });
687 }
688}
689
690/// Wrapper struct that implements the AgentConnection trait
691#[derive(Clone)]
692pub struct NativeAgentConnection(pub Entity<NativeAgent>);
693
694impl NativeAgentConnection {
695 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
696 self.0
697 .read(cx)
698 .sessions
699 .get(session_id)
700 .map(|session| session.thread.clone())
701 }
702
703 fn run_turn(
704 &self,
705 session_id: acp::SessionId,
706 cx: &mut App,
707 f: impl 'static
708 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
709 ) -> Task<Result<acp::PromptResponse>> {
710 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
711 agent
712 .sessions
713 .get_mut(&session_id)
714 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
715 }) else {
716 return Task::ready(Err(anyhow!("Session not found")));
717 };
718 log::debug!("Found session for: {}", session_id);
719
720 let response_stream = match f(thread, cx) {
721 Ok(stream) => stream,
722 Err(err) => return Task::ready(Err(err)),
723 };
724 Self::handle_thread_events(response_stream, acp_thread, cx)
725 }
726
727 fn handle_thread_events(
728 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
729 acp_thread: WeakEntity<AcpThread>,
730 cx: &App,
731 ) -> Task<Result<acp::PromptResponse>> {
732 cx.spawn(async move |cx| {
733 // Handle response stream and forward to session.acp_thread
734 while let Some(result) = events.next().await {
735 match result {
736 Ok(event) => {
737 log::trace!("Received completion event: {:?}", event);
738
739 match event {
740 ThreadEvent::UserMessage(message) => {
741 acp_thread.update(cx, |thread, cx| {
742 for content in message.content {
743 thread.push_user_content_block(
744 Some(message.id.clone()),
745 content.into(),
746 cx,
747 );
748 }
749 })?;
750 }
751 ThreadEvent::AgentText(text) => {
752 acp_thread.update(cx, |thread, cx| {
753 thread.push_assistant_content_block(
754 acp::ContentBlock::Text(acp::TextContent {
755 text,
756 annotations: None,
757 meta: None,
758 }),
759 false,
760 cx,
761 )
762 })?;
763 }
764 ThreadEvent::AgentThinking(text) => {
765 acp_thread.update(cx, |thread, cx| {
766 thread.push_assistant_content_block(
767 acp::ContentBlock::Text(acp::TextContent {
768 text,
769 annotations: None,
770 meta: None,
771 }),
772 true,
773 cx,
774 )
775 })?;
776 }
777 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
778 tool_call,
779 options,
780 response,
781 }) => {
782 let outcome_task = acp_thread.update(cx, |thread, cx| {
783 thread.request_tool_call_authorization(
784 tool_call, options, true, cx,
785 )
786 })??;
787 cx.background_spawn(async move {
788 if let acp::RequestPermissionOutcome::Selected { option_id } =
789 outcome_task.await
790 {
791 response
792 .send(option_id)
793 .map(|_| anyhow!("authorization receiver was dropped"))
794 .log_err();
795 }
796 })
797 .detach();
798 }
799 ThreadEvent::ToolCall(tool_call) => {
800 acp_thread.update(cx, |thread, cx| {
801 thread.upsert_tool_call(tool_call, cx)
802 })??;
803 }
804 ThreadEvent::ToolCallUpdate(update) => {
805 acp_thread.update(cx, |thread, cx| {
806 thread.update_tool_call(update, cx)
807 })??;
808 }
809 ThreadEvent::Retry(status) => {
810 acp_thread.update(cx, |thread, cx| {
811 thread.update_retry_status(status, cx)
812 })?;
813 }
814 ThreadEvent::Stop(stop_reason) => {
815 log::debug!("Assistant message complete: {:?}", stop_reason);
816 return Ok(acp::PromptResponse {
817 stop_reason,
818 meta: None,
819 });
820 }
821 }
822 }
823 Err(e) => {
824 log::error!("Error in model response stream: {:?}", e);
825 return Err(e);
826 }
827 }
828 }
829
830 log::debug!("Response stream completed");
831 anyhow::Ok(acp::PromptResponse {
832 stop_reason: acp::StopReason::EndTurn,
833 meta: None,
834 })
835 })
836 }
837}
838
839struct NativeAgentModelSelector {
840 session_id: acp::SessionId,
841 connection: NativeAgentConnection,
842}
843
844impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
845 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
846 log::debug!("NativeAgentConnection::list_models called");
847 let list = self.connection.0.read(cx).models.model_list.clone();
848 Task::ready(if list.is_empty() {
849 Err(anyhow::anyhow!("No models available"))
850 } else {
851 Ok(list)
852 })
853 }
854
855 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
856 log::debug!(
857 "Setting model for session {}: {}",
858 self.session_id,
859 model_id
860 );
861 let Some(thread) = self
862 .connection
863 .0
864 .read(cx)
865 .sessions
866 .get(&self.session_id)
867 .map(|session| session.thread.clone())
868 else {
869 return Task::ready(Err(anyhow!("Session not found")));
870 };
871
872 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
873 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
874 };
875
876 thread.update(cx, |thread, cx| {
877 thread.set_model(model.clone(), cx);
878 });
879
880 update_settings_file(
881 self.connection.0.read(cx).fs.clone(),
882 cx,
883 move |settings, _cx| {
884 let provider = model.provider_id().0.to_string();
885 let model = model.id().0.to_string();
886 settings
887 .agent
888 .get_or_insert_default()
889 .set_model(LanguageModelSelection {
890 provider: provider.into(),
891 model,
892 });
893 },
894 );
895
896 Task::ready(Ok(()))
897 }
898
899 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
900 let Some(thread) = self
901 .connection
902 .0
903 .read(cx)
904 .sessions
905 .get(&self.session_id)
906 .map(|session| session.thread.clone())
907 else {
908 return Task::ready(Err(anyhow!("Session not found")));
909 };
910 let Some(model) = thread.read(cx).model() else {
911 return Task::ready(Err(anyhow!("Model not found")));
912 };
913 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
914 else {
915 return Task::ready(Err(anyhow!("Provider not found")));
916 };
917 Task::ready(Ok(LanguageModels::map_language_model_to_info(
918 model, &provider,
919 )))
920 }
921
922 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
923 Some(self.connection.0.read(cx).models.watch())
924 }
925}
926
927impl acp_thread::AgentConnection for NativeAgentConnection {
928 fn new_thread(
929 self: Rc<Self>,
930 project: Entity<Project>,
931 cwd: &Path,
932 cx: &mut App,
933 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
934 let agent = self.0.clone();
935 log::debug!("Creating new thread for project at: {:?}", cwd);
936
937 cx.spawn(async move |cx| {
938 log::debug!("Starting thread creation in async context");
939
940 // Create Thread
941 let thread = agent.update(
942 cx,
943 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
944 // Fetch default model from registry settings
945 let registry = LanguageModelRegistry::read_global(cx);
946 // Log available models for debugging
947 let available_count = registry.available_models(cx).count();
948 log::debug!("Total available models: {}", available_count);
949
950 let default_model = registry.default_model().and_then(|default_model| {
951 agent
952 .models
953 .model_from_id(&LanguageModels::model_id(&default_model.model))
954 });
955 Ok(cx.new(|cx| {
956 Thread::new(
957 project.clone(),
958 agent.project_context.clone(),
959 agent.context_server_registry.clone(),
960 agent.templates.clone(),
961 default_model,
962 cx,
963 )
964 }))
965 },
966 )??;
967 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
968 })
969 }
970
971 fn auth_methods(&self) -> &[acp::AuthMethod] {
972 &[] // No auth for in-process
973 }
974
975 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
976 Task::ready(Ok(()))
977 }
978
979 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
980 Some(Rc::new(NativeAgentModelSelector {
981 session_id: session_id.clone(),
982 connection: self.clone(),
983 }) as Rc<dyn AgentModelSelector>)
984 }
985
986 fn prompt(
987 &self,
988 id: Option<acp_thread::UserMessageId>,
989 params: acp::PromptRequest,
990 cx: &mut App,
991 ) -> Task<Result<acp::PromptResponse>> {
992 let id = id.expect("UserMessageId is required");
993 let session_id = params.session_id.clone();
994 log::info!("Received prompt request for session: {}", session_id);
995 log::debug!("Prompt blocks count: {}", params.prompt.len());
996
997 self.run_turn(session_id, cx, |thread, cx| {
998 let content: Vec<UserMessageContent> = params
999 .prompt
1000 .into_iter()
1001 .map(Into::into)
1002 .collect::<Vec<_>>();
1003 log::debug!("Converted prompt to message: {} chars", content.len());
1004 log::debug!("Message id: {:?}", id);
1005 log::debug!("Message content: {:?}", content);
1006
1007 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1008 })
1009 }
1010
1011 fn resume(
1012 &self,
1013 session_id: &acp::SessionId,
1014 _cx: &App,
1015 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1016 Some(Rc::new(NativeAgentSessionResume {
1017 connection: self.clone(),
1018 session_id: session_id.clone(),
1019 }) as _)
1020 }
1021
1022 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1023 log::info!("Cancelling on session: {}", session_id);
1024 self.0.update(cx, |agent, cx| {
1025 if let Some(agent) = agent.sessions.get(session_id) {
1026 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1027 }
1028 });
1029 }
1030
1031 fn truncate(
1032 &self,
1033 session_id: &agent_client_protocol::SessionId,
1034 cx: &App,
1035 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1036 self.0.read_with(cx, |agent, _cx| {
1037 agent.sessions.get(session_id).map(|session| {
1038 Rc::new(NativeAgentSessionTruncate {
1039 thread: session.thread.clone(),
1040 acp_thread: session.acp_thread.clone(),
1041 }) as _
1042 })
1043 })
1044 }
1045
1046 fn set_title(
1047 &self,
1048 session_id: &acp::SessionId,
1049 _cx: &App,
1050 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1051 Some(Rc::new(NativeAgentSessionSetTitle {
1052 connection: self.clone(),
1053 session_id: session_id.clone(),
1054 }) as _)
1055 }
1056
1057 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1058 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1059 }
1060
1061 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1062 self
1063 }
1064}
1065
1066impl acp_thread::AgentTelemetry for NativeAgentConnection {
1067 fn agent_name(&self) -> String {
1068 "Zed".into()
1069 }
1070
1071 fn thread_data(
1072 &self,
1073 session_id: &acp::SessionId,
1074 cx: &mut App,
1075 ) -> Task<Result<serde_json::Value>> {
1076 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1077 return Task::ready(Err(anyhow!("Session not found")));
1078 };
1079
1080 let task = session.thread.read(cx).to_db(cx);
1081 cx.background_spawn(async move {
1082 serde_json::to_value(task.await).context("Failed to serialize thread")
1083 })
1084 }
1085}
1086
1087struct NativeAgentSessionTruncate {
1088 thread: Entity<Thread>,
1089 acp_thread: WeakEntity<AcpThread>,
1090}
1091
1092impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1093 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1094 match self.thread.update(cx, |thread, cx| {
1095 thread.truncate(message_id.clone(), cx)?;
1096 Ok(thread.latest_token_usage())
1097 }) {
1098 Ok(usage) => {
1099 self.acp_thread
1100 .update(cx, |thread, cx| {
1101 thread.update_token_usage(usage, cx);
1102 })
1103 .ok();
1104 Task::ready(Ok(()))
1105 }
1106 Err(error) => Task::ready(Err(error)),
1107 }
1108 }
1109}
1110
1111struct NativeAgentSessionResume {
1112 connection: NativeAgentConnection,
1113 session_id: acp::SessionId,
1114}
1115
1116impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1117 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1118 self.connection
1119 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1120 thread.update(cx, |thread, cx| thread.resume(cx))
1121 })
1122 }
1123}
1124
1125struct NativeAgentSessionSetTitle {
1126 connection: NativeAgentConnection,
1127 session_id: acp::SessionId,
1128}
1129
1130impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1131 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1132 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1133 return Task::ready(Err(anyhow!("session not found")));
1134 };
1135 let thread = session.thread.clone();
1136 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1137 Task::ready(Ok(()))
1138 }
1139}
1140
/// [`ThreadEnvironment`] backed by an [`AcpThread`].
///
/// Terminals requested by the agent are created through the ACP thread and
/// released back to it when their handle is dropped.
pub struct AcpThreadEnvironment {
    /// Held weakly, so this environment does not keep the ACP thread alive.
    acp_thread: WeakEntity<AcpThread>,
}

impl ThreadEnvironment for AcpThreadEnvironment {
    /// Creates a terminal that runs `command` via the ACP thread.
    ///
    /// The returned [`AcpTerminalHandle`] owns the sending half of a oneshot
    /// channel. When the handle is dropped, a detached task sees the channel
    /// close and asks the ACP thread to release the terminal.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the ACP thread can no longer be updated (e.g. it was released);
    /// - terminal creation fails;
    /// - the new terminal's id cannot be read.
    fn create_terminal(
        &self,
        command: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
        // Start creation synchronously. A failed `update` is surfaced by
        // `task?` below.
        // NOTE(review): the two empty `vec![]`s are presumably extra args and
        // env vars for the command; confirm against `AcpThread::create_terminal`.
        let task = self.acp_thread.update(cx, |thread, cx| {
            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
        });

        let acp_thread = self.acp_thread.clone();
        cx.spawn(async move |cx| {
            let terminal = task?.await?;

            let (drop_tx, drop_rx) = oneshot::channel();
            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;

            // Release watcher. Nothing in this file sends on `drop_tx`, so
            // `drop_rx` resolves (with `Canceled`, ignored) only once the
            // handle's sender is dropped. The release result is discarded,
            // e.g. if the ACP thread is already gone.
            cx.spawn(async move |cx| {
                drop_rx.await.ok();
                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
            })
            .detach();

            let handle = AcpTerminalHandle {
                terminal,
                _drop_tx: Some(drop_tx),
            };

            Ok(Rc::new(handle) as _)
        })
    }
}
1179
1180pub struct AcpTerminalHandle {
1181 terminal: Entity<acp_thread::Terminal>,
1182 _drop_tx: Option<oneshot::Sender<()>>,
1183}
1184
1185impl TerminalHandle for AcpTerminalHandle {
1186 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1187 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1188 }
1189
1190 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1191 self.terminal
1192 .read_with(cx, |term, _cx| term.wait_for_exit())
1193 }
1194
1195 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1196 self.terminal
1197 .read_with(cx, |term, cx| term.current_output(cx))
1198 }
1199}
1200
1201#[cfg(test)]
1202mod tests {
1203 use crate::HistoryEntryId;
1204
1205 use super::*;
1206 use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
1207 use fs::FakeFs;
1208 use gpui::TestAppContext;
1209 use indoc::formatdoc;
1210 use language_model::fake_provider::FakeLanguageModel;
1211 use serde_json::json;
1212 use settings::SettingsStore;
1213 use util::{path, rel_path::rel_path};
1214
    /// The agent's `ProjectContext` should track worktrees as they are added.
    /// It should also pick up a `.rules` file created after the worktree exists.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created with no worktrees, so the context lists none.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree that has no rules file records it with `rules_file: None`.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }
1281
    /// A session's model selector should list exactly one model, "fake/fake",
    /// grouped under "Fake".
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent is expected to return a grouped list.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId("fake/fake".into()),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }
1339
    /// Selecting a model through a session's selector should:
    /// - switch that session's thread to the model;
    /// - overwrite `agent.default_model` in the settings file, which starts out
    ///   as `foo`/`bar`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed a different default model so the later assertions prove it was overwritten.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId("fake/fake".into());
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write complete before reading it back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }
1420
    /// Round-trips a thread through the history store.
    ///
    /// - An empty thread is not saved, even after it is mutated.
    /// - A completed exchange is saved under the summary model's title.
    /// - Dropping the ACP thread removes the session.
    /// - Reopening the thread by session id reproduces the same markdown.
    #[gpui::test]
    #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a message that mentions `/a/b.md` via a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                        meta: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // The main model answers, then the summary model supplies the title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                "Explaining /a/b.md".into()
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }
1565
1566 fn history_entries(
1567 history: &Entity<HistoryStore>,
1568 cx: &mut TestAppContext,
1569 ) -> Vec<(HistoryEntryId, String)> {
1570 history.read_with(cx, |history, _| {
1571 history
1572 .entries()
1573 .map(|e| (e.id(), e.title().to_string()))
1574 .collect::<Vec<_>>()
1575 })
1576 }
1577
1578 fn init_test(cx: &mut TestAppContext) {
1579 env_logger::try_init().ok();
1580 cx.update(|cx| {
1581 let settings_store = SettingsStore::test(cx);
1582 cx.set_global(settings_store);
1583 Project::init_settings(cx);
1584 agent_settings::init(cx);
1585 language::init(cx);
1586 LanguageModelRegistry::test(cx);
1587 });
1588 }
1589}