1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8mod thread;
9mod tools;
10
11#[cfg(test)]
12mod tests;
13
14pub use db::*;
15pub use history_store::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use tools::*;
20
21use acp_thread::{AcpThread, AgentModelIcon, AgentModelSelector};
22use agent_client_protocol as acp;
23use anyhow::{Context as _, Result, anyhow};
24use chrono::{DateTime, Utc};
25use collections::{HashSet, IndexMap};
26use fs::Fs;
27use futures::channel::{mpsc, oneshot};
28use futures::future::Shared;
29use futures::{StreamExt, future};
30use gpui::{
31 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
32};
33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
34use project::{Project, ProjectItem, ProjectPath, Worktree};
35use prompt_store::{
36 ProjectContext, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
37};
38use serde::{Deserialize, Serialize};
39use settings::{LanguageModelSelection, update_settings_file};
40use std::any::Any;
41use std::collections::HashMap;
42use std::path::{Path, PathBuf};
43use std::rc::Rc;
44use std::sync::Arc;
45use util::ResultExt;
46use util::rel_path::RelPath;
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub struct ProjectSnapshot {
50 pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
51 pub timestamp: DateTime<Utc>,
52}
53
/// Candidate rules-file paths, relative to a worktree's root.
///
/// Order matters: when loading a worktree's rules, the first entry that exists
/// as a *file* in the worktree wins and the rest are ignored. A change to any
/// of these paths also triggers a project-context refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    // Generic / editor-specific dotfiles.
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    // Nested path under the worktree root.
    ".github/copilot-instructions.md",
    // Markdown rules files.
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
65
66pub struct RulesLoadingError {
67 pub message: SharedString,
68}
69
70/// Holds both the internal Thread and the AcpThread for a session
71struct Session {
72 /// The internal thread that processes messages
73 thread: Entity<Thread>,
74 /// The ACP thread that handles protocol communication
75 acp_thread: WeakEntity<acp_thread::AcpThread>,
76 pending_save: Task<()>,
77 _subscriptions: Vec<Subscription>,
78}
79
80pub struct LanguageModels {
81 /// Access language model by ID
82 models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
83 /// Cached list for returning language model information
84 model_list: acp_thread::AgentModelList,
85 refresh_models_rx: watch::Receiver<()>,
86 refresh_models_tx: watch::Sender<()>,
87 _authenticate_all_providers_task: Task<()>,
88}
89
90impl LanguageModels {
91 fn new(cx: &mut App) -> Self {
92 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
93
94 let mut this = Self {
95 models: HashMap::default(),
96 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
97 refresh_models_rx,
98 refresh_models_tx,
99 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
100 };
101 this.refresh_list(cx);
102 this
103 }
104
105 fn refresh_list(&mut self, cx: &App) {
106 let providers = LanguageModelRegistry::global(cx)
107 .read(cx)
108 .visible_providers()
109 .into_iter()
110 .filter(|provider| provider.is_authenticated(cx))
111 .collect::<Vec<_>>();
112
113 let mut language_model_list = IndexMap::default();
114 let mut recommended_models = HashSet::default();
115
116 let mut recommended = Vec::new();
117 for provider in &providers {
118 for model in provider.recommended_models(cx) {
119 recommended_models.insert((model.provider_id(), model.id()));
120 recommended.push(Self::map_language_model_to_info(&model, provider));
121 }
122 }
123 if !recommended.is_empty() {
124 language_model_list.insert(
125 acp_thread::AgentModelGroupName("Recommended".into()),
126 recommended,
127 );
128 }
129
130 let mut models = HashMap::default();
131 for provider in providers {
132 let mut provider_models = Vec::new();
133 for model in provider.provided_models(cx) {
134 let model_info = Self::map_language_model_to_info(&model, &provider);
135 let model_id = model_info.id.clone();
136 provider_models.push(model_info);
137 models.insert(model_id, model);
138 }
139 if !provider_models.is_empty() {
140 language_model_list.insert(
141 acp_thread::AgentModelGroupName(provider.name().0.clone()),
142 provider_models,
143 );
144 }
145 }
146
147 self.models = models;
148 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
149 self.refresh_models_tx.send(()).ok();
150 }
151
152 fn watch(&self) -> watch::Receiver<()> {
153 self.refresh_models_rx.clone()
154 }
155
156 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
157 self.models.get(model_id).cloned()
158 }
159
160 fn map_language_model_to_info(
161 model: &Arc<dyn LanguageModel>,
162 provider: &Arc<dyn LanguageModelProvider>,
163 ) -> acp_thread::AgentModelInfo {
164 let icon = if let Some(path) = provider.icon_path() {
165 Some(AgentModelIcon::Path(path))
166 } else {
167 Some(AgentModelIcon::Named(provider.icon()))
168 };
169 acp_thread::AgentModelInfo {
170 id: Self::model_id(model),
171 name: model.name().0,
172 description: None,
173 icon,
174 }
175 }
176
177 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
178 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
179 }
180
181 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
182 let authenticate_all_providers = LanguageModelRegistry::global(cx)
183 .read(cx)
184 .providers()
185 .iter()
186 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
187 .collect::<Vec<_>>();
188
189 cx.background_spawn(async move {
190 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
191 if let Err(err) = authenticate_task.await {
192 match err {
193 language_model::AuthenticateError::CredentialsNotFound => {
194 // Since we're authenticating these providers in the
195 // background for the purposes of populating the
196 // language selector, we don't care about providers
197 // where the credentials are not found.
198 }
199 language_model::AuthenticateError::ConnectionRefused => {
200 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
201 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
202 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
203 }
204 _ => {
205 // Some providers have noisy failure states that we
206 // don't want to spam the logs with every time the
207 // language model selector is initialized.
208 //
209 // Ideally these should have more clear failure modes
210 // that we know are safe to ignore here, like what we do
211 // with `CredentialsNotFound` above.
212 match provider_id.0.as_ref() {
213 "lmstudio" | "ollama" => {
214 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
215 //
216 // These fail noisily, so we don't log them.
217 }
218 "copilot_chat" => {
219 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
220 }
221 _ => {
222 log::error!(
223 "Failed to authenticate provider: {}: {err:#}",
224 provider_name.0
225 );
226 }
227 }
228 }
229 }
230 }
231 }
232 })
233 }
234}
235
236pub struct NativeAgent {
237 /// Session ID -> Session mapping
238 sessions: HashMap<acp::SessionId, Session>,
239 history: Entity<HistoryStore>,
240 /// Shared project context for all threads
241 project_context: Entity<ProjectContext>,
242 project_context_needs_refresh: watch::Sender<()>,
243 _maintain_project_context: Task<Result<()>>,
244 context_server_registry: Entity<ContextServerRegistry>,
245 /// Shared templates for all threads
246 templates: Arc<Templates>,
247 /// Cached model information
248 models: LanguageModels,
249 project: Entity<Project>,
250 prompt_store: Option<Entity<PromptStore>>,
251 fs: Arc<dyn Fs>,
252 _subscriptions: Vec<Subscription>,
253}
254
255impl NativeAgent {
256 pub async fn new(
257 project: Entity<Project>,
258 history: Entity<HistoryStore>,
259 templates: Arc<Templates>,
260 prompt_store: Option<Entity<PromptStore>>,
261 fs: Arc<dyn Fs>,
262 cx: &mut AsyncApp,
263 ) -> Result<Entity<NativeAgent>> {
264 log::debug!("Creating new NativeAgent");
265
266 let project_context = cx
267 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
268 .await;
269
270 cx.new(|cx| {
271 let mut subscriptions = vec![
272 cx.subscribe(&project, Self::handle_project_event),
273 cx.subscribe(
274 &LanguageModelRegistry::global(cx),
275 Self::handle_models_updated_event,
276 ),
277 ];
278 if let Some(prompt_store) = prompt_store.as_ref() {
279 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
280 }
281
282 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
283 watch::channel(());
284 Self {
285 sessions: HashMap::new(),
286 history,
287 project_context: cx.new(|_| project_context),
288 project_context_needs_refresh: project_context_needs_refresh_tx,
289 _maintain_project_context: cx.spawn(async move |this, cx| {
290 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
291 }),
292 context_server_registry: cx.new(|cx| {
293 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
294 }),
295 templates,
296 models: LanguageModels::new(cx),
297 project,
298 prompt_store,
299 fs,
300 _subscriptions: subscriptions,
301 }
302 })
303 }
304
305 fn register_session(
306 &mut self,
307 thread_handle: Entity<Thread>,
308 cx: &mut Context<Self>,
309 ) -> Entity<AcpThread> {
310 let connection = Rc::new(NativeAgentConnection(cx.entity()));
311
312 let thread = thread_handle.read(cx);
313 let session_id = thread.id().clone();
314 let title = thread.title();
315 let project = thread.project.clone();
316 let action_log = thread.action_log.clone();
317 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
318 let acp_thread = cx.new(|cx| {
319 acp_thread::AcpThread::new(
320 title,
321 connection,
322 project.clone(),
323 action_log.clone(),
324 session_id.clone(),
325 prompt_capabilities_rx,
326 cx,
327 )
328 });
329
330 let registry = LanguageModelRegistry::read_global(cx);
331 let summarization_model = registry.thread_summary_model().map(|c| c.model);
332
333 thread_handle.update(cx, |thread, cx| {
334 thread.set_summarization_model(summarization_model, cx);
335 thread.add_default_tools(
336 Rc::new(AcpThreadEnvironment {
337 acp_thread: acp_thread.downgrade(),
338 }) as _,
339 cx,
340 )
341 });
342
343 let subscriptions = vec![
344 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
345 this.sessions.remove(acp_thread.session_id());
346 }),
347 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
348 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
349 cx.observe(&thread_handle, move |this, thread, cx| {
350 this.save_thread(thread, cx)
351 }),
352 ];
353
354 self.sessions.insert(
355 session_id,
356 Session {
357 thread: thread_handle,
358 acp_thread: acp_thread.downgrade(),
359 _subscriptions: subscriptions,
360 pending_save: Task::ready(()),
361 },
362 );
363 acp_thread
364 }
365
366 pub fn models(&self) -> &LanguageModels {
367 &self.models
368 }
369
370 async fn maintain_project_context(
371 this: WeakEntity<Self>,
372 mut needs_refresh: watch::Receiver<()>,
373 cx: &mut AsyncApp,
374 ) -> Result<()> {
375 while needs_refresh.changed().await.is_ok() {
376 let project_context = this
377 .update(cx, |this, cx| {
378 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
379 })?
380 .await;
381 this.update(cx, |this, cx| {
382 this.project_context = cx.new(|_| project_context);
383 })?;
384 }
385
386 Ok(())
387 }
388
389 fn build_project_context(
390 project: &Entity<Project>,
391 prompt_store: Option<&Entity<PromptStore>>,
392 cx: &mut App,
393 ) -> Task<ProjectContext> {
394 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
395 let worktree_tasks = worktrees
396 .into_iter()
397 .map(|worktree| {
398 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
399 })
400 .collect::<Vec<_>>();
401 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
402 prompt_store.read_with(cx, |prompt_store, cx| {
403 let prompts = prompt_store.default_prompt_metadata();
404 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
405 let contents = prompt_store.load(prompt_metadata.id, cx);
406 async move { (contents.await, prompt_metadata) }
407 });
408 cx.background_spawn(future::join_all(load_tasks))
409 })
410 } else {
411 Task::ready(vec![])
412 };
413
414 cx.spawn(async move |_cx| {
415 let (worktrees, default_user_rules) =
416 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
417
418 let worktrees = worktrees
419 .into_iter()
420 .map(|(worktree, _rules_error)| {
421 // TODO: show error message
422 // if let Some(rules_error) = rules_error {
423 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
424 // }
425 worktree
426 })
427 .collect::<Vec<_>>();
428
429 let default_user_rules = default_user_rules
430 .into_iter()
431 .flat_map(|(contents, prompt_metadata)| match contents {
432 Ok(contents) => Some(UserRulesContext {
433 uuid: match prompt_metadata.id {
434 prompt_store::PromptId::User { uuid } => uuid,
435 prompt_store::PromptId::EditWorkflow => return None,
436 },
437 title: prompt_metadata.title.map(|title| title.to_string()),
438 contents,
439 }),
440 Err(_err) => {
441 // TODO: show error message
442 // this.update(cx, |_, cx| {
443 // cx.emit(RulesLoadingError {
444 // message: format!("{err:?}").into(),
445 // });
446 // })
447 // .ok();
448 None
449 }
450 })
451 .collect::<Vec<_>>();
452
453 ProjectContext::new(worktrees, default_user_rules)
454 })
455 }
456
457 fn load_worktree_info_for_system_prompt(
458 worktree: Entity<Worktree>,
459 project: Entity<Project>,
460 cx: &mut App,
461 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
462 let tree = worktree.read(cx);
463 let root_name = tree.root_name_str().into();
464 let abs_path = tree.abs_path();
465
466 let mut context = WorktreeContext {
467 root_name,
468 abs_path,
469 rules_file: None,
470 };
471
472 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
473 let Some(rules_task) = rules_task else {
474 return Task::ready((context, None));
475 };
476
477 cx.spawn(async move |_| {
478 let (rules_file, rules_file_error) = match rules_task.await {
479 Ok(rules_file) => (Some(rules_file), None),
480 Err(err) => (
481 None,
482 Some(RulesLoadingError {
483 message: format!("{err}").into(),
484 }),
485 ),
486 };
487 context.rules_file = rules_file;
488 (context, rules_file_error)
489 })
490 }
491
492 fn load_worktree_rules_file(
493 worktree: Entity<Worktree>,
494 project: Entity<Project>,
495 cx: &mut App,
496 ) -> Option<Task<Result<RulesFileContext>>> {
497 let worktree = worktree.read(cx);
498 let worktree_id = worktree.id();
499 let selected_rules_file = RULES_FILE_NAMES
500 .into_iter()
501 .filter_map(|name| {
502 worktree
503 .entry_for_path(RelPath::unix(name).unwrap())
504 .filter(|entry| entry.is_file())
505 .map(|entry| entry.path.clone())
506 })
507 .next();
508
509 // Note that Cline supports `.clinerules` being a directory, but that is not currently
510 // supported. This doesn't seem to occur often in GitHub repositories.
511 selected_rules_file.map(|path_in_worktree| {
512 let project_path = ProjectPath {
513 worktree_id,
514 path: path_in_worktree.clone(),
515 };
516 let buffer_task =
517 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
518 let rope_task = cx.spawn(async move |cx| {
519 buffer_task.await?.read_with(cx, |buffer, cx| {
520 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
521 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
522 })?
523 });
524 // Build a string from the rope on a background thread.
525 cx.background_spawn(async move {
526 let (project_entry_id, rope) = rope_task.await?;
527 anyhow::Ok(RulesFileContext {
528 path_in_worktree,
529 text: rope.to_string().trim().to_string(),
530 project_entry_id: project_entry_id.to_usize(),
531 })
532 })
533 })
534 }
535
536 fn handle_thread_title_updated(
537 &mut self,
538 thread: Entity<Thread>,
539 _: &TitleUpdated,
540 cx: &mut Context<Self>,
541 ) {
542 let session_id = thread.read(cx).id();
543 let Some(session) = self.sessions.get(session_id) else {
544 return;
545 };
546 let thread = thread.downgrade();
547 let acp_thread = session.acp_thread.clone();
548 cx.spawn(async move |_, cx| {
549 let title = thread.read_with(cx, |thread, _| thread.title())?;
550 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
551 task.await
552 })
553 .detach_and_log_err(cx);
554 }
555
556 fn handle_thread_token_usage_updated(
557 &mut self,
558 thread: Entity<Thread>,
559 usage: &TokenUsageUpdated,
560 cx: &mut Context<Self>,
561 ) {
562 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
563 return;
564 };
565 session
566 .acp_thread
567 .update(cx, |acp_thread, cx| {
568 acp_thread.update_token_usage(usage.0.clone(), cx);
569 })
570 .ok();
571 }
572
573 fn handle_project_event(
574 &mut self,
575 _project: Entity<Project>,
576 event: &project::Event,
577 _cx: &mut Context<Self>,
578 ) {
579 match event {
580 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
581 self.project_context_needs_refresh.send(()).ok();
582 }
583 project::Event::WorktreeUpdatedEntries(_, items) => {
584 if items.iter().any(|(path, _, _)| {
585 RULES_FILE_NAMES
586 .iter()
587 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
588 }) {
589 self.project_context_needs_refresh.send(()).ok();
590 }
591 }
592 _ => {}
593 }
594 }
595
596 fn handle_prompts_updated_event(
597 &mut self,
598 _prompt_store: Entity<PromptStore>,
599 _event: &prompt_store::PromptsUpdatedEvent,
600 _cx: &mut Context<Self>,
601 ) {
602 self.project_context_needs_refresh.send(()).ok();
603 }
604
605 fn handle_models_updated_event(
606 &mut self,
607 _registry: Entity<LanguageModelRegistry>,
608 _event: &language_model::Event,
609 cx: &mut Context<Self>,
610 ) {
611 self.models.refresh_list(cx);
612
613 let registry = LanguageModelRegistry::read_global(cx);
614 let default_model = registry.default_model().map(|m| m.model);
615 let summarization_model = registry.thread_summary_model().map(|m| m.model);
616
617 for session in self.sessions.values_mut() {
618 session.thread.update(cx, |thread, cx| {
619 if thread.model().is_none()
620 && let Some(model) = default_model.clone()
621 {
622 thread.set_model(model, cx);
623 cx.notify();
624 }
625 thread.set_summarization_model(summarization_model.clone(), cx);
626 });
627 }
628 }
629
630 pub fn load_thread(
631 &mut self,
632 id: acp::SessionId,
633 cx: &mut Context<Self>,
634 ) -> Task<Result<Entity<Thread>>> {
635 let database_future = ThreadsDatabase::connect(cx);
636 cx.spawn(async move |this, cx| {
637 let database = database_future.await.map_err(|err| anyhow!(err))?;
638 let db_thread = database
639 .load_thread(id.clone())
640 .await?
641 .with_context(|| format!("no thread found with ID: {id:?}"))?;
642
643 this.update(cx, |this, cx| {
644 let summarization_model = LanguageModelRegistry::read_global(cx)
645 .thread_summary_model()
646 .map(|c| c.model);
647
648 cx.new(|cx| {
649 let mut thread = Thread::from_db(
650 id.clone(),
651 db_thread,
652 this.project.clone(),
653 this.project_context.clone(),
654 this.context_server_registry.clone(),
655 this.templates.clone(),
656 cx,
657 );
658 thread.set_summarization_model(summarization_model, cx);
659 thread
660 })
661 })
662 })
663 }
664
665 pub fn open_thread(
666 &mut self,
667 id: acp::SessionId,
668 cx: &mut Context<Self>,
669 ) -> Task<Result<Entity<AcpThread>>> {
670 let task = self.load_thread(id, cx);
671 cx.spawn(async move |this, cx| {
672 let thread = task.await?;
673 let acp_thread =
674 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
675 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
676 cx.update(|cx| {
677 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
678 })?
679 .await?;
680 Ok(acp_thread)
681 })
682 }
683
684 pub fn thread_summary(
685 &mut self,
686 id: acp::SessionId,
687 cx: &mut Context<Self>,
688 ) -> Task<Result<SharedString>> {
689 let thread = self.open_thread(id.clone(), cx);
690 cx.spawn(async move |this, cx| {
691 let acp_thread = thread.await?;
692 let result = this
693 .update(cx, |this, cx| {
694 this.sessions
695 .get(&id)
696 .unwrap()
697 .thread
698 .update(cx, |thread, cx| thread.summary(cx))
699 })?
700 .await
701 .context("Failed to generate summary")?;
702 drop(acp_thread);
703 Ok(result)
704 })
705 }
706
707 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
708 if thread.read(cx).is_empty() {
709 return;
710 }
711
712 let database_future = ThreadsDatabase::connect(cx);
713 let (id, db_thread) =
714 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
715 let Some(session) = self.sessions.get_mut(&id) else {
716 return;
717 };
718 let history = self.history.clone();
719 session.pending_save = cx.spawn(async move |_, cx| {
720 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
721 return;
722 };
723 let db_thread = db_thread.await;
724 database.save_thread(id, db_thread).await.log_err();
725 history.update(cx, |history, cx| history.reload(cx)).ok();
726 });
727 }
728}
729
730/// Wrapper struct that implements the AgentConnection trait
731#[derive(Clone)]
732pub struct NativeAgentConnection(pub Entity<NativeAgent>);
733
734impl NativeAgentConnection {
735 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
736 self.0
737 .read(cx)
738 .sessions
739 .get(session_id)
740 .map(|session| session.thread.clone())
741 }
742
743 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
744 self.0.update(cx, |this, cx| this.load_thread(id, cx))
745 }
746
747 fn run_turn(
748 &self,
749 session_id: acp::SessionId,
750 cx: &mut App,
751 f: impl 'static
752 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
753 ) -> Task<Result<acp::PromptResponse>> {
754 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
755 agent
756 .sessions
757 .get_mut(&session_id)
758 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
759 }) else {
760 return Task::ready(Err(anyhow!("Session not found")));
761 };
762 log::debug!("Found session for: {}", session_id);
763
764 let response_stream = match f(thread, cx) {
765 Ok(stream) => stream,
766 Err(err) => return Task::ready(Err(err)),
767 };
768 Self::handle_thread_events(response_stream, acp_thread, cx)
769 }
770
771 fn handle_thread_events(
772 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
773 acp_thread: WeakEntity<AcpThread>,
774 cx: &App,
775 ) -> Task<Result<acp::PromptResponse>> {
776 cx.spawn(async move |cx| {
777 // Handle response stream and forward to session.acp_thread
778 while let Some(result) = events.next().await {
779 match result {
780 Ok(event) => {
781 log::trace!("Received completion event: {:?}", event);
782
783 match event {
784 ThreadEvent::UserMessage(message) => {
785 acp_thread.update(cx, |thread, cx| {
786 for content in message.content {
787 thread.push_user_content_block(
788 Some(message.id.clone()),
789 content.into(),
790 cx,
791 );
792 }
793 })?;
794 }
795 ThreadEvent::AgentText(text) => {
796 acp_thread.update(cx, |thread, cx| {
797 thread.push_assistant_content_block(text.into(), false, cx)
798 })?;
799 }
800 ThreadEvent::AgentThinking(text) => {
801 acp_thread.update(cx, |thread, cx| {
802 thread.push_assistant_content_block(text.into(), true, cx)
803 })?;
804 }
805 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
806 tool_call,
807 options,
808 response,
809 }) => {
810 let outcome_task = acp_thread.update(cx, |thread, cx| {
811 thread.request_tool_call_authorization(
812 tool_call, options, true, cx,
813 )
814 })??;
815 cx.background_spawn(async move {
816 if let acp::RequestPermissionOutcome::Selected(
817 acp::SelectedPermissionOutcome { option_id, .. },
818 ) = outcome_task.await
819 {
820 response
821 .send(option_id)
822 .map(|_| anyhow!("authorization receiver was dropped"))
823 .log_err();
824 }
825 })
826 .detach();
827 }
828 ThreadEvent::ToolCall(tool_call) => {
829 acp_thread.update(cx, |thread, cx| {
830 thread.upsert_tool_call(tool_call, cx)
831 })??;
832 }
833 ThreadEvent::ToolCallUpdate(update) => {
834 acp_thread.update(cx, |thread, cx| {
835 thread.update_tool_call(update, cx)
836 })??;
837 }
838 ThreadEvent::Retry(status) => {
839 acp_thread.update(cx, |thread, cx| {
840 thread.update_retry_status(status, cx)
841 })?;
842 }
843 ThreadEvent::Stop(stop_reason) => {
844 log::debug!("Assistant message complete: {:?}", stop_reason);
845 return Ok(acp::PromptResponse::new(stop_reason));
846 }
847 }
848 }
849 Err(e) => {
850 log::error!("Error in model response stream: {:?}", e);
851 return Err(e);
852 }
853 }
854 }
855
856 log::debug!("Response stream completed");
857 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
858 })
859 }
860}
861
862struct NativeAgentModelSelector {
863 session_id: acp::SessionId,
864 connection: NativeAgentConnection,
865}
866
867impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
868 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
869 log::debug!("NativeAgentConnection::list_models called");
870 let list = self.connection.0.read(cx).models.model_list.clone();
871 Task::ready(if list.is_empty() {
872 Err(anyhow::anyhow!("No models available"))
873 } else {
874 Ok(list)
875 })
876 }
877
878 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
879 log::debug!(
880 "Setting model for session {}: {}",
881 self.session_id,
882 model_id
883 );
884 let Some(thread) = self
885 .connection
886 .0
887 .read(cx)
888 .sessions
889 .get(&self.session_id)
890 .map(|session| session.thread.clone())
891 else {
892 return Task::ready(Err(anyhow!("Session not found")));
893 };
894
895 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
896 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
897 };
898
899 thread.update(cx, |thread, cx| {
900 thread.set_model(model.clone(), cx);
901 });
902
903 update_settings_file(
904 self.connection.0.read(cx).fs.clone(),
905 cx,
906 move |settings, _cx| {
907 let provider = model.provider_id().0.to_string();
908 let model = model.id().0.to_string();
909 settings
910 .agent
911 .get_or_insert_default()
912 .set_model(LanguageModelSelection {
913 provider: provider.into(),
914 model,
915 });
916 },
917 );
918
919 Task::ready(Ok(()))
920 }
921
922 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
923 let Some(thread) = self
924 .connection
925 .0
926 .read(cx)
927 .sessions
928 .get(&self.session_id)
929 .map(|session| session.thread.clone())
930 else {
931 return Task::ready(Err(anyhow!("Session not found")));
932 };
933 let Some(model) = thread.read(cx).model() else {
934 return Task::ready(Err(anyhow!("Model not found")));
935 };
936 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
937 else {
938 return Task::ready(Err(anyhow!("Provider not found")));
939 };
940 Task::ready(Ok(LanguageModels::map_language_model_to_info(
941 model, &provider,
942 )))
943 }
944
945 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
946 Some(self.connection.0.read(cx).models.watch())
947 }
948
949 fn should_render_footer(&self) -> bool {
950 true
951 }
952}
953
954impl acp_thread::AgentConnection for NativeAgentConnection {
955 fn telemetry_id(&self) -> SharedString {
956 "zed".into()
957 }
958
959 fn new_thread(
960 self: Rc<Self>,
961 project: Entity<Project>,
962 cwd: &Path,
963 cx: &mut App,
964 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
965 let agent = self.0.clone();
966 log::debug!("Creating new thread for project at: {:?}", cwd);
967
968 cx.spawn(async move |cx| {
969 log::debug!("Starting thread creation in async context");
970
971 // Create Thread
972 let thread = agent.update(
973 cx,
974 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
975 // Fetch default model from registry settings
976 let registry = LanguageModelRegistry::read_global(cx);
977 // Log available models for debugging
978 let available_count = registry.available_models(cx).count();
979 log::debug!("Total available models: {}", available_count);
980
981 let default_model = registry.default_model().and_then(|default_model| {
982 agent
983 .models
984 .model_from_id(&LanguageModels::model_id(&default_model.model))
985 });
986 Ok(cx.new(|cx| {
987 Thread::new(
988 project.clone(),
989 agent.project_context.clone(),
990 agent.context_server_registry.clone(),
991 agent.templates.clone(),
992 default_model,
993 cx,
994 )
995 }))
996 },
997 )??;
998 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
999 })
1000 }
1001
1002 fn auth_methods(&self) -> &[acp::AuthMethod] {
1003 &[] // No auth for in-process
1004 }
1005
1006 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1007 Task::ready(Ok(()))
1008 }
1009
1010 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1011 Some(Rc::new(NativeAgentModelSelector {
1012 session_id: session_id.clone(),
1013 connection: self.clone(),
1014 }) as Rc<dyn AgentModelSelector>)
1015 }
1016
1017 fn prompt(
1018 &self,
1019 id: Option<acp_thread::UserMessageId>,
1020 params: acp::PromptRequest,
1021 cx: &mut App,
1022 ) -> Task<Result<acp::PromptResponse>> {
1023 let id = id.expect("UserMessageId is required");
1024 let session_id = params.session_id.clone();
1025 log::info!("Received prompt request for session: {}", session_id);
1026 log::debug!("Prompt blocks count: {}", params.prompt.len());
1027 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1028
1029 self.run_turn(session_id, cx, move |thread, cx| {
1030 let content: Vec<UserMessageContent> = params
1031 .prompt
1032 .into_iter()
1033 .map(|block| UserMessageContent::from_content_block(block, path_style))
1034 .collect::<Vec<_>>();
1035 log::debug!("Converted prompt to message: {} chars", content.len());
1036 log::debug!("Message id: {:?}", id);
1037 log::debug!("Message content: {:?}", content);
1038
1039 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1040 })
1041 }
1042
1043 fn resume(
1044 &self,
1045 session_id: &acp::SessionId,
1046 _cx: &App,
1047 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1048 Some(Rc::new(NativeAgentSessionResume {
1049 connection: self.clone(),
1050 session_id: session_id.clone(),
1051 }) as _)
1052 }
1053
1054 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1055 log::info!("Cancelling on session: {}", session_id);
1056 self.0.update(cx, |agent, cx| {
1057 if let Some(agent) = agent.sessions.get(session_id) {
1058 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1059 }
1060 });
1061 }
1062
1063 fn truncate(
1064 &self,
1065 session_id: &agent_client_protocol::SessionId,
1066 cx: &App,
1067 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1068 self.0.read_with(cx, |agent, _cx| {
1069 agent.sessions.get(session_id).map(|session| {
1070 Rc::new(NativeAgentSessionTruncate {
1071 thread: session.thread.clone(),
1072 acp_thread: session.acp_thread.clone(),
1073 }) as _
1074 })
1075 })
1076 }
1077
1078 fn set_title(
1079 &self,
1080 session_id: &acp::SessionId,
1081 _cx: &App,
1082 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1083 Some(Rc::new(NativeAgentSessionSetTitle {
1084 connection: self.clone(),
1085 session_id: session_id.clone(),
1086 }) as _)
1087 }
1088
1089 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1090 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1091 }
1092
1093 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1094 self
1095 }
1096}
1097
1098impl acp_thread::AgentTelemetry for NativeAgentConnection {
1099 fn thread_data(
1100 &self,
1101 session_id: &acp::SessionId,
1102 cx: &mut App,
1103 ) -> Task<Result<serde_json::Value>> {
1104 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1105 return Task::ready(Err(anyhow!("Session not found")));
1106 };
1107
1108 let task = session.thread.read(cx).to_db(cx);
1109 cx.background_spawn(async move {
1110 serde_json::to_value(task.await).context("Failed to serialize thread")
1111 })
1112 }
1113}
1114
1115struct NativeAgentSessionTruncate {
1116 thread: Entity<Thread>,
1117 acp_thread: WeakEntity<AcpThread>,
1118}
1119
1120impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1121 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1122 match self.thread.update(cx, |thread, cx| {
1123 thread.truncate(message_id.clone(), cx)?;
1124 Ok(thread.latest_token_usage())
1125 }) {
1126 Ok(usage) => {
1127 self.acp_thread
1128 .update(cx, |thread, cx| {
1129 thread.update_token_usage(usage, cx);
1130 })
1131 .ok();
1132 Task::ready(Ok(()))
1133 }
1134 Err(error) => Task::ready(Err(error)),
1135 }
1136 }
1137}
1138
1139struct NativeAgentSessionResume {
1140 connection: NativeAgentConnection,
1141 session_id: acp::SessionId,
1142}
1143
1144impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1145 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1146 self.connection
1147 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1148 thread.update(cx, |thread, cx| thread.resume(cx))
1149 })
1150 }
1151}
1152
1153struct NativeAgentSessionSetTitle {
1154 connection: NativeAgentConnection,
1155 session_id: acp::SessionId,
1156}
1157
1158impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1159 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1160 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1161 return Task::ready(Err(anyhow!("session not found")));
1162 };
1163 let thread = session.thread.clone();
1164 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1165 Task::ready(Ok(()))
1166 }
1167}
1168
1169pub struct AcpThreadEnvironment {
1170 acp_thread: WeakEntity<AcpThread>,
1171}
1172
1173impl ThreadEnvironment for AcpThreadEnvironment {
1174 fn create_terminal(
1175 &self,
1176 command: String,
1177 cwd: Option<PathBuf>,
1178 output_byte_limit: Option<u64>,
1179 cx: &mut AsyncApp,
1180 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1181 let task = self.acp_thread.update(cx, |thread, cx| {
1182 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1183 });
1184
1185 let acp_thread = self.acp_thread.clone();
1186 cx.spawn(async move |cx| {
1187 let terminal = task?.await?;
1188
1189 let (drop_tx, drop_rx) = oneshot::channel();
1190 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1191
1192 cx.spawn(async move |cx| {
1193 drop_rx.await.ok();
1194 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1195 })
1196 .detach();
1197
1198 let handle = AcpTerminalHandle {
1199 terminal,
1200 _drop_tx: Some(drop_tx),
1201 };
1202
1203 Ok(Rc::new(handle) as _)
1204 })
1205 }
1206}
1207
1208pub struct AcpTerminalHandle {
1209 terminal: Entity<acp_thread::Terminal>,
1210 _drop_tx: Option<oneshot::Sender<()>>,
1211}
1212
1213impl TerminalHandle for AcpTerminalHandle {
1214 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1215 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1216 }
1217
1218 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1219 self.terminal
1220 .read_with(cx, |term, _cx| term.wait_for_exit())
1221 }
1222
1223 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1224 self.terminal
1225 .read_with(cx, |term, cx| term.current_output(cx))
1226 }
1227}
1228
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // The agent's `project_context` lists each worktree, and it picks up a rules file that
    // is created after the worktree was added.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project starts without worktrees, so the context lists none.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree adds a matching entry that has no rules file yet.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // A session's model selector lists the fake provider's single model in one group.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
                }]
            )])
        );
    }

    // Selecting a model through a session's model selector changes the thread's model and
    // overwrites `agent.default_model` in the user settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed a different default model so the overwrite is observable.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model. Read the thread through the `App`
        // passed to the callback instead of re-entering the outer test context.
        agent.read_with(cx, |agent, cx| {
            let session = agent.sessions.get(&session_id).unwrap();
            let thread = session.thread.read(cx);
            assert_eq!(thread.model().unwrap().id().0, "fake");
        });

        // Let the settings write finish before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // A thread with content is saved to history, but an empty one is not, even after being
    // mutated. Dropping the ACP thread drops the session, and reopening the saved thread
    // reproduces the same transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Spawn a prompt that mentions `b.md`, so the fake models can be driven while it
        // is pending.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Script the replies. The main model answers, then the summarization model
        // produces the text that is later asserted as the history entry title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Snapshot of the history store as `(id, title)` pairs, in `entries()` order.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Shared setup: best-effort logger init, a test `SettingsStore` global, and the test
    /// language model registry.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` fails if a logger is already installed (e.g. by an earlier test); ignore.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}