1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8mod thread;
9mod tools;
10
11#[cfg(test)]
12mod tests;
13
14pub use db::*;
15pub use history_store::*;
16pub use native_agent_server::NativeAgentServer;
17pub use templates::*;
18pub use thread::*;
19pub use tools::*;
20
21use acp_thread::{AcpThread, AgentModelSelector};
22use agent_client_protocol as acp;
23use anyhow::{Context as _, Result, anyhow};
24use chrono::{DateTime, Utc};
25use collections::{HashSet, IndexMap};
26use fs::Fs;
27use futures::channel::{mpsc, oneshot};
28use futures::future::Shared;
29use futures::{StreamExt, future};
30use gpui::{
31 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
32};
33use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
34use project::{Project, ProjectItem, ProjectPath, Worktree};
35use prompt_store::{
36 ProjectContext, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
37};
38use serde::{Deserialize, Serialize};
39use settings::{LanguageModelSelection, update_settings_file};
40use std::any::Any;
41use std::collections::HashMap;
42use std::path::{Path, PathBuf};
43use std::rc::Rc;
44use std::sync::Arc;
45use util::ResultExt;
46use util::rel_path::RelPath;
47
/// A serializable snapshot of the project's worktrees.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One telemetry snapshot per worktree.
    pub worktree_snapshots: Vec<project::telemetry_snapshot::TelemetryWorktreeSnapshot>,
    /// Timestamp associated with this snapshot, in UTC.
    pub timestamp: DateTime<Utc>,
}
53
/// Candidate rules-file paths, relative to a worktree root.
///
/// Order matters: when loading a worktree's rules, the first entry that exists as a
/// file in the worktree is used. Changes to any of these paths also trigger a
/// project-context refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
65
/// Describes a failure to load a worktree's rules file.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
69
/// Holds both the internal Thread and the AcpThread for a session
struct Session {
    /// The internal thread that processes messages
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication (held weakly; the
    /// session is removed from the agent when this thread is released)
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recent background save of `thread`; each new save replaces it.
    pending_save: Task<()>,
    /// Subscriptions held only to keep them registered for the session's lifetime.
    _subscriptions: Vec<Subscription>,
}
79
/// Cache of the language models offered by authenticated providers, plus a
/// channel that fires whenever the cache is rebuilt.
pub struct LanguageModels {
    /// Access language model by ID
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out (cloned) to observers of model-list refreshes.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender signalled at the end of every model-list refresh.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all registered providers; kept so it is not dropped.
    _authenticate_all_providers_task: Task<()>,
}
89
90impl LanguageModels {
91 fn new(cx: &mut App) -> Self {
92 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
93
94 let mut this = Self {
95 models: HashMap::default(),
96 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
97 refresh_models_rx,
98 refresh_models_tx,
99 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
100 };
101 this.refresh_list(cx);
102 this
103 }
104
105 fn refresh_list(&mut self, cx: &App) {
106 let providers = LanguageModelRegistry::global(cx)
107 .read(cx)
108 .providers()
109 .into_iter()
110 .filter(|provider| provider.is_authenticated(cx))
111 .collect::<Vec<_>>();
112
113 let mut language_model_list = IndexMap::default();
114 let mut recommended_models = HashSet::default();
115
116 let mut recommended = Vec::new();
117 for provider in &providers {
118 for model in provider.recommended_models(cx) {
119 recommended_models.insert((model.provider_id(), model.id()));
120 recommended.push(Self::map_language_model_to_info(&model, provider));
121 }
122 }
123 if !recommended.is_empty() {
124 language_model_list.insert(
125 acp_thread::AgentModelGroupName("Recommended".into()),
126 recommended,
127 );
128 }
129
130 let mut models = HashMap::default();
131 for provider in providers {
132 let mut provider_models = Vec::new();
133 for model in provider.provided_models(cx) {
134 let model_info = Self::map_language_model_to_info(&model, &provider);
135 let model_id = model_info.id.clone();
136 provider_models.push(model_info);
137 models.insert(model_id, model);
138 }
139 if !provider_models.is_empty() {
140 language_model_list.insert(
141 acp_thread::AgentModelGroupName(provider.name().0.clone()),
142 provider_models,
143 );
144 }
145 }
146
147 self.models = models;
148 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
149 self.refresh_models_tx.send(()).ok();
150 }
151
152 fn watch(&self) -> watch::Receiver<()> {
153 self.refresh_models_rx.clone()
154 }
155
156 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
157 self.models.get(model_id).cloned()
158 }
159
160 fn map_language_model_to_info(
161 model: &Arc<dyn LanguageModel>,
162 provider: &Arc<dyn LanguageModelProvider>,
163 ) -> acp_thread::AgentModelInfo {
164 acp_thread::AgentModelInfo {
165 id: Self::model_id(model),
166 name: model.name().0,
167 description: None,
168 icon: Some(provider.icon()),
169 }
170 }
171
172 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
173 acp::ModelId::new(format!("{}/{}", model.provider_id().0, model.id().0))
174 }
175
176 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
177 let authenticate_all_providers = LanguageModelRegistry::global(cx)
178 .read(cx)
179 .providers()
180 .iter()
181 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
182 .collect::<Vec<_>>();
183
184 cx.background_spawn(async move {
185 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
186 if let Err(err) = authenticate_task.await {
187 match err {
188 language_model::AuthenticateError::CredentialsNotFound => {
189 // Since we're authenticating these providers in the
190 // background for the purposes of populating the
191 // language selector, we don't care about providers
192 // where the credentials are not found.
193 }
194 language_model::AuthenticateError::ConnectionRefused => {
195 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
196 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
197 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
198 }
199 _ => {
200 // Some providers have noisy failure states that we
201 // don't want to spam the logs with every time the
202 // language model selector is initialized.
203 //
204 // Ideally these should have more clear failure modes
205 // that we know are safe to ignore here, like what we do
206 // with `CredentialsNotFound` above.
207 match provider_id.0.as_ref() {
208 "lmstudio" | "ollama" => {
209 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
210 //
211 // These fail noisily, so we don't log them.
212 }
213 "copilot_chat" => {
214 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
215 }
216 _ => {
217 log::error!(
218 "Failed to authenticate provider: {}: {err:#}",
219 provider_name.0
220 );
221 }
222 }
223 }
224 }
225 }
226 }
227 })
228 }
229}
230
/// The in-process agent: owns the sessions, shared project context, tools
/// registry, templates and model cache used by native threads.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history; reloaded after each thread save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signals the project-context maintenance task that a rebuild is needed.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds the project context whenever a refresh is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Context-server registry handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store providing default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem used to persist settings changes (e.g. the selected model).
    fs: Arc<dyn Fs>,
    /// Subscriptions held only to keep them registered for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
249
250impl NativeAgent {
251 pub async fn new(
252 project: Entity<Project>,
253 history: Entity<HistoryStore>,
254 templates: Arc<Templates>,
255 prompt_store: Option<Entity<PromptStore>>,
256 fs: Arc<dyn Fs>,
257 cx: &mut AsyncApp,
258 ) -> Result<Entity<NativeAgent>> {
259 log::debug!("Creating new NativeAgent");
260
261 let project_context = cx
262 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
263 .await;
264
265 cx.new(|cx| {
266 let mut subscriptions = vec![
267 cx.subscribe(&project, Self::handle_project_event),
268 cx.subscribe(
269 &LanguageModelRegistry::global(cx),
270 Self::handle_models_updated_event,
271 ),
272 ];
273 if let Some(prompt_store) = prompt_store.as_ref() {
274 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
275 }
276
277 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
278 watch::channel(());
279 Self {
280 sessions: HashMap::new(),
281 history,
282 project_context: cx.new(|_| project_context),
283 project_context_needs_refresh: project_context_needs_refresh_tx,
284 _maintain_project_context: cx.spawn(async move |this, cx| {
285 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
286 }),
287 context_server_registry: cx.new(|cx| {
288 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
289 }),
290 templates,
291 models: LanguageModels::new(cx),
292 project,
293 prompt_store,
294 fs,
295 _subscriptions: subscriptions,
296 }
297 })
298 }
299
300 fn register_session(
301 &mut self,
302 thread_handle: Entity<Thread>,
303 cx: &mut Context<Self>,
304 ) -> Entity<AcpThread> {
305 let connection = Rc::new(NativeAgentConnection(cx.entity()));
306
307 let thread = thread_handle.read(cx);
308 let session_id = thread.id().clone();
309 let title = thread.title();
310 let project = thread.project.clone();
311 let action_log = thread.action_log.clone();
312 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
313 let acp_thread = cx.new(|cx| {
314 acp_thread::AcpThread::new(
315 title,
316 connection,
317 project.clone(),
318 action_log.clone(),
319 session_id.clone(),
320 prompt_capabilities_rx,
321 cx,
322 )
323 });
324
325 let registry = LanguageModelRegistry::read_global(cx);
326 let summarization_model = registry.thread_summary_model().map(|c| c.model);
327
328 thread_handle.update(cx, |thread, cx| {
329 thread.set_summarization_model(summarization_model, cx);
330 thread.add_default_tools(
331 Rc::new(AcpThreadEnvironment {
332 acp_thread: acp_thread.downgrade(),
333 }) as _,
334 cx,
335 )
336 });
337
338 let subscriptions = vec![
339 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
340 this.sessions.remove(acp_thread.session_id());
341 }),
342 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
343 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
344 cx.observe(&thread_handle, move |this, thread, cx| {
345 this.save_thread(thread, cx)
346 }),
347 ];
348
349 self.sessions.insert(
350 session_id,
351 Session {
352 thread: thread_handle,
353 acp_thread: acp_thread.downgrade(),
354 _subscriptions: subscriptions,
355 pending_save: Task::ready(()),
356 },
357 );
358 acp_thread
359 }
360
361 pub fn models(&self) -> &LanguageModels {
362 &self.models
363 }
364
365 async fn maintain_project_context(
366 this: WeakEntity<Self>,
367 mut needs_refresh: watch::Receiver<()>,
368 cx: &mut AsyncApp,
369 ) -> Result<()> {
370 while needs_refresh.changed().await.is_ok() {
371 let project_context = this
372 .update(cx, |this, cx| {
373 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
374 })?
375 .await;
376 this.update(cx, |this, cx| {
377 this.project_context = cx.new(|_| project_context);
378 })?;
379 }
380
381 Ok(())
382 }
383
384 fn build_project_context(
385 project: &Entity<Project>,
386 prompt_store: Option<&Entity<PromptStore>>,
387 cx: &mut App,
388 ) -> Task<ProjectContext> {
389 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
390 let worktree_tasks = worktrees
391 .into_iter()
392 .map(|worktree| {
393 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
394 })
395 .collect::<Vec<_>>();
396 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
397 prompt_store.read_with(cx, |prompt_store, cx| {
398 let prompts = prompt_store.default_prompt_metadata();
399 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
400 let contents = prompt_store.load(prompt_metadata.id, cx);
401 async move { (contents.await, prompt_metadata) }
402 });
403 cx.background_spawn(future::join_all(load_tasks))
404 })
405 } else {
406 Task::ready(vec![])
407 };
408
409 cx.spawn(async move |_cx| {
410 let (worktrees, default_user_rules) =
411 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
412
413 let worktrees = worktrees
414 .into_iter()
415 .map(|(worktree, _rules_error)| {
416 // TODO: show error message
417 // if let Some(rules_error) = rules_error {
418 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
419 // }
420 worktree
421 })
422 .collect::<Vec<_>>();
423
424 let default_user_rules = default_user_rules
425 .into_iter()
426 .flat_map(|(contents, prompt_metadata)| match contents {
427 Ok(contents) => Some(UserRulesContext {
428 uuid: match prompt_metadata.id {
429 prompt_store::PromptId::User { uuid } => uuid,
430 prompt_store::PromptId::EditWorkflow => return None,
431 },
432 title: prompt_metadata.title.map(|title| title.to_string()),
433 contents,
434 }),
435 Err(_err) => {
436 // TODO: show error message
437 // this.update(cx, |_, cx| {
438 // cx.emit(RulesLoadingError {
439 // message: format!("{err:?}").into(),
440 // });
441 // })
442 // .ok();
443 None
444 }
445 })
446 .collect::<Vec<_>>();
447
448 ProjectContext::new(worktrees, default_user_rules)
449 })
450 }
451
452 fn load_worktree_info_for_system_prompt(
453 worktree: Entity<Worktree>,
454 project: Entity<Project>,
455 cx: &mut App,
456 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
457 let tree = worktree.read(cx);
458 let root_name = tree.root_name_str().into();
459 let abs_path = tree.abs_path();
460
461 let mut context = WorktreeContext {
462 root_name,
463 abs_path,
464 rules_file: None,
465 };
466
467 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
468 let Some(rules_task) = rules_task else {
469 return Task::ready((context, None));
470 };
471
472 cx.spawn(async move |_| {
473 let (rules_file, rules_file_error) = match rules_task.await {
474 Ok(rules_file) => (Some(rules_file), None),
475 Err(err) => (
476 None,
477 Some(RulesLoadingError {
478 message: format!("{err}").into(),
479 }),
480 ),
481 };
482 context.rules_file = rules_file;
483 (context, rules_file_error)
484 })
485 }
486
487 fn load_worktree_rules_file(
488 worktree: Entity<Worktree>,
489 project: Entity<Project>,
490 cx: &mut App,
491 ) -> Option<Task<Result<RulesFileContext>>> {
492 let worktree = worktree.read(cx);
493 let worktree_id = worktree.id();
494 let selected_rules_file = RULES_FILE_NAMES
495 .into_iter()
496 .filter_map(|name| {
497 worktree
498 .entry_for_path(RelPath::unix(name).unwrap())
499 .filter(|entry| entry.is_file())
500 .map(|entry| entry.path.clone())
501 })
502 .next();
503
504 // Note that Cline supports `.clinerules` being a directory, but that is not currently
505 // supported. This doesn't seem to occur often in GitHub repositories.
506 selected_rules_file.map(|path_in_worktree| {
507 let project_path = ProjectPath {
508 worktree_id,
509 path: path_in_worktree.clone(),
510 };
511 let buffer_task =
512 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
513 let rope_task = cx.spawn(async move |cx| {
514 buffer_task.await?.read_with(cx, |buffer, cx| {
515 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
516 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
517 })?
518 });
519 // Build a string from the rope on a background thread.
520 cx.background_spawn(async move {
521 let (project_entry_id, rope) = rope_task.await?;
522 anyhow::Ok(RulesFileContext {
523 path_in_worktree,
524 text: rope.to_string().trim().to_string(),
525 project_entry_id: project_entry_id.to_usize(),
526 })
527 })
528 })
529 }
530
531 fn handle_thread_title_updated(
532 &mut self,
533 thread: Entity<Thread>,
534 _: &TitleUpdated,
535 cx: &mut Context<Self>,
536 ) {
537 let session_id = thread.read(cx).id();
538 let Some(session) = self.sessions.get(session_id) else {
539 return;
540 };
541 let thread = thread.downgrade();
542 let acp_thread = session.acp_thread.clone();
543 cx.spawn(async move |_, cx| {
544 let title = thread.read_with(cx, |thread, _| thread.title())?;
545 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
546 task.await
547 })
548 .detach_and_log_err(cx);
549 }
550
551 fn handle_thread_token_usage_updated(
552 &mut self,
553 thread: Entity<Thread>,
554 usage: &TokenUsageUpdated,
555 cx: &mut Context<Self>,
556 ) {
557 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
558 return;
559 };
560 session
561 .acp_thread
562 .update(cx, |acp_thread, cx| {
563 acp_thread.update_token_usage(usage.0.clone(), cx);
564 })
565 .ok();
566 }
567
568 fn handle_project_event(
569 &mut self,
570 _project: Entity<Project>,
571 event: &project::Event,
572 _cx: &mut Context<Self>,
573 ) {
574 match event {
575 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
576 self.project_context_needs_refresh.send(()).ok();
577 }
578 project::Event::WorktreeUpdatedEntries(_, items) => {
579 if items.iter().any(|(path, _, _)| {
580 RULES_FILE_NAMES
581 .iter()
582 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
583 }) {
584 self.project_context_needs_refresh.send(()).ok();
585 }
586 }
587 _ => {}
588 }
589 }
590
591 fn handle_prompts_updated_event(
592 &mut self,
593 _prompt_store: Entity<PromptStore>,
594 _event: &prompt_store::PromptsUpdatedEvent,
595 _cx: &mut Context<Self>,
596 ) {
597 self.project_context_needs_refresh.send(()).ok();
598 }
599
600 fn handle_models_updated_event(
601 &mut self,
602 _registry: Entity<LanguageModelRegistry>,
603 _event: &language_model::Event,
604 cx: &mut Context<Self>,
605 ) {
606 self.models.refresh_list(cx);
607
608 let registry = LanguageModelRegistry::read_global(cx);
609 let default_model = registry.default_model().map(|m| m.model);
610 let summarization_model = registry.thread_summary_model().map(|m| m.model);
611
612 for session in self.sessions.values_mut() {
613 session.thread.update(cx, |thread, cx| {
614 if thread.model().is_none()
615 && let Some(model) = default_model.clone()
616 {
617 thread.set_model(model, cx);
618 cx.notify();
619 }
620 thread.set_summarization_model(summarization_model.clone(), cx);
621 });
622 }
623 }
624
625 pub fn load_thread(
626 &mut self,
627 id: acp::SessionId,
628 cx: &mut Context<Self>,
629 ) -> Task<Result<Entity<Thread>>> {
630 let database_future = ThreadsDatabase::connect(cx);
631 cx.spawn(async move |this, cx| {
632 let database = database_future.await.map_err(|err| anyhow!(err))?;
633 let db_thread = database
634 .load_thread(id.clone())
635 .await?
636 .with_context(|| format!("no thread found with ID: {id:?}"))?;
637
638 this.update(cx, |this, cx| {
639 let summarization_model = LanguageModelRegistry::read_global(cx)
640 .thread_summary_model()
641 .map(|c| c.model);
642
643 cx.new(|cx| {
644 let mut thread = Thread::from_db(
645 id.clone(),
646 db_thread,
647 this.project.clone(),
648 this.project_context.clone(),
649 this.context_server_registry.clone(),
650 this.templates.clone(),
651 cx,
652 );
653 thread.set_summarization_model(summarization_model, cx);
654 thread
655 })
656 })
657 })
658 }
659
660 pub fn open_thread(
661 &mut self,
662 id: acp::SessionId,
663 cx: &mut Context<Self>,
664 ) -> Task<Result<Entity<AcpThread>>> {
665 let task = self.load_thread(id, cx);
666 cx.spawn(async move |this, cx| {
667 let thread = task.await?;
668 let acp_thread =
669 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
670 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
671 cx.update(|cx| {
672 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
673 })?
674 .await?;
675 Ok(acp_thread)
676 })
677 }
678
679 pub fn thread_summary(
680 &mut self,
681 id: acp::SessionId,
682 cx: &mut Context<Self>,
683 ) -> Task<Result<SharedString>> {
684 let thread = self.open_thread(id.clone(), cx);
685 cx.spawn(async move |this, cx| {
686 let acp_thread = thread.await?;
687 let result = this
688 .update(cx, |this, cx| {
689 this.sessions
690 .get(&id)
691 .unwrap()
692 .thread
693 .update(cx, |thread, cx| thread.summary(cx))
694 })?
695 .await
696 .context("Failed to generate summary")?;
697 drop(acp_thread);
698 Ok(result)
699 })
700 }
701
702 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
703 if thread.read(cx).is_empty() {
704 return;
705 }
706
707 let database_future = ThreadsDatabase::connect(cx);
708 let (id, db_thread) =
709 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
710 let Some(session) = self.sessions.get_mut(&id) else {
711 return;
712 };
713 let history = self.history.clone();
714 session.pending_save = cx.spawn(async move |_, cx| {
715 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
716 return;
717 };
718 let db_thread = db_thread.await;
719 database.save_thread(id, db_thread).await.log_err();
720 history.update(cx, |history, cx| history.reload(cx)).ok();
721 });
722 }
723}
724
/// Wrapper struct that implements the AgentConnection trait (and `AgentTelemetry`)
/// on top of a shared `NativeAgent` entity.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
728
729impl NativeAgentConnection {
730 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
731 self.0
732 .read(cx)
733 .sessions
734 .get(session_id)
735 .map(|session| session.thread.clone())
736 }
737
738 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
739 self.0.update(cx, |this, cx| this.load_thread(id, cx))
740 }
741
742 fn run_turn(
743 &self,
744 session_id: acp::SessionId,
745 cx: &mut App,
746 f: impl 'static
747 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
748 ) -> Task<Result<acp::PromptResponse>> {
749 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
750 agent
751 .sessions
752 .get_mut(&session_id)
753 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
754 }) else {
755 return Task::ready(Err(anyhow!("Session not found")));
756 };
757 log::debug!("Found session for: {}", session_id);
758
759 let response_stream = match f(thread, cx) {
760 Ok(stream) => stream,
761 Err(err) => return Task::ready(Err(err)),
762 };
763 Self::handle_thread_events(response_stream, acp_thread, cx)
764 }
765
766 fn handle_thread_events(
767 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
768 acp_thread: WeakEntity<AcpThread>,
769 cx: &App,
770 ) -> Task<Result<acp::PromptResponse>> {
771 cx.spawn(async move |cx| {
772 // Handle response stream and forward to session.acp_thread
773 while let Some(result) = events.next().await {
774 match result {
775 Ok(event) => {
776 log::trace!("Received completion event: {:?}", event);
777
778 match event {
779 ThreadEvent::UserMessage(message) => {
780 acp_thread.update(cx, |thread, cx| {
781 for content in message.content {
782 thread.push_user_content_block(
783 Some(message.id.clone()),
784 content.into(),
785 cx,
786 );
787 }
788 })?;
789 }
790 ThreadEvent::AgentText(text) => {
791 acp_thread.update(cx, |thread, cx| {
792 thread.push_assistant_content_block(text.into(), false, cx)
793 })?;
794 }
795 ThreadEvent::AgentThinking(text) => {
796 acp_thread.update(cx, |thread, cx| {
797 thread.push_assistant_content_block(text.into(), true, cx)
798 })?;
799 }
800 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
801 tool_call,
802 options,
803 response,
804 }) => {
805 let outcome_task = acp_thread.update(cx, |thread, cx| {
806 thread.request_tool_call_authorization(
807 tool_call, options, true, cx,
808 )
809 })??;
810 cx.background_spawn(async move {
811 if let acp::RequestPermissionOutcome::Selected(
812 acp::SelectedPermissionOutcome { option_id, .. },
813 ) = outcome_task.await
814 {
815 response
816 .send(option_id)
817 .map(|_| anyhow!("authorization receiver was dropped"))
818 .log_err();
819 }
820 })
821 .detach();
822 }
823 ThreadEvent::ToolCall(tool_call) => {
824 acp_thread.update(cx, |thread, cx| {
825 thread.upsert_tool_call(tool_call, cx)
826 })??;
827 }
828 ThreadEvent::ToolCallUpdate(update) => {
829 acp_thread.update(cx, |thread, cx| {
830 thread.update_tool_call(update, cx)
831 })??;
832 }
833 ThreadEvent::Retry(status) => {
834 acp_thread.update(cx, |thread, cx| {
835 thread.update_retry_status(status, cx)
836 })?;
837 }
838 ThreadEvent::Stop(stop_reason) => {
839 log::debug!("Assistant message complete: {:?}", stop_reason);
840 return Ok(acp::PromptResponse::new(stop_reason));
841 }
842 }
843 }
844 Err(e) => {
845 log::error!("Error in model response stream: {:?}", e);
846 return Err(e);
847 }
848 }
849 }
850
851 log::debug!("Response stream completed");
852 anyhow::Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
853 })
854 }
855}
856
/// Model selector bound to a single session of a `NativeAgentConnection`.
struct NativeAgentModelSelector {
    /// Session whose thread model is read and changed.
    session_id: acp::SessionId,
    /// Connection to the agent that owns the session.
    connection: NativeAgentConnection,
}
861
862impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
863 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
864 log::debug!("NativeAgentConnection::list_models called");
865 let list = self.connection.0.read(cx).models.model_list.clone();
866 Task::ready(if list.is_empty() {
867 Err(anyhow::anyhow!("No models available"))
868 } else {
869 Ok(list)
870 })
871 }
872
873 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
874 log::debug!(
875 "Setting model for session {}: {}",
876 self.session_id,
877 model_id
878 );
879 let Some(thread) = self
880 .connection
881 .0
882 .read(cx)
883 .sessions
884 .get(&self.session_id)
885 .map(|session| session.thread.clone())
886 else {
887 return Task::ready(Err(anyhow!("Session not found")));
888 };
889
890 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
891 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
892 };
893
894 thread.update(cx, |thread, cx| {
895 thread.set_model(model.clone(), cx);
896 });
897
898 update_settings_file(
899 self.connection.0.read(cx).fs.clone(),
900 cx,
901 move |settings, _cx| {
902 let provider = model.provider_id().0.to_string();
903 let model = model.id().0.to_string();
904 settings
905 .agent
906 .get_or_insert_default()
907 .set_model(LanguageModelSelection {
908 provider: provider.into(),
909 model,
910 });
911 },
912 );
913
914 Task::ready(Ok(()))
915 }
916
917 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
918 let Some(thread) = self
919 .connection
920 .0
921 .read(cx)
922 .sessions
923 .get(&self.session_id)
924 .map(|session| session.thread.clone())
925 else {
926 return Task::ready(Err(anyhow!("Session not found")));
927 };
928 let Some(model) = thread.read(cx).model() else {
929 return Task::ready(Err(anyhow!("Model not found")));
930 };
931 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
932 else {
933 return Task::ready(Err(anyhow!("Provider not found")));
934 };
935 Task::ready(Ok(LanguageModels::map_language_model_to_info(
936 model, &provider,
937 )))
938 }
939
940 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
941 Some(self.connection.0.read(cx).models.watch())
942 }
943
944 fn should_render_footer(&self) -> bool {
945 true
946 }
947}
948
949impl acp_thread::AgentConnection for NativeAgentConnection {
950 fn telemetry_id(&self) -> &'static str {
951 "zed"
952 }
953
954 fn new_thread(
955 self: Rc<Self>,
956 project: Entity<Project>,
957 cwd: &Path,
958 cx: &mut App,
959 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
960 let agent = self.0.clone();
961 log::debug!("Creating new thread for project at: {:?}", cwd);
962
963 cx.spawn(async move |cx| {
964 log::debug!("Starting thread creation in async context");
965
966 // Create Thread
967 let thread = agent.update(
968 cx,
969 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
970 // Fetch default model from registry settings
971 let registry = LanguageModelRegistry::read_global(cx);
972 // Log available models for debugging
973 let available_count = registry.available_models(cx).count();
974 log::debug!("Total available models: {}", available_count);
975
976 let default_model = registry.default_model().and_then(|default_model| {
977 agent
978 .models
979 .model_from_id(&LanguageModels::model_id(&default_model.model))
980 });
981 Ok(cx.new(|cx| {
982 Thread::new(
983 project.clone(),
984 agent.project_context.clone(),
985 agent.context_server_registry.clone(),
986 agent.templates.clone(),
987 default_model,
988 cx,
989 )
990 }))
991 },
992 )??;
993 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
994 })
995 }
996
997 fn auth_methods(&self) -> &[acp::AuthMethod] {
998 &[] // No auth for in-process
999 }
1000
1001 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1002 Task::ready(Ok(()))
1003 }
1004
1005 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1006 Some(Rc::new(NativeAgentModelSelector {
1007 session_id: session_id.clone(),
1008 connection: self.clone(),
1009 }) as Rc<dyn AgentModelSelector>)
1010 }
1011
1012 fn prompt(
1013 &self,
1014 id: Option<acp_thread::UserMessageId>,
1015 params: acp::PromptRequest,
1016 cx: &mut App,
1017 ) -> Task<Result<acp::PromptResponse>> {
1018 let id = id.expect("UserMessageId is required");
1019 let session_id = params.session_id.clone();
1020 log::info!("Received prompt request for session: {}", session_id);
1021 log::debug!("Prompt blocks count: {}", params.prompt.len());
1022 let path_style = self.0.read(cx).project.read(cx).path_style(cx);
1023
1024 self.run_turn(session_id, cx, move |thread, cx| {
1025 let content: Vec<UserMessageContent> = params
1026 .prompt
1027 .into_iter()
1028 .map(|block| UserMessageContent::from_content_block(block, path_style))
1029 .collect::<Vec<_>>();
1030 log::debug!("Converted prompt to message: {} chars", content.len());
1031 log::debug!("Message id: {:?}", id);
1032 log::debug!("Message content: {:?}", content);
1033
1034 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1035 })
1036 }
1037
1038 fn resume(
1039 &self,
1040 session_id: &acp::SessionId,
1041 _cx: &App,
1042 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1043 Some(Rc::new(NativeAgentSessionResume {
1044 connection: self.clone(),
1045 session_id: session_id.clone(),
1046 }) as _)
1047 }
1048
1049 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1050 log::info!("Cancelling on session: {}", session_id);
1051 self.0.update(cx, |agent, cx| {
1052 if let Some(agent) = agent.sessions.get(session_id) {
1053 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1054 }
1055 });
1056 }
1057
1058 fn truncate(
1059 &self,
1060 session_id: &agent_client_protocol::SessionId,
1061 cx: &App,
1062 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1063 self.0.read_with(cx, |agent, _cx| {
1064 agent.sessions.get(session_id).map(|session| {
1065 Rc::new(NativeAgentSessionTruncate {
1066 thread: session.thread.clone(),
1067 acp_thread: session.acp_thread.clone(),
1068 }) as _
1069 })
1070 })
1071 }
1072
1073 fn set_title(
1074 &self,
1075 session_id: &acp::SessionId,
1076 _cx: &App,
1077 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1078 Some(Rc::new(NativeAgentSessionSetTitle {
1079 connection: self.clone(),
1080 session_id: session_id.clone(),
1081 }) as _)
1082 }
1083
1084 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1085 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1086 }
1087
1088 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1089 self
1090 }
1091}
1092
1093impl acp_thread::AgentTelemetry for NativeAgentConnection {
1094 fn thread_data(
1095 &self,
1096 session_id: &acp::SessionId,
1097 cx: &mut App,
1098 ) -> Task<Result<serde_json::Value>> {
1099 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1100 return Task::ready(Err(anyhow!("Session not found")));
1101 };
1102
1103 let task = session.thread.read(cx).to_db(cx);
1104 cx.background_spawn(async move {
1105 serde_json::to_value(task.await).context("Failed to serialize thread")
1106 })
1107 }
1108}
1109
1110struct NativeAgentSessionTruncate {
1111 thread: Entity<Thread>,
1112 acp_thread: WeakEntity<AcpThread>,
1113}
1114
1115impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1116 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1117 match self.thread.update(cx, |thread, cx| {
1118 thread.truncate(message_id.clone(), cx)?;
1119 Ok(thread.latest_token_usage())
1120 }) {
1121 Ok(usage) => {
1122 self.acp_thread
1123 .update(cx, |thread, cx| {
1124 thread.update_token_usage(usage, cx);
1125 })
1126 .ok();
1127 Task::ready(Ok(()))
1128 }
1129 Err(error) => Task::ready(Err(error)),
1130 }
1131 }
1132}
1133
1134struct NativeAgentSessionResume {
1135 connection: NativeAgentConnection,
1136 session_id: acp::SessionId,
1137}
1138
1139impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1140 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1141 self.connection
1142 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1143 thread.update(cx, |thread, cx| thread.resume(cx))
1144 })
1145 }
1146}
1147
1148struct NativeAgentSessionSetTitle {
1149 connection: NativeAgentConnection,
1150 session_id: acp::SessionId,
1151}
1152
1153impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1154 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1155 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1156 return Task::ready(Err(anyhow!("session not found")));
1157 };
1158 let thread = session.thread.clone();
1159 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1160 Task::ready(Ok(()))
1161 }
1162}
1163
1164pub struct AcpThreadEnvironment {
1165 acp_thread: WeakEntity<AcpThread>,
1166}
1167
1168impl ThreadEnvironment for AcpThreadEnvironment {
1169 fn create_terminal(
1170 &self,
1171 command: String,
1172 cwd: Option<PathBuf>,
1173 output_byte_limit: Option<u64>,
1174 cx: &mut AsyncApp,
1175 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1176 let task = self.acp_thread.update(cx, |thread, cx| {
1177 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1178 });
1179
1180 let acp_thread = self.acp_thread.clone();
1181 cx.spawn(async move |cx| {
1182 let terminal = task?.await?;
1183
1184 let (drop_tx, drop_rx) = oneshot::channel();
1185 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1186
1187 cx.spawn(async move |cx| {
1188 drop_rx.await.ok();
1189 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1190 })
1191 .detach();
1192
1193 let handle = AcpTerminalHandle {
1194 terminal,
1195 _drop_tx: Some(drop_tx),
1196 };
1197
1198 Ok(Rc::new(handle) as _)
1199 })
1200 }
1201}
1202
1203pub struct AcpTerminalHandle {
1204 terminal: Entity<acp_thread::Terminal>,
1205 _drop_tx: Option<oneshot::Sender<()>>,
1206}
1207
1208impl TerminalHandle for AcpTerminalHandle {
1209 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1210 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1211 }
1212
1213 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1214 self.terminal
1215 .read_with(cx, |term, _cx| term.wait_for_exit())
1216 }
1217
1218 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1219 self.terminal
1220 .read_with(cx, |term, cx| term.current_output(cx))
1221 }
1222}
1223
// Tests for `NativeAgent` / `NativeAgentConnection` that run against `FakeFs`,
// a test `Project`, and the fake language model provider.
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // The agent's `project_context` should track worktrees as they are added,
    // and should pick up a `.rules` file once one is created in a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project has no worktrees yet, so the context lists none.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree (still without a rules file) shows up in the context.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // The session's model selector should list exactly the fake provider's
    // `fake/fake` model, grouped under "Fake".
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // Models are expected to be grouped by provider.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId::new("fake/fake"),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Selecting a model through the session's selector should switch the thread's
    // model and overwrite `agent.default_model` in the settings file.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed a different default model (foo/bar) so the overwrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId::new("fake/fake");
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Presumably lets the asynchronous settings-file write finish before the
        // file is read back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // Full persistence round trip:
    // - an empty thread is not saved;
    // - a thread with one exchange appears in history under its generated summary
    //   title after the session is dropped;
    // - reopening it restores the same conversation.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let text_thread_store =
            cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        // Grab the native thread so the test can plug in fake models.
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a prompt that mentions `b.md` through a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new(
                        "b.md",
                        MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                    )),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Stream the assistant reply, then the summary. The summary text is what
        // later appears as the history entry title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread renders the same markdown as before it was dropped.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    // Snapshot of the history store as `(id, title)` pairs, in `entries()` order.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    // Shared setup: best-effort logger init (ignored if already initialized),
    // a test `SettingsStore` global, and the test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            LanguageModelRegistry::test(cx);
        });
    }
}