1mod db;
2mod edit_agent;
3mod history_store;
4mod legacy_thread;
5mod native_agent_server;
6pub mod outline;
7mod templates;
8mod thread;
9mod tool_schema;
10mod tools;
11
12#[cfg(test)]
13mod tests;
14
15pub use db::*;
16pub use history_store::*;
17pub use native_agent_server::NativeAgentServer;
18pub use templates::*;
19pub use thread::*;
20pub use tools::*;
21
22use acp_thread::{AcpThread, AgentModelSelector};
23use agent_client_protocol as acp;
24use anyhow::{Context as _, Result, anyhow};
25use chrono::{DateTime, Utc};
26use collections::{HashSet, IndexMap};
27use fs::Fs;
28use futures::channel::{mpsc, oneshot};
29use futures::future::Shared;
30use futures::{StreamExt, future};
31use gpui::{
32 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
33};
34use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
35use project::{Project, ProjectItem, ProjectPath, Worktree};
36use prompt_store::{
37 ProjectContext, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
38};
39use serde::{Deserialize, Serialize};
40use settings::{LanguageModelSelection, update_settings_file};
41use std::any::Any;
42use std::collections::HashMap;
43use std::path::{Path, PathBuf};
44use std::rc::Rc;
45use std::sync::Arc;
46use util::ResultExt;
47use util::rel_path::RelPath;
48
/// Serializable snapshot of a project's worktrees.
///
/// Field order is part of the serialized representation; do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProjectSnapshot {
    /// One entry per captured worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// When the snapshot was taken (UTC).
    pub timestamp: DateTime<Utc>,
}
54
/// Serializable snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, stored as a string.
    pub worktree_path: String,
    /// Git information for the worktree; presumably `None` when the worktree is not a
    /// git repository or the state could not be captured (NOTE(review): confirm at producer).
    pub git_state: Option<GitState>,
}
60
/// Serializable git state captured for a worktree.
///
/// Every field is optional so that partially-available information can still be recorded.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GitState {
    /// URL of the repository remote, if known.
    pub remote_url: Option<String>,
    /// SHA of the `HEAD` commit, if known.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch, if any.
    pub current_branch: Option<String>,
    /// Textual diff of the working tree, if captured.
    pub diff: Option<String>,
}
68
/// Worktree-relative paths that are recognized as agent rules files.
///
/// Order matters: when loading a worktree's rules, the first entry that exists as a file
/// in the worktree is used and the rest are ignored. Changes to any of these paths also
/// trigger a project-context refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
80
/// Error produced when a worktree's rules file cannot be loaded.
///
/// NOTE(review): these are currently created but not surfaced to the user; the project
/// context build discards them.
pub struct RulesLoadingError {
    /// Human-readable description of the failure.
    pub message: SharedString,
}
84
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication. Held weakly; the session is
    /// removed from the agent when this entity is released.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recently spawned save task for this session; each new save replaces it.
    pending_save: Task<()>,
    /// Subscriptions kept alive for as long as the session exists.
    _subscriptions: Vec<Subscription>,
}
94
/// Cache of the language models offered by authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID (IDs have the form `"<provider_id>/<model_id>"`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached grouped list for returning language model information.
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out to watchers; signalled after every list refresh.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used to signal watchers after every list refresh.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating all providers; kept so it is not dropped early.
    _authenticate_all_providers_task: Task<()>,
}
104
105impl LanguageModels {
106 fn new(cx: &mut App) -> Self {
107 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
108
109 let mut this = Self {
110 models: HashMap::default(),
111 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
112 refresh_models_rx,
113 refresh_models_tx,
114 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
115 };
116 this.refresh_list(cx);
117 this
118 }
119
120 fn refresh_list(&mut self, cx: &App) {
121 let providers = LanguageModelRegistry::global(cx)
122 .read(cx)
123 .providers()
124 .into_iter()
125 .filter(|provider| provider.is_authenticated(cx))
126 .collect::<Vec<_>>();
127
128 let mut language_model_list = IndexMap::default();
129 let mut recommended_models = HashSet::default();
130
131 let mut recommended = Vec::new();
132 for provider in &providers {
133 for model in provider.recommended_models(cx) {
134 recommended_models.insert((model.provider_id(), model.id()));
135 recommended.push(Self::map_language_model_to_info(&model, provider));
136 }
137 }
138 if !recommended.is_empty() {
139 language_model_list.insert(
140 acp_thread::AgentModelGroupName("Recommended".into()),
141 recommended,
142 );
143 }
144
145 let mut models = HashMap::default();
146 for provider in providers {
147 let mut provider_models = Vec::new();
148 for model in provider.provided_models(cx) {
149 let model_info = Self::map_language_model_to_info(&model, &provider);
150 let model_id = model_info.id.clone();
151 if !recommended_models.contains(&(model.provider_id(), model.id())) {
152 provider_models.push(model_info);
153 }
154 models.insert(model_id, model);
155 }
156 if !provider_models.is_empty() {
157 language_model_list.insert(
158 acp_thread::AgentModelGroupName(provider.name().0.clone()),
159 provider_models,
160 );
161 }
162 }
163
164 self.models = models;
165 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
166 self.refresh_models_tx.send(()).ok();
167 }
168
169 fn watch(&self) -> watch::Receiver<()> {
170 self.refresh_models_rx.clone()
171 }
172
173 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
174 self.models.get(model_id).cloned()
175 }
176
177 fn map_language_model_to_info(
178 model: &Arc<dyn LanguageModel>,
179 provider: &Arc<dyn LanguageModelProvider>,
180 ) -> acp_thread::AgentModelInfo {
181 acp_thread::AgentModelInfo {
182 id: Self::model_id(model),
183 name: model.name().0,
184 description: None,
185 icon: Some(provider.icon()),
186 }
187 }
188
189 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
190 acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
191 }
192
193 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
194 let authenticate_all_providers = LanguageModelRegistry::global(cx)
195 .read(cx)
196 .providers()
197 .iter()
198 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
199 .collect::<Vec<_>>();
200
201 cx.background_spawn(async move {
202 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
203 if let Err(err) = authenticate_task.await {
204 match err {
205 language_model::AuthenticateError::CredentialsNotFound => {
206 // Since we're authenticating these providers in the
207 // background for the purposes of populating the
208 // language selector, we don't care about providers
209 // where the credentials are not found.
210 }
211 language_model::AuthenticateError::ConnectionRefused => {
212 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
213 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
214 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
215 }
216 _ => {
217 // Some providers have noisy failure states that we
218 // don't want to spam the logs with every time the
219 // language model selector is initialized.
220 //
221 // Ideally these should have more clear failure modes
222 // that we know are safe to ignore here, like what we do
223 // with `CredentialsNotFound` above.
224 match provider_id.0.as_ref() {
225 "lmstudio" | "ollama" => {
226 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
227 //
228 // These fail noisily, so we don't log them.
229 }
230 "copilot_chat" => {
231 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
232 }
233 _ => {
234 log::error!(
235 "Failed to authenticate provider: {}: {err}",
236 provider_name.0
237 );
238 }
239 }
240 }
241 }
242 }
243 }
244 })
245 }
246}
247
/// In-process agent that owns the native [`Thread`]s and exposes them as ACP sessions.
pub struct NativeAgent {
    /// Session ID -> Session mapping.
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history; reloaded after each successful save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads.
    project_context: Entity<ProjectContext>,
    /// Signalled to request a rebuild of `project_context` (worktree or rules changes).
    project_context_needs_refresh: watch::Sender<()>,
    /// Background task that rebuilds the project context when a refresh is requested.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers handed to every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads.
    templates: Arc<Templates>,
    /// Cached model information.
    models: LanguageModels,
    /// Project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store supplying default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem handle, used e.g. to persist the selected model to settings.
    fs: Arc<dyn Fs>,
    /// Project, model-registry and prompt-store subscriptions kept alive with the agent.
    _subscriptions: Vec<Subscription>,
}
266
267impl NativeAgent {
268 pub async fn new(
269 project: Entity<Project>,
270 history: Entity<HistoryStore>,
271 templates: Arc<Templates>,
272 prompt_store: Option<Entity<PromptStore>>,
273 fs: Arc<dyn Fs>,
274 cx: &mut AsyncApp,
275 ) -> Result<Entity<NativeAgent>> {
276 log::debug!("Creating new NativeAgent");
277
278 let project_context = cx
279 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
280 .await;
281
282 cx.new(|cx| {
283 let mut subscriptions = vec![
284 cx.subscribe(&project, Self::handle_project_event),
285 cx.subscribe(
286 &LanguageModelRegistry::global(cx),
287 Self::handle_models_updated_event,
288 ),
289 ];
290 if let Some(prompt_store) = prompt_store.as_ref() {
291 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
292 }
293
294 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
295 watch::channel(());
296 Self {
297 sessions: HashMap::new(),
298 history,
299 project_context: cx.new(|_| project_context),
300 project_context_needs_refresh: project_context_needs_refresh_tx,
301 _maintain_project_context: cx.spawn(async move |this, cx| {
302 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
303 }),
304 context_server_registry: cx.new(|cx| {
305 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
306 }),
307 templates,
308 models: LanguageModels::new(cx),
309 project,
310 prompt_store,
311 fs,
312 _subscriptions: subscriptions,
313 }
314 })
315 }
316
317 fn register_session(
318 &mut self,
319 thread_handle: Entity<Thread>,
320 cx: &mut Context<Self>,
321 ) -> Entity<AcpThread> {
322 let connection = Rc::new(NativeAgentConnection(cx.entity()));
323
324 let thread = thread_handle.read(cx);
325 let session_id = thread.id().clone();
326 let title = thread.title();
327 let project = thread.project.clone();
328 let action_log = thread.action_log.clone();
329 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
330 let acp_thread = cx.new(|cx| {
331 acp_thread::AcpThread::new(
332 title,
333 connection,
334 project.clone(),
335 action_log.clone(),
336 session_id.clone(),
337 prompt_capabilities_rx,
338 cx,
339 )
340 });
341
342 let registry = LanguageModelRegistry::read_global(cx);
343 let summarization_model = registry.thread_summary_model().map(|c| c.model);
344
345 thread_handle.update(cx, |thread, cx| {
346 thread.set_summarization_model(summarization_model, cx);
347 thread.add_default_tools(
348 Rc::new(AcpThreadEnvironment {
349 acp_thread: acp_thread.downgrade(),
350 }) as _,
351 cx,
352 )
353 });
354
355 let subscriptions = vec![
356 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
357 this.sessions.remove(acp_thread.session_id());
358 }),
359 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
360 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
361 cx.observe(&thread_handle, move |this, thread, cx| {
362 this.save_thread(thread, cx)
363 }),
364 ];
365
366 self.sessions.insert(
367 session_id,
368 Session {
369 thread: thread_handle,
370 acp_thread: acp_thread.downgrade(),
371 _subscriptions: subscriptions,
372 pending_save: Task::ready(()),
373 },
374 );
375 acp_thread
376 }
377
378 pub fn models(&self) -> &LanguageModels {
379 &self.models
380 }
381
382 async fn maintain_project_context(
383 this: WeakEntity<Self>,
384 mut needs_refresh: watch::Receiver<()>,
385 cx: &mut AsyncApp,
386 ) -> Result<()> {
387 while needs_refresh.changed().await.is_ok() {
388 let project_context = this
389 .update(cx, |this, cx| {
390 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
391 })?
392 .await;
393 this.update(cx, |this, cx| {
394 this.project_context = cx.new(|_| project_context);
395 })?;
396 }
397
398 Ok(())
399 }
400
401 fn build_project_context(
402 project: &Entity<Project>,
403 prompt_store: Option<&Entity<PromptStore>>,
404 cx: &mut App,
405 ) -> Task<ProjectContext> {
406 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
407 let worktree_tasks = worktrees
408 .into_iter()
409 .map(|worktree| {
410 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
411 })
412 .collect::<Vec<_>>();
413 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
414 prompt_store.read_with(cx, |prompt_store, cx| {
415 let prompts = prompt_store.default_prompt_metadata();
416 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
417 let contents = prompt_store.load(prompt_metadata.id, cx);
418 async move { (contents.await, prompt_metadata) }
419 });
420 cx.background_spawn(future::join_all(load_tasks))
421 })
422 } else {
423 Task::ready(vec![])
424 };
425
426 cx.spawn(async move |_cx| {
427 let (worktrees, default_user_rules) =
428 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
429
430 let worktrees = worktrees
431 .into_iter()
432 .map(|(worktree, _rules_error)| {
433 // TODO: show error message
434 // if let Some(rules_error) = rules_error {
435 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
436 // }
437 worktree
438 })
439 .collect::<Vec<_>>();
440
441 let default_user_rules = default_user_rules
442 .into_iter()
443 .flat_map(|(contents, prompt_metadata)| match contents {
444 Ok(contents) => Some(UserRulesContext {
445 uuid: match prompt_metadata.id {
446 prompt_store::PromptId::User { uuid } => uuid,
447 prompt_store::PromptId::EditWorkflow => return None,
448 },
449 title: prompt_metadata.title.map(|title| title.to_string()),
450 contents,
451 }),
452 Err(_err) => {
453 // TODO: show error message
454 // this.update(cx, |_, cx| {
455 // cx.emit(RulesLoadingError {
456 // message: format!("{err:?}").into(),
457 // });
458 // })
459 // .ok();
460 None
461 }
462 })
463 .collect::<Vec<_>>();
464
465 ProjectContext::new(worktrees, default_user_rules)
466 })
467 }
468
469 fn load_worktree_info_for_system_prompt(
470 worktree: Entity<Worktree>,
471 project: Entity<Project>,
472 cx: &mut App,
473 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
474 let tree = worktree.read(cx);
475 let root_name = tree.root_name_str().into();
476 let abs_path = tree.abs_path();
477
478 let mut context = WorktreeContext {
479 root_name,
480 abs_path,
481 rules_file: None,
482 };
483
484 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
485 let Some(rules_task) = rules_task else {
486 return Task::ready((context, None));
487 };
488
489 cx.spawn(async move |_| {
490 let (rules_file, rules_file_error) = match rules_task.await {
491 Ok(rules_file) => (Some(rules_file), None),
492 Err(err) => (
493 None,
494 Some(RulesLoadingError {
495 message: format!("{err}").into(),
496 }),
497 ),
498 };
499 context.rules_file = rules_file;
500 (context, rules_file_error)
501 })
502 }
503
504 fn load_worktree_rules_file(
505 worktree: Entity<Worktree>,
506 project: Entity<Project>,
507 cx: &mut App,
508 ) -> Option<Task<Result<RulesFileContext>>> {
509 let worktree = worktree.read(cx);
510 let worktree_id = worktree.id();
511 let selected_rules_file = RULES_FILE_NAMES
512 .into_iter()
513 .filter_map(|name| {
514 worktree
515 .entry_for_path(RelPath::unix(name).unwrap())
516 .filter(|entry| entry.is_file())
517 .map(|entry| entry.path.clone())
518 })
519 .next();
520
521 // Note that Cline supports `.clinerules` being a directory, but that is not currently
522 // supported. This doesn't seem to occur often in GitHub repositories.
523 selected_rules_file.map(|path_in_worktree| {
524 let project_path = ProjectPath {
525 worktree_id,
526 path: path_in_worktree.clone(),
527 };
528 let buffer_task =
529 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
530 let rope_task = cx.spawn(async move |cx| {
531 buffer_task.await?.read_with(cx, |buffer, cx| {
532 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
533 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
534 })?
535 });
536 // Build a string from the rope on a background thread.
537 cx.background_spawn(async move {
538 let (project_entry_id, rope) = rope_task.await?;
539 anyhow::Ok(RulesFileContext {
540 path_in_worktree,
541 text: rope.to_string().trim().to_string(),
542 project_entry_id: project_entry_id.to_usize(),
543 })
544 })
545 })
546 }
547
548 fn handle_thread_title_updated(
549 &mut self,
550 thread: Entity<Thread>,
551 _: &TitleUpdated,
552 cx: &mut Context<Self>,
553 ) {
554 let session_id = thread.read(cx).id();
555 let Some(session) = self.sessions.get(session_id) else {
556 return;
557 };
558 let thread = thread.downgrade();
559 let acp_thread = session.acp_thread.clone();
560 cx.spawn(async move |_, cx| {
561 let title = thread.read_with(cx, |thread, _| thread.title())?;
562 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
563 task.await
564 })
565 .detach_and_log_err(cx);
566 }
567
568 fn handle_thread_token_usage_updated(
569 &mut self,
570 thread: Entity<Thread>,
571 usage: &TokenUsageUpdated,
572 cx: &mut Context<Self>,
573 ) {
574 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
575 return;
576 };
577 session
578 .acp_thread
579 .update(cx, |acp_thread, cx| {
580 acp_thread.update_token_usage(usage.0.clone(), cx);
581 })
582 .ok();
583 }
584
585 fn handle_project_event(
586 &mut self,
587 _project: Entity<Project>,
588 event: &project::Event,
589 _cx: &mut Context<Self>,
590 ) {
591 match event {
592 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
593 self.project_context_needs_refresh.send(()).ok();
594 }
595 project::Event::WorktreeUpdatedEntries(_, items) => {
596 if items.iter().any(|(path, _, _)| {
597 RULES_FILE_NAMES
598 .iter()
599 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
600 }) {
601 self.project_context_needs_refresh.send(()).ok();
602 }
603 }
604 _ => {}
605 }
606 }
607
608 fn handle_prompts_updated_event(
609 &mut self,
610 _prompt_store: Entity<PromptStore>,
611 _event: &prompt_store::PromptsUpdatedEvent,
612 _cx: &mut Context<Self>,
613 ) {
614 self.project_context_needs_refresh.send(()).ok();
615 }
616
617 fn handle_models_updated_event(
618 &mut self,
619 _registry: Entity<LanguageModelRegistry>,
620 _event: &language_model::Event,
621 cx: &mut Context<Self>,
622 ) {
623 self.models.refresh_list(cx);
624
625 let registry = LanguageModelRegistry::read_global(cx);
626 let default_model = registry.default_model().map(|m| m.model);
627 let summarization_model = registry.thread_summary_model().map(|m| m.model);
628
629 for session in self.sessions.values_mut() {
630 session.thread.update(cx, |thread, cx| {
631 if thread.model().is_none()
632 && let Some(model) = default_model.clone()
633 {
634 thread.set_model(model, cx);
635 cx.notify();
636 }
637 thread.set_summarization_model(summarization_model.clone(), cx);
638 });
639 }
640 }
641
642 pub fn load_thread(
643 &mut self,
644 id: acp::SessionId,
645 cx: &mut Context<Self>,
646 ) -> Task<Result<Entity<Thread>>> {
647 let database_future = ThreadsDatabase::connect(cx);
648 cx.spawn(async move |this, cx| {
649 let database = database_future.await.map_err(|err| anyhow!(err))?;
650 let db_thread = database
651 .load_thread(id.clone())
652 .await?
653 .with_context(|| format!("no thread found with ID: {id:?}"))?;
654
655 this.update(cx, |this, cx| {
656 let summarization_model = LanguageModelRegistry::read_global(cx)
657 .thread_summary_model()
658 .map(|c| c.model);
659
660 cx.new(|cx| {
661 let mut thread = Thread::from_db(
662 id.clone(),
663 db_thread,
664 this.project.clone(),
665 this.project_context.clone(),
666 this.context_server_registry.clone(),
667 this.templates.clone(),
668 cx,
669 );
670 thread.set_summarization_model(summarization_model, cx);
671 thread
672 })
673 })
674 })
675 }
676
677 pub fn open_thread(
678 &mut self,
679 id: acp::SessionId,
680 cx: &mut Context<Self>,
681 ) -> Task<Result<Entity<AcpThread>>> {
682 let task = self.load_thread(id, cx);
683 cx.spawn(async move |this, cx| {
684 let thread = task.await?;
685 let acp_thread =
686 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
687 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
688 cx.update(|cx| {
689 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
690 })?
691 .await?;
692 Ok(acp_thread)
693 })
694 }
695
696 pub fn thread_summary(
697 &mut self,
698 id: acp::SessionId,
699 cx: &mut Context<Self>,
700 ) -> Task<Result<SharedString>> {
701 let thread = self.open_thread(id.clone(), cx);
702 cx.spawn(async move |this, cx| {
703 let acp_thread = thread.await?;
704 let result = this
705 .update(cx, |this, cx| {
706 this.sessions
707 .get(&id)
708 .unwrap()
709 .thread
710 .update(cx, |thread, cx| thread.summary(cx))
711 })?
712 .await
713 .context("Failed to generate summary")?;
714 drop(acp_thread);
715 Ok(result)
716 })
717 }
718
719 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
720 if thread.read(cx).is_empty() {
721 return;
722 }
723
724 let database_future = ThreadsDatabase::connect(cx);
725 let (id, db_thread) =
726 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
727 let Some(session) = self.sessions.get_mut(&id) else {
728 return;
729 };
730 let history = self.history.clone();
731 session.pending_save = cx.spawn(async move |_, cx| {
732 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
733 return;
734 };
735 let db_thread = db_thread.await;
736 database.save_thread(id, db_thread).await.log_err();
737 history.update(cx, |history, cx| history.reload(cx)).ok();
738 });
739 }
740}
741
/// Clonable handle to a [`NativeAgent`] entity.
///
/// Implements [`acp_thread::AgentConnection`] (plus model selection, telemetry, truncation,
/// resume and title-setting capabilities) on top of the agent's sessions.
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
745
746impl NativeAgentConnection {
747 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
748 self.0
749 .read(cx)
750 .sessions
751 .get(session_id)
752 .map(|session| session.thread.clone())
753 }
754
755 pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task<Result<Entity<Thread>>> {
756 self.0.update(cx, |this, cx| this.load_thread(id, cx))
757 }
758
759 fn run_turn(
760 &self,
761 session_id: acp::SessionId,
762 cx: &mut App,
763 f: impl 'static
764 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
765 ) -> Task<Result<acp::PromptResponse>> {
766 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
767 agent
768 .sessions
769 .get_mut(&session_id)
770 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
771 }) else {
772 return Task::ready(Err(anyhow!("Session not found")));
773 };
774 log::debug!("Found session for: {}", session_id);
775
776 let response_stream = match f(thread, cx) {
777 Ok(stream) => stream,
778 Err(err) => return Task::ready(Err(err)),
779 };
780 Self::handle_thread_events(response_stream, acp_thread, cx)
781 }
782
783 fn handle_thread_events(
784 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
785 acp_thread: WeakEntity<AcpThread>,
786 cx: &App,
787 ) -> Task<Result<acp::PromptResponse>> {
788 cx.spawn(async move |cx| {
789 // Handle response stream and forward to session.acp_thread
790 while let Some(result) = events.next().await {
791 match result {
792 Ok(event) => {
793 log::trace!("Received completion event: {:?}", event);
794
795 match event {
796 ThreadEvent::UserMessage(message) => {
797 acp_thread.update(cx, |thread, cx| {
798 for content in message.content {
799 thread.push_user_content_block(
800 Some(message.id.clone()),
801 content.into(),
802 cx,
803 );
804 }
805 })?;
806 }
807 ThreadEvent::AgentText(text) => {
808 acp_thread.update(cx, |thread, cx| {
809 thread.push_assistant_content_block(
810 acp::ContentBlock::Text(acp::TextContent {
811 text,
812 annotations: None,
813 meta: None,
814 }),
815 false,
816 cx,
817 )
818 })?;
819 }
820 ThreadEvent::AgentThinking(text) => {
821 acp_thread.update(cx, |thread, cx| {
822 thread.push_assistant_content_block(
823 acp::ContentBlock::Text(acp::TextContent {
824 text,
825 annotations: None,
826 meta: None,
827 }),
828 true,
829 cx,
830 )
831 })?;
832 }
833 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
834 tool_call,
835 options,
836 response,
837 }) => {
838 let outcome_task = acp_thread.update(cx, |thread, cx| {
839 thread.request_tool_call_authorization(
840 tool_call, options, true, cx,
841 )
842 })??;
843 cx.background_spawn(async move {
844 if let acp::RequestPermissionOutcome::Selected { option_id } =
845 outcome_task.await
846 {
847 response
848 .send(option_id)
849 .map(|_| anyhow!("authorization receiver was dropped"))
850 .log_err();
851 }
852 })
853 .detach();
854 }
855 ThreadEvent::ToolCall(tool_call) => {
856 acp_thread.update(cx, |thread, cx| {
857 thread.upsert_tool_call(tool_call, cx)
858 })??;
859 }
860 ThreadEvent::ToolCallUpdate(update) => {
861 acp_thread.update(cx, |thread, cx| {
862 thread.update_tool_call(update, cx)
863 })??;
864 }
865 ThreadEvent::Retry(status) => {
866 acp_thread.update(cx, |thread, cx| {
867 thread.update_retry_status(status, cx)
868 })?;
869 }
870 ThreadEvent::Stop(stop_reason) => {
871 log::debug!("Assistant message complete: {:?}", stop_reason);
872 return Ok(acp::PromptResponse {
873 stop_reason,
874 meta: None,
875 });
876 }
877 }
878 }
879 Err(e) => {
880 log::error!("Error in model response stream: {:?}", e);
881 return Err(e);
882 }
883 }
884 }
885
886 log::debug!("Response stream completed");
887 anyhow::Ok(acp::PromptResponse {
888 stop_reason: acp::StopReason::EndTurn,
889 meta: None,
890 })
891 })
892 }
893}
894
/// Per-session model selector backed by the native agent's model cache.
struct NativeAgentModelSelector {
    /// Session whose thread model is read and changed.
    session_id: acp::SessionId,
    /// Connection used to reach the agent's sessions, models and filesystem.
    connection: NativeAgentConnection,
}
899
900impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
901 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
902 log::debug!("NativeAgentConnection::list_models called");
903 let list = self.connection.0.read(cx).models.model_list.clone();
904 Task::ready(if list.is_empty() {
905 Err(anyhow::anyhow!("No models available"))
906 } else {
907 Ok(list)
908 })
909 }
910
911 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
912 log::debug!(
913 "Setting model for session {}: {}",
914 self.session_id,
915 model_id
916 );
917 let Some(thread) = self
918 .connection
919 .0
920 .read(cx)
921 .sessions
922 .get(&self.session_id)
923 .map(|session| session.thread.clone())
924 else {
925 return Task::ready(Err(anyhow!("Session not found")));
926 };
927
928 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
929 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
930 };
931
932 thread.update(cx, |thread, cx| {
933 thread.set_model(model.clone(), cx);
934 });
935
936 update_settings_file(
937 self.connection.0.read(cx).fs.clone(),
938 cx,
939 move |settings, _cx| {
940 let provider = model.provider_id().0.to_string();
941 let model = model.id().0.to_string();
942 settings
943 .agent
944 .get_or_insert_default()
945 .set_model(LanguageModelSelection {
946 provider: provider.into(),
947 model,
948 });
949 },
950 );
951
952 Task::ready(Ok(()))
953 }
954
955 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
956 let Some(thread) = self
957 .connection
958 .0
959 .read(cx)
960 .sessions
961 .get(&self.session_id)
962 .map(|session| session.thread.clone())
963 else {
964 return Task::ready(Err(anyhow!("Session not found")));
965 };
966 let Some(model) = thread.read(cx).model() else {
967 return Task::ready(Err(anyhow!("Model not found")));
968 };
969 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
970 else {
971 return Task::ready(Err(anyhow!("Provider not found")));
972 };
973 Task::ready(Ok(LanguageModels::map_language_model_to_info(
974 model, &provider,
975 )))
976 }
977
978 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
979 Some(self.connection.0.read(cx).models.watch())
980 }
981}
982
983impl acp_thread::AgentConnection for NativeAgentConnection {
984 fn new_thread(
985 self: Rc<Self>,
986 project: Entity<Project>,
987 cwd: &Path,
988 cx: &mut App,
989 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
990 let agent = self.0.clone();
991 log::debug!("Creating new thread for project at: {:?}", cwd);
992
993 cx.spawn(async move |cx| {
994 log::debug!("Starting thread creation in async context");
995
996 // Create Thread
997 let thread = agent.update(
998 cx,
999 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
1000 // Fetch default model from registry settings
1001 let registry = LanguageModelRegistry::read_global(cx);
1002 // Log available models for debugging
1003 let available_count = registry.available_models(cx).count();
1004 log::debug!("Total available models: {}", available_count);
1005
1006 let default_model = registry.default_model().and_then(|default_model| {
1007 agent
1008 .models
1009 .model_from_id(&LanguageModels::model_id(&default_model.model))
1010 });
1011 Ok(cx.new(|cx| {
1012 Thread::new(
1013 project.clone(),
1014 agent.project_context.clone(),
1015 agent.context_server_registry.clone(),
1016 agent.templates.clone(),
1017 default_model,
1018 cx,
1019 )
1020 }))
1021 },
1022 )??;
1023 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
1024 })
1025 }
1026
1027 fn auth_methods(&self) -> &[acp::AuthMethod] {
1028 &[] // No auth for in-process
1029 }
1030
1031 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
1032 Task::ready(Ok(()))
1033 }
1034
1035 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
1036 Some(Rc::new(NativeAgentModelSelector {
1037 session_id: session_id.clone(),
1038 connection: self.clone(),
1039 }) as Rc<dyn AgentModelSelector>)
1040 }
1041
1042 fn prompt(
1043 &self,
1044 id: Option<acp_thread::UserMessageId>,
1045 params: acp::PromptRequest,
1046 cx: &mut App,
1047 ) -> Task<Result<acp::PromptResponse>> {
1048 let id = id.expect("UserMessageId is required");
1049 let session_id = params.session_id.clone();
1050 log::info!("Received prompt request for session: {}", session_id);
1051 log::debug!("Prompt blocks count: {}", params.prompt.len());
1052
1053 self.run_turn(session_id, cx, |thread, cx| {
1054 let content: Vec<UserMessageContent> = params
1055 .prompt
1056 .into_iter()
1057 .map(Into::into)
1058 .collect::<Vec<_>>();
1059 log::debug!("Converted prompt to message: {} chars", content.len());
1060 log::debug!("Message id: {:?}", id);
1061 log::debug!("Message content: {:?}", content);
1062
1063 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1064 })
1065 }
1066
1067 fn resume(
1068 &self,
1069 session_id: &acp::SessionId,
1070 _cx: &App,
1071 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1072 Some(Rc::new(NativeAgentSessionResume {
1073 connection: self.clone(),
1074 session_id: session_id.clone(),
1075 }) as _)
1076 }
1077
1078 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1079 log::info!("Cancelling on session: {}", session_id);
1080 self.0.update(cx, |agent, cx| {
1081 if let Some(agent) = agent.sessions.get(session_id) {
1082 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1083 }
1084 });
1085 }
1086
1087 fn truncate(
1088 &self,
1089 session_id: &agent_client_protocol::SessionId,
1090 cx: &App,
1091 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1092 self.0.read_with(cx, |agent, _cx| {
1093 agent.sessions.get(session_id).map(|session| {
1094 Rc::new(NativeAgentSessionTruncate {
1095 thread: session.thread.clone(),
1096 acp_thread: session.acp_thread.clone(),
1097 }) as _
1098 })
1099 })
1100 }
1101
1102 fn set_title(
1103 &self,
1104 session_id: &acp::SessionId,
1105 _cx: &App,
1106 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1107 Some(Rc::new(NativeAgentSessionSetTitle {
1108 connection: self.clone(),
1109 session_id: session_id.clone(),
1110 }) as _)
1111 }
1112
1113 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1114 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1115 }
1116
1117 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1118 self
1119 }
1120}
1121
1122impl acp_thread::AgentTelemetry for NativeAgentConnection {
1123 fn agent_name(&self) -> String {
1124 "Zed".into()
1125 }
1126
1127 fn thread_data(
1128 &self,
1129 session_id: &acp::SessionId,
1130 cx: &mut App,
1131 ) -> Task<Result<serde_json::Value>> {
1132 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1133 return Task::ready(Err(anyhow!("Session not found")));
1134 };
1135
1136 let task = session.thread.read(cx).to_db(cx);
1137 cx.background_spawn(async move {
1138 serde_json::to_value(task.await).context("Failed to serialize thread")
1139 })
1140 }
1141}
1142
1143struct NativeAgentSessionTruncate {
1144 thread: Entity<Thread>,
1145 acp_thread: WeakEntity<AcpThread>,
1146}
1147
1148impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1149 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1150 match self.thread.update(cx, |thread, cx| {
1151 thread.truncate(message_id.clone(), cx)?;
1152 Ok(thread.latest_token_usage())
1153 }) {
1154 Ok(usage) => {
1155 self.acp_thread
1156 .update(cx, |thread, cx| {
1157 thread.update_token_usage(usage, cx);
1158 })
1159 .ok();
1160 Task::ready(Ok(()))
1161 }
1162 Err(error) => Task::ready(Err(error)),
1163 }
1164 }
1165}
1166
1167struct NativeAgentSessionResume {
1168 connection: NativeAgentConnection,
1169 session_id: acp::SessionId,
1170}
1171
1172impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1173 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1174 self.connection
1175 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1176 thread.update(cx, |thread, cx| thread.resume(cx))
1177 })
1178 }
1179}
1180
1181struct NativeAgentSessionSetTitle {
1182 connection: NativeAgentConnection,
1183 session_id: acp::SessionId,
1184}
1185
1186impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1187 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1188 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1189 return Task::ready(Err(anyhow!("session not found")));
1190 };
1191 let thread = session.thread.clone();
1192 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1193 Task::ready(Ok(()))
1194 }
1195}
1196
1197pub struct AcpThreadEnvironment {
1198 acp_thread: WeakEntity<AcpThread>,
1199}
1200
1201impl ThreadEnvironment for AcpThreadEnvironment {
1202 fn create_terminal(
1203 &self,
1204 command: String,
1205 cwd: Option<PathBuf>,
1206 output_byte_limit: Option<u64>,
1207 cx: &mut AsyncApp,
1208 ) -> Task<Result<Rc<dyn TerminalHandle>>> {
1209 let task = self.acp_thread.update(cx, |thread, cx| {
1210 thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
1211 });
1212
1213 let acp_thread = self.acp_thread.clone();
1214 cx.spawn(async move |cx| {
1215 let terminal = task?.await?;
1216
1217 let (drop_tx, drop_rx) = oneshot::channel();
1218 let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;
1219
1220 cx.spawn(async move |cx| {
1221 drop_rx.await.ok();
1222 acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
1223 })
1224 .detach();
1225
1226 let handle = AcpTerminalHandle {
1227 terminal,
1228 _drop_tx: Some(drop_tx),
1229 };
1230
1231 Ok(Rc::new(handle) as _)
1232 })
1233 }
1234}
1235
1236pub struct AcpTerminalHandle {
1237 terminal: Entity<acp_thread::Terminal>,
1238 _drop_tx: Option<oneshot::Sender<()>>,
1239}
1240
1241impl TerminalHandle for AcpTerminalHandle {
1242 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1243 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1244 }
1245
1246 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1247 self.terminal
1248 .read_with(cx, |term, _cx| term.wait_for_exit())
1249 }
1250
1251 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1252 self.terminal
1253 .read_with(cx, |term, cx| term.current_output(cx))
1254 }
1255}
1256
#[cfg(test)]
mod internal_tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    /// Verifies that the agent's `project_context` tracks worktrees added to the
    /// project, and picks up a `.rules` file created inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project starts with no worktrees, so the context has none either.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree is reflected once pending work has run.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    /// Verifies that a session's model selector lists the fake test model as a
    /// single grouped entry.
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        // The native agent is expected to return models grouped by provider.
        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId("fake/fake".into()),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    /// Verifies that selecting a model switches the session's thread to it and
    /// writes it to the settings file as `agent.default_model`.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Seed the settings file with a different default model ("foo"/"bar") so
        // the overwrite is observable below.
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId("fake/fake".into());
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write run before reading the file back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    /// Covers the thread persistence round trip:
    /// - an empty thread is not saved to history;
    /// - after a completed turn, the thread is saved with the summarization
    ///   model's output as its title;
    /// - after the session is dropped, the thread can be reopened from history
    ///   with the same content.
    #[gpui::test]
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a prompt that mentions `/a/b.md` through a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                        meta: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant reply first, then the summary completion whose
        // text is asserted below as the history entry's title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model
            .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md")));
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        // The test's own strong `thread` handle is dropped too.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                format!("Explaining {}", path!("/a/b.md"))
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread renders the same conversation as before it was dropped.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    /// Returns the `(id, title)` pair of every entry in the history store.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    /// Sets up the shared test state:
    /// - best-effort logging;
    /// - a test `SettingsStore` global;
    /// - project, agent, and language settings;
    /// - a test `LanguageModelRegistry`.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}