1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8mod thread;
9mod tools;
10
11#[cfg(test)]
12mod tests;
13
14pub use db::*;
15pub use history_store::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use tools::*;
20
21use acp_thread::{AcpThread, AgentModelSelector};
22use agent_client_protocol as acp;
23use anyhow::{Context as _, Result, anyhow};
24use chrono::{DateTime, Utc};
25use collections::{HashSet, IndexMap};
26use fs::Fs;
27use futures::channel::{mpsc, oneshot};
28use futures::future::Shared;
29use futures::{StreamExt, future};
30use gpui::{
31 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
32};
33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
34use project::{Project, ProjectItem, ProjectPath, Worktree};
35use prompt_store::{
36 ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
37 WorktreeContext,
38};
39use serde::{Deserialize, Serialize};
40use settings::{LanguageModelSelection, update_settings_file};
41use std::any::Any;
42use std::collections::HashMap;
43use std::path::{Path, PathBuf};
44use std::rc::Rc;
45use std::sync::Arc;
46use util::ResultExt;
47use util::rel_path::RelPath;
48
/// A point-in-time capture of the project's worktrees, stored as telemetry
/// worktree snapshots, together with a UTC timestamp (presumably the moment the
/// snapshot was taken — confirm at the construction site).
///
/// NOTE: field order determines the serialized layout (serde derive).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per captured worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// When the snapshot was recorded, in UTC.
    pub timestamp: DateTime<Utc>,
}
54
/// Describes a failure to load a worktree rules file.
///
/// Produced by `NativeAgent::load_worktree_info_for_system_prompt` from the
/// error returned while reading the rules file.
pub struct RulesLoadingError {
    /// Human-readable error text (the `Display` form of the underlying error).
    pub message: SharedString,
}
58
/// Holds both the internal Thread and the AcpThread for a session.
///
/// Sessions are stored in `NativeAgent::sessions`, keyed by the thread's session id.
///
/// NOTE: fields are dropped in declaration order; keep the order stable.
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication.
    ///
    /// Held weakly: the session is removed from `NativeAgent::sessions` when this
    /// AcpThread is released (see the `observe_release` in `register_session`).
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// The most recent save of `thread` started by `NativeAgent::save_thread`.
    /// Each new save replaces this task (dropping the previous one — presumably
    /// cancelling it, per gpui `Task` semantics; confirm).
    pending_save: Task<()>,
    /// Subscriptions for the AcpThread release and the thread's title, token-usage,
    /// and change notifications; held only to keep them alive.
    _subscriptions: Vec<Subscription>,
}
68
/// Cache of the language models exposed by authenticated providers, in the shape
/// the ACP model selector expects.
///
/// NOTE: fields are dropped in declaration order; keep the order stable.
pub struct LanguageModels {
    /// Access language model by ID (`"<provider_id>/<model_id>"`, see `model_id`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information, grouped by
    /// "Recommended" and then by provider name.
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out by `watch`; it changes whenever `refresh_list` runs.
    refresh_models_rx: watch::Receiver<()>,
    /// Signalled at the end of every `refresh_list`.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every registered provider; held so it is
    /// not dropped.
    _authenticate_all_providers_task: Task<()>,
}
78
79impl LanguageModels {
80 fn new(cx: &mut App) -> Self {
81 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
82
83 let mut this = Self {
84 models: HashMap::default(),
85 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
86 refresh_models_rx,
87 refresh_models_tx,
88 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
89 };
90 this.refresh_list(cx);
91 this
92 }
93
94 fn refresh_list(&mut self, cx: &App) {
95 let providers = LanguageModelRegistry::global(cx)
96 .read(cx)
97 .providers()
98 .into_iter()
99 .filter(|provider| provider.is_authenticated(cx))
100 .collect::<Vec<_>>();
101
102 let mut language_model_list = IndexMap::default();
103 let mut recommended_models = HashSet::default();
104
105 let mut recommended = Vec::new();
106 for provider in &providers {
107 for model in provider.recommended_models(cx) {
108 recommended_models.insert((model.provider_id(), model.id()));
109 recommended.push(Self::map_language_model_to_info(&model, provider));
110 }
111 }
112 if !recommended.is_empty() {
113 language_model_list.insert(
114 acp_thread::AgentModelGroupName("Recommended".into()),
115 recommended,
116 );
117 }
118
119 let mut models = HashMap::default();
120 for provider in providers {
121 let mut provider_models = Vec::new();
122 for model in provider.provided_models(cx) {
123 let model_info = Self::map_language_model_to_info(&model, &provider);
124 let model_id = model_info.id.clone();
125 provider_models.push(model_info);
126 models.insert(model_id, model);
127 }
128 if !provider_models.is_empty() {
129 language_model_list.insert(
130 acp_thread::AgentModelGroupName(provider.name().0.clone()),
131 provider_models,
132 );
133 }
134 }
135
136 self.models = models;
137 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
138 self.refresh_models_tx.send(()).ok();
139 }
140
141 fn watch(&self) -> watch::Receiver<()> {
142 self.refresh_models_rx.clone()
143 }
144
145 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
146 self.models.get(model_id).cloned()
147 }
148
149 fn map_language_model_to_info(
150 model: &Arc<dyn LanguageModel>,
151 provider: &Arc<dyn LanguageModelProvider>,
152 ) -> acp_thread::AgentModelInfo {
153 acp_thread::AgentModelInfo {
154 id: Self::model_id(model),
155 name: model.name().0,
156 description: None,
157 icon: Some(provider.icon()),
158 }
159 }
160
161 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
162 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
163 }
164
165 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
166 let authenticate_all_providers = LanguageModelRegistry::global(cx)
167 .read(cx)
168 .providers()
169 .iter()
170 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
171 .collect::<Vec<_>>();
172
173 cx.background_spawn(async move {
174 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
175 if let Err(err) = authenticate_task.await {
176 match err {
177 language_model::AuthenticateError::CredentialsNotFound => {
178 // Since we're authenticating these providers in the
179 // background for the purposes of populating the
180 // language selector, we don't care about providers
181 // where the credentials are not found.
182 }
183 language_model::AuthenticateError::ConnectionRefused => {
184 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
185 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
186 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
187 }
188 _ => {
189 // Some providers have noisy failure states that we
190 // don't want to spam the logs with every time the
191 // language model selector is initialized.
192 //
193 // Ideally these should have more clear failure modes
194 // that we know are safe to ignore here, like what we do
195 // with `CredentialsNotFound` above.
196 match provider_id.0.as_ref() {
197 "lmstudio" | "ollama" => {
198 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
199 //
200 // These fail noisily, so we don't log them.
201 }
202 "copilot_chat" => {
203 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
204 }
205 _ => {
206 log::error!(
207 "Failed to authenticate provider: {}: {err:#}",
208 provider_name.0
209 );
210 }
211 }
212 }
213 }
214 }
215 }
216 })
217 }
218}
219
/// The in-process ("native") agent: owns every open session for one project and
/// the state shared between them.
///
/// NOTE: fields are dropped in declaration order; keep the order stable.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history; reloaded after each thread save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Any send on this channel asks `maintain_project_context` to rebuild
    /// `project_context`.
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task running `maintain_project_context`; held to keep it alive.
    _maintain_project_context: Task<Result<()>>,
    /// Context server registry handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project all sessions operate on.
    project: Entity<Project>,
    /// Source of default user rules, when available.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem used to persist model selection to the settings file.
    fs: Arc<dyn Fs>,
    /// Subscriptions to project, model-registry and prompt-store events; held
    /// only to keep them alive.
    _subscriptions: Vec<Subscription>,
}
238
239impl NativeAgent {
240 pub async fn new(
241 project: Entity<Project>,
242 history: Entity<HistoryStore>,
243 templates: Arc<Templates>,
244 prompt_store: Option<Entity<PromptStore>>,
245 fs: Arc<dyn Fs>,
246 cx: &mut AsyncApp,
247 ) -> Result<Entity<NativeAgent>> {
248 log::debug!("Creating new NativeAgent");
249
250 let project_context = cx
251 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
252 .await;
253
254 cx.new(|cx| {
255 let mut subscriptions = vec![
256 cx.subscribe(&project, Self::handle_project_event),
257 cx.subscribe(
258 &LanguageModelRegistry::global(cx),
259 Self::handle_models_updated_event,
260 ),
261 ];
262 if let Some(prompt_store) = prompt_store.as_ref() {
263 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
264 }
265
266 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
267 watch::channel(());
268 Self {
269 sessions: HashMap::new(),
270 history,
271 project_context: cx.new(|_| project_context),
272 project_context_needs_refresh: project_context_needs_refresh_tx,
273 _maintain_project_context: cx.spawn(async move |this, cx| {
274 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
275 }),
276 context_server_registry: cx.new(|cx| {
277 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
278 }),
279 templates,
280 models: LanguageModels::new(cx),
281 project,
282 prompt_store,
283 fs,
284 _subscriptions: subscriptions,
285 }
286 })
287 }
288
289 fn register_session(
290 &mut self,
291 thread_handle: Entity<Thread>,
292 cx: &mut Context<Self>,
293 ) -> Entity<AcpThread> {
294 let connection = Rc::new(NativeAgentConnection(cx.entity()));
295
296 let thread = thread_handle.read(cx);
297 let session_id = thread.id().clone();
298 let title = thread.title();
299 let project = thread.project.clone();
300 let action_log = thread.action_log.clone();
301 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
302 let acp_thread = cx.new(|cx| {
303 acp_thread::AcpThread::new(
304 title,
305 connection,
306 project.clone(),
307 action_log.clone(),
308 session_id.clone(),
309 prompt_capabilities_rx,
310 cx,
311 )
312 });
313
314 let registry = LanguageModelRegistry::read_global(cx);
315 let summarization_model = registry.thread_summary_model().map(|c| c.model);
316
317 thread_handle.update(cx, |thread, cx| {
318 thread.set_summarization_model(summarization_model, cx);
319 thread.add_default_tools(
320 Rc::new(AcpThreadEnvironment {
321 acp_thread: acp_thread.downgrade(),
322 }) as _,
323 cx,
324 )
325 });
326
327 let subscriptions = vec![
328 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
329 this.sessions.remove(acp_thread.session_id());
330 }),
331 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
332 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
333 cx.observe(&thread_handle, move |this, thread, cx| {
334 this.save_thread(thread, cx)
335 }),
336 ];
337
338 self.sessions.insert(
339 session_id,
340 Session {
341 thread: thread_handle,
342 acp_thread: acp_thread.downgrade(),
343 _subscriptions: subscriptions,
344 pending_save: Task::ready(()),
345 },
346 );
347 acp_thread
348 }
349
350 pub fn models(&self) -> &LanguageModels {
351 &self.models
352 }
353
354 async fn maintain_project_context(
355 this: WeakEntity<Self>,
356 mut needs_refresh: watch::Receiver<()>,
357 cx: &mut AsyncApp,
358 ) -> Result<()> {
359 while needs_refresh.changed().await.is_ok() {
360 let project_context = this
361 .update(cx, |this, cx| {
362 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
363 })?
364 .await;
365 this.update(cx, |this, cx| {
366 this.project_context = cx.new(|_| project_context);
367 })?;
368 }
369
370 Ok(())
371 }
372
373 fn build_project_context(
374 project: &Entity<Project>,
375 prompt_store: Option<&Entity<PromptStore>>,
376 cx: &mut App,
377 ) -> Task<ProjectContext> {
378 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
379 let worktree_tasks = worktrees
380 .into_iter()
381 .map(|worktree| {
382 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
383 })
384 .collect::<Vec<_>>();
385 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
386 prompt_store.read_with(cx, |prompt_store, cx| {
387 let prompts = prompt_store.default_prompt_metadata();
388 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
389 let contents = prompt_store.load(prompt_metadata.id, cx);
390 async move { (contents.await, prompt_metadata) }
391 });
392 cx.background_spawn(future::join_all(load_tasks))
393 })
394 } else {
395 Task::ready(vec![])
396 };
397
398 cx.spawn(async move |_cx| {
399 let (worktrees, default_user_rules) =
400 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
401
402 let worktrees = worktrees
403 .into_iter()
404 .map(|(worktree, _rules_error)| {
405 // TODO: show error message
406 // if let Some(rules_error) = rules_error {
407 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
408 // }
409 worktree
410 })
411 .collect::<Vec<_>>();
412
413 let default_user_rules = default_user_rules
414 .into_iter()
415 .flat_map(|(contents, prompt_metadata)| match contents {
416 Ok(contents) => Some(UserRulesContext {
417 uuid: prompt_metadata.id.user_id()?,
418 title: prompt_metadata.title.map(|title| title.to_string()),
419 contents,
420 }),
421 Err(_err) => {
422 // TODO: show error message
423 // this.update(cx, |_, cx| {
424 // cx.emit(RulesLoadingError {
425 // message: format!("{err:?}").into(),
426 // });
427 // })
428 // .ok();
429 None
430 }
431 })
432 .collect::<Vec<_>>();
433
434 ProjectContext::new(worktrees, default_user_rules)
435 })
436 }
437
438 fn load_worktree_info_for_system_prompt(
439 worktree: Entity<Worktree>,
440 project: Entity<Project>,
441 cx: &mut App,
442 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
443 let tree = worktree.read(cx);
444 let root_name = tree.root_name_str().into();
445 let abs_path = tree.abs_path();
446
447 let mut context = WorktreeContext {
448 root_name,
449 abs_path,
450 rules_file: None,
451 };
452
453 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
454 let Some(rules_task) = rules_task else {
455 return Task::ready((context, None));
456 };
457
458 cx.spawn(async move |_| {
459 let (rules_file, rules_file_error) = match rules_task.await {
460 Ok(rules_file) => (Some(rules_file), None),
461 Err(err) => (
462 None,
463 Some(RulesLoadingError {
464 message: format!("{err}").into(),
465 }),
466 ),
467 };
468 context.rules_file = rules_file;
469 (context, rules_file_error)
470 })
471 }
472
473 fn load_worktree_rules_file(
474 worktree: Entity<Worktree>,
475 project: Entity<Project>,
476 cx: &mut App,
477 ) -> Option<Task<Result<RulesFileContext>>> {
478 let worktree = worktree.read(cx);
479 let worktree_id = worktree.id();
480 let selected_rules_file = RULES_FILE_NAMES
481 .into_iter()
482 .filter_map(|name| {
483 worktree
484 .entry_for_path(RelPath::unix(name).unwrap())
485 .filter(|entry| entry.is_file())
486 .map(|entry| entry.path.clone())
487 })
488 .next();
489
490 // Note that Cline supports `.clinerules` being a directory, but that is not currently
491 // supported. This doesn't seem to occur often in GitHub repositories.
492 selected_rules_file.map(|path_in_worktree| {
493 let project_path = ProjectPath {
494 worktree_id,
495 path: path_in_worktree.clone(),
496 };
497 let buffer_task =
498 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
499 let rope_task = cx.spawn(async move |cx| {
500 buffer_task.await?.read_with(cx, |buffer, cx| {
501 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
502 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
503 })?
504 });
505 // Build a string from the rope on a background thread.
506 cx.background_spawn(async move {
507 let (project_entry_id, rope) = rope_task.await?;
508 anyhow::Ok(RulesFileContext {
509 path_in_worktree,
510 text: rope.to_string().trim().to_string(),
511 project_entry_id: project_entry_id.to_usize(),
512 })
513 })
514 })
515 }
516
517 fn handle_thread_title_updated(
518 &mut self,
519 thread: Entity<Thread>,
520 _: &TitleUpdated,
521 cx: &mut Context<Self>,
522 ) {
523 let session_id = thread.read(cx).id();
524 let Some(session) = self.sessions.get(session_id) else {
525 return;
526 };
527 let thread = thread.downgrade();
528 let acp_thread = session.acp_thread.clone();
529 cx.spawn(async move |_, cx| {
530 let title = thread.read_with(cx, |thread, _| thread.title())?;
531 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
532 task.await
533 })
534 .detach_and_log_err(cx);
535 }
536
537 fn handle_thread_token_usage_updated(
538 &mut self,
539 thread: Entity<Thread>,
540 usage: &TokenUsageUpdated,
541 cx: &mut Context<Self>,
542 ) {
543 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
544 return;
545 };
546 session
547 .acp_thread
548 .update(cx, |acp_thread, cx| {
549 acp_thread.update_token_usage(usage.0.clone(), cx);
550 })
551 .ok();
552 }
553
554 fn handle_project_event(
555 &mut self,
556 _project: Entity<Project>,
557 event: &project::Event,
558 _cx: &mut Context<Self>,
559 ) {
560 match event {
561 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
562 self.project_context_needs_refresh.send(()).ok();
563 }
564 project::Event::WorktreeUpdatedEntries(_, items) => {
565 if items.iter().any(|(path, _, _)| {
566 RULES_FILE_NAMES
567 .iter()
568 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
569 }) {
570 self.project_context_needs_refresh.send(()).ok();
571 }
572 }
573 _ => {}
574 }
575 }
576
577 fn handle_prompts_updated_event(
578 &mut self,
579 _prompt_store: Entity<PromptStore>,
580 _event: &prompt_store::PromptsUpdatedEvent,
581 _cx: &mut Context<Self>,
582 ) {
583 self.project_context_needs_refresh.send(()).ok();
584 }
585
586 fn handle_models_updated_event(
587 &mut self,
588 _registry: Entity<LanguageModelRegistry>,
589 _event: &language_model::Event,
590 cx: &mut Context<Self>,
591 ) {
592 self.models.refresh_list(cx);
593
594 let registry = LanguageModelRegistry::read_global(cx);
595 let default_model = registry.default_model().map(|m| m.model);
596 let summarization_model = registry.thread_summary_model().map(|m| m.model);
597
598 for session in self.sessions.values_mut() {
599 session.thread.update(cx, |thread, cx| {
600 if thread.model().is_none()
601 && let Some(model) = default_model.clone()
602 {
603 thread.set_model(model, cx);
604 cx.notify();
605 }
606 thread.set_summarization_model(summarization_model.clone(), cx);
607 });
608 }
609 }
610
611 pub fn load_thread(
612 &mut self,
613 id: acp::SessionId,
614 cx: &mut Context<Self>,
615 ) -> Task<Result<Entity<Thread>>> {
616 let database_future = ThreadsDatabase::connect(cx);
617 cx.spawn(async move |this, cx| {
618 let database = database_future.await.map_err(|err| anyhow!(err))?;
619 let db_thread = database
620 .load_thread(id.clone())
621 .await?
622 .with_context(|| format!("no thread found with ID: {id:?}"))?;
623
624 this.update(cx, |this, cx| {
625 let summarization_model = LanguageModelRegistry::read_global(cx)
626 .thread_summary_model()
627 .map(|c| c.model);
628
629 cx.new(|cx| {
630 let mut thread = Thread::from_db(
631 id.clone(),
632 db_thread,
633 this.project.clone(),
634 this.project_context.clone(),
635 this.context_server_registry.clone(),
636 this.templates.clone(),
637 cx,
638 );
639 thread.set_summarization_model(summarization_model, cx);
640 thread
641 })
642 })
643 })
644 }
645
646 pub fn open_thread(
647 &mut self,
648 id: acp::SessionId,
649 cx: &mut Context<Self>,
650 ) -> Task<Result<Entity<AcpThread>>> {
651 let task = self.load_thread(id, cx);
652 cx.spawn(async move |this, cx| {
653 let thread = task.await?;
654 let acp_thread =
655 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
656 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
657 cx.update(|cx| {
658 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
659 })?
660 .await?;
661 Ok(acp_thread)
662 })
663 }
664
665 pub fn thread_summary(
666 &mut self,
667 id: acp::SessionId,
668 cx: &mut Context<Self>,
669 ) -> Task<Result<SharedString>> {
670 let thread = self.open_thread(id.clone(), cx);
671 cx.spawn(async move |this, cx| {
672 let acp_thread = thread.await?;
673 let result = this
674 .update(cx, |this, cx| {
675 this.sessions
676 .get(&id)
677 .unwrap()
678 .thread
679 .update(cx, |thread, cx| thread.summary(cx))
680 })?
681 .await
682 .context("Failed to generate summary")?;
683 drop(acp_thread);
684 Ok(result)
685 })
686 }
687
688 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
689 if thread.read(cx).is_empty() {
690 return;
691 }
692
693 let database_future = ThreadsDatabase::connect(cx);
694 let (id, db_thread) =
695 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
696 let Some(session) = self.sessions.get_mut(&id) else {
697 return;
698 };
699 let history = self.history.clone();
700 session.pending_save = cx.spawn(async move |_, cx| {
701 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
702 return;
703 };
704 let db_thread = db_thread.await;
705 database.save_thread(id, db_thread).await.log_err();
706 history.update(cx, |history, cx| history.reload(cx)).ok();
707 });
708 }
709}
710
/// Wrapper struct that implements the AgentConnection trait (and
/// `AgentTelemetry`) on top of a shared [`NativeAgent`] entity.
///
/// Cloning is cheap: it clones the entity handle, so every clone talks to the
/// same agent and its sessions.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
714
715impl NativeAgentConnection {
716 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
717 self.0
718 .read(cx)
719 .sessions
720 .get(session_id)
721 .map(|session| session.thread.clone())
722 }
723
724 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
725 self.0.update(cx, |this, cx| this.load_thread(id, cx))
726 }
727
728 fn run_turn(
729 &self,
730 session_id: acp::SessionId,
731 cx: &mut App,
732 f: impl 'static
733 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
734 ) -> Task<Result<acp::PromptResponse>> {
735 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
736 agent
737 .sessions
738 .get_mut(&session_id)
739 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
740 }) else {
741 return Task::ready(Err(anyhow!("Session not found")));
742 };
743 log::debug!("Found session for: {}", session_id);
744
745 let response_stream = match f(thread, cx) {
746 Ok(stream) => stream,
747 Err(err) => return Task::ready(Err(err)),
748 };
749 Self::handle_thread_events(response_stream, acp_thread, cx)
750 }
751
752 fn handle_thread_events(
753 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
754 acp_thread: WeakEntity<AcpThread>,
755 cx: &App,
756 ) -> Task<Result<acp::PromptResponse>> {
757 cx.spawn(async move |cx| {
758 // Handle response stream and forward to session.acp_thread
759 while let Some(result) = events.next().await {
760 match result {
761 Ok(event) => {
762 log::trace!("Received completion event: {:?}", event);
763
764 match event {
765 ThreadEvent::UserMessage(message) => {
766 acp_thread.update(cx, |thread, cx| {
767 for content in message.content {
768 thread.push_user_content_block(
769 Some(message.id.clone()),
770 content.into(),
771 cx,
772 );
773 }
774 })?;
775 }
776 ThreadEvent::AgentText(text) => {
777 acp_thread.update(cx, |thread, cx| {
778 thread.push_assistant_content_block(text.into(), false, cx)
779 })?;
780 }
781 ThreadEvent::AgentThinking(text) => {
782 acp_thread.update(cx, |thread, cx| {
783 thread.push_assistant_content_block(text.into(), true, cx)
784 })?;
785 }
786 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
787 tool_call,
788 options,
789 response,
790 }) => {
791 let outcome_task = acp_thread.update(cx, |thread, cx| {
792 thread.request_tool_call_authorization(
793 tool_call, options, true, cx,
794 )
795 })??;
796 cx.background_spawn(async move {
797 if let acp::RequestPermissionOutcome::Selected(
798 acp::SelectedPermissionOutcome { option_id, .. },
799 ) = outcome_task.await
800 {
801 response
802 .send(option_id)
803 .map(|_| anyhow!("authorization receiver was dropped"))
804 .log_err();
805 }
806 })
807 .detach();
808 }
809 ThreadEvent::ToolCall(tool_call) => {
810 acp_thread.update(cx, |thread, cx| {
811 thread.upsert_tool_call(tool_call, cx)
812 })??;
813 }
814 ThreadEvent::ToolCallUpdate(update) => {
815 acp_thread.update(cx, |thread, cx| {
816 thread.update_tool_call(update, cx)
817 })??;
818 }
819 ThreadEvent::Retry(status) => {
820 acp_thread.update(cx, |thread, cx| {
821 thread.update_retry_status(status, cx)
822 })?;
823 }
824 ThreadEvent::Stop(stop_reason) => {
825 log::debug!("Assistant message complete: {:?}", stop_reason);
826 return Ok(acp::PromptResponse::new(stop_reason));
827 }
828 }
829 }
830 Err(e) => {
831 log::error!("Error in model response stream: {:?}", e);
832 return Err(e);
833 }
834 }
835 }
836
837 log::debug!("Response stream completed");
838 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
839 })
840 }
841}
842
/// Model selector bound to a single session of a [`NativeAgentConnection`].
struct NativeAgentModelSelector {
    /// The session whose thread model is read and changed.
    session_id: acp::SessionId,
    /// Connection to the agent that owns the session.
    connection: NativeAgentConnection,
}
847
848impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
849 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
850 log::debug!("NativeAgentConnection::list_models called");
851 let list = self.connection.0.read(cx).models.model_list.clone();
852 Task::ready(if list.is_empty() {
853 Err(anyhow::anyhow!("No models available"))
854 } else {
855 Ok(list)
856 })
857 }
858
859 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
860 log::debug!(
861 "Setting model for session {}: {}",
862 self.session_id,
863 model_id
864 );
865 let Some(thread) = self
866 .connection
867 .0
868 .read(cx)
869 .sessions
870 .get(&self.session_id)
871 .map(|session| session.thread.clone())
872 else {
873 return Task::ready(Err(anyhow!("Session not found")));
874 };
875
876 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
877 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
878 };
879
880 thread.update(cx, |thread, cx| {
881 thread.set_model(model.clone(), cx);
882 });
883
884 update_settings_file(
885 self.connection.0.read(cx).fs.clone(),
886 cx,
887 move |settings, _cx| {
888 let provider = model.provider_id().0.to_string();
889 let model = model.id().0.to_string();
890 settings
891 .agent
892 .get_or_insert_default()
893 .set_model(LanguageModelSelection {
894 provider: provider.into(),
895 model,
896 });
897 },
898 );
899
900 Task::ready(Ok(()))
901 }
902
903 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
904 let Some(thread) = self
905 .connection
906 .0
907 .read(cx)
908 .sessions
909 .get(&self.session_id)
910 .map(|session| session.thread.clone())
911 else {
912 return Task::ready(Err(anyhow!("Session not found")));
913 };
914 let Some(model) = thread.read(cx).model() else {
915 return Task::ready(Err(anyhow!("Model not found")));
916 };
917 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
918 else {
919 return Task::ready(Err(anyhow!("Provider not found")));
920 };
921 Task::ready(Ok(LanguageModels::map_language_model_to_info(
922 model, &provider,
923 )))
924 }
925
926 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
927 Some(self.connection.0.read(cx).models.watch())
928 }
929
930 fn should_render_footer(&self) -> bool {
931 true
932 }
933}
934
935impl acp_thread::AgentConnection for NativeAgentConnection {
936 fn telemetry_id(&self) -> SharedString {
937 "zed".into()
938 }
939
940 fn new_thread(
941 self: Rc<Self>,
942 project: Entity<Project>,
943 cwd: &Path,
944 cx: &mut App,
945 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
946 let agent = self.0.clone();
947 log::debug!("Creating new thread for project at: {:?}", cwd);
948
949 cx.spawn(async move |cx| {
950 log::debug!("Starting thread creation in async context");
951
952 // Create Thread
953 let thread = agent.update(
954 cx,
955 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
956 // Fetch default model from registry settings
957 let registry = LanguageModelRegistry::read_global(cx);
958 // Log available models for debugging
959 let available_count = registry.available_models(cx).count();
960 log::debug!("Total available models: {}", available_count);
961
962 let default_model = registry.default_model().and_then(|default_model| {
963 agent
964 .models
965 .model_from_id(&LanguageModels::model_id(&default_model.model))
966 });
967 Ok(cx.new(|cx| {
968 Thread::new(
969 project.clone(),
970 agent.project_context.clone(),
971 agent.context_server_registry.clone(),
972 agent.templates.clone(),
973 default_model,
974 cx,
975 )
976 }))
977 },
978 )??;
979 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
980 })
981 }
982
983 fn auth_methods(&self) -> &[acp::AuthMethod] {
984 &[] // No auth for in-process
985 }
986
987 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
988 Task::ready(Ok(()))
989 }
990
991 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
992 Some(Rc::new(NativeAgentModelSelector {
993 session_id: session_id.clone(),
994 connection: self.clone(),
995 }) as Rc<dyn AgentModelSelector>)
996 }
997
998 fn prompt(
999 &self,
1000 id: Option<acp_thread::UserMessageId>,
1001 params: acp::PromptRequest,
1002 cx: &mut App,
1003 ) -> Task<Result<acp::PromptResponse>> {
1004 let id = id.expect("UserMessageId is required");
1005 let session_id = params.session_id.clone();
1006 log::info!("Received prompt request for session: {}", session_id);
1007 log::debug!("Prompt blocks count: {}", params.prompt.len());
1008 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1009
1010 self.run_turn(session_id, cx, move |thread, cx| {
1011 let content: Vec<UserMessageContent> = params
1012 .prompt
1013 .into_iter()
1014 .map(|block| UserMessageContent::from_content_block(block, path_style))
1015 .collect::<Vec<_>>();
1016 log::debug!("Converted prompt to message: {} chars", content.len());
1017 log::debug!("Message id: {:?}", id);
1018 log::debug!("Message content: {:?}", content);
1019
1020 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1021 })
1022 }
1023
1024 fn resume(
1025 &self,
1026 session_id: &acp::SessionId,
1027 _cx: &App,
1028 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1029 Some(Rc::new(NativeAgentSessionResume {
1030 connection: self.clone(),
1031 session_id: session_id.clone(),
1032 }) as _)
1033 }
1034
1035 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1036 log::info!("Cancelling on session: {}", session_id);
1037 self.0.update(cx, |agent, cx| {
1038 if let Some(agent) = agent.sessions.get(session_id) {
1039 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1040 }
1041 });
1042 }
1043
1044 fn truncate(
1045 &self,
1046 session_id: &agent_client_protocol::SessionId,
1047 cx: &App,
1048 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1049 self.0.read_with(cx, |agent, _cx| {
1050 agent.sessions.get(session_id).map(|session| {
1051 Rc::new(NativeAgentSessionTruncate {
1052 thread: session.thread.clone(),
1053 acp_thread: session.acp_thread.clone(),
1054 }) as _
1055 })
1056 })
1057 }
1058
1059 fn set_title(
1060 &self,
1061 session_id: &acp::SessionId,
1062 _cx: &App,
1063 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1064 Some(Rc::new(NativeAgentSessionSetTitle {
1065 connection: self.clone(),
1066 session_id: session_id.clone(),
1067 }) as _)
1068 }
1069
1070 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1071 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1072 }
1073
1074 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1075 self
1076 }
1077}
1078
1079impl acp_thread::AgentTelemetry for NativeAgentConnection {
1080 fn thread_data(
1081 &self,
1082 session_id: &acp::SessionId,
1083 cx: &mut App,
1084 ) -> Task<Result<serde_json::Value>> {
1085 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1086 return Task::ready(Err(anyhow!("Session not found")));
1087 };
1088
1089 let task = session.thread.read(cx).to_db(cx);
1090 cx.background_spawn(async move {
1091 serde_json::to_value(task.await).context("Failed to serialize thread")
1092 })
1093 }
1094}
1095
1096struct NativeAgentSessionTruncate {
1097 thread: Entity<Thread>,
1098 acp_thread: WeakEntity<AcpThread>,
1099}
1100
1101impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1102 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1103 match self.thread.update(cx, |thread, cx| {
1104 thread.truncate(message_id.clone(), cx)?;
1105 Ok(thread.latest_token_usage())
1106 }) {
1107 Ok(usage) => {
1108 self.acp_thread
1109 .update(cx, |thread, cx| {
1110 thread.update_token_usage(usage, cx);
1111 })
1112 .ok();
1113 Task::ready(Ok(()))
1114 }
1115 Err(error) => Task::ready(Err(error)),
1116 }
1117 }
1118}
1119
1120struct NativeAgentSessionResume {
1121 connection: NativeAgentConnection,
1122 session_id: acp::SessionId,
1123}
1124
1125impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1126 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1127 self.connection
1128 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1129 thread.update(cx, |thread, cx| thread.resume(cx))
1130 })
1131 }
1132}
1133
1134struct NativeAgentSessionSetTitle {
1135 connection: NativeAgentConnection,
1136 session_id: acp::SessionId,
1137}
1138
1139impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1140 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1141 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1142 return Task::ready(Err(anyhow!("session not found")));
1143 };
1144 let thread = session.thread.clone();
1145 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1146 Task::ready(Ok(()))
1147 }
1148}
1149
1150pub struct AcpThreadEnvironment {
1151 acp_thread: WeakEntity<AcpThread>,
1152}
1153
1154impl ThreadEnvironment for AcpThreadEnvironment {
1155 fn create_terminal(
1156 &self,
1157 command: String,
1158 cwd: Option<PathBuf>,
1159 output_byte_limit: Option<u64>,
1160 cx: &mut AsyncApp,
1161 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1162 let task = self.acp_thread.update(cx, |thread, cx| {
1163 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1164 });
1165
1166 let acp_thread = self.acp_thread.clone();
1167 cx.spawn(async move |cx| {
1168 let terminal = task?.await?;
1169
1170 let (drop_tx, drop_rx) = oneshot::channel();
1171 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1172
1173 cx.spawn(async move |cx| {
1174 drop_rx.await.ok();
1175 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1176 })
1177 .detach();
1178
1179 let handle = AcpTerminalHandle {
1180 terminal,
1181 _drop_tx: Some(drop_tx),
1182 };
1183
1184 Ok(Rc::new(handle) as _)
1185 })
1186 }
1187}
1188
1189pub struct AcpTerminalHandle {
1190 terminal: Entity<acp_thread::Terminal>,
1191 _drop_tx: Option<oneshot::Sender<()>>,
1192}
1193
1194impl TerminalHandle for AcpTerminalHandle {
1195 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1196 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1197 }
1198
1199 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1200 self.terminal
1201 .read_with(cx, |term, _cx| term.wait_for_exit())
1202 }
1203
1204 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1205 self.terminal
1206 .read_with(cx, |term, cx| term.current_output(cx))
1207 }
1208
1209 fn kill(&self, cx: &AsyncApp) -> Result<()> {
1210 cx.update(|cx| {
1211 self.terminal.update(cx, |terminal, cx| {
1212 terminal.kill(cx);
1213 });
1214 })?;
1215 Ok(())
1216 }
1217}
1218
/// Tests for `NativeAgent` / `NativeAgentConnection` that reach into crate-private
/// state (e.g. `agent.sessions`, `agent.project_context`).
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// The agent's project context follows worktrees as they are added and picks up a
    /// `.rules` file once it is created inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project was created without worktrees, so the context starts empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree shows up in the context, with no rules file yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// The connection's model selector returns a grouped list that contains only the
    /// fake model, under the "Fake" group.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // Only the grouped variant is expected here.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Selecting a model through the connection sets it on the session's thread and
    /// rewrites `agent.default_model` in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed a different default model ("foo"/"bar") so the overwrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Drain pending tasks before reading the file back. This presumably lets an
        // asynchronous settings write land (confirm against `update_settings_file`).
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// End-to-end persistence check:
    /// - An empty thread is not written to history.
    /// - A thread with content is written to history, titled with the summarization
    ///   model's output.
    /// - After the session is dropped, the thread reopens from history with the same
    ///   transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Keep a strong handle to the internal thread so fake models can be installed.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a prompt that mentions `/a/b.md` through a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream the assistant reply first, then the summary that later shows up as
        // the history entry's title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread renders the same transcript as before it was dropped.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns the history store's entries as `(id, title)` pairs, in the order
    /// yielded by `HistoryStore::entries`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Shared test setup:
    /// - initializes the logger (best effort; repeated init is ignored);
    /// - installs a test `SettingsStore` as a global;
    /// - calls `LanguageModelRegistry::test`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}