1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8mod thread;
9mod tools;
10
11#[cfg(test)]
12mod tests;
13
14pub use db::*;
15pub use history_store::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use tools::*;
20
21use acp_thread::{AcpThread, AgentModelSelector};
22use agent_client_protocol as acp;
23use anyhow::{Context as _, Result, anyhow};
24use chrono::{DateTime, Utc};
25use collections::{HashSet, IndexMap};
26use fs::Fs;
27use futures::channel::{mpsc, oneshot};
28use futures::future::Shared;
29use futures::{StreamExt, future};
30use gpui::{
31 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
32};
33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
34use project::{Project, ProjectItem, ProjectPath, Worktree};
35use prompt_store::{
36 ProjectContext, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
37};
38use serde::{Deserialize, Serialize};
39use settings::{LanguageModelSelection, update_settings_file};
40use std::any::Any;
41use std::collections::HashMap;
42use std::path::{Path, PathBuf};
43use std::rc::Rc;
44use std::sync::Arc;
45use util::ResultExt;
46use util::rel_path::RelPath;
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub struct ProjectSnapshot {
50 pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
51 pub timestamp: DateTime<Utc>,
52}
53
/// Candidate rules files, relative to a worktree root, in priority order.
///
/// Only the first entry that exists as a file is loaded, so the order of this
/// list is significant.
const RULES_FILE_NAMES: [&'static str; 9] = [
    // Zed's own format.
    ".rules",
    // Formats used by other editors / agents.
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
65
66pub struct RulesLoadingError {
67 pub message: SharedString,
68}
69
70/// Holds both the internal Thread and the AcpThread for a session
71struct Session {
72 /// The internal thread that processes messages
73 thread: Entity<Thread>,
74 /// The ACP thread that handles protocol communication
75 acp_thread: WeakEntity<acp_thread::AcpThread>,
76 pending_save: Task<()>,
77 _subscriptions: Vec<Subscription>,
78}
79
80pub struct LanguageModels {
81 /// Access language model by ID
82 models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
83 /// Cached list for returning language model information
84 model_list: acp_thread::AgentModelList,
85 refresh_models_rx: watch::Receiver<()>,
86 refresh_models_tx: watch::Sender<()>,
87 _authenticate_all_providers_task: Task<()>,
88}
89
90impl LanguageModels {
91 fn new(cx: &mut App) -> Self {
92 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
93
94 let mut this = Self {
95 models: HashMap::default(),
96 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
97 refresh_models_rx,
98 refresh_models_tx,
99 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
100 };
101 this.refresh_list(cx);
102 this
103 }
104
105 fn refresh_list(&mut self, cx: &App) {
106 let providers = LanguageModelRegistry::global(cx)
107 .read(cx)
108 .providers()
109 .into_iter()
110 .filter(|provider| provider.is_authenticated(cx))
111 .collect::<Vec<_>>();
112
113 let mut language_model_list = IndexMap::default();
114 let mut recommended_models = HashSet::default();
115
116 let mut recommended = Vec::new();
117 for provider in &providers {
118 for model in provider.recommended_models(cx) {
119 recommended_models.insert((model.provider_id(), model.id()));
120 recommended.push(Self::map_language_model_to_info(&model, provider));
121 }
122 }
123 if !recommended.is_empty() {
124 language_model_list.insert(
125 acp_thread::AgentModelGroupName("Recommended".into()),
126 recommended,
127 );
128 }
129
130 let mut models = HashMap::default();
131 for provider in providers {
132 let mut provider_models = Vec::new();
133 for model in provider.provided_models(cx) {
134 let model_info = Self::map_language_model_to_info(&model, &provider);
135 let model_id = model_info.id.clone();
136 provider_models.push(model_info);
137 models.insert(model_id, model);
138 }
139 if !provider_models.is_empty() {
140 language_model_list.insert(
141 acp_thread::AgentModelGroupName(provider.name().0.clone()),
142 provider_models,
143 );
144 }
145 }
146
147 self.models = models;
148 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
149 self.refresh_models_tx.send(()).ok();
150 }
151
152 fn watch(&self) -> watch::Receiver<()> {
153 self.refresh_models_rx.clone()
154 }
155
156 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
157 self.models.get(model_id).cloned()
158 }
159
160 fn map_language_model_to_info(
161 model: &Arc<dyn LanguageModel>,
162 provider: &Arc<dyn LanguageModelProvider>,
163 ) -> acp_thread::AgentModelInfo {
164 acp_thread::AgentModelInfo {
165 id: Self::model_id(model),
166 name: model.name().0,
167 description: None,
168 icon: Some(provider.icon()),
169 }
170 }
171
172 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
173 acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
174 }
175
176 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
177 let authenticate_all_providers = LanguageModelRegistry::global(cx)
178 .read(cx)
179 .providers()
180 .iter()
181 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
182 .collect::<Vec<_>>();
183
184 cx.background_spawn(async move {
185 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
186 if let Err(err) = authenticate_task.await {
187 match err {
188 language_model::AuthenticateError::CredentialsNotFound => {
189 // Since we're authenticating these providers in the
190 // background for the purposes of populating the
191 // language selector, we don't care about providers
192 // where the credentials are not found.
193 }
194 language_model::AuthenticateError::ConnectionRefused => {
195 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
196 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
197 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
198 }
199 _ => {
200 // Some providers have noisy failure states that we
201 // don't want to spam the logs with every time the
202 // language model selector is initialized.
203 //
204 // Ideally these should have more clear failure modes
205 // that we know are safe to ignore here, like what we do
206 // with `CredentialsNotFound` above.
207 match provider_id.0.as_ref() {
208 "lmstudio" | "ollama" => {
209 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
210 //
211 // These fail noisily, so we don't log them.
212 }
213 "copilot_chat" => {
214 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
215 }
216 _ => {
217 log::error!(
218 "Failed to authenticate provider: {}: {err:#}",
219 provider_name.0
220 );
221 }
222 }
223 }
224 }
225 }
226 }
227 })
228 }
229}
230
231pub struct NativeAgent {
232 /// Session ID -> Session mapping
233 sessions: HashMap<acp::SessionId, Session>,
234 history: Entity<HistoryStore>,
235 /// Shared project context for all threads
236 project_context: Entity<ProjectContext>,
237 project_context_needs_refresh: watch::Sender<()>,
238 _maintain_project_context: Task<Result<()>>,
239 context_server_registry: Entity<ContextServerRegistry>,
240 /// Shared templates for all threads
241 templates: Arc<Templates>,
242 /// Cached model information
243 models: LanguageModels,
244 project: Entity<Project>,
245 prompt_store: Option<Entity<PromptStore>>,
246 fs: Arc<dyn Fs>,
247 _subscriptions: Vec<Subscription>,
248}
249
250impl NativeAgent {
251 pub async fn new(
252 project: Entity<Project>,
253 history: Entity<HistoryStore>,
254 templates: Arc<Templates>,
255 prompt_store: Option<Entity<PromptStore>>,
256 fs: Arc<dyn Fs>,
257 cx: &mut AsyncApp,
258 ) -> Result<Entity<NativeAgent>> {
259 log::debug!("Creating new NativeAgent");
260
261 let project_context = cx
262 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
263 .await;
264
265 cx.new(|cx| {
266 let mut subscriptions = vec![
267 cx.subscribe(&project, Self::handle_project_event),
268 cx.subscribe(
269 &LanguageModelRegistry::global(cx),
270 Self::handle_models_updated_event,
271 ),
272 ];
273 if let Some(prompt_store) = prompt_store.as_ref() {
274 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
275 }
276
277 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
278 watch::channel(());
279 Self {
280 sessions: HashMap::new(),
281 history,
282 project_context: cx.new(|_| project_context),
283 project_context_needs_refresh: project_context_needs_refresh_tx,
284 _maintain_project_context: cx.spawn(async move |this, cx| {
285 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
286 }),
287 context_server_registry: cx.new(|cx| {
288 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
289 }),
290 templates,
291 models: LanguageModels::new(cx),
292 project,
293 prompt_store,
294 fs,
295 _subscriptions: subscriptions,
296 }
297 })
298 }
299
300 fn register_session(
301 &mut self,
302 thread_handle: Entity<Thread>,
303 cx: &mut Context<Self>,
304 ) -> Entity<AcpThread> {
305 let connection = Rc::new(NativeAgentConnection(cx.entity()));
306
307 let thread = thread_handle.read(cx);
308 let session_id = thread.id().clone();
309 let title = thread.title();
310 let project = thread.project.clone();
311 let action_log = thread.action_log.clone();
312 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
313 let acp_thread = cx.new(|cx| {
314 acp_thread::AcpThread::new(
315 title,
316 connection,
317 project.clone(),
318 action_log.clone(),
319 session_id.clone(),
320 prompt_capabilities_rx,
321 cx,
322 )
323 });
324
325 let registry = LanguageModelRegistry::read_global(cx);
326 let summarization_model = registry.thread_summary_model().map(|c| c.model);
327
328 thread_handle.update(cx, |thread, cx| {
329 thread.set_summarization_model(summarization_model, cx);
330 thread.add_default_tools(
331 Rc::new(AcpThreadEnvironment {
332 acp_thread: acp_thread.downgrade(),
333 }) as _,
334 cx,
335 )
336 });
337
338 let subscriptions = vec![
339 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
340 this.sessions.remove(acp_thread.session_id());
341 }),
342 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
343 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
344 cx.observe(&thread_handle, move |this, thread, cx| {
345 this.save_thread(thread, cx)
346 }),
347 ];
348
349 self.sessions.insert(
350 session_id,
351 Session {
352 thread: thread_handle,
353 acp_thread: acp_thread.downgrade(),
354 _subscriptions: subscriptions,
355 pending_save: Task::ready(()),
356 },
357 );
358 acp_thread
359 }
360
361 pub fn models(&self) -> &LanguageModels {
362 &self.models
363 }
364
365 async fn maintain_project_context(
366 this: WeakEntity<Self>,
367 mut needs_refresh: watch::Receiver<()>,
368 cx: &mut AsyncApp,
369 ) -> Result<()> {
370 while needs_refresh.changed().await.is_ok() {
371 let project_context = this
372 .update(cx, |this, cx| {
373 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
374 })?
375 .await;
376 this.update(cx, |this, cx| {
377 this.project_context = cx.new(|_| project_context);
378 })?;
379 }
380
381 Ok(())
382 }
383
384 fn build_project_context(
385 project: &Entity<Project>,
386 prompt_store: Option<&Entity<PromptStore>>,
387 cx: &mut App,
388 ) -> Task<ProjectContext> {
389 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
390 let worktree_tasks = worktrees
391 .into_iter()
392 .map(|worktree| {
393 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
394 })
395 .collect::<Vec<_>>();
396 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
397 prompt_store.read_with(cx, |prompt_store, cx| {
398 let prompts = prompt_store.default_prompt_metadata();
399 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
400 let contents = prompt_store.load(prompt_metadata.id, cx);
401 async move { (contents.await, prompt_metadata) }
402 });
403 cx.background_spawn(future::join_all(load_tasks))
404 })
405 } else {
406 Task::ready(vec![])
407 };
408
409 cx.spawn(async move |_cx| {
410 let (worktrees, default_user_rules) =
411 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
412
413 let worktrees = worktrees
414 .into_iter()
415 .map(|(worktree, _rules_error)| {
416 // TODO: show error message
417 // if let Some(rules_error) = rules_error {
418 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
419 // }
420 worktree
421 })
422 .collect::<Vec<_>>();
423
424 let default_user_rules = default_user_rules
425 .into_iter()
426 .flat_map(|(contents, prompt_metadata)| match contents {
427 Ok(contents) => Some(UserRulesContext {
428 uuid: match prompt_metadata.id {
429 prompt_store::PromptId::User { uuid } => uuid,
430 prompt_store::PromptId::EditWorkflow => return None,
431 },
432 title: prompt_metadata.title.map(|title| title.to_string()),
433 contents,
434 }),
435 Err(_err) => {
436 // TODO: show error message
437 // this.update(cx, |_, cx| {
438 // cx.emit(RulesLoadingError {
439 // message: format!("{err:?}").into(),
440 // });
441 // })
442 // .ok();
443 None
444 }
445 })
446 .collect::<Vec<_>>();
447
448 ProjectContext::new(worktrees, default_user_rules)
449 })
450 }
451
452 fn load_worktree_info_for_system_prompt(
453 worktree: Entity<Worktree>,
454 project: Entity<Project>,
455 cx: &mut App,
456 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
457 let tree = worktree.read(cx);
458 let root_name = tree.root_name_str().into();
459 let abs_path = tree.abs_path();
460
461 let mut context = WorktreeContext {
462 root_name,
463 abs_path,
464 rules_file: None,
465 };
466
467 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
468 let Some(rules_task) = rules_task else {
469 return Task::ready((context, None));
470 };
471
472 cx.spawn(async move |_| {
473 let (rules_file, rules_file_error) = match rules_task.await {
474 Ok(rules_file) => (Some(rules_file), None),
475 Err(err) => (
476 None,
477 Some(RulesLoadingError {
478 message: format!("{err}").into(),
479 }),
480 ),
481 };
482 context.rules_file = rules_file;
483 (context, rules_file_error)
484 })
485 }
486
487 fn load_worktree_rules_file(
488 worktree: Entity<Worktree>,
489 project: Entity<Project>,
490 cx: &mut App,
491 ) -> Option<Task<Result<RulesFileContext>>> {
492 let worktree = worktree.read(cx);
493 let worktree_id = worktree.id();
494 let selected_rules_file = RULES_FILE_NAMES
495 .into_iter()
496 .filter_map(|name| {
497 worktree
498 .entry_for_path(RelPath::unix(name).unwrap())
499 .filter(|entry| entry.is_file())
500 .map(|entry| entry.path.clone())
501 })
502 .next();
503
504 // Note that Cline supports `.clinerules` being a directory, but that is not currently
505 // supported. This doesn't seem to occur often in GitHub repositories.
506 selected_rules_file.map(|path_in_worktree| {
507 let project_path = ProjectPath {
508 worktree_id,
509 path: path_in_worktree.clone(),
510 };
511 let buffer_task =
512 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
513 let rope_task = cx.spawn(async move |cx| {
514 buffer_task.await?.read_with(cx, |buffer, cx| {
515 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
516 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
517 })?
518 });
519 // Build a string from the rope on a background thread.
520 cx.background_spawn(async move {
521 let (project_entry_id, rope) = rope_task.await?;
522 anyhow::Ok(RulesFileContext {
523 path_in_worktree,
524 text: rope.to_string().trim().to_string(),
525 project_entry_id: project_entry_id.to_usize(),
526 })
527 })
528 })
529 }
530
531 fn handle_thread_title_updated(
532 &mut self,
533 thread: Entity<Thread>,
534 _: &TitleUpdated,
535 cx: &mut Context<Self>,
536 ) {
537 let session_id = thread.read(cx).id();
538 let Some(session) = self.sessions.get(session_id) else {
539 return;
540 };
541 let thread = thread.downgrade();
542 let acp_thread = session.acp_thread.clone();
543 cx.spawn(async move |_, cx| {
544 let title = thread.read_with(cx, |thread, _| thread.title())?;
545 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
546 task.await
547 })
548 .detach_and_log_err(cx);
549 }
550
551 fn handle_thread_token_usage_updated(
552 &mut self,
553 thread: Entity<Thread>,
554 usage: &TokenUsageUpdated,
555 cx: &mut Context<Self>,
556 ) {
557 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
558 return;
559 };
560 session
561 .acp_thread
562 .update(cx, |acp_thread, cx| {
563 acp_thread.update_token_usage(usage.0.clone(), cx);
564 })
565 .ok();
566 }
567
568 fn handle_project_event(
569 &mut self,
570 _project: Entity<Project>,
571 event: &project::Event,
572 _cx: &mut Context<Self>,
573 ) {
574 match event {
575 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
576 self.project_context_needs_refresh.send(()).ok();
577 }
578 project::Event::WorktreeUpdatedEntries(_, items) => {
579 if items.iter().any(|(path, _, _)| {
580 RULES_FILE_NAMES
581 .iter()
582 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
583 }) {
584 self.project_context_needs_refresh.send(()).ok();
585 }
586 }
587 _ => {}
588 }
589 }
590
591 fn handle_prompts_updated_event(
592 &mut self,
593 _prompt_store: Entity<PromptStore>,
594 _event: &prompt_store::PromptsUpdatedEvent,
595 _cx: &mut Context<Self>,
596 ) {
597 self.project_context_needs_refresh.send(()).ok();
598 }
599
600 fn handle_models_updated_event(
601 &mut self,
602 _registry: Entity<LanguageModelRegistry>,
603 _event: &language_model::Event,
604 cx: &mut Context<Self>,
605 ) {
606 self.models.refresh_list(cx);
607
608 let registry = LanguageModelRegistry::read_global(cx);
609 let default_model = registry.default_model().map(|m| m.model);
610 let summarization_model = registry.thread_summary_model().map(|m| m.model);
611
612 for session in self.sessions.values_mut() {
613 session.thread.update(cx, |thread, cx| {
614 if thread.model().is_none()
615 && let Some(model) = default_model.clone()
616 {
617 thread.set_model(model, cx);
618 cx.notify();
619 }
620 thread.set_summarization_model(summarization_model.clone(), cx);
621 });
622 }
623 }
624
625 pub fn load_thread(
626 &mut self,
627 id: acp::SessionId,
628 cx: &mut Context<Self>,
629 ) -> Task<Result<Entity<Thread>>> {
630 let database_future = ThreadsDatabase::connect(cx);
631 cx.spawn(async move |this, cx| {
632 let database = database_future.await.map_err(|err| anyhow!(err))?;
633 let db_thread = database
634 .load_thread(id.clone())
635 .await?
636 .with_context(|| format!("no thread found with ID: {id:?}"))?;
637
638 this.update(cx, |this, cx| {
639 let summarization_model = LanguageModelRegistry::read_global(cx)
640 .thread_summary_model()
641 .map(|c| c.model);
642
643 cx.new(|cx| {
644 let mut thread = Thread::from_db(
645 id.clone(),
646 db_thread,
647 this.project.clone(),
648 this.project_context.clone(),
649 this.context_server_registry.clone(),
650 this.templates.clone(),
651 cx,
652 );
653 thread.set_summarization_model(summarization_model, cx);
654 thread
655 })
656 })
657 })
658 }
659
660 pub fn open_thread(
661 &mut self,
662 id: acp::SessionId,
663 cx: &mut Context<Self>,
664 ) -> Task<Result<Entity<AcpThread>>> {
665 let task = self.load_thread(id, cx);
666 cx.spawn(async move |this, cx| {
667 let thread = task.await?;
668 let acp_thread =
669 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
670 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
671 cx.update(|cx| {
672 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
673 })?
674 .await?;
675 Ok(acp_thread)
676 })
677 }
678
679 pub fn thread_summary(
680 &mut self,
681 id: acp::SessionId,
682 cx: &mut Context<Self>,
683 ) -> Task<Result<SharedString>> {
684 let thread = self.open_thread(id.clone(), cx);
685 cx.spawn(async move |this, cx| {
686 let acp_thread = thread.await?;
687 let result = this
688 .update(cx, |this, cx| {
689 this.sessions
690 .get(&id)
691 .unwrap()
692 .thread
693 .update(cx, |thread, cx| thread.summary(cx))
694 })?
695 .await
696 .context("Failed to generate summary")?;
697 drop(acp_thread);
698 Ok(result)
699 })
700 }
701
702 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
703 if thread.read(cx).is_empty() {
704 return;
705 }
706
707 let database_future = ThreadsDatabase::connect(cx);
708 let (id, db_thread) =
709 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
710 let Some(session) = self.sessions.get_mut(&id) else {
711 return;
712 };
713 let history = self.history.clone();
714 session.pending_save = cx.spawn(async move |_, cx| {
715 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
716 return;
717 };
718 let db_thread = db_thread.await;
719 database.save_thread(id, db_thread).await.log_err();
720 history.update(cx, |history, cx| history.reload(cx)).ok();
721 });
722 }
723}
724
725/// Wrapper struct that implements the AgentConnection trait
726#[derive(Clone)]
727pub struct NativeAgentConnection(pub Entity<NativeAgent>);
728
729impl NativeAgentConnection {
730 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
731 self.0
732 .read(cx)
733 .sessions
734 .get(session_id)
735 .map(|session| session.thread.clone())
736 }
737
738 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
739 self.0.update(cx, |this, cx| this.load_thread(id, cx))
740 }
741
742 fn run_turn(
743 &self,
744 session_id: acp::SessionId,
745 cx: &mut App,
746 f: impl 'static
747 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
748 ) -> Task<Result<acp::PromptResponse>> {
749 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
750 agent
751 .sessions
752 .get_mut(&session_id)
753 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
754 }) else {
755 return Task::ready(Err(anyhow!("Session not found")));
756 };
757 log::debug!("Found session for: {}", session_id);
758
759 let response_stream = match f(thread, cx) {
760 Ok(stream) => stream,
761 Err(err) => return Task::ready(Err(err)),
762 };
763 Self::handle_thread_events(response_stream, acp_thread, cx)
764 }
765
766 fn handle_thread_events(
767 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
768 acp_thread: WeakEntity<AcpThread>,
769 cx: &App,
770 ) -> Task<Result<acp::PromptResponse>> {
771 cx.spawn(async move |cx| {
772 // Handle response stream and forward to session.acp_thread
773 while let Some(result) = events.next().await {
774 match result {
775 Ok(event) => {
776 log::trace!("Received completion event: {:?}", event);
777
778 match event {
779 ThreadEvent::UserMessage(message) => {
780 acp_thread.update(cx, |thread, cx| {
781 for content in message.content {
782 thread.push_user_content_block(
783 Some(message.id.clone()),
784 content.into(),
785 cx,
786 );
787 }
788 })?;
789 }
790 ThreadEvent::AgentText(text) => {
791 acp_thread.update(cx, |thread, cx| {
792 thread.push_assistant_content_block(
793 acp::ContentBlock::Text(acp::TextContent {
794 text,
795 annotations: None,
796 meta: None,
797 }),
798 false,
799 cx,
800 )
801 })?;
802 }
803 ThreadEvent::AgentThinking(text) => {
804 acp_thread.update(cx, |thread, cx| {
805 thread.push_assistant_content_block(
806 acp::ContentBlock::Text(acp::TextContent {
807 text,
808 annotations: None,
809 meta: None,
810 }),
811 true,
812 cx,
813 )
814 })?;
815 }
816 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
817 tool_call,
818 options,
819 response,
820 }) => {
821 let outcome_task = acp_thread.update(cx, |thread, cx| {
822 thread.request_tool_call_authorization(
823 tool_call, options, true, cx,
824 )
825 })??;
826 cx.background_spawn(async move {
827 if let acp::RequestPermissionOutcome::Selected { option_id } =
828 outcome_task.await
829 {
830 response
831 .send(option_id)
832 .map(|_| anyhow!("authorization receiver was dropped"))
833 .log_err();
834 }
835 })
836 .detach();
837 }
838 ThreadEvent::ToolCall(tool_call) => {
839 acp_thread.update(cx, |thread, cx| {
840 thread.upsert_tool_call(tool_call, cx)
841 })??;
842 }
843 ThreadEvent::ToolCallUpdate(update) => {
844 acp_thread.update(cx, |thread, cx| {
845 thread.update_tool_call(update, cx)
846 })??;
847 }
848 ThreadEvent::Retry(status) => {
849 acp_thread.update(cx, |thread, cx| {
850 thread.update_retry_status(status, cx)
851 })?;
852 }
853 ThreadEvent::Stop(stop_reason) => {
854 log::debug!("Assistant message complete: {:?}", stop_reason);
855 return Ok(acp::PromptResponse {
856 stop_reason,
857 meta: None,
858 });
859 }
860 }
861 }
862 Err(e) => {
863 log::error!("Error in model response stream: {:?}", e);
864 return Err(e);
865 }
866 }
867 }
868
869 log::debug!("Response stream completed");
870 anyhow::Ok(acp::PromptResponse {
871 stop_reason: acp::StopReason::EndTurn,
872 meta: None,
873 })
874 })
875 }
876}
877
878struct NativeAgentModelSelector {
879 session_id: acp::SessionId,
880 connection: NativeAgentConnection,
881}
882
883impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
884 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
885 log::debug!("NativeAgentConnection::list_models called");
886 let list = self.connection.0.read(cx).models.model_list.clone();
887 Task::ready(if list.is_empty() {
888 Err(anyhow::anyhow!("No models available"))
889 } else {
890 Ok(list)
891 })
892 }
893
894 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
895 log::debug!(
896 "Setting model for session {}: {}",
897 self.session_id,
898 model_id
899 );
900 let Some(thread) = self
901 .connection
902 .0
903 .read(cx)
904 .sessions
905 .get(&self.session_id)
906 .map(|session| session.thread.clone())
907 else {
908 return Task::ready(Err(anyhow!("Session not found")));
909 };
910
911 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
912 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
913 };
914
915 thread.update(cx, |thread, cx| {
916 thread.set_model(model.clone(), cx);
917 });
918
919 update_settings_file(
920 self.connection.0.read(cx).fs.clone(),
921 cx,
922 move |settings, _cx| {
923 let provider = model.provider_id().0.to_string();
924 let model = model.id().0.to_string();
925 settings
926 .agent
927 .get_or_insert_default()
928 .set_model(LanguageModelSelection {
929 provider: provider.into(),
930 model,
931 });
932 },
933 );
934
935 Task::ready(Ok(()))
936 }
937
938 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
939 let Some(thread) = self
940 .connection
941 .0
942 .read(cx)
943 .sessions
944 .get(&self.session_id)
945 .map(|session| session.thread.clone())
946 else {
947 return Task::ready(Err(anyhow!("Session not found")));
948 };
949 let Some(model) = thread.read(cx).model() else {
950 return Task::ready(Err(anyhow!("Model not found")));
951 };
952 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
953 else {
954 return Task::ready(Err(anyhow!("Provider not found")));
955 };
956 Task::ready(Ok(LanguageModels::map_language_model_to_info(
957 model, &provider,
958 )))
959 }
960
961 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
962 Some(self.connection.0.read(cx).models.watch())
963 }
964}
965
966impl acp_thread::AgentConnection for NativeAgentConnection {
967 fn telemetry_id(&self) -> &'static str {
968 "zed"
969 }
970
971 fn new_thread(
972 self: Rc<Self>,
973 project: Entity<Project>,
974 cwd: &Path,
975 cx: &mut App,
976 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
977 let agent = self.0.clone();
978 log::debug!("Creating new thread for project at: {:?}", cwd);
979
980 cx.spawn(async move |cx| {
981 log::debug!("Starting thread creation in async context");
982
983 // Create Thread
984 let thread = agent.update(
985 cx,
986 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
987 // Fetch default model from registry settings
988 let registry = LanguageModelRegistry::read_global(cx);
989 // Log available models for debugging
990 let available_count = registry.available_models(cx).count();
991 log::debug!("Total available models: {}", available_count);
992
993 let default_model = registry.default_model().and_then(|default_model| {
994 agent
995 .models
996 .model_from_id(&LanguageModels::model_id(&default_model.model))
997 });
998 Ok(cx.new(|cx| {
999 Thread::new(
1000 project.clone(),
1001 agent.project_context.clone(),
1002 agent.context_server_registry.clone(),
1003 agent.templates.clone(),
1004 default_model,
1005 cx,
1006 )
1007 }))
1008 },
1009 )??;
1010 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
1011 })
1012 }
1013
1014 fn auth_methods(&self) -> &[acp::AuthMethod] {
1015 &[] // No auth for in-process
1016 }
1017
1018 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1019 Task::ready(Ok(()))
1020 }
1021
1022 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1023 Some(Rc::new(NativeAgentModelSelector {
1024 session_id: session_id.clone(),
1025 connection: self.clone(),
1026 }) as Rc<dyn AgentModelSelector>)
1027 }
1028
1029 fn prompt(
1030 &self,
1031 id: Option<acp_thread::UserMessageId>,
1032 params: acp::PromptRequest,
1033 cx: &mut App,
1034 ) -> Task<Result<acp::PromptResponse>> {
1035 let id = id.expect("UserMessageId is required");
1036 let session_id = params.session_id.clone();
1037 log::info!("Received prompt request for session: {}", session_id);
1038 log::debug!("Prompt blocks count: {}", params.prompt.len());
1039 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1040
1041 self.run_turn(session_id, cx, move |thread, cx| {
1042 let content: Vec<UserMessageContent> = params
1043 .prompt
1044 .into_iter()
1045 .map(|block| UserMessageContent::from_content_block(block, path_style))
1046 .collect::<Vec<_>>();
1047 log::debug!("Converted prompt to message: {} chars", content.len());
1048 log::debug!("Message id: {:?}", id);
1049 log::debug!("Message content: {:?}", content);
1050
1051 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1052 })
1053 }
1054
1055 fn resume(
1056 &self,
1057 session_id: &acp::SessionId,
1058 _cx: &App,
1059 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1060 Some(Rc::new(NativeAgentSessionResume {
1061 connection: self.clone(),
1062 session_id: session_id.clone(),
1063 }) as _)
1064 }
1065
1066 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1067 log::info!("Cancelling on session: {}", session_id);
1068 self.0.update(cx, |agent, cx| {
1069 if let Some(agent) = agent.sessions.get(session_id) {
1070 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1071 }
1072 });
1073 }
1074
1075 fn truncate(
1076 &self,
1077 session_id: &agent_client_protocol::SessionId,
1078 cx: &App,
1079 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1080 self.0.read_with(cx, |agent, _cx| {
1081 agent.sessions.get(session_id).map(|session| {
1082 Rc::new(NativeAgentSessionTruncate {
1083 thread: session.thread.clone(),
1084 acp_thread: session.acp_thread.clone(),
1085 }) as _
1086 })
1087 })
1088 }
1089
1090 fn set_title(
1091 &self,
1092 session_id: &acp::SessionId,
1093 _cx: &App,
1094 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1095 Some(Rc::new(NativeAgentSessionSetTitle {
1096 connection: self.clone(),
1097 session_id: session_id.clone(),
1098 }) as _)
1099 }
1100
1101 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1102 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1103 }
1104
1105 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1106 self
1107 }
1108}
1109
1110impl acp_thread::AgentTelemetry for NativeAgentConnection {
1111 fn thread_data(
1112 &self,
1113 session_id: &acp::SessionId,
1114 cx: &mut App,
1115 ) -> Task<Result<serde_json::Value>> {
1116 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1117 return Task::ready(Err(anyhow!("Session not found")));
1118 };
1119
1120 let task = session.thread.read(cx).to_db(cx);
1121 cx.background_spawn(async move {
1122 serde_json::to_value(task.await).context("Failed to serialize thread")
1123 })
1124 }
1125}
1126
1127struct NativeAgentSessionTruncate {
1128 thread: Entity<Thread>,
1129 acp_thread: WeakEntity<AcpThread>,
1130}
1131
1132impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1133 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1134 match self.thread.update(cx, |thread, cx| {
1135 thread.truncate(message_id.clone(), cx)?;
1136 Ok(thread.latest_token_usage())
1137 }) {
1138 Ok(usage) => {
1139 self.acp_thread
1140 .update(cx, |thread, cx| {
1141 thread.update_token_usage(usage, cx);
1142 })
1143 .ok();
1144 Task::ready(Ok(()))
1145 }
1146 Err(error) => Task::ready(Err(error)),
1147 }
1148 }
1149}
1150
1151struct NativeAgentSessionResume {
1152 connection: NativeAgentConnection,
1153 session_id: acp::SessionId,
1154}
1155
1156impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1157 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1158 self.connection
1159 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1160 thread.update(cx, |thread, cx| thread.resume(cx))
1161 })
1162 }
1163}
1164
1165struct NativeAgentSessionSetTitle {
1166 connection: NativeAgentConnection,
1167 session_id: acp::SessionId,
1168}
1169
1170impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1171 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1172 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1173 return Task::ready(Err(anyhow!("session not found")));
1174 };
1175 let thread = session.thread.clone();
1176 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1177 Task::ready(Ok(()))
1178 }
1179}
1180
1181pub struct AcpThreadEnvironment {
1182 acp_thread: WeakEntity<AcpThread>,
1183}
1184
1185impl ThreadEnvironment for AcpThreadEnvironment {
1186 fn create_terminal(
1187 &self,
1188 command: String,
1189 cwd: Option<PathBuf>,
1190 output_byte_limit: Option<u64>,
1191 cx: &mut AsyncApp,
1192 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1193 let task = self.acp_thread.update(cx, |thread, cx| {
1194 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1195 });
1196
1197 let acp_thread = self.acp_thread.clone();
1198 cx.spawn(async move |cx| {
1199 let terminal = task?.await?;
1200
1201 let (drop_tx, drop_rx) = oneshot::channel();
1202 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1203
1204 cx.spawn(async move |cx| {
1205 drop_rx.await.ok();
1206 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1207 })
1208 .detach();
1209
1210 let handle = AcpTerminalHandle {
1211 terminal,
1212 _drop_tx: Some(drop_tx),
1213 };
1214
1215 Ok(Rc::new(handle) as _)
1216 })
1217 }
1218}
1219
1220pub struct AcpTerminalHandle {
1221 terminal: Entity<acp_thread::Terminal>,
1222 _drop_tx: Option<oneshot::Sender<()>>,
1223}
1224
1225impl TerminalHandle for AcpTerminalHandle {
1226 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1227 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1228 }
1229
1230 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1231 self.terminal
1232 .read_with(cx, |term, _cx| term.wait_for_exit())
1233 }
1234
1235 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1236 self.terminal
1237 .read_with(cx, |term, cx| term.current_output(cx))
1238 }
1239}
1240
/// Tests for `NativeAgent` and `NativeAgentConnection`, run against a `FakeFs`,
/// a test `Project`, and fake language models.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// Checks that `NativeAgent::project_context` tracks the project's worktrees.
    /// The context is empty at startup and gains one entry once `/a` is added.
    /// That entry's `rules_file` is populated after `/a/.rules` is created.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        // The project starts with no worktrees, so the initial context is empty.
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        // Let pending work settle before inspecting the refreshed context.
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Checks that a new session's model selector returns a grouped list.
    /// The list holds exactly one group, "Fake", containing the single
    /// `fake/fake` model.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId("fake/fake".into()),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Checks that selecting `fake/fake` through the model selector switches the
    /// session's thread to that model. It also checks that the settings file's
    /// `agent.default_model` is rewritten from the seeded `foo`/`bar` to
    /// `fake`/`fake`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed a different default model so the rewrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId("fake/fake".into());
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let pending work (including the settings write) settle before reading the file.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// End-to-end persistence check. Configuring an empty thread adds no history
    /// entry. After one completed turn, dropping the ACP thread removes the
    /// session, and history holds one entry titled by the summary model.
    /// Reopening that entry reproduces the same markdown transcript.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                        meta: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        // Spawn the send so the fake models' streams can be driven while it is pending.
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant reply, then the summary that becomes the history title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns the `(id, title)` pair of every entry currently in `history`,
    /// in the order yielded by `HistoryStore::entries`.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Shared test setup. It installs a logger and a test `SettingsStore` as a
    /// global, then calls `LanguageModelRegistry::test`.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` fails if a logger is already installed (e.g. by an earlier test); ignore that.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}