1use crate::{
2 ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
3 UserMessageContent, templates::Templates,
4};
5use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated};
6use acp_thread::{AcpThread, AgentModelSelector};
7use action_log::ActionLog;
8use agent_client_protocol as acp;
9use agent_settings::AgentSettings;
10use anyhow::{Context as _, Result, anyhow};
11use collections::{HashSet, IndexMap};
12use fs::Fs;
13use futures::channel::{mpsc, oneshot};
14use futures::future::Shared;
15use futures::{StreamExt, future};
16use gpui::{
17 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
18};
19use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
20use project::{Project, ProjectItem, ProjectPath, Worktree};
21use prompt_store::{
22 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
23};
24use settings::update_settings_file;
25use std::any::Any;
26use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28use std::rc::Rc;
29use std::sync::Arc;
30use util::ResultExt;
31
/// File names, relative to a worktree root, that are recognized as rules files.
///
/// `load_worktree_rules_file` checks these names in order and uses the first one that resolves
/// to a file entry, so earlier names take precedence over later ones.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
43
/// Error reported when a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Display form of the underlying error.
    pub message: SharedString,
}
47
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly: `register_session` removes the session once the ACP thread is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recent save task started by `save_thread`; each new save replaces the previous one.
    pending_save: Task<()>,
    /// Subscriptions created in `register_session`, kept for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
57
/// Cache of the language models offered by authenticated providers, together with a watch
/// channel that is signaled whenever the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp_thread::AgentModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiving side of the refresh channel; `watch` hands out clones of it.
    refresh_models_rx: watch::Receiver<()>,
    /// Sending side of the refresh channel; signaled at the end of `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Task returned by `authenticate_all_language_model_providers`, kept alongside the cache.
    _authenticate_all_providers_task: Task<()>,
}
67
68impl LanguageModels {
69 fn new(cx: &mut App) -> Self {
70 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
71
72 let mut this = Self {
73 models: HashMap::default(),
74 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
75 refresh_models_rx,
76 refresh_models_tx,
77 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
78 };
79 this.refresh_list(cx);
80 this
81 }
82
83 fn refresh_list(&mut self, cx: &App) {
84 let providers = LanguageModelRegistry::global(cx)
85 .read(cx)
86 .providers()
87 .into_iter()
88 .filter(|provider| provider.is_authenticated(cx))
89 .collect::<Vec<_>>();
90
91 let mut language_model_list = IndexMap::default();
92 let mut recommended_models = HashSet::default();
93
94 let mut recommended = Vec::new();
95 for provider in &providers {
96 for model in provider.recommended_models(cx) {
97 recommended_models.insert((model.provider_id(), model.id()));
98 recommended.push(Self::map_language_model_to_info(&model, provider));
99 }
100 }
101 if !recommended.is_empty() {
102 language_model_list.insert(
103 acp_thread::AgentModelGroupName("Recommended".into()),
104 recommended,
105 );
106 }
107
108 let mut models = HashMap::default();
109 for provider in providers {
110 let mut provider_models = Vec::new();
111 for model in provider.provided_models(cx) {
112 let model_info = Self::map_language_model_to_info(&model, &provider);
113 let model_id = model_info.id.clone();
114 if !recommended_models.contains(&(model.provider_id(), model.id())) {
115 provider_models.push(model_info);
116 }
117 models.insert(model_id, model);
118 }
119 if !provider_models.is_empty() {
120 language_model_list.insert(
121 acp_thread::AgentModelGroupName(provider.name().0.clone()),
122 provider_models,
123 );
124 }
125 }
126
127 self.models = models;
128 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
129 self.refresh_models_tx.send(()).ok();
130 }
131
    /// Returns a receiver that is signaled each time `refresh_list` rebuilds the model list.
    fn watch(&self) -> watch::Receiver<()> {
        self.refresh_models_rx.clone()
    }
135
    /// Looks up a cached model by its agent model ID (see `model_id` for the format).
    ///
    /// Returns `None` when no model with that ID was present at the last refresh.
    pub fn model_from_id(
        &self,
        model_id: &acp_thread::AgentModelId,
    ) -> Option<Arc<dyn LanguageModel>> {
        self.models.get(model_id).cloned()
    }
142
    /// Builds the model info for `model`: its agent model ID, its name, and the icon of
    /// `provider`.
    fn map_language_model_to_info(
        model: &Arc<dyn LanguageModel>,
        provider: &Arc<dyn LanguageModelProvider>,
    ) -> acp_thread::AgentModelInfo {
        acp_thread::AgentModelInfo {
            id: Self::model_id(model),
            name: model.name().0,
            icon: Some(provider.icon()),
        }
    }
153
154 fn model_id(model: &Arc<dyn LanguageModel>) -> acp_thread::AgentModelId {
155 acp_thread::AgentModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
156 }
157
158 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
159 let authenticate_all_providers = LanguageModelRegistry::global(cx)
160 .read(cx)
161 .providers()
162 .iter()
163 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
164 .collect::<Vec<_>>();
165
166 cx.background_spawn(async move {
167 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
168 if let Err(err) = authenticate_task.await {
169 match err {
170 language_model::AuthenticateError::CredentialsNotFound => {
171 // Since we're authenticating these providers in the
172 // background for the purposes of populating the
173 // language selector, we don't care about providers
174 // where the credentials are not found.
175 }
176 language_model::AuthenticateError::ConnectionRefused => {
177 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
178 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
179 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
180 }
181 _ => {
182 // Some providers have noisy failure states that we
183 // don't want to spam the logs with every time the
184 // language model selector is initialized.
185 //
186 // Ideally these should have more clear failure modes
187 // that we know are safe to ignore here, like what we do
188 // with `CredentialsNotFound` above.
189 match provider_id.0.as_ref() {
190 "lmstudio" | "ollama" => {
191 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
192 //
193 // These fail noisily, so we don't log them.
194 }
195 "copilot_chat" => {
196 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
197 }
198 _ => {
199 log::error!(
200 "Failed to authenticate provider: {}: {err}",
201 provider_name.0
202 );
203 }
204 }
205 }
206 }
207 }
208 }
209 })
210 }
211}
212
/// In-process agent that owns all sessions for one project, along with the state they share:
/// project context, templates, and the model cache.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// History store that is reloaded after each thread save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signals `maintain_project_context` to rebuild the project context.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task running `maintain_project_context`.
    _maintain_project_context: Task<Result<()>>,
    /// Registry created from the project's context server store.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Source of default user rules; when `None`, no user rules are loaded.
    prompt_store: Option<Entity<PromptStore>>,
    /// File system handle used when persisting the selected model to settings.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model registry, and (if present) prompt store events.
    _subscriptions: Vec<Subscription>,
}
231
232impl NativeAgent {
233 pub async fn new(
234 project: Entity<Project>,
235 history: Entity<HistoryStore>,
236 templates: Arc<Templates>,
237 prompt_store: Option<Entity<PromptStore>>,
238 fs: Arc<dyn Fs>,
239 cx: &mut AsyncApp,
240 ) -> Result<Entity<NativeAgent>> {
241 log::debug!("Creating new NativeAgent");
242
243 let project_context = cx
244 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
245 .await;
246
247 cx.new(|cx| {
248 let mut subscriptions = vec![
249 cx.subscribe(&project, Self::handle_project_event),
250 cx.subscribe(
251 &LanguageModelRegistry::global(cx),
252 Self::handle_models_updated_event,
253 ),
254 ];
255 if let Some(prompt_store) = prompt_store.as_ref() {
256 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
257 }
258
259 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
260 watch::channel(());
261 Self {
262 sessions: HashMap::new(),
263 history,
264 project_context: cx.new(|_| project_context),
265 project_context_needs_refresh: project_context_needs_refresh_tx,
266 _maintain_project_context: cx.spawn(async move |this, cx| {
267 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
268 }),
269 context_server_registry: cx.new(|cx| {
270 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
271 }),
272 templates,
273 models: LanguageModels::new(cx),
274 project,
275 prompt_store,
276 fs,
277 _subscriptions: subscriptions,
278 }
279 })
280 }
281
282 fn register_session(
283 &mut self,
284 thread_handle: Entity<Thread>,
285 cx: &mut Context<Self>,
286 ) -> Entity<AcpThread> {
287 let connection = Rc::new(NativeAgentConnection(cx.entity()));
288
289 let thread = thread_handle.read(cx);
290 let session_id = thread.id().clone();
291 let title = thread.title();
292 let project = thread.project.clone();
293 let action_log = thread.action_log.clone();
294 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
295 let acp_thread = cx.new(|cx| {
296 acp_thread::AcpThread::new(
297 title,
298 connection,
299 project.clone(),
300 action_log.clone(),
301 session_id.clone(),
302 prompt_capabilities_rx,
303 cx,
304 )
305 });
306
307 let registry = LanguageModelRegistry::read_global(cx);
308 let summarization_model = registry.thread_summary_model().map(|c| c.model);
309
310 thread_handle.update(cx, |thread, cx| {
311 thread.set_summarization_model(summarization_model, cx);
312 thread.add_default_tools(
313 Rc::new(AcpThreadEnvironment {
314 acp_thread: acp_thread.downgrade(),
315 }) as _,
316 cx,
317 )
318 });
319
320 let subscriptions = vec![
321 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
322 this.sessions.remove(acp_thread.session_id());
323 }),
324 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
325 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
326 cx.observe(&thread_handle, move |this, thread, cx| {
327 this.save_thread(thread, cx)
328 }),
329 ];
330
331 self.sessions.insert(
332 session_id,
333 Session {
334 thread: thread_handle,
335 acp_thread: acp_thread.downgrade(),
336 _subscriptions: subscriptions,
337 pending_save: Task::ready(()),
338 },
339 );
340 acp_thread
341 }
342
    /// Returns the agent's cached language model information.
    pub fn models(&self) -> &LanguageModels {
        &self.models
    }
346
347 async fn maintain_project_context(
348 this: WeakEntity<Self>,
349 mut needs_refresh: watch::Receiver<()>,
350 cx: &mut AsyncApp,
351 ) -> Result<()> {
352 while needs_refresh.changed().await.is_ok() {
353 let project_context = this
354 .update(cx, |this, cx| {
355 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
356 })?
357 .await;
358 this.update(cx, |this, cx| {
359 this.project_context = cx.new(|_| project_context);
360 })?;
361 }
362
363 Ok(())
364 }
365
366 fn build_project_context(
367 project: &Entity<Project>,
368 prompt_store: Option<&Entity<PromptStore>>,
369 cx: &mut App,
370 ) -> Task<ProjectContext> {
371 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
372 let worktree_tasks = worktrees
373 .into_iter()
374 .map(|worktree| {
375 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
376 })
377 .collect::<Vec<_>>();
378 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
379 prompt_store.read_with(cx, |prompt_store, cx| {
380 let prompts = prompt_store.default_prompt_metadata();
381 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
382 let contents = prompt_store.load(prompt_metadata.id, cx);
383 async move { (contents.await, prompt_metadata) }
384 });
385 cx.background_spawn(future::join_all(load_tasks))
386 })
387 } else {
388 Task::ready(vec![])
389 };
390
391 cx.spawn(async move |_cx| {
392 let (worktrees, default_user_rules) =
393 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
394
395 let worktrees = worktrees
396 .into_iter()
397 .map(|(worktree, _rules_error)| {
398 // TODO: show error message
399 // if let Some(rules_error) = rules_error {
400 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
401 // }
402 worktree
403 })
404 .collect::<Vec<_>>();
405
406 let default_user_rules = default_user_rules
407 .into_iter()
408 .flat_map(|(contents, prompt_metadata)| match contents {
409 Ok(contents) => Some(UserRulesContext {
410 uuid: match prompt_metadata.id {
411 PromptId::User { uuid } => uuid,
412 PromptId::EditWorkflow => return None,
413 },
414 title: prompt_metadata.title.map(|title| title.to_string()),
415 contents,
416 }),
417 Err(_err) => {
418 // TODO: show error message
419 // this.update(cx, |_, cx| {
420 // cx.emit(RulesLoadingError {
421 // message: format!("{err:?}").into(),
422 // });
423 // })
424 // .ok();
425 None
426 }
427 })
428 .collect::<Vec<_>>();
429
430 ProjectContext::new(worktrees, default_user_rules)
431 })
432 }
433
434 fn load_worktree_info_for_system_prompt(
435 worktree: Entity<Worktree>,
436 project: Entity<Project>,
437 cx: &mut App,
438 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
439 let tree = worktree.read(cx);
440 let root_name = tree.root_name().into();
441 let abs_path = tree.abs_path();
442
443 let mut context = WorktreeContext {
444 root_name,
445 abs_path,
446 rules_file: None,
447 };
448
449 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
450 let Some(rules_task) = rules_task else {
451 return Task::ready((context, None));
452 };
453
454 cx.spawn(async move |_| {
455 let (rules_file, rules_file_error) = match rules_task.await {
456 Ok(rules_file) => (Some(rules_file), None),
457 Err(err) => (
458 None,
459 Some(RulesLoadingError {
460 message: format!("{err}").into(),
461 }),
462 ),
463 };
464 context.rules_file = rules_file;
465 (context, rules_file_error)
466 })
467 }
468
469 fn load_worktree_rules_file(
470 worktree: Entity<Worktree>,
471 project: Entity<Project>,
472 cx: &mut App,
473 ) -> Option<Task<Result<RulesFileContext>>> {
474 let worktree = worktree.read(cx);
475 let worktree_id = worktree.id();
476 let selected_rules_file = RULES_FILE_NAMES
477 .into_iter()
478 .filter_map(|name| {
479 worktree
480 .entry_for_path(name)
481 .filter(|entry| entry.is_file())
482 .map(|entry| entry.path.clone())
483 })
484 .next();
485
486 // Note that Cline supports `.clinerules` being a directory, but that is not currently
487 // supported. This doesn't seem to occur often in GitHub repositories.
488 selected_rules_file.map(|path_in_worktree| {
489 let project_path = ProjectPath {
490 worktree_id,
491 path: path_in_worktree.clone(),
492 };
493 let buffer_task =
494 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
495 let rope_task = cx.spawn(async move |cx| {
496 buffer_task.await?.read_with(cx, |buffer, cx| {
497 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
498 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
499 })?
500 });
501 // Build a string from the rope on a background thread.
502 cx.background_spawn(async move {
503 let (project_entry_id, rope) = rope_task.await?;
504 anyhow::Ok(RulesFileContext {
505 path_in_worktree,
506 text: rope.to_string().trim().to_string(),
507 project_entry_id: project_entry_id.to_usize(),
508 })
509 })
510 })
511 }
512
513 fn handle_thread_title_updated(
514 &mut self,
515 thread: Entity<Thread>,
516 _: &TitleUpdated,
517 cx: &mut Context<Self>,
518 ) {
519 let session_id = thread.read(cx).id();
520 let Some(session) = self.sessions.get(session_id) else {
521 return;
522 };
523 let thread = thread.downgrade();
524 let acp_thread = session.acp_thread.clone();
525 cx.spawn(async move |_, cx| {
526 let title = thread.read_with(cx, |thread, _| thread.title())?;
527 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
528 task.await
529 })
530 .detach_and_log_err(cx);
531 }
532
533 fn handle_thread_token_usage_updated(
534 &mut self,
535 thread: Entity<Thread>,
536 usage: &TokenUsageUpdated,
537 cx: &mut Context<Self>,
538 ) {
539 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
540 return;
541 };
542 session
543 .acp_thread
544 .update(cx, |acp_thread, cx| {
545 acp_thread.update_token_usage(usage.0.clone(), cx);
546 })
547 .ok();
548 }
549
550 fn handle_project_event(
551 &mut self,
552 _project: Entity<Project>,
553 event: &project::Event,
554 _cx: &mut Context<Self>,
555 ) {
556 match event {
557 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
558 self.project_context_needs_refresh.send(()).ok();
559 }
560 project::Event::WorktreeUpdatedEntries(_, items) => {
561 if items.iter().any(|(path, _, _)| {
562 RULES_FILE_NAMES
563 .iter()
564 .any(|name| path.as_ref() == Path::new(name))
565 }) {
566 self.project_context_needs_refresh.send(()).ok();
567 }
568 }
569 _ => {}
570 }
571 }
572
    /// Requests a project context refresh whenever the prompt store reports updated prompts.
    fn handle_prompts_updated_event(
        &mut self,
        _prompt_store: Entity<PromptStore>,
        _event: &prompt_store::PromptsUpdatedEvent,
        _cx: &mut Context<Self>,
    ) {
        // A send failure is ignored on purpose.
        self.project_context_needs_refresh.send(()).ok();
    }
581
582 fn handle_models_updated_event(
583 &mut self,
584 _registry: Entity<LanguageModelRegistry>,
585 _event: &language_model::Event,
586 cx: &mut Context<Self>,
587 ) {
588 self.models.refresh_list(cx);
589
590 let registry = LanguageModelRegistry::read_global(cx);
591 let default_model = registry.default_model().map(|m| m.model);
592 let summarization_model = registry.thread_summary_model().map(|m| m.model);
593
594 for session in self.sessions.values_mut() {
595 session.thread.update(cx, |thread, cx| {
596 if thread.model().is_none()
597 && let Some(model) = default_model.clone()
598 {
599 thread.set_model(model, cx);
600 cx.notify();
601 }
602 thread.set_summarization_model(summarization_model.clone(), cx);
603 });
604 }
605 }
606
607 pub fn open_thread(
608 &mut self,
609 id: acp::SessionId,
610 cx: &mut Context<Self>,
611 ) -> Task<Result<Entity<AcpThread>>> {
612 let database_future = ThreadsDatabase::connect(cx);
613 cx.spawn(async move |this, cx| {
614 let database = database_future.await.map_err(|err| anyhow!(err))?;
615 let db_thread = database
616 .load_thread(id.clone())
617 .await?
618 .with_context(|| format!("no thread found with ID: {id:?}"))?;
619
620 let thread = this.update(cx, |this, cx| {
621 let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
622 cx.new(|cx| {
623 Thread::from_db(
624 id.clone(),
625 db_thread,
626 this.project.clone(),
627 this.project_context.clone(),
628 this.context_server_registry.clone(),
629 action_log.clone(),
630 this.templates.clone(),
631 cx,
632 )
633 })
634 })?;
635 let acp_thread =
636 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
637 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
638 cx.update(|cx| {
639 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
640 })?
641 .await?;
642 Ok(acp_thread)
643 })
644 }
645
646 pub fn thread_summary(
647 &mut self,
648 id: acp::SessionId,
649 cx: &mut Context<Self>,
650 ) -> Task<Result<SharedString>> {
651 let thread = self.open_thread(id.clone(), cx);
652 cx.spawn(async move |this, cx| {
653 let acp_thread = thread.await?;
654 let result = this
655 .update(cx, |this, cx| {
656 this.sessions
657 .get(&id)
658 .unwrap()
659 .thread
660 .update(cx, |thread, cx| thread.summary(cx))
661 })?
662 .await?;
663 drop(acp_thread);
664 Ok(result)
665 })
666 }
667
668 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
669 if thread.read(cx).is_empty() {
670 return;
671 }
672
673 let database_future = ThreadsDatabase::connect(cx);
674 let (id, db_thread) =
675 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
676 let Some(session) = self.sessions.get_mut(&id) else {
677 return;
678 };
679 let history = self.history.clone();
680 session.pending_save = cx.spawn(async move |_, cx| {
681 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
682 return;
683 };
684 let db_thread = db_thread.await;
685 database.save_thread(id, db_thread).await.log_err();
686 history.update(cx, |history, cx| history.reload(cx)).ok();
687 });
688 }
689}
690
/// Wrapper struct that implements the AgentConnection trait
///
/// Holds a handle to the `NativeAgent` entity whose sessions it operates on.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
694
695impl NativeAgentConnection {
696 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
697 self.0
698 .read(cx)
699 .sessions
700 .get(session_id)
701 .map(|session| session.thread.clone())
702 }
703
704 fn run_turn(
705 &self,
706 session_id: acp::SessionId,
707 cx: &mut App,
708 f: impl 'static
709 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
710 ) -> Task<Result<acp::PromptResponse>> {
711 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
712 agent
713 .sessions
714 .get_mut(&session_id)
715 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
716 }) else {
717 return Task::ready(Err(anyhow!("Session not found")));
718 };
719 log::debug!("Found session for: {}", session_id);
720
721 let response_stream = match f(thread, cx) {
722 Ok(stream) => stream,
723 Err(err) => return Task::ready(Err(err)),
724 };
725 Self::handle_thread_events(response_stream, acp_thread, cx)
726 }
727
728 fn handle_thread_events(
729 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
730 acp_thread: WeakEntity<AcpThread>,
731 cx: &App,
732 ) -> Task<Result<acp::PromptResponse>> {
733 cx.spawn(async move |cx| {
734 // Handle response stream and forward to session.acp_thread
735 while let Some(result) = events.next().await {
736 match result {
737 Ok(event) => {
738 log::trace!("Received completion event: {:?}", event);
739
740 match event {
741 ThreadEvent::UserMessage(message) => {
742 acp_thread.update(cx, |thread, cx| {
743 for content in message.content {
744 thread.push_user_content_block(
745 Some(message.id.clone()),
746 content.into(),
747 cx,
748 );
749 }
750 })?;
751 }
752 ThreadEvent::AgentText(text) => {
753 acp_thread.update(cx, |thread, cx| {
754 thread.push_assistant_content_block(
755 acp::ContentBlock::Text(acp::TextContent {
756 text,
757 annotations: None,
758 meta: None,
759 }),
760 false,
761 cx,
762 )
763 })?;
764 }
765 ThreadEvent::AgentThinking(text) => {
766 acp_thread.update(cx, |thread, cx| {
767 thread.push_assistant_content_block(
768 acp::ContentBlock::Text(acp::TextContent {
769 text,
770 annotations: None,
771 meta: None,
772 }),
773 true,
774 cx,
775 )
776 })?;
777 }
778 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
779 tool_call,
780 options,
781 response,
782 }) => {
783 let outcome_task = acp_thread.update(cx, |thread, cx| {
784 thread.request_tool_call_authorization(
785 tool_call, options, true, cx,
786 )
787 })??;
788 cx.background_spawn(async move {
789 if let acp::RequestPermissionOutcome::Selected { option_id } =
790 outcome_task.await
791 {
792 response
793 .send(option_id)
794 .map(|_| anyhow!("authorization receiver was dropped"))
795 .log_err();
796 }
797 })
798 .detach();
799 }
800 ThreadEvent::ToolCall(tool_call) => {
801 acp_thread.update(cx, |thread, cx| {
802 thread.upsert_tool_call(tool_call, cx)
803 })??;
804 }
805 ThreadEvent::ToolCallUpdate(update) => {
806 acp_thread.update(cx, |thread, cx| {
807 thread.update_tool_call(update, cx)
808 })??;
809 }
810 ThreadEvent::Retry(status) => {
811 acp_thread.update(cx, |thread, cx| {
812 thread.update_retry_status(status, cx)
813 })?;
814 }
815 ThreadEvent::Stop(stop_reason) => {
816 log::debug!("Assistant message complete: {:?}", stop_reason);
817 return Ok(acp::PromptResponse {
818 stop_reason,
819 meta: None,
820 });
821 }
822 }
823 }
824 Err(e) => {
825 log::error!("Error in model response stream: {:?}", e);
826 return Err(e);
827 }
828 }
829 }
830
831 log::debug!("Response stream completed");
832 anyhow::Ok(acp::PromptResponse {
833 stop_reason: acp::StopReason::EndTurn,
834 meta: None,
835 })
836 })
837 }
838}
839
840impl AgentModelSelector for NativeAgentConnection {
841 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
842 log::debug!("NativeAgentConnection::list_models called");
843 let list = self.0.read(cx).models.model_list.clone();
844 Task::ready(if list.is_empty() {
845 Err(anyhow::anyhow!("No models available"))
846 } else {
847 Ok(list)
848 })
849 }
850
851 fn select_model(
852 &self,
853 session_id: acp::SessionId,
854 model_id: acp_thread::AgentModelId,
855 cx: &mut App,
856 ) -> Task<Result<()>> {
857 log::debug!("Setting model for session {}: {}", session_id, model_id);
858 let Some(thread) = self
859 .0
860 .read(cx)
861 .sessions
862 .get(&session_id)
863 .map(|session| session.thread.clone())
864 else {
865 return Task::ready(Err(anyhow!("Session not found")));
866 };
867
868 let Some(model) = self.0.read(cx).models.model_from_id(&model_id) else {
869 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
870 };
871
872 thread.update(cx, |thread, cx| {
873 thread.set_model(model.clone(), cx);
874 });
875
876 update_settings_file::<AgentSettings>(
877 self.0.read(cx).fs.clone(),
878 cx,
879 move |settings, _cx| {
880 settings.set_model(model);
881 },
882 );
883
884 Task::ready(Ok(()))
885 }
886
887 fn selected_model(
888 &self,
889 session_id: &acp::SessionId,
890 cx: &mut App,
891 ) -> Task<Result<acp_thread::AgentModelInfo>> {
892 let session_id = session_id.clone();
893
894 let Some(thread) = self
895 .0
896 .read(cx)
897 .sessions
898 .get(&session_id)
899 .map(|session| session.thread.clone())
900 else {
901 return Task::ready(Err(anyhow!("Session not found")));
902 };
903 let Some(model) = thread.read(cx).model() else {
904 return Task::ready(Err(anyhow!("Model not found")));
905 };
906 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
907 else {
908 return Task::ready(Err(anyhow!("Provider not found")));
909 };
910 Task::ready(Ok(LanguageModels::map_language_model_to_info(
911 model, &provider,
912 )))
913 }
914
    /// Returns a receiver that is signaled whenever the agent's model list is refreshed.
    fn watch(&self, cx: &mut App) -> watch::Receiver<()> {
        self.0.read(cx).models.watch()
    }
918}
919
920impl acp_thread::AgentConnection for NativeAgentConnection {
921 fn new_thread(
922 self: Rc<Self>,
923 project: Entity<Project>,
924 cwd: &Path,
925 cx: &mut App,
926 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
927 let agent = self.0.clone();
928 log::debug!("Creating new thread for project at: {:?}", cwd);
929
930 cx.spawn(async move |cx| {
931 log::debug!("Starting thread creation in async context");
932
933 // Create Thread
934 let thread = agent.update(
935 cx,
936 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
937 // Fetch default model from registry settings
938 let registry = LanguageModelRegistry::read_global(cx);
939 // Log available models for debugging
940 let available_count = registry.available_models(cx).count();
941 log::debug!("Total available models: {}", available_count);
942
943 let default_model = registry.default_model().and_then(|default_model| {
944 agent
945 .models
946 .model_from_id(&LanguageModels::model_id(&default_model.model))
947 });
948 Ok(cx.new(|cx| {
949 Thread::new(
950 project.clone(),
951 agent.project_context.clone(),
952 agent.context_server_registry.clone(),
953 agent.templates.clone(),
954 default_model,
955 cx,
956 )
957 }))
958 },
959 )??;
960 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
961 })
962 }
963
    /// Returns no authentication methods.
    fn auth_methods(&self) -> &[acp::AuthMethod] {
        &[] // No auth for in-process
    }
967
    /// Succeeds immediately; both arguments are ignored.
    fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
971
    /// Returns a clone of this connection as the model selector.
    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
        Some(Rc::new(self.clone()) as Rc<dyn AgentModelSelector>)
    }
975
976 fn prompt(
977 &self,
978 id: Option<acp_thread::UserMessageId>,
979 params: acp::PromptRequest,
980 cx: &mut App,
981 ) -> Task<Result<acp::PromptResponse>> {
982 let id = id.expect("UserMessageId is required");
983 let session_id = params.session_id.clone();
984 log::info!("Received prompt request for session: {}", session_id);
985 log::debug!("Prompt blocks count: {}", params.prompt.len());
986
987 self.run_turn(session_id, cx, |thread, cx| {
988 let content: Vec<UserMessageContent> = params
989 .prompt
990 .into_iter()
991 .map(Into::into)
992 .collect::<Vec<_>>();
993 log::debug!("Converted prompt to message: {} chars", content.len());
994 log::debug!("Message id: {:?}", id);
995 log::debug!("Message content: {:?}", content);
996
997 thread.update(cx, |thread, cx| thread.send(id, content, cx))
998 })
999 }
1000
    /// Returns a handle that resumes the given session.
    ///
    /// Always returns `Some`; whether the session exists is only checked when the handle runs.
    fn resume(
        &self,
        session_id: &acp::SessionId,
        _cx: &App,
    ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
        Some(Rc::new(NativeAgentSessionResume {
            connection: self.clone(),
            session_id: session_id.clone(),
        }) as _)
    }
1011
1012 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1013 log::info!("Cancelling on session: {}", session_id);
1014 self.0.update(cx, |agent, cx| {
1015 if let Some(agent) = agent.sessions.get(session_id) {
1016 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1017 }
1018 });
1019 }
1020
1021 fn truncate(
1022 &self,
1023 session_id: &agent_client_protocol::SessionId,
1024 cx: &App,
1025 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1026 self.0.read_with(cx, |agent, _cx| {
1027 agent.sessions.get(session_id).map(|session| {
1028 Rc::new(NativeAgentSessionTruncate {
1029 thread: session.thread.clone(),
1030 acp_thread: session.acp_thread.clone(),
1031 }) as _
1032 })
1033 })
1034 }
1035
    /// Returns a handle that sets the title of the given session's thread.
    ///
    /// Always returns `Some`; whether the session exists is only checked when the handle runs.
    fn set_title(
        &self,
        session_id: &acp::SessionId,
        _cx: &App,
    ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
        Some(Rc::new(NativeAgentSessionSetTitle {
            connection: self.clone(),
            session_id: session_id.clone(),
        }) as _)
    }
1046
    /// Returns a clone of this connection as the telemetry provider.
    fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
        Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
    }
1050
    /// Converts the connection into `Rc<dyn Any>`.
    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
        self
    }
1054}
1055
1056impl acp_thread::AgentTelemetry for NativeAgentConnection {
1057 fn agent_name(&self) -> String {
1058 "Zed".into()
1059 }
1060
1061 fn thread_data(
1062 &self,
1063 session_id: &acp::SessionId,
1064 cx: &mut App,
1065 ) -> Task<Result<serde_json::Value>> {
1066 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1067 return Task::ready(Err(anyhow!("Session not found")));
1068 };
1069
1070 let task = session.thread.read(cx).to_db(cx);
1071 cx.background_spawn(async move {
1072 serde_json::to_value(task.await).context("Failed to serialize thread")
1073 })
1074 }
1075}
1076
/// Handle returned by `NativeAgentConnection::truncate` for truncating one session.
struct NativeAgentSessionTruncate {
    /// The internal thread to truncate.
    thread: Entity<Thread>,
    /// The ACP thread that receives the token usage after truncation.
    acp_thread: WeakEntity<AcpThread>,
}
1081
1082impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1083 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1084 match self.thread.update(cx, |thread, cx| {
1085 thread.truncate(message_id.clone(), cx)?;
1086 Ok(thread.latest_token_usage())
1087 }) {
1088 Ok(usage) => {
1089 self.acp_thread
1090 .update(cx, |thread, cx| {
1091 thread.update_token_usage(usage, cx);
1092 })
1093 .ok();
1094 Task::ready(Ok(()))
1095 }
1096 Err(error) => Task::ready(Err(error)),
1097 }
1098 }
1099}
1100
1101struct NativeAgentSessionResume {
1102 connection: NativeAgentConnection,
1103 session_id: acp::SessionId,
1104}
1105
1106impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1107 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1108 self.connection
1109 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1110 thread.update(cx, |thread, cx| thread.resume(cx))
1111 })
1112 }
1113}
1114
1115struct NativeAgentSessionSetTitle {
1116 connection: NativeAgentConnection,
1117 session_id: acp::SessionId,
1118}
1119
1120impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1121 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1122 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1123 return Task::ready(Err(anyhow!("session not found")));
1124 };
1125 let thread = session.thread.clone();
1126 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1127 Task::ready(Ok(()))
1128 }
1129}
1130
1131pub struct AcpThreadEnvironment {
1132 acp_thread: WeakEntity<AcpThread>,
1133}
1134
1135impl ThreadEnvironment for AcpThreadEnvironment {
1136 fn create_terminal(
1137 &self,
1138 command: String,
1139 cwd: Option<PathBuf>,
1140 output_byte_limit: Option<u64>,
1141 cx: &mut AsyncApp,
1142 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1143 let task = self.acp_thread.update(cx, |thread, cx| {
1144 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1145 });
1146
1147 let acp_thread = self.acp_thread.clone();
1148 cx.spawn(async move |cx| {
1149 let terminal = task?.await?;
1150
1151 let (drop_tx, drop_rx) = oneshot::channel();
1152 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1153
1154 cx.spawn(async move |cx| {
1155 drop_rx.await.ok();
1156 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1157 })
1158 .detach();
1159
1160 let handle = AcpTerminalHandle {
1161 terminal,
1162 _drop_tx: Some(drop_tx),
1163 };
1164
1165 Ok(Rc::new(handle) as _)
1166 })
1167 }
1168}
1169
1170pub struct AcpTerminalHandle {
1171 terminal: Entity<acp_thread::Terminal>,
1172 _drop_tx: Option<oneshot::Sender<()>>,
1173}
1174
1175impl TerminalHandle for AcpTerminalHandle {
1176 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1177 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1178 }
1179
1180 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1181 self.terminal
1182 .read_with(cx, |term, _cx| term.wait_for_exit())
1183 }
1184
1185 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1186 self.terminal
1187 .read_with(cx, |term, cx| term.current_output(cx))
1188 }
1189}
1190
1191#[cfg(test)]
1192mod tests {
1193 use crate::HistoryEntryId;
1194
1195 use super::*;
1196 use acp_thread::{
1197 AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
1198 };
1199 use fs::FakeFs;
1200 use gpui::TestAppContext;
1201 use indoc::indoc;
1202 use language_model::fake_provider::FakeLanguageModel;
1203 use serde_json::json;
1204 use settings::SettingsStore;
1205 use util::path;
1206
1207 #[gpui::test]
1208 async fn test_maintaining_project_context(cx: &mut TestAppContext) {
1209 init_test(cx);
1210 let fs = FakeFs::new(cx.executor());
1211 fs.insert_tree(
1212 "/",
1213 json!({
1214 "a": {}
1215 }),
1216 )
1217 .await;
1218 let project = Project::test(fs.clone(), [], cx).await;
1219 let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1220 let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1221 let agent = NativeAgent::new(
1222 project.clone(),
1223 history_store,
1224 Templates::new(),
1225 None,
1226 fs.clone(),
1227 &mut cx.to_async(),
1228 )
1229 .await
1230 .unwrap();
1231 agent.read_with(cx, |agent, cx| {
1232 assert_eq!(agent.project_context.read(cx).worktrees, vec![])
1233 });
1234
1235 let worktree = project
1236 .update(cx, |project, cx| project.create_worktree("/a", true, cx))
1237 .await
1238 .unwrap();
1239 cx.run_until_parked();
1240 agent.read_with(cx, |agent, cx| {
1241 assert_eq!(
1242 agent.project_context.read(cx).worktrees,
1243 vec![WorktreeContext {
1244 root_name: "a".into(),
1245 abs_path: Path::new("/a").into(),
1246 rules_file: None
1247 }]
1248 )
1249 });
1250
1251 // Creating `/a/.rules` updates the project context.
1252 fs.insert_file("/a/.rules", Vec::new()).await;
1253 cx.run_until_parked();
1254 agent.read_with(cx, |agent, cx| {
1255 let rules_entry = worktree.read(cx).entry_for_path(".rules").unwrap();
1256 assert_eq!(
1257 agent.project_context.read(cx).worktrees,
1258 vec![WorktreeContext {
1259 root_name: "a".into(),
1260 abs_path: Path::new("/a").into(),
1261 rules_file: Some(RulesFileContext {
1262 path_in_worktree: Path::new(".rules").into(),
1263 text: "".into(),
1264 project_entry_id: rules_entry.id.to_usize()
1265 })
1266 }]
1267 )
1268 });
1269 }
1270
1271 #[gpui::test]
1272 async fn test_listing_models(cx: &mut TestAppContext) {
1273 init_test(cx);
1274 let fs = FakeFs::new(cx.executor());
1275 fs.insert_tree("/", json!({ "a": {} })).await;
1276 let project = Project::test(fs.clone(), [], cx).await;
1277 let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1278 let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1279 let connection = NativeAgentConnection(
1280 NativeAgent::new(
1281 project.clone(),
1282 history_store,
1283 Templates::new(),
1284 None,
1285 fs.clone(),
1286 &mut cx.to_async(),
1287 )
1288 .await
1289 .unwrap(),
1290 );
1291
1292 let models = cx.update(|cx| connection.list_models(cx)).await.unwrap();
1293
1294 let acp_thread::AgentModelList::Grouped(models) = models else {
1295 panic!("Unexpected model group");
1296 };
1297 assert_eq!(
1298 models,
1299 IndexMap::from_iter([(
1300 AgentModelGroupName("Fake".into()),
1301 vec![AgentModelInfo {
1302 id: AgentModelId("fake/fake".into()),
1303 name: "Fake".into(),
1304 icon: Some(ui::IconName::ZedAssistant),
1305 }]
1306 )])
1307 );
1308 }
1309
1310 #[gpui::test]
1311 async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
1312 init_test(cx);
1313 let fs = FakeFs::new(cx.executor());
1314 fs.create_dir(paths::settings_file().parent().unwrap())
1315 .await
1316 .unwrap();
1317 fs.insert_file(
1318 paths::settings_file(),
1319 json!({
1320 "agent": {
1321 "default_model": {
1322 "provider": "foo",
1323 "model": "bar"
1324 }
1325 }
1326 })
1327 .to_string()
1328 .into_bytes(),
1329 )
1330 .await;
1331 let project = Project::test(fs.clone(), [], cx).await;
1332
1333 let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1334 let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1335
1336 // Create the agent and connection
1337 let agent = NativeAgent::new(
1338 project.clone(),
1339 history_store,
1340 Templates::new(),
1341 None,
1342 fs.clone(),
1343 &mut cx.to_async(),
1344 )
1345 .await
1346 .unwrap();
1347 let connection = NativeAgentConnection(agent.clone());
1348
1349 // Create a thread/session
1350 let acp_thread = cx
1351 .update(|cx| {
1352 Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
1353 })
1354 .await
1355 .unwrap();
1356
1357 let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());
1358
1359 // Select a model
1360 let model_id = AgentModelId("fake/fake".into());
1361 cx.update(|cx| connection.select_model(session_id.clone(), model_id.clone(), cx))
1362 .await
1363 .unwrap();
1364
1365 // Verify the thread has the selected model
1366 agent.read_with(cx, |agent, _| {
1367 let session = agent.sessions.get(&session_id).unwrap();
1368 session.thread.read_with(cx, |thread, _| {
1369 assert_eq!(thread.model().unwrap().id().0, "fake");
1370 });
1371 });
1372
1373 cx.run_until_parked();
1374
1375 // Verify settings file was updated
1376 let settings_content = fs.load(paths::settings_file()).await.unwrap();
1377 let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();
1378
1379 // Check that the agent settings contain the selected model
1380 assert_eq!(
1381 settings_json["agent"]["default_model"]["model"],
1382 json!("fake")
1383 );
1384 assert_eq!(
1385 settings_json["agent"]["default_model"]["provider"],
1386 json!("fake")
1387 );
1388 }
1389
1390 #[gpui::test]
1391 #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
1392 async fn test_save_load_thread(cx: &mut TestAppContext) {
1393 init_test(cx);
1394 let fs = FakeFs::new(cx.executor());
1395 fs.insert_tree(
1396 "/",
1397 json!({
1398 "a": {
1399 "b.md": "Lorem"
1400 }
1401 }),
1402 )
1403 .await;
1404 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
1405 let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1406 let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1407 let agent = NativeAgent::new(
1408 project.clone(),
1409 history_store.clone(),
1410 Templates::new(),
1411 None,
1412 fs.clone(),
1413 &mut cx.to_async(),
1414 )
1415 .await
1416 .unwrap();
1417 let connection = Rc::new(NativeAgentConnection(agent.clone()));
1418
1419 let acp_thread = cx
1420 .update(|cx| {
1421 connection
1422 .clone()
1423 .new_thread(project.clone(), Path::new(""), cx)
1424 })
1425 .await
1426 .unwrap();
1427 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1428 let thread = agent.read_with(cx, |agent, _| {
1429 agent.sessions.get(&session_id).unwrap().thread.clone()
1430 });
1431
1432 // Ensure empty threads are not saved, even if they get mutated.
1433 let model = Arc::new(FakeLanguageModel::default());
1434 let summary_model = Arc::new(FakeLanguageModel::default());
1435 thread.update(cx, |thread, cx| {
1436 thread.set_model(model.clone(), cx);
1437 thread.set_summarization_model(Some(summary_model.clone()), cx);
1438 });
1439 cx.run_until_parked();
1440 assert_eq!(history_entries(&history_store, cx), vec![]);
1441
1442 let send = acp_thread.update(cx, |thread, cx| {
1443 thread.send(
1444 vec![
1445 "What does ".into(),
1446 acp::ContentBlock::ResourceLink(acp::ResourceLink {
1447 name: "b.md".into(),
1448 uri: MentionUri::File {
1449 abs_path: path!("/a/b.md").into(),
1450 }
1451 .to_uri()
1452 .to_string(),
1453 annotations: None,
1454 description: None,
1455 mime_type: None,
1456 size: None,
1457 title: None,
1458 meta: None,
1459 }),
1460 " mean?".into(),
1461 ],
1462 cx,
1463 )
1464 });
1465 let send = cx.foreground_executor().spawn(send);
1466 cx.run_until_parked();
1467
1468 model.send_last_completion_stream_text_chunk("Lorem.");
1469 model.end_last_completion_stream();
1470 cx.run_until_parked();
1471 summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
1472 summary_model.end_last_completion_stream();
1473
1474 send.await.unwrap();
1475 acp_thread.read_with(cx, |thread, cx| {
1476 assert_eq!(
1477 thread.to_markdown(cx),
1478 indoc! {"
1479 ## User
1480
1481 What does [@b.md](file:///a/b.md) mean?
1482
1483 ## Assistant
1484
1485 Lorem.
1486
1487 "}
1488 )
1489 });
1490
1491 cx.run_until_parked();
1492
1493 // Drop the ACP thread, which should cause the session to be dropped as well.
1494 cx.update(|_| {
1495 drop(thread);
1496 drop(acp_thread);
1497 });
1498 agent.read_with(cx, |agent, _| {
1499 assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
1500 });
1501
1502 // Ensure the thread can be reloaded from disk.
1503 assert_eq!(
1504 history_entries(&history_store, cx),
1505 vec![(
1506 HistoryEntryId::AcpThread(session_id.clone()),
1507 "Explaining /a/b.md".into()
1508 )]
1509 );
1510 let acp_thread = agent
1511 .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
1512 .await
1513 .unwrap();
1514 acp_thread.read_with(cx, |thread, cx| {
1515 assert_eq!(
1516 thread.to_markdown(cx),
1517 indoc! {"
1518 ## User
1519
1520 What does [@b.md](file:///a/b.md) mean?
1521
1522 ## Assistant
1523
1524 Lorem.
1525
1526 "}
1527 )
1528 });
1529 }
1530
1531 fn history_entries(
1532 history: &Entity<HistoryStore>,
1533 cx: &mut TestAppContext,
1534 ) -> Vec<(HistoryEntryId, String)> {
1535 history.read_with(cx, |history, _| {
1536 history
1537 .entries()
1538 .map(|e| (e.id(), e.title().to_string()))
1539 .collect::<Vec<_>>()
1540 })
1541 }
1542
1543 fn init_test(cx: &mut TestAppContext) {
1544 env_logger::try_init().ok();
1545 cx.update(|cx| {
1546 let settings_store = SettingsStore::test(cx);
1547 cx.set_global(settings_store);
1548 Project::init_settings(cx);
1549 agent_settings::init(cx);
1550 language::init(cx);
1551 LanguageModelRegistry::test(cx);
1552 });
1553 }
1554}