1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8mod thread;
9mod tools;
10
11#[cfg(test)]
12mod tests;
13
14pub use db::*;
15pub use history_store::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use tools::*;
20
21use acp_thread::{AcpThread, AgentModelSelector};
22use agent_client_protocol as acp;
23use anyhow::{Context as _, Result, anyhow};
24use chrono::{DateTime, Utc};
25use collections::{HashSet, IndexMap};
26use fs::Fs;
27use futures::channel::{mpsc, oneshot};
28use futures::future::Shared;
29use futures::{StreamExt, future};
30use gpui::{
31 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
32};
33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
34use project::{Project, ProjectItem, ProjectPath, Worktree};
35use prompt_store::{
36 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
37 WorktreeContext,
38};
39use serde::{Deserialize, Serialize};
40use settings::{LanguageModelSelection, update_settings_file};
41use std::any::Any;
42use std::collections::HashMap;
43use std::path::{Path, PathBuf};
44use std::rc::Rc;
45use std::sync::Arc;
46use util::ResultExt;
47use util::rel_path::RelPath;
48
/// Point-in-time capture of the project's worktrees, stamped with when it was taken.
///
/// Derives serde's `Serialize`/`Deserialize` so it can be serialized; field order is
/// kept as-is because it determines the serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// Telemetry snapshots of the project's worktrees.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// UTC time at which the snapshot was captured.
    pub timestamp: DateTime<Utc>,
}
54
/// An error produced while loading a rules file, reduced to a displayable message.
pub struct RulesLoadingError {
    /// Human-readable description, built from the underlying error's `Display` output.
    pub message: SharedString,
}
58
/// Holds both the internal Thread and the AcpThread for a session.
///
/// Field order is significant: it is also the drop order.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication. Held weakly; the agent
    /// removes the session once this ACP thread is released (see `register_session`).
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// The most recently started save of `thread`. Assigning a new save drops the
    /// previous task. NOTE(review): gpui tasks are cancelled on drop, so an older
    /// in-flight save may be abandoned in favour of the newer one; confirm intended.
    pending_save: Task<()>,
    /// Subscriptions that keep this session's event handlers registered.
    _subscriptions: Vec<Subscription>,
}
68
/// Cache of the language models exposed by authenticated providers.
///
/// It is keyed by ACP model ID (`"{provider_id}/{model_id}"`) and also holds a grouped
/// list used to answer model-selector queries.
pub struct LanguageModels {
    /// Access language model by ID.
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information.
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out by `watch()`; signalled after every `refresh_list`.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used by `refresh_list` to notify watchers.
    refresh_models_tx: watch::Sender<()>,
    /// Keeps the background provider-authentication task alive.
    _authenticate_all_providers_task: Task<()>,
}
78
79impl LanguageModels {
80 fn new(cx: &mut App) -> Self {
81 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
82
83 let mut this = Self {
84 models: HashMap::default(),
85 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
86 refresh_models_rx,
87 refresh_models_tx,
88 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
89 };
90 this.refresh_list(cx);
91 this
92 }
93
94 fn refresh_list(&mut self, cx: &App) {
95 let providers = LanguageModelRegistry::global(cx)
96 .read(cx)
97 .providers()
98 .into_iter()
99 .filter(|provider| provider.is_authenticated(cx))
100 .collect::<Vec<_>>();
101
102 let mut language_model_list = IndexMap::default();
103 let mut recommended_models = HashSet::default();
104
105 let mut recommended = Vec::new();
106 for provider in &providers {
107 for model in provider.recommended_models(cx) {
108 recommended_models.insert((model.provider_id(), model.id()));
109 recommended.push(Self::map_language_model_to_info(&model, provider));
110 }
111 }
112 if !recommended.is_empty() {
113 language_model_list.insert(
114 acp_thread::AgentModelGroupName("Recommended".into()),
115 recommended,
116 );
117 }
118
119 let mut models = HashMap::default();
120 for provider in providers {
121 let mut provider_models = Vec::new();
122 for model in provider.provided_models(cx) {
123 let model_info = Self::map_language_model_to_info(&model, &provider);
124 let model_id = model_info.id.clone();
125 provider_models.push(model_info);
126 models.insert(model_id, model);
127 }
128 if !provider_models.is_empty() {
129 language_model_list.insert(
130 acp_thread::AgentModelGroupName(provider.name().0.clone()),
131 provider_models,
132 );
133 }
134 }
135
136 self.models = models;
137 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
138 self.refresh_models_tx.send(()).ok();
139 }
140
141 fn watch(&self) -> watch::Receiver<()> {
142 self.refresh_models_rx.clone()
143 }
144
145 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
146 self.models.get(model_id).cloned()
147 }
148
149 fn map_language_model_to_info(
150 model: &Arc<dyn LanguageModel>,
151 provider: &Arc<dyn LanguageModelProvider>,
152 ) -> acp_thread::AgentModelInfo {
153 acp_thread::AgentModelInfo {
154 id: Self::model_id(model),
155 name: model.name().0,
156 description: None,
157 icon: Some(provider.icon()),
158 }
159 }
160
161 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
162 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
163 }
164
165 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
166 let authenticate_all_providers = LanguageModelRegistry::global(cx)
167 .read(cx)
168 .providers()
169 .iter()
170 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
171 .collect::<Vec<_>>();
172
173 cx.background_spawn(async move {
174 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
175 if let Err(err) = authenticate_task.await {
176 match err {
177 language_model::AuthenticateError::CredentialsNotFound => {
178 // Since we're authenticating these providers in the
179 // background for the purposes of populating the
180 // language selector, we don't care about providers
181 // where the credentials are not found.
182 }
183 language_model::AuthenticateError::ConnectionRefused => {
184 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
185 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
186 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
187 }
188 _ => {
189 // Some providers have noisy failure states that we
190 // don't want to spam the logs with every time the
191 // language model selector is initialized.
192 //
193 // Ideally these should have more clear failure modes
194 // that we know are safe to ignore here, like what we do
195 // with `CredentialsNotFound` above.
196 match provider_id.0.as_ref() {
197 "lmstudio" | "ollama" => {
198 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
199 //
200 // These fail noisily, so we don't log them.
201 }
202 "copilot_chat" => {
203 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
204 }
205 _ => {
206 log::error!(
207 "Failed to authenticate provider: {}: {err:#}",
208 provider_name.0
209 );
210 }
211 }
212 }
213 }
214 }
215 }
216 })
217 }
218}
219
/// In-process agent that owns one [`Session`] per open thread plus the state those
/// threads share (project context, context servers, templates, model cache).
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history store; reloaded after each thread save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signals `_maintain_project_context` that worktrees, rules files, or prompts changed.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background loop that rebuilds the project context whenever a refresh is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project the agent operates on.
    project: Entity<Project>,
    /// Optional prompt store that supplies default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle, used when writing the selected model to settings.
    fs: Arc<dyn Fs>,
    /// Project, model-registry, and prompt-store subscriptions.
    _subscriptions: Vec<Subscription>,
}
238
239impl NativeAgent {
240 pub async fn new(
241 project: Entity<Project>,
242 history: Entity<HistoryStore>,
243 templates: Arc<Templates>,
244 prompt_store: Option<Entity<PromptStore>>,
245 fs: Arc<dyn Fs>,
246 cx: &mut AsyncApp,
247 ) -> Result<Entity<NativeAgent>> {
248 log::debug!("Creating new NativeAgent");
249
250 let project_context = cx
251 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
252 .await;
253
254 cx.new(|cx| {
255 let mut subscriptions = vec![
256 cx.subscribe(&project, Self::handle_project_event),
257 cx.subscribe(
258 &LanguageModelRegistry::global(cx),
259 Self::handle_models_updated_event,
260 ),
261 ];
262 if let Some(prompt_store) = prompt_store.as_ref() {
263 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
264 }
265
266 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
267 watch::channel(());
268 Self {
269 sessions: HashMap::new(),
270 history,
271 project_context: cx.new(|_| project_context),
272 project_context_needs_refresh: project_context_needs_refresh_tx,
273 _maintain_project_context: cx.spawn(async move |this, cx| {
274 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
275 }),
276 context_server_registry: cx.new(|cx| {
277 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
278 }),
279 templates,
280 models: LanguageModels::new(cx),
281 project,
282 prompt_store,
283 fs,
284 _subscriptions: subscriptions,
285 }
286 })
287 }
288
    /// Wraps `thread_handle` in a new [`AcpThread`] and records both as a session.
    ///
    /// It also configures the thread's summarization model and default tools, and wires
    /// up the subscriptions that keep the session in sync. Returns the strong ACP thread
    /// handle; the session itself only keeps a weak reference to it.
    fn register_session(
        &mut self,
        thread_handle: Entity<Thread>,
        cx: &mut Context<Self>,
    ) -> Entity<AcpThread> {
        let connection = Rc::new(NativeAgentConnection(cx.entity()));

        let thread = thread_handle.read(cx);
        let session_id = thread.id().clone();
        let title = thread.title();
        let project = thread.project.clone();
        let action_log = thread.action_log.clone();
        let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
        let acp_thread = cx.new(|cx| {
            acp_thread::AcpThread::new(
                title,
                connection,
                project.clone(),
                action_log.clone(),
                session_id.clone(),
                prompt_capabilities_rx,
                cx,
            )
        });

        let registry = LanguageModelRegistry::read_global(cx);
        let summarization_model = registry.thread_summary_model().map(|c| c.model);

        thread_handle.update(cx, |thread, cx| {
            thread.set_summarization_model(summarization_model, cx);
            // Tools reach back into the ACP thread through a weak handle so they don't
            // keep it alive.
            thread.add_default_tools(
                Rc::new(AcpThreadEnvironment {
                    acp_thread: acp_thread.downgrade(),
                }) as _,
                cx,
            )
        });

        let subscriptions = vec![
            // Releasing the ACP thread ends the session.
            cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                this.sessions.remove(acp_thread.session_id());
            }),
            cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
            cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
            // Persist the thread whenever it notifies.
            cx.observe(&thread_handle, move |this, thread, cx| {
                this.save_thread(thread, cx)
            }),
        ];

        self.sessions.insert(
            session_id,
            Session {
                thread: thread_handle,
                acp_thread: acp_thread.downgrade(),
                _subscriptions: subscriptions,
                pending_save: Task::ready(()),
            },
        );
        acp_thread
    }
349
350 pub fn models(&self) -> &LanguageModels {
351 &self.models
352 }
353
354 async fn maintain_project_context(
355 this: WeakEntity<Self>,
356 mut needs_refresh: watch::Receiver<()>,
357 cx: &mut AsyncApp,
358 ) -> Result<()> {
359 while needs_refresh.changed().await.is_ok() {
360 let project_context = this
361 .update(cx, |this, cx| {
362 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
363 })?
364 .await;
365 this.update(cx, |this, cx| {
366 this.project_context = cx.new(|_| project_context);
367 })?;
368 }
369
370 Ok(())
371 }
372
373 fn build_project_context(
374 project: &Entity<Project>,
375 prompt_store: Option<&Entity<PromptStore>>,
376 cx: &mut App,
377 ) -> Task<ProjectContext> {
378 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
379 let worktree_tasks = worktrees
380 .into_iter()
381 .map(|worktree| {
382 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
383 })
384 .collect::<Vec<_>>();
385 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
386 prompt_store.read_with(cx, |prompt_store, cx| {
387 let prompts = prompt_store.default_prompt_metadata();
388 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
389 let contents = prompt_store.load(prompt_metadata.id, cx);
390 async move { (contents.await, prompt_metadata) }
391 });
392 cx.background_spawn(future::join_all(load_tasks))
393 })
394 } else {
395 Task::ready(vec![])
396 };
397
398 cx.spawn(async move |_cx| {
399 let (worktrees, default_user_rules) =
400 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
401
402 let worktrees = worktrees
403 .into_iter()
404 .map(|(worktree, _rules_error)| {
405 // TODO: show error message
406 // if let Some(rules_error) = rules_error {
407 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
408 // }
409 worktree
410 })
411 .collect::<Vec<_>>();
412
413 let default_user_rules = default_user_rules
414 .into_iter()
415 .flat_map(|(contents, prompt_metadata)| match contents {
416 Ok(contents) => Some(UserRulesContext {
417 uuid: match prompt_metadata.id {
418 prompt_store::PromptId::User { uuid } => uuid,
419 prompt_store::PromptId::EditWorkflow => return None,
420 },
421 title: prompt_metadata.title.map(|title| title.to_string()),
422 contents,
423 }),
424 Err(_err) => {
425 // TODO: show error message
426 // this.update(cx, |_, cx| {
427 // cx.emit(RulesLoadingError {
428 // message: format!("{err:?}").into(),
429 // });
430 // })
431 // .ok();
432 None
433 }
434 })
435 .collect::<Vec<_>>();
436
437 ProjectContext::new(worktrees, default_user_rules)
438 })
439 }
440
441 fn load_worktree_info_for_system_prompt(
442 worktree: Entity<Worktree>,
443 project: Entity<Project>,
444 cx: &mut App,
445 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
446 let tree = worktree.read(cx);
447 let root_name = tree.root_name_str().into();
448 let abs_path = tree.abs_path();
449
450 let mut context = WorktreeContext {
451 root_name,
452 abs_path,
453 rules_file: None,
454 };
455
456 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
457 let Some(rules_task) = rules_task else {
458 return Task::ready((context, None));
459 };
460
461 cx.spawn(async move |_| {
462 let (rules_file, rules_file_error) = match rules_task.await {
463 Ok(rules_file) => (Some(rules_file), None),
464 Err(err) => (
465 None,
466 Some(RulesLoadingError {
467 message: format!("{err}").into(),
468 }),
469 ),
470 };
471 context.rules_file = rules_file;
472 (context, rules_file_error)
473 })
474 }
475
476 fn load_worktree_rules_file(
477 worktree: Entity<Worktree>,
478 project: Entity<Project>,
479 cx: &mut App,
480 ) -> Option<Task<Result<RulesFileContext>>> {
481 let worktree = worktree.read(cx);
482 let worktree_id = worktree.id();
483 let selected_rules_file = RULES_FILE_NAMES
484 .into_iter()
485 .filter_map(|name| {
486 worktree
487 .entry_for_path(RelPath::unix(name).unwrap())
488 .filter(|entry| entry.is_file())
489 .map(|entry| entry.path.clone())
490 })
491 .next();
492
493 // Note that Cline supports `.clinerules` being a directory, but that is not currently
494 // supported. This doesn't seem to occur often in GitHub repositories.
495 selected_rules_file.map(|path_in_worktree| {
496 let project_path = ProjectPath {
497 worktree_id,
498 path: path_in_worktree.clone(),
499 };
500 let buffer_task =
501 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
502 let rope_task = cx.spawn(async move |cx| {
503 buffer_task.await?.read_with(cx, |buffer, cx| {
504 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
505 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
506 })?
507 });
508 // Build a string from the rope on a background thread.
509 cx.background_spawn(async move {
510 let (project_entry_id, rope) = rope_task.await?;
511 anyhow::Ok(RulesFileContext {
512 path_in_worktree,
513 text: rope.to_string().trim().to_string(),
514 project_entry_id: project_entry_id.to_usize(),
515 })
516 })
517 })
518 }
519
520 fn handle_thread_title_updated(
521 &mut self,
522 thread: Entity<Thread>,
523 _: &TitleUpdated,
524 cx: &mut Context<Self>,
525 ) {
526 let session_id = thread.read(cx).id();
527 let Some(session) = self.sessions.get(session_id) else {
528 return;
529 };
530 let thread = thread.downgrade();
531 let acp_thread = session.acp_thread.clone();
532 cx.spawn(async move |_, cx| {
533 let title = thread.read_with(cx, |thread, _| thread.title())?;
534 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
535 task.await
536 })
537 .detach_and_log_err(cx);
538 }
539
540 fn handle_thread_token_usage_updated(
541 &mut self,
542 thread: Entity<Thread>,
543 usage: &TokenUsageUpdated,
544 cx: &mut Context<Self>,
545 ) {
546 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
547 return;
548 };
549 session
550 .acp_thread
551 .update(cx, |acp_thread, cx| {
552 acp_thread.update_token_usage(usage.0.clone(), cx);
553 })
554 .ok();
555 }
556
557 fn handle_project_event(
558 &mut self,
559 _project: Entity<Project>,
560 event: &project::Event,
561 _cx: &mut Context<Self>,
562 ) {
563 match event {
564 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
565 self.project_context_needs_refresh.send(()).ok();
566 }
567 project::Event::WorktreeUpdatedEntries(_, items) => {
568 if items.iter().any(|(path, _, _)| {
569 RULES_FILE_NAMES
570 .iter()
571 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
572 }) {
573 self.project_context_needs_refresh.send(()).ok();
574 }
575 }
576 _ => {}
577 }
578 }
579
580 fn handle_prompts_updated_event(
581 &mut self,
582 _prompt_store: Entity<PromptStore>,
583 _event: &prompt_store::PromptsUpdatedEvent,
584 _cx: &mut Context<Self>,
585 ) {
586 self.project_context_needs_refresh.send(()).ok();
587 }
588
589 fn handle_models_updated_event(
590 &mut self,
591 _registry: Entity<LanguageModelRegistry>,
592 _event: &language_model::Event,
593 cx: &mut Context<Self>,
594 ) {
595 self.models.refresh_list(cx);
596
597 let registry = LanguageModelRegistry::read_global(cx);
598 let default_model = registry.default_model().map(|m| m.model);
599 let summarization_model = registry.thread_summary_model().map(|m| m.model);
600
601 for session in self.sessions.values_mut() {
602 session.thread.update(cx, |thread, cx| {
603 if thread.model().is_none()
604 && let Some(model) = default_model.clone()
605 {
606 thread.set_model(model, cx);
607 cx.notify();
608 }
609 thread.set_summarization_model(summarization_model.clone(), cx);
610 });
611 }
612 }
613
614 pub fn load_thread(
615 &mut self,
616 id: acp::SessionId,
617 cx: &mut Context<Self>,
618 ) -> Task<Result<Entity<Thread>>> {
619 let database_future = ThreadsDatabase::connect(cx);
620 cx.spawn(async move |this, cx| {
621 let database = database_future.await.map_err(|err| anyhow!(err))?;
622 let db_thread = database
623 .load_thread(id.clone())
624 .await?
625 .with_context(|| format!("no thread found with ID: {id:?}"))?;
626
627 this.update(cx, |this, cx| {
628 let summarization_model = LanguageModelRegistry::read_global(cx)
629 .thread_summary_model()
630 .map(|c| c.model);
631
632 cx.new(|cx| {
633 let mut thread = Thread::from_db(
634 id.clone(),
635 db_thread,
636 this.project.clone(),
637 this.project_context.clone(),
638 this.context_server_registry.clone(),
639 this.templates.clone(),
640 cx,
641 );
642 thread.set_summarization_model(summarization_model, cx);
643 thread
644 })
645 })
646 })
647 }
648
649 pub fn open_thread(
650 &mut self,
651 id: acp::SessionId,
652 cx: &mut Context<Self>,
653 ) -> Task<Result<Entity<AcpThread>>> {
654 let task = self.load_thread(id, cx);
655 cx.spawn(async move |this, cx| {
656 let thread = task.await?;
657 let acp_thread =
658 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
659 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
660 cx.update(|cx| {
661 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
662 })?
663 .await?;
664 Ok(acp_thread)
665 })
666 }
667
668 pub fn thread_summary(
669 &mut self,
670 id: acp::SessionId,
671 cx: &mut Context<Self>,
672 ) -> Task<Result<SharedString>> {
673 let thread = self.open_thread(id.clone(), cx);
674 cx.spawn(async move |this, cx| {
675 let acp_thread = thread.await?;
676 let result = this
677 .update(cx, |this, cx| {
678 this.sessions
679 .get(&id)
680 .unwrap()
681 .thread
682 .update(cx, |thread, cx| thread.summary(cx))
683 })?
684 .await
685 .context("Failed to generate summary")?;
686 drop(acp_thread);
687 Ok(result)
688 })
689 }
690
691 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
692 if thread.read(cx).is_empty() {
693 return;
694 }
695
696 let database_future = ThreadsDatabase::connect(cx);
697 let (id, db_thread) =
698 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
699 let Some(session) = self.sessions.get_mut(&id) else {
700 return;
701 };
702 let history = self.history.clone();
703 session.pending_save = cx.spawn(async move |_, cx| {
704 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
705 return;
706 };
707 let db_thread = db_thread.await;
708 database.save_thread(id, db_thread).await.log_err();
709 history.update(cx, |history, cx| history.reload(cx)).ok();
710 });
711 }
712}
713
/// Wrapper struct that implements the `AgentConnection` trait on top of a
/// [`NativeAgent`] entity. Cloning it only clones the entity handle.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
717
718impl NativeAgentConnection {
719 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
720 self.0
721 .read(cx)
722 .sessions
723 .get(session_id)
724 .map(|session| session.thread.clone())
725 }
726
727 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
728 self.0.update(cx, |this, cx| this.load_thread(id, cx))
729 }
730
731 fn run_turn(
732 &self,
733 session_id: acp::SessionId,
734 cx: &mut App,
735 f: impl 'static
736 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
737 ) -> Task<Result<acp::PromptResponse>> {
738 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
739 agent
740 .sessions
741 .get_mut(&session_id)
742 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
743 }) else {
744 return Task::ready(Err(anyhow!("Session not found")));
745 };
746 log::debug!("Found session for: {}", session_id);
747
748 let response_stream = match f(thread, cx) {
749 Ok(stream) => stream,
750 Err(err) => return Task::ready(Err(err)),
751 };
752 Self::handle_thread_events(response_stream, acp_thread, cx)
753 }
754
755 fn handle_thread_events(
756 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
757 acp_thread: WeakEntity<AcpThread>,
758 cx: &App,
759 ) -> Task<Result<acp::PromptResponse>> {
760 cx.spawn(async move |cx| {
761 // Handle response stream and forward to session.acp_thread
762 while let Some(result) = events.next().await {
763 match result {
764 Ok(event) => {
765 log::trace!("Received completion event: {:?}", event);
766
767 match event {
768 ThreadEvent::UserMessage(message) => {
769 acp_thread.update(cx, |thread, cx| {
770 for content in message.content {
771 thread.push_user_content_block(
772 Some(message.id.clone()),
773 content.into(),
774 cx,
775 );
776 }
777 })?;
778 }
779 ThreadEvent::AgentText(text) => {
780 acp_thread.update(cx, |thread, cx| {
781 thread.push_assistant_content_block(text.into(), false, cx)
782 })?;
783 }
784 ThreadEvent::AgentThinking(text) => {
785 acp_thread.update(cx, |thread, cx| {
786 thread.push_assistant_content_block(text.into(), true, cx)
787 })?;
788 }
789 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
790 tool_call,
791 options,
792 response,
793 }) => {
794 let outcome_task = acp_thread.update(cx, |thread, cx| {
795 thread.request_tool_call_authorization(
796 tool_call, options, true, cx,
797 )
798 })??;
799 cx.background_spawn(async move {
800 if let acp::RequestPermissionOutcome::Selected(
801 acp::SelectedPermissionOutcome { option_id, .. },
802 ) = outcome_task.await
803 {
804 response
805 .send(option_id)
806 .map(|_| anyhow!("authorization receiver was dropped"))
807 .log_err();
808 }
809 })
810 .detach();
811 }
812 ThreadEvent::ToolCall(tool_call) => {
813 acp_thread.update(cx, |thread, cx| {
814 thread.upsert_tool_call(tool_call, cx)
815 })??;
816 }
817 ThreadEvent::ToolCallUpdate(update) => {
818 acp_thread.update(cx, |thread, cx| {
819 thread.update_tool_call(update, cx)
820 })??;
821 }
822 ThreadEvent::Retry(status) => {
823 acp_thread.update(cx, |thread, cx| {
824 thread.update_retry_status(status, cx)
825 })?;
826 }
827 ThreadEvent::Stop(stop_reason) => {
828 log::debug!("Assistant message complete: {:?}", stop_reason);
829 return Ok(acp::PromptResponse::new(stop_reason));
830 }
831 }
832 }
833 Err(e) => {
834 log::error!("Error in model response stream: {:?}", e);
835 return Err(e);
836 }
837 }
838 }
839
840 log::debug!("Response stream completed");
841 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
842 })
843 }
844}
845
/// Model selector bound to a single session of a [`NativeAgentConnection`].
struct NativeAgentModelSelector {
    /// Session whose thread model this selector reads and changes.
    session_id: acp::SessionId,
    /// Connection to the agent that owns the session.
    connection: NativeAgentConnection,
}
850
851impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
852 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
853 log::debug!("NativeAgentConnection::list_models called");
854 let list = self.connection.0.read(cx).models.model_list.clone();
855 Task::ready(if list.is_empty() {
856 Err(anyhow::anyhow!("No models available"))
857 } else {
858 Ok(list)
859 })
860 }
861
862 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
863 log::debug!(
864 "Setting model for session {}: {}",
865 self.session_id,
866 model_id
867 );
868 let Some(thread) = self
869 .connection
870 .0
871 .read(cx)
872 .sessions
873 .get(&self.session_id)
874 .map(|session| session.thread.clone())
875 else {
876 return Task::ready(Err(anyhow!("Session not found")));
877 };
878
879 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
880 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
881 };
882
883 thread.update(cx, |thread, cx| {
884 thread.set_model(model.clone(), cx);
885 });
886
887 update_settings_file(
888 self.connection.0.read(cx).fs.clone(),
889 cx,
890 move |settings, _cx| {
891 let provider = model.provider_id().0.to_string();
892 let model = model.id().0.to_string();
893 settings
894 .agent
895 .get_or_insert_default()
896 .set_model(LanguageModelSelection {
897 provider: provider.into(),
898 model,
899 });
900 },
901 );
902
903 Task::ready(Ok(()))
904 }
905
906 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
907 let Some(thread) = self
908 .connection
909 .0
910 .read(cx)
911 .sessions
912 .get(&self.session_id)
913 .map(|session| session.thread.clone())
914 else {
915 return Task::ready(Err(anyhow!("Session not found")));
916 };
917 let Some(model) = thread.read(cx).model() else {
918 return Task::ready(Err(anyhow!("Model not found")));
919 };
920 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
921 else {
922 return Task::ready(Err(anyhow!("Provider not found")));
923 };
924 Task::ready(Ok(LanguageModels::map_language_model_to_info(
925 model, &provider,
926 )))
927 }
928
929 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
930 Some(self.connection.0.read(cx).models.watch())
931 }
932
933 fn should_render_footer(&self) -> bool {
934 true
935 }
936}
937
938impl acp_thread::AgentConnection for NativeAgentConnection {
939 fn telemetry_id(&self) -> SharedString {
940 "zed".into()
941 }
942
943 fn new_thread(
944 self: Rc<Self>,
945 project: Entity<Project>,
946 cwd: &Path,
947 cx: &mut App,
948 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
949 let agent = self.0.clone();
950 log::debug!("Creating new thread for project at: {:?}", cwd);
951
952 cx.spawn(async move |cx| {
953 log::debug!("Starting thread creation in async context");
954
955 // Create Thread
956 let thread = agent.update(
957 cx,
958 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
959 // Fetch default model from registry settings
960 let registry = LanguageModelRegistry::read_global(cx);
961 // Log available models for debugging
962 let available_count = registry.available_models(cx).count();
963 log::debug!("Total available models: {}", available_count);
964
965 let default_model = registry.default_model().and_then(|default_model| {
966 agent
967 .models
968 .model_from_id(&LanguageModels::model_id(&default_model.model))
969 });
970 Ok(cx.new(|cx| {
971 Thread::new(
972 project.clone(),
973 agent.project_context.clone(),
974 agent.context_server_registry.clone(),
975 agent.templates.clone(),
976 default_model,
977 cx,
978 )
979 }))
980 },
981 )??;
982 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
983 })
984 }
985
986 fn auth_methods(&self) -> &[acp::AuthMethod] {
987 &[] // No auth for in-process
988 }
989
990 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
991 Task::ready(Ok(()))
992 }
993
994 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
995 Some(Rc::new(NativeAgentModelSelector {
996 session_id: session_id.clone(),
997 connection: self.clone(),
998 }) as Rc<dyn AgentModelSelector>)
999 }
1000
1001 fn prompt(
1002 &self,
1003 id: Option<acp_thread::UserMessageId>,
1004 params: acp::PromptRequest,
1005 cx: &mut App,
1006 ) -> Task<Result<acp::PromptResponse>> {
1007 let id = id.expect("UserMessageId is required");
1008 let session_id = params.session_id.clone();
1009 log::info!("Received prompt request for session: {}", session_id);
1010 log::debug!("Prompt blocks count: {}", params.prompt.len());
1011 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1012
1013 self.run_turn(session_id, cx, move |thread, cx| {
1014 let content: Vec<UserMessageContent> = params
1015 .prompt
1016 .into_iter()
1017 .map(|block| UserMessageContent::from_content_block(block, path_style))
1018 .collect::<Vec<_>>();
1019 log::debug!("Converted prompt to message: {} chars", content.len());
1020 log::debug!("Message id: {:?}", id);
1021 log::debug!("Message content: {:?}", content);
1022
1023 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1024 })
1025 }
1026
1027 fn resume(
1028 &self,
1029 session_id: &acp::SessionId,
1030 _cx: &App,
1031 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1032 Some(Rc::new(NativeAgentSessionResume {
1033 connection: self.clone(),
1034 session_id: session_id.clone(),
1035 }) as _)
1036 }
1037
1038 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1039 log::info!("Cancelling on session: {}", session_id);
1040 self.0.update(cx, |agent, cx| {
1041 if let Some(agent) = agent.sessions.get(session_id) {
1042 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1043 }
1044 });
1045 }
1046
1047 fn truncate(
1048 &self,
1049 session_id: &agent_client_protocol::SessionId,
1050 cx: &App,
1051 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1052 self.0.read_with(cx, |agent, _cx| {
1053 agent.sessions.get(session_id).map(|session| {
1054 Rc::new(NativeAgentSessionTruncate {
1055 thread: session.thread.clone(),
1056 acp_thread: session.acp_thread.clone(),
1057 }) as _
1058 })
1059 })
1060 }
1061
1062 fn set_title(
1063 &self,
1064 session_id: &acp::SessionId,
1065 _cx: &App,
1066 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1067 Some(Rc::new(NativeAgentSessionSetTitle {
1068 connection: self.clone(),
1069 session_id: session_id.clone(),
1070 }) as _)
1071 }
1072
1073 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1074 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1075 }
1076
1077 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1078 self
1079 }
1080}
1081
1082impl acp_thread::AgentTelemetry for NativeAgentConnection {
1083 fn thread_data(
1084 &self,
1085 session_id: &acp::SessionId,
1086 cx: &mut App,
1087 ) -> Task<Result<serde_json::Value>> {
1088 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1089 return Task::ready(Err(anyhow!("Session not found")));
1090 };
1091
1092 let task = session.thread.read(cx).to_db(cx);
1093 cx.background_spawn(async move {
1094 serde_json::to_value(task.await).context("Failed to serialize thread")
1095 })
1096 }
1097}
1098
/// Truncation handle for one native agent session.
struct NativeAgentSessionTruncate {
    /// The session's internal thread, which is the one actually truncated.
    thread: Entity<Thread>,
    /// The session's ACP thread; its token usage is refreshed after truncation.
    acp_thread: WeakEntity<AcpThread>,
}
1103
1104impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1105 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1106 match self.thread.update(cx, |thread, cx| {
1107 thread.truncate(message_id.clone(), cx)?;
1108 Ok(thread.latest_token_usage())
1109 }) {
1110 Ok(usage) => {
1111 self.acp_thread
1112 .update(cx, |thread, cx| {
1113 thread.update_token_usage(usage, cx);
1114 })
1115 .ok();
1116 Task::ready(Ok(()))
1117 }
1118 Err(error) => Task::ready(Err(error)),
1119 }
1120 }
1121}
1122
/// Handle that resumes a native agent session by running a new turn on it.
struct NativeAgentSessionResume {
    /// Connection used to run the turn.
    connection: NativeAgentConnection,
    /// Session to resume; it is looked up only when the handle runs.
    session_id: acp::SessionId,
}
1127
1128impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1129 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1130 self.connection
1131 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1132 thread.update(cx, |thread, cx| thread.resume(cx))
1133 })
1134 }
1135}
1136
/// Handle that sets the title of a native agent session's thread.
struct NativeAgentSessionSetTitle {
    /// Connection whose agent owns the session.
    connection: NativeAgentConnection,
    /// Session whose thread is retitled; it is looked up only when the handle runs.
    session_id: acp::SessionId,
}
1141
1142impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1143 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1144 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1145 return Task::ready(Err(anyhow!("session not found")));
1146 };
1147 let thread = session.thread.clone();
1148 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1149 Task::ready(Ok(()))
1150 }
1151}
1152
/// Thread environment that creates terminals through an ACP thread.
pub struct AcpThreadEnvironment {
    /// Weak handle to the ACP thread; operations fail if it has been released.
    acp_thread: WeakEntity<AcpThread>,
}
1156
1157impl ThreadEnvironment for AcpThreadEnvironment {
1158 fn create_terminal(
1159 &self,
1160 command: String,
1161 cwd: Option<PathBuf>,
1162 output_byte_limit: Option<u64>,
1163 cx: &mut AsyncApp,
1164 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1165 let task = self.acp_thread.update(cx, |thread, cx| {
1166 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1167 });
1168
1169 let acp_thread = self.acp_thread.clone();
1170 cx.spawn(async move |cx| {
1171 let terminal = task?.await?;
1172
1173 let (drop_tx, drop_rx) = oneshot::channel();
1174 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1175
1176 cx.spawn(async move |cx| {
1177 drop_rx.await.ok();
1178 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1179 })
1180 .detach();
1181
1182 let handle = AcpTerminalHandle {
1183 terminal,
1184 _drop_tx: Some(drop_tx),
1185 };
1186
1187 Ok(Rc::new(handle) as _)
1188 })
1189 }
1190}
1191
/// Terminal handle backed by an ACP terminal entity.
pub struct AcpTerminalHandle {
    /// The ACP terminal this handle forwards to.
    terminal: Entity<acp_thread::Terminal>,
    /// Drop signal. When this sender is dropped together with the handle, the
    /// receiver awaited by the release task in `AcpThreadEnvironment::create_terminal`
    /// completes and the terminal is released.
    _drop_tx: Option<oneshot::Sender<()>>,
}
1196
1197impl TerminalHandle for AcpTerminalHandle {
1198 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1199 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1200 }
1201
1202 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1203 self.terminal
1204 .read_with(cx, |term, _cx| term.wait_for_exit())
1205 }
1206
1207 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1208 self.terminal
1209 .read_with(cx, |term, cx| term.current_output(cx))
1210 }
1211
1212 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1213 cx.update(|cx| {
1214 self.terminal.update(cx, |terminal, cx| {
1215 terminal.kill(cx);
1216 });
1217 })?;
1218 Ok(())
1219 }
1220}
1221
// Tests for `NativeAgent` / `NativeAgentConnection` that inspect agent
// internals such as `sessions` and `project_context`.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// Verifies that the agent's project context follows the project's worktrees.
    /// It starts empty, gains an entry when `/a` is added as a worktree, and
    /// records `.rules` once that file is created in the worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without worktrees, so the context starts empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Verifies that a new session's model selector lists the fake model
    /// (`fake/fake`) in a single "Fake" group.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Verifies that selecting a model through a session's model selector sets
    /// the thread's model and rewrites `agent.default_model` in the settings
    /// file. The file initially names provider "foo" and model "bar".
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model. The selector id has the form "provider/model"; both
        // halves are "fake" here, which is what the settings assertions expect.
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings write finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Verifies thread persistence:
    /// - an empty thread is not saved, even after its models are changed;
    /// - after one exchange and dropping the session, the thread is listed in
    ///   history, titled with the summary model's output;
    /// - reopening the thread restores the same conversation.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Run the send in the background so the fake model streams below can be
        // fed while it is in flight.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Finish the assistant reply, then the summary that later becomes the
        // history title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns `(id, title)` for every entry currently in `history`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Shared setup. Initializes the logger on a best-effort basis (`try_init`
    /// fails when a logger is already installed, and that error is ignored),
    /// then installs test settings and a test language model registry. The
    /// tests above expect that registry to expose the `fake/fake` model.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}