1use crate::{
2 ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization,
3 UserMessageContent, templates::Templates,
4};
5use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated};
6use acp_thread::{AcpThread, AgentModelSelector};
7use action_log::ActionLog;
8use agent_client_protocol as acp;
9use anyhow::{Context as _, Result, anyhow};
10use collections::{HashSet, IndexMap};
11use fs::Fs;
12use futures::channel::{mpsc, oneshot};
13use futures::future::Shared;
14use futures::{StreamExt, future};
15use gpui::{
16 App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
17};
18use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
19use project::{Project, ProjectItem, ProjectPath, Worktree};
20use prompt_store::{
21 ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
22};
23use settings::{LanguageModelSelection, update_settings_file};
24use std::any::Any;
25use std::collections::HashMap;
26use std::path::{Path, PathBuf};
27use std::rc::Rc;
28use std::sync::Arc;
29use util::ResultExt;
30use util::rel_path::RelPath;
31
/// Candidate rules-file paths, relative to the root of each worktree.
///
/// `load_worktree_rules_file` uses the first entry in this list that exists as
/// a file in the worktree, so earlier entries take precedence over later ones.
/// `handle_project_event` also uses this list: a worktree entry update touching
/// one of these paths triggers a project-context refresh.
const RULES_FILE_NAMES: [&str; 9] = [
    ".rules",
    ".cursorrules",
    ".windsurfrules",
    ".clinerules",
    ".github/copilot-instructions.md",
    "CLAUDE.md",
    "AGENT.md",
    "AGENTS.md",
    "GEMINI.md",
];
43
/// Describes why a worktree's rules file could not be loaded.
pub struct RulesLoadingError {
    /// Human-readable failure description (the underlying error's `Display` output).
    pub message: SharedString,
}
47
/// Holds both the internal Thread and the AcpThread for a session.
struct Session {
    /// The internal thread that processes messages.
    thread: Entity<Thread>,
    /// The ACP thread that handles protocol communication. Held weakly; when it
    /// is released, the agent removes this session.
    acp_thread: WeakEntity<acp_thread::AcpThread>,
    /// Most recent save task for this session, overwritten on every save.
    pending_save: Task<()>,
    /// Subscriptions/observers registered for this session, held for its lifetime.
    _subscriptions: Vec<Subscription>,
}
57
/// Cache of the language models offered by the currently authenticated providers.
pub struct LanguageModels {
    /// Access language model by ID (formatted as `"{provider_id}/{model_id}"`).
    models: HashMap<acp::ModelId, Arc<dyn LanguageModel>>,
    /// Cached list for returning language model information, grouped for display.
    model_list: acp_thread::AgentModelList,
    /// Receiver handed out to watchers; signalled whenever the cache is rebuilt.
    refresh_models_rx: watch::Receiver<()>,
    /// Sender used to signal watchers after the cache is rebuilt.
    refresh_models_tx: watch::Sender<()>,
    /// Background task authenticating every registered provider, held for the
    /// lifetime of the cache.
    _authenticate_all_providers_task: Task<()>,
}
67
68impl LanguageModels {
69 fn new(cx: &mut App) -> Self {
70 let (refresh_models_tx, refresh_models_rx) = watch::channel(());
71
72 let mut this = Self {
73 models: HashMap::default(),
74 model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()),
75 refresh_models_rx,
76 refresh_models_tx,
77 _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx),
78 };
79 this.refresh_list(cx);
80 this
81 }
82
83 fn refresh_list(&mut self, cx: &App) {
84 let providers = LanguageModelRegistry::global(cx)
85 .read(cx)
86 .providers()
87 .into_iter()
88 .filter(|provider| provider.is_authenticated(cx))
89 .collect::<Vec<_>>();
90
91 let mut language_model_list = IndexMap::default();
92 let mut recommended_models = HashSet::default();
93
94 let mut recommended = Vec::new();
95 for provider in &providers {
96 for model in provider.recommended_models(cx) {
97 recommended_models.insert((model.provider_id(), model.id()));
98 recommended.push(Self::map_language_model_to_info(&model, provider));
99 }
100 }
101 if !recommended.is_empty() {
102 language_model_list.insert(
103 acp_thread::AgentModelGroupName("Recommended".into()),
104 recommended,
105 );
106 }
107
108 let mut models = HashMap::default();
109 for provider in providers {
110 let mut provider_models = Vec::new();
111 for model in provider.provided_models(cx) {
112 let model_info = Self::map_language_model_to_info(&model, &provider);
113 let model_id = model_info.id.clone();
114 if !recommended_models.contains(&(model.provider_id(), model.id())) {
115 provider_models.push(model_info);
116 }
117 models.insert(model_id, model);
118 }
119 if !provider_models.is_empty() {
120 language_model_list.insert(
121 acp_thread::AgentModelGroupName(provider.name().0.clone()),
122 provider_models,
123 );
124 }
125 }
126
127 self.models = models;
128 self.model_list = acp_thread::AgentModelList::Grouped(language_model_list);
129 self.refresh_models_tx.send(()).ok();
130 }
131
132 fn watch(&self) -> watch::Receiver<()> {
133 self.refresh_models_rx.clone()
134 }
135
136 pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option<Arc<dyn LanguageModel>> {
137 self.models.get(model_id).cloned()
138 }
139
140 fn map_language_model_to_info(
141 model: &Arc<dyn LanguageModel>,
142 provider: &Arc<dyn LanguageModelProvider>,
143 ) -> acp_thread::AgentModelInfo {
144 acp_thread::AgentModelInfo {
145 id: Self::model_id(model),
146 name: model.name().0,
147 description: None,
148 icon: Some(provider.icon()),
149 }
150 }
151
152 fn model_id(model: &Arc<dyn LanguageModel>) -> acp::ModelId {
153 acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into())
154 }
155
156 fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> {
157 let authenticate_all_providers = LanguageModelRegistry::global(cx)
158 .read(cx)
159 .providers()
160 .iter()
161 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
162 .collect::<Vec<_>>();
163
164 cx.background_spawn(async move {
165 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
166 if let Err(err) = authenticate_task.await {
167 match err {
168 language_model::AuthenticateError::CredentialsNotFound => {
169 // Since we're authenticating these providers in the
170 // background for the purposes of populating the
171 // language selector, we don't care about providers
172 // where the credentials are not found.
173 }
174 language_model::AuthenticateError::ConnectionRefused => {
175 // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures.
176 // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it.
177 // TODO: Better manage LM Studio auth logic to avoid these noisy failures.
178 }
179 _ => {
180 // Some providers have noisy failure states that we
181 // don't want to spam the logs with every time the
182 // language model selector is initialized.
183 //
184 // Ideally these should have more clear failure modes
185 // that we know are safe to ignore here, like what we do
186 // with `CredentialsNotFound` above.
187 match provider_id.0.as_ref() {
188 "lmstudio" | "ollama" => {
189 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
190 //
191 // These fail noisily, so we don't log them.
192 }
193 "copilot_chat" => {
194 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
195 }
196 _ => {
197 log::error!(
198 "Failed to authenticate provider: {}: {err}",
199 provider_name.0
200 );
201 }
202 }
203 }
204 }
205 }
206 }
207 })
208 }
209}
210
/// In-process agent that owns every open session (thread) for one project.
pub struct NativeAgent {
    /// Session ID -> Session mapping
    sessions: HashMap<acp::SessionId, Session>,
    /// Thread history store, reloaded after each thread save.
    history: Entity<HistoryStore>,
    /// Shared project context for all threads
    project_context: Entity<ProjectContext>,
    /// Signals the project-context maintenance task to rebuild the context.
    project_context_needs_refresh: watch::Sender<()>,
    /// Task that rebuilds the project context whenever a refresh is signalled.
    _maintain_project_context: Task<Result<()>>,
    /// Registry of context servers, shared with every thread.
    context_server_registry: Entity<ContextServerRegistry>,
    /// Shared templates for all threads
    templates: Arc<Templates>,
    /// Cached model information
    models: LanguageModels,
    /// The project this agent operates on.
    project: Entity<Project>,
    /// Optional prompt store supplying default user rules.
    prompt_store: Option<Entity<PromptStore>>,
    /// Filesystem used when persisting settings changes (e.g. model selection).
    fs: Arc<dyn Fs>,
    /// Project, model-registry and prompt-store subscriptions, held for the agent's lifetime.
    _subscriptions: Vec<Subscription>,
}
229
230impl NativeAgent {
231 pub async fn new(
232 project: Entity<Project>,
233 history: Entity<HistoryStore>,
234 templates: Arc<Templates>,
235 prompt_store: Option<Entity<PromptStore>>,
236 fs: Arc<dyn Fs>,
237 cx: &mut AsyncApp,
238 ) -> Result<Entity<NativeAgent>> {
239 log::debug!("Creating new NativeAgent");
240
241 let project_context = cx
242 .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))?
243 .await;
244
245 cx.new(|cx| {
246 let mut subscriptions = vec![
247 cx.subscribe(&project, Self::handle_project_event),
248 cx.subscribe(
249 &LanguageModelRegistry::global(cx),
250 Self::handle_models_updated_event,
251 ),
252 ];
253 if let Some(prompt_store) = prompt_store.as_ref() {
254 subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event))
255 }
256
257 let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) =
258 watch::channel(());
259 Self {
260 sessions: HashMap::new(),
261 history,
262 project_context: cx.new(|_| project_context),
263 project_context_needs_refresh: project_context_needs_refresh_tx,
264 _maintain_project_context: cx.spawn(async move |this, cx| {
265 Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await
266 }),
267 context_server_registry: cx.new(|cx| {
268 ContextServerRegistry::new(project.read(cx).context_server_store(), cx)
269 }),
270 templates,
271 models: LanguageModels::new(cx),
272 project,
273 prompt_store,
274 fs,
275 _subscriptions: subscriptions,
276 }
277 })
278 }
279
280 fn register_session(
281 &mut self,
282 thread_handle: Entity<Thread>,
283 cx: &mut Context<Self>,
284 ) -> Entity<AcpThread> {
285 let connection = Rc::new(NativeAgentConnection(cx.entity()));
286
287 let thread = thread_handle.read(cx);
288 let session_id = thread.id().clone();
289 let title = thread.title();
290 let project = thread.project.clone();
291 let action_log = thread.action_log.clone();
292 let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone();
293 let acp_thread = cx.new(|cx| {
294 acp_thread::AcpThread::new(
295 title,
296 connection,
297 project.clone(),
298 action_log.clone(),
299 session_id.clone(),
300 prompt_capabilities_rx,
301 cx,
302 )
303 });
304
305 let registry = LanguageModelRegistry::read_global(cx);
306 let summarization_model = registry.thread_summary_model().map(|c| c.model);
307
308 thread_handle.update(cx, |thread, cx| {
309 thread.set_summarization_model(summarization_model, cx);
310 thread.add_default_tools(
311 Rc::new(AcpThreadEnvironment {
312 acp_thread: acp_thread.downgrade(),
313 }) as _,
314 cx,
315 )
316 });
317
318 let subscriptions = vec![
319 cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
320 this.sessions.remove(acp_thread.session_id());
321 }),
322 cx.subscribe(&thread_handle, Self::handle_thread_title_updated),
323 cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated),
324 cx.observe(&thread_handle, move |this, thread, cx| {
325 this.save_thread(thread, cx)
326 }),
327 ];
328
329 self.sessions.insert(
330 session_id,
331 Session {
332 thread: thread_handle,
333 acp_thread: acp_thread.downgrade(),
334 _subscriptions: subscriptions,
335 pending_save: Task::ready(()),
336 },
337 );
338 acp_thread
339 }
340
341 pub fn models(&self) -> &LanguageModels {
342 &self.models
343 }
344
345 async fn maintain_project_context(
346 this: WeakEntity<Self>,
347 mut needs_refresh: watch::Receiver<()>,
348 cx: &mut AsyncApp,
349 ) -> Result<()> {
350 while needs_refresh.changed().await.is_ok() {
351 let project_context = this
352 .update(cx, |this, cx| {
353 Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx)
354 })?
355 .await;
356 this.update(cx, |this, cx| {
357 this.project_context = cx.new(|_| project_context);
358 })?;
359 }
360
361 Ok(())
362 }
363
364 fn build_project_context(
365 project: &Entity<Project>,
366 prompt_store: Option<&Entity<PromptStore>>,
367 cx: &mut App,
368 ) -> Task<ProjectContext> {
369 let worktrees = project.read(cx).visible_worktrees(cx).collect::<Vec<_>>();
370 let worktree_tasks = worktrees
371 .into_iter()
372 .map(|worktree| {
373 Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx)
374 })
375 .collect::<Vec<_>>();
376 let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() {
377 prompt_store.read_with(cx, |prompt_store, cx| {
378 let prompts = prompt_store.default_prompt_metadata();
379 let load_tasks = prompts.into_iter().map(|prompt_metadata| {
380 let contents = prompt_store.load(prompt_metadata.id, cx);
381 async move { (contents.await, prompt_metadata) }
382 });
383 cx.background_spawn(future::join_all(load_tasks))
384 })
385 } else {
386 Task::ready(vec![])
387 };
388
389 cx.spawn(async move |_cx| {
390 let (worktrees, default_user_rules) =
391 future::join(future::join_all(worktree_tasks), default_user_rules_task).await;
392
393 let worktrees = worktrees
394 .into_iter()
395 .map(|(worktree, _rules_error)| {
396 // TODO: show error message
397 // if let Some(rules_error) = rules_error {
398 // this.update(cx, |_, cx| cx.emit(rules_error)).ok();
399 // }
400 worktree
401 })
402 .collect::<Vec<_>>();
403
404 let default_user_rules = default_user_rules
405 .into_iter()
406 .flat_map(|(contents, prompt_metadata)| match contents {
407 Ok(contents) => Some(UserRulesContext {
408 uuid: match prompt_metadata.id {
409 PromptId::User { uuid } => uuid,
410 PromptId::EditWorkflow => return None,
411 },
412 title: prompt_metadata.title.map(|title| title.to_string()),
413 contents,
414 }),
415 Err(_err) => {
416 // TODO: show error message
417 // this.update(cx, |_, cx| {
418 // cx.emit(RulesLoadingError {
419 // message: format!("{err:?}").into(),
420 // });
421 // })
422 // .ok();
423 None
424 }
425 })
426 .collect::<Vec<_>>();
427
428 ProjectContext::new(worktrees, default_user_rules)
429 })
430 }
431
432 fn load_worktree_info_for_system_prompt(
433 worktree: Entity<Worktree>,
434 project: Entity<Project>,
435 cx: &mut App,
436 ) -> Task<(WorktreeContext, Option<RulesLoadingError>)> {
437 let tree = worktree.read(cx);
438 let root_name = tree.root_name_str().into();
439 let abs_path = tree.abs_path();
440
441 let mut context = WorktreeContext {
442 root_name,
443 abs_path,
444 rules_file: None,
445 };
446
447 let rules_task = Self::load_worktree_rules_file(worktree, project, cx);
448 let Some(rules_task) = rules_task else {
449 return Task::ready((context, None));
450 };
451
452 cx.spawn(async move |_| {
453 let (rules_file, rules_file_error) = match rules_task.await {
454 Ok(rules_file) => (Some(rules_file), None),
455 Err(err) => (
456 None,
457 Some(RulesLoadingError {
458 message: format!("{err}").into(),
459 }),
460 ),
461 };
462 context.rules_file = rules_file;
463 (context, rules_file_error)
464 })
465 }
466
467 fn load_worktree_rules_file(
468 worktree: Entity<Worktree>,
469 project: Entity<Project>,
470 cx: &mut App,
471 ) -> Option<Task<Result<RulesFileContext>>> {
472 let worktree = worktree.read(cx);
473 let worktree_id = worktree.id();
474 let selected_rules_file = RULES_FILE_NAMES
475 .into_iter()
476 .filter_map(|name| {
477 worktree
478 .entry_for_path(RelPath::unix(name).unwrap())
479 .filter(|entry| entry.is_file())
480 .map(|entry| entry.path.clone())
481 })
482 .next();
483
484 // Note that Cline supports `.clinerules` being a directory, but that is not currently
485 // supported. This doesn't seem to occur often in GitHub repositories.
486 selected_rules_file.map(|path_in_worktree| {
487 let project_path = ProjectPath {
488 worktree_id,
489 path: path_in_worktree.clone(),
490 };
491 let buffer_task =
492 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
493 let rope_task = cx.spawn(async move |cx| {
494 buffer_task.await?.read_with(cx, |buffer, cx| {
495 let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?;
496 anyhow::Ok((project_entry_id, buffer.as_rope().clone()))
497 })?
498 });
499 // Build a string from the rope on a background thread.
500 cx.background_spawn(async move {
501 let (project_entry_id, rope) = rope_task.await?;
502 anyhow::Ok(RulesFileContext {
503 path_in_worktree,
504 text: rope.to_string().trim().to_string(),
505 project_entry_id: project_entry_id.to_usize(),
506 })
507 })
508 })
509 }
510
511 fn handle_thread_title_updated(
512 &mut self,
513 thread: Entity<Thread>,
514 _: &TitleUpdated,
515 cx: &mut Context<Self>,
516 ) {
517 let session_id = thread.read(cx).id();
518 let Some(session) = self.sessions.get(session_id) else {
519 return;
520 };
521 let thread = thread.downgrade();
522 let acp_thread = session.acp_thread.clone();
523 cx.spawn(async move |_, cx| {
524 let title = thread.read_with(cx, |thread, _| thread.title())?;
525 let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?;
526 task.await
527 })
528 .detach_and_log_err(cx);
529 }
530
531 fn handle_thread_token_usage_updated(
532 &mut self,
533 thread: Entity<Thread>,
534 usage: &TokenUsageUpdated,
535 cx: &mut Context<Self>,
536 ) {
537 let Some(session) = self.sessions.get(thread.read(cx).id()) else {
538 return;
539 };
540 session
541 .acp_thread
542 .update(cx, |acp_thread, cx| {
543 acp_thread.update_token_usage(usage.0.clone(), cx);
544 })
545 .ok();
546 }
547
548 fn handle_project_event(
549 &mut self,
550 _project: Entity<Project>,
551 event: &project::Event,
552 _cx: &mut Context<Self>,
553 ) {
554 match event {
555 project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => {
556 self.project_context_needs_refresh.send(()).ok();
557 }
558 project::Event::WorktreeUpdatedEntries(_, items) => {
559 if items.iter().any(|(path, _, _)| {
560 RULES_FILE_NAMES
561 .iter()
562 .any(|name| path.as_ref() == RelPath::unix(name).unwrap())
563 }) {
564 self.project_context_needs_refresh.send(()).ok();
565 }
566 }
567 _ => {}
568 }
569 }
570
571 fn handle_prompts_updated_event(
572 &mut self,
573 _prompt_store: Entity<PromptStore>,
574 _event: &prompt_store::PromptsUpdatedEvent,
575 _cx: &mut Context<Self>,
576 ) {
577 self.project_context_needs_refresh.send(()).ok();
578 }
579
580 fn handle_models_updated_event(
581 &mut self,
582 _registry: Entity<LanguageModelRegistry>,
583 _event: &language_model::Event,
584 cx: &mut Context<Self>,
585 ) {
586 self.models.refresh_list(cx);
587
588 let registry = LanguageModelRegistry::read_global(cx);
589 let default_model = registry.default_model().map(|m| m.model);
590 let summarization_model = registry.thread_summary_model().map(|m| m.model);
591
592 for session in self.sessions.values_mut() {
593 session.thread.update(cx, |thread, cx| {
594 if thread.model().is_none()
595 && let Some(model) = default_model.clone()
596 {
597 thread.set_model(model, cx);
598 cx.notify();
599 }
600 thread.set_summarization_model(summarization_model.clone(), cx);
601 });
602 }
603 }
604
605 pub fn open_thread(
606 &mut self,
607 id: acp::SessionId,
608 cx: &mut Context<Self>,
609 ) -> Task<Result<Entity<AcpThread>>> {
610 let database_future = ThreadsDatabase::connect(cx);
611 cx.spawn(async move |this, cx| {
612 let database = database_future.await.map_err(|err| anyhow!(err))?;
613 let db_thread = database
614 .load_thread(id.clone())
615 .await?
616 .with_context(|| format!("no thread found with ID: {id:?}"))?;
617
618 let thread = this.update(cx, |this, cx| {
619 let action_log = cx.new(|_cx| ActionLog::new(this.project.clone()));
620 cx.new(|cx| {
621 Thread::from_db(
622 id.clone(),
623 db_thread,
624 this.project.clone(),
625 this.project_context.clone(),
626 this.context_server_registry.clone(),
627 action_log.clone(),
628 this.templates.clone(),
629 cx,
630 )
631 })
632 })?;
633 let acp_thread =
634 this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?;
635 let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
636 cx.update(|cx| {
637 NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
638 })?
639 .await?;
640 Ok(acp_thread)
641 })
642 }
643
644 pub fn thread_summary(
645 &mut self,
646 id: acp::SessionId,
647 cx: &mut Context<Self>,
648 ) -> Task<Result<SharedString>> {
649 let thread = self.open_thread(id.clone(), cx);
650 cx.spawn(async move |this, cx| {
651 let acp_thread = thread.await?;
652 let result = this
653 .update(cx, |this, cx| {
654 this.sessions
655 .get(&id)
656 .unwrap()
657 .thread
658 .update(cx, |thread, cx| thread.summary(cx))
659 })?
660 .await?;
661 drop(acp_thread);
662 Ok(result)
663 })
664 }
665
666 fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
667 if thread.read(cx).is_empty() {
668 return;
669 }
670
671 let database_future = ThreadsDatabase::connect(cx);
672 let (id, db_thread) =
673 thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
674 let Some(session) = self.sessions.get_mut(&id) else {
675 return;
676 };
677 let history = self.history.clone();
678 session.pending_save = cx.spawn(async move |_, cx| {
679 let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else {
680 return;
681 };
682 let db_thread = db_thread.await;
683 database.save_thread(id, db_thread).await.log_err();
684 history.update(cx, |history, cx| history.reload(cx)).ok();
685 });
686 }
687}
688
/// Clonable handle to a [`NativeAgent`] that implements the
/// `acp_thread::AgentConnection` trait (plus telemetry and model selection).
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
692
693impl NativeAgentConnection {
694 pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
695 self.0
696 .read(cx)
697 .sessions
698 .get(session_id)
699 .map(|session| session.thread.clone())
700 }
701
702 fn run_turn(
703 &self,
704 session_id: acp::SessionId,
705 cx: &mut App,
706 f: impl 'static
707 + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
708 ) -> Task<Result<acp::PromptResponse>> {
709 let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
710 agent
711 .sessions
712 .get_mut(&session_id)
713 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
714 }) else {
715 return Task::ready(Err(anyhow!("Session not found")));
716 };
717 log::debug!("Found session for: {}", session_id);
718
719 let response_stream = match f(thread, cx) {
720 Ok(stream) => stream,
721 Err(err) => return Task::ready(Err(err)),
722 };
723 Self::handle_thread_events(response_stream, acp_thread, cx)
724 }
725
726 fn handle_thread_events(
727 mut events: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
728 acp_thread: WeakEntity<AcpThread>,
729 cx: &App,
730 ) -> Task<Result<acp::PromptResponse>> {
731 cx.spawn(async move |cx| {
732 // Handle response stream and forward to session.acp_thread
733 while let Some(result) = events.next().await {
734 match result {
735 Ok(event) => {
736 log::trace!("Received completion event: {:?}", event);
737
738 match event {
739 ThreadEvent::UserMessage(message) => {
740 acp_thread.update(cx, |thread, cx| {
741 for content in message.content {
742 thread.push_user_content_block(
743 Some(message.id.clone()),
744 content.into(),
745 cx,
746 );
747 }
748 })?;
749 }
750 ThreadEvent::AgentText(text) => {
751 acp_thread.update(cx, |thread, cx| {
752 thread.push_assistant_content_block(
753 acp::ContentBlock::Text(acp::TextContent {
754 text,
755 annotations: None,
756 meta: None,
757 }),
758 false,
759 cx,
760 )
761 })?;
762 }
763 ThreadEvent::AgentThinking(text) => {
764 acp_thread.update(cx, |thread, cx| {
765 thread.push_assistant_content_block(
766 acp::ContentBlock::Text(acp::TextContent {
767 text,
768 annotations: None,
769 meta: None,
770 }),
771 true,
772 cx,
773 )
774 })?;
775 }
776 ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
777 tool_call,
778 options,
779 response,
780 }) => {
781 let outcome_task = acp_thread.update(cx, |thread, cx| {
782 thread.request_tool_call_authorization(
783 tool_call, options, true, cx,
784 )
785 })??;
786 cx.background_spawn(async move {
787 if let acp::RequestPermissionOutcome::Selected { option_id } =
788 outcome_task.await
789 {
790 response
791 .send(option_id)
792 .map(|_| anyhow!("authorization receiver was dropped"))
793 .log_err();
794 }
795 })
796 .detach();
797 }
798 ThreadEvent::ToolCall(tool_call) => {
799 acp_thread.update(cx, |thread, cx| {
800 thread.upsert_tool_call(tool_call, cx)
801 })??;
802 }
803 ThreadEvent::ToolCallUpdate(update) => {
804 acp_thread.update(cx, |thread, cx| {
805 thread.update_tool_call(update, cx)
806 })??;
807 }
808 ThreadEvent::Retry(status) => {
809 acp_thread.update(cx, |thread, cx| {
810 thread.update_retry_status(status, cx)
811 })?;
812 }
813 ThreadEvent::Stop(stop_reason) => {
814 log::debug!("Assistant message complete: {:?}", stop_reason);
815 return Ok(acp::PromptResponse {
816 stop_reason,
817 meta: None,
818 });
819 }
820 }
821 }
822 Err(e) => {
823 log::error!("Error in model response stream: {:?}", e);
824 return Err(e);
825 }
826 }
827 }
828
829 log::debug!("Response stream completed");
830 anyhow::Ok(acp::PromptResponse {
831 stop_reason: acp::StopReason::EndTurn,
832 meta: None,
833 })
834 })
835 }
836}
837
/// Model selector bound to a single session of a [`NativeAgentConnection`].
struct NativeAgentModelSelector {
    /// Session whose thread this selector reads and updates.
    session_id: acp::SessionId,
    /// Connection to the agent that owns the session.
    connection: NativeAgentConnection,
}
842
843impl acp_thread::AgentModelSelector for NativeAgentModelSelector {
844 fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
845 log::debug!("NativeAgentConnection::list_models called");
846 let list = self.connection.0.read(cx).models.model_list.clone();
847 Task::ready(if list.is_empty() {
848 Err(anyhow::anyhow!("No models available"))
849 } else {
850 Ok(list)
851 })
852 }
853
854 fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>> {
855 log::debug!(
856 "Setting model for session {}: {}",
857 self.session_id,
858 model_id
859 );
860 let Some(thread) = self
861 .connection
862 .0
863 .read(cx)
864 .sessions
865 .get(&self.session_id)
866 .map(|session| session.thread.clone())
867 else {
868 return Task::ready(Err(anyhow!("Session not found")));
869 };
870
871 let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else {
872 return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
873 };
874
875 thread.update(cx, |thread, cx| {
876 thread.set_model(model.clone(), cx);
877 });
878
879 update_settings_file(
880 self.connection.0.read(cx).fs.clone(),
881 cx,
882 move |settings, _cx| {
883 let provider = model.provider_id().0.to_string();
884 let model = model.id().0.to_string();
885 settings
886 .agent
887 .get_or_insert_default()
888 .set_model(LanguageModelSelection {
889 provider: provider.into(),
890 model,
891 });
892 },
893 );
894
895 Task::ready(Ok(()))
896 }
897
898 fn selected_model(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelInfo>> {
899 let Some(thread) = self
900 .connection
901 .0
902 .read(cx)
903 .sessions
904 .get(&self.session_id)
905 .map(|session| session.thread.clone())
906 else {
907 return Task::ready(Err(anyhow!("Session not found")));
908 };
909 let Some(model) = thread.read(cx).model() else {
910 return Task::ready(Err(anyhow!("Model not found")));
911 };
912 let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
913 else {
914 return Task::ready(Err(anyhow!("Provider not found")));
915 };
916 Task::ready(Ok(LanguageModels::map_language_model_to_info(
917 model, &provider,
918 )))
919 }
920
921 fn watch(&self, cx: &mut App) -> Option<watch::Receiver<()>> {
922 Some(self.connection.0.read(cx).models.watch())
923 }
924}
925
926impl acp_thread::AgentConnection for NativeAgentConnection {
927 fn new_thread(
928 self: Rc<Self>,
929 project: Entity<Project>,
930 cwd: &Path,
931 cx: &mut App,
932 ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
933 let agent = self.0.clone();
934 log::debug!("Creating new thread for project at: {:?}", cwd);
935
936 cx.spawn(async move |cx| {
937 log::debug!("Starting thread creation in async context");
938
939 // Create Thread
940 let thread = agent.update(
941 cx,
942 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
943 // Fetch default model from registry settings
944 let registry = LanguageModelRegistry::read_global(cx);
945 // Log available models for debugging
946 let available_count = registry.available_models(cx).count();
947 log::debug!("Total available models: {}", available_count);
948
949 let default_model = registry.default_model().and_then(|default_model| {
950 agent
951 .models
952 .model_from_id(&LanguageModels::model_id(&default_model.model))
953 });
954 Ok(cx.new(|cx| {
955 Thread::new(
956 project.clone(),
957 agent.project_context.clone(),
958 agent.context_server_registry.clone(),
959 agent.templates.clone(),
960 default_model,
961 cx,
962 )
963 }))
964 },
965 )??;
966 agent.update(cx, |agent, cx| agent.register_session(thread, cx))
967 })
968 }
969
970 fn auth_methods(&self) -> &[acp::AuthMethod] {
971 &[] // No auth for in-process
972 }
973
974 fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
975 Task::ready(Ok(()))
976 }
977
978 fn model_selector(&self, session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
979 Some(Rc::new(NativeAgentModelSelector {
980 session_id: session_id.clone(),
981 connection: self.clone(),
982 }) as Rc<dyn AgentModelSelector>)
983 }
984
985 fn prompt(
986 &self,
987 id: Option<acp_thread::UserMessageId>,
988 params: acp::PromptRequest,
989 cx: &mut App,
990 ) -> Task<Result<acp::PromptResponse>> {
991 let id = id.expect("UserMessageId is required");
992 let session_id = params.session_id.clone();
993 log::info!("Received prompt request for session: {}", session_id);
994 log::debug!("Prompt blocks count: {}", params.prompt.len());
995
996 self.run_turn(session_id, cx, |thread, cx| {
997 let content: Vec<UserMessageContent> = params
998 .prompt
999 .into_iter()
1000 .map(Into::into)
1001 .collect::<Vec<_>>();
1002 log::debug!("Converted prompt to message: {} chars", content.len());
1003 log::debug!("Message id: {:?}", id);
1004 log::debug!("Message content: {:?}", content);
1005
1006 thread.update(cx, |thread, cx| thread.send(id, content, cx))
1007 })
1008 }
1009
1010 fn resume(
1011 &self,
1012 session_id: &acp::SessionId,
1013 _cx: &App,
1014 ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
1015 Some(Rc::new(NativeAgentSessionResume {
1016 connection: self.clone(),
1017 session_id: session_id.clone(),
1018 }) as _)
1019 }
1020
1021 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1022 log::info!("Cancelling on session: {}", session_id);
1023 self.0.update(cx, |agent, cx| {
1024 if let Some(agent) = agent.sessions.get(session_id) {
1025 agent.thread.update(cx, |thread, cx| thread.cancel(cx));
1026 }
1027 });
1028 }
1029
1030 fn truncate(
1031 &self,
1032 session_id: &agent_client_protocol::SessionId,
1033 cx: &App,
1034 ) -> Option<Rc<dyn acp_thread::AgentSessionTruncate>> {
1035 self.0.read_with(cx, |agent, _cx| {
1036 agent.sessions.get(session_id).map(|session| {
1037 Rc::new(NativeAgentSessionTruncate {
1038 thread: session.thread.clone(),
1039 acp_thread: session.acp_thread.clone(),
1040 }) as _
1041 })
1042 })
1043 }
1044
1045 fn set_title(
1046 &self,
1047 session_id: &acp::SessionId,
1048 _cx: &App,
1049 ) -> Option<Rc<dyn acp_thread::AgentSessionSetTitle>> {
1050 Some(Rc::new(NativeAgentSessionSetTitle {
1051 connection: self.clone(),
1052 session_id: session_id.clone(),
1053 }) as _)
1054 }
1055
1056 fn telemetry(&self) -> Option<Rc<dyn acp_thread::AgentTelemetry>> {
1057 Some(Rc::new(self.clone()) as Rc<dyn acp_thread::AgentTelemetry>)
1058 }
1059
1060 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
1061 self
1062 }
1063}
1064
1065impl acp_thread::AgentTelemetry for NativeAgentConnection {
1066 fn agent_name(&self) -> String {
1067 "Zed".into()
1068 }
1069
1070 fn thread_data(
1071 &self,
1072 session_id: &acp::SessionId,
1073 cx: &mut App,
1074 ) -> Task<Result<serde_json::Value>> {
1075 let Some(session) = self.0.read(cx).sessions.get(session_id) else {
1076 return Task::ready(Err(anyhow!("Session not found")));
1077 };
1078
1079 let task = session.thread.read(cx).to_db(cx);
1080 cx.background_spawn(async move {
1081 serde_json::to_value(task.await).context("Failed to serialize thread")
1082 })
1083 }
1084}
1085
/// Truncation handle for one session.
struct NativeAgentSessionTruncate {
    /// Thread to truncate.
    thread: Entity<Thread>,
    /// ACP thread that receives the updated token usage after truncation.
    acp_thread: WeakEntity<AcpThread>,
}
1090
1091impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate {
1092 fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
1093 match self.thread.update(cx, |thread, cx| {
1094 thread.truncate(message_id.clone(), cx)?;
1095 Ok(thread.latest_token_usage())
1096 }) {
1097 Ok(usage) => {
1098 self.acp_thread
1099 .update(cx, |thread, cx| {
1100 thread.update_token_usage(usage, cx);
1101 })
1102 .ok();
1103 Task::ready(Ok(()))
1104 }
1105 Err(error) => Task::ready(Err(error)),
1106 }
1107 }
1108}
1109
1110struct NativeAgentSessionResume {
1111 connection: NativeAgentConnection,
1112 session_id: acp::SessionId,
1113}
1114
1115impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
1116 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
1117 self.connection
1118 .run_turn(self.session_id.clone(), cx, |thread, cx| {
1119 thread.update(cx, |thread, cx| thread.resume(cx))
1120 })
1121 }
1122}
1123
1124struct NativeAgentSessionSetTitle {
1125 connection: NativeAgentConnection,
1126 session_id: acp::SessionId,
1127}
1128
1129impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle {
1130 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>> {
1131 let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else {
1132 return Task::ready(Err(anyhow!("session not found")));
1133 };
1134 let thread = session.thread.clone();
1135 thread.update(cx, |thread, cx| thread.set_title(title, cx));
1136 Task::ready(Ok(()))
1137 }
1138}
1139
/// [`ThreadEnvironment`] implementation that creates terminals through an
/// [`AcpThread`].
pub struct AcpThreadEnvironment {
    /// ACP thread that terminals are created on. It is held weakly, and
    /// `create_terminal` propagates the error if updating it fails
    /// (presumably because the thread has been released).
    acp_thread: WeakEntity<AcpThread>,
}

impl ThreadEnvironment for AcpThreadEnvironment {
    /// Creates a terminal running `command` on the ACP thread and wraps it in
    /// an [`AcpTerminalHandle`].
    ///
    /// `cwd` and `output_byte_limit` are forwarded to
    /// `AcpThread::create_terminal`. The two empty `vec![]` arguments are
    /// presumably extra args and env vars; confirm against that signature.
    /// When the returned handle is dropped, a detached task calls
    /// `AcpThread::release_terminal` for the terminal's id.
    fn create_terminal(
        &self,
        command: String,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        cx: &mut AsyncApp,
    ) -> Task<Result<Rc<dyn TerminalHandle>>> {
        // Kick off terminal creation now; the resulting task (or the update
        // error) is handled inside the spawned future below.
        let task = self.acp_thread.update(cx, |thread, cx| {
            thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx)
        });

        let acp_thread = self.acp_thread.clone();
        cx.spawn(async move |cx| {
            let terminal = task?.await?;

            // `drop_tx` lives inside the handle. Dropping the handle drops the
            // sender, which makes `drop_rx` resolve (with `Canceled`).
            let (drop_tx, drop_rx) = oneshot::channel();
            let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?;

            // Release the terminal on the ACP thread once the handle is gone.
            // The result is discarded because the task is detached.
            cx.spawn(async move |cx| {
                drop_rx.await.ok();
                acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx))
            })
            .detach();

            let handle = AcpTerminalHandle {
                terminal,
                _drop_tx: Some(drop_tx),
            };

            Ok(Rc::new(handle) as _)
        })
    }
}
1178
1179pub struct AcpTerminalHandle {
1180 terminal: Entity<acp_thread::Terminal>,
1181 _drop_tx: Option<oneshot::Sender<()>>,
1182}
1183
1184impl TerminalHandle for AcpTerminalHandle {
1185 fn id(&self, cx: &AsyncApp) -> Result<acp::TerminalId> {
1186 self.terminal.read_with(cx, |term, _cx| term.id().clone())
1187 }
1188
1189 fn wait_for_exit(&self, cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
1190 self.terminal
1191 .read_with(cx, |term, _cx| term.wait_for_exit())
1192 }
1193
1194 fn current_output(&self, cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
1195 self.terminal
1196 .read_with(cx, |term, cx| term.current_output(cx))
1197 }
1198}
1199
#[cfg(test)]
mod tests {
    use crate::HistoryEntryId;

    use super::*;
    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri};
    use fs::FakeFs;
    use gpui::TestAppContext;
    use indoc::formatdoc;
    use language_model::fake_provider::FakeLanguageModel;
    use serde_json::json;
    use settings::SettingsStore;
    use util::{path, rel_path::rel_path};

    // Verifies that the agent's `project_context` tracks worktrees as they are
    // added and picks up a `.rules` file created inside a worktree.
    #[gpui::test]
    async fn test_maintaining_project_context(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        // The project starts with no worktrees, so the context is empty.
        agent.read_with(cx, |agent, cx| {
            assert_eq!(agent.project_context.read(cx).worktrees, vec![])
        });

        // Adding a worktree without a rules file is reflected once pending
        // work has settled.
        let worktree = project
            .update(cx, |project, cx| project.create_worktree("/a", true, cx))
            .await
            .unwrap();
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: None
                }]
            )
        });

        // Creating `/a/.rules` updates the project context.
        fs.insert_file("/a/.rules", Vec::new()).await;
        cx.run_until_parked();
        agent.read_with(cx, |agent, cx| {
            let rules_entry = worktree
                .read(cx)
                .entry_for_path(rel_path(".rules"))
                .unwrap();
            assert_eq!(
                agent.project_context.read(cx).worktrees,
                vec![WorktreeContext {
                    root_name: "a".into(),
                    abs_path: Path::new("/a").into(),
                    rules_file: Some(RulesFileContext {
                        path_in_worktree: rel_path(".rules").into(),
                        text: "".into(),
                        project_entry_id: rules_entry.id.to_usize()
                    })
                }]
            )
        });
    }

    // Verifies that a session's model selector lists the fake provider's
    // model, grouped under "Fake".
    #[gpui::test]
    async fn test_listing_models(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({ "a": {} })).await;
        let project = Project::test(fs.clone(), [], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let connection = NativeAgentConnection(
            NativeAgent::new(
                project.clone(),
                history_store,
                Templates::new(),
                None,
                fs.clone(),
                &mut cx.to_async(),
            )
            .await
            .unwrap(),
        );

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        let models = cx
            .update(|cx| {
                connection
                    .model_selector(&session_id)
                    .unwrap()
                    .list_models(cx)
            })
            .await
            .unwrap();

        let acp_thread::AgentModelList::Grouped(models) = models else {
            panic!("Unexpected model group");
        };
        assert_eq!(
            models,
            IndexMap::from_iter([(
                AgentModelGroupName("Fake".into()),
                vec![AgentModelInfo {
                    id: acp::ModelId("fake/fake".into()),
                    name: "Fake".into(),
                    description: None,
                    icon: Some(ui::IconName::ZedAssistant),
                }]
            )])
        );
    }

    // Verifies that selecting a model through the session's model selector
    // sets it on the thread and writes it to `agent.default_model` in the
    // settings file, replacing the pre-seeded `foo`/`bar` value.
    #[gpui::test]
    async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.create_dir(paths::settings_file().parent().unwrap())
            .await
            .unwrap();
        // Seed a different default model so the overwrite is observable.
        fs.insert_file(
            paths::settings_file(),
            json!({
                "agent": {
                    "default_model": {
                        "provider": "foo",
                        "model": "bar"
                    }
                }
            })
            .to_string()
            .into_bytes(),
        )
        .await;
        let project = Project::test(fs.clone(), [], cx).await;

        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        // Create the agent and connection
        let agent = NativeAgent::new(
            project.clone(),
            history_store,
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = NativeAgentConnection(agent.clone());

        // Create a thread/session
        let acp_thread = cx
            .update(|cx| {
                Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx)
            })
            .await
            .unwrap();

        let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone());

        // Select a model
        let selector = connection.model_selector(&session_id).unwrap();
        let model_id = acp::ModelId("fake/fake".into());
        cx.update(|cx| selector.select_model(model_id.clone(), cx))
            .await
            .unwrap();

        // Verify the thread has the selected model
        agent.read_with(cx, |agent, _| {
            let session = agent.sessions.get(&session_id).unwrap();
            session.thread.read_with(cx, |thread, _| {
                assert_eq!(thread.model().unwrap().id().0, "fake");
            });
        });

        // Let the settings-file write finish before reading it back.
        cx.run_until_parked();

        // Verify settings file was updated
        let settings_content = fs.load(paths::settings_file()).await.unwrap();
        let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap();

        // Check that the agent settings contain the selected model
        assert_eq!(
            settings_json["agent"]["default_model"]["model"],
            json!("fake")
        );
        assert_eq!(
            settings_json["agent"]["default_model"]["provider"],
            json!("fake")
        );
    }

    // Verifies thread persistence end to end. Empty threads are not saved. A
    // thread with a completed turn is saved to history under its summary
    // title. Dropping the ACP thread removes the session. Reopening the
    // session restores the same content.
    #[gpui::test]
    #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
    async fn test_save_load_thread(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/",
            json!({
                "a": {
                    "b.md": "Lorem"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
        let agent = NativeAgent::new(
            project.clone(),
            history_store.clone(),
            Templates::new(),
            None,
            fs.clone(),
            &mut cx.to_async(),
        )
        .await
        .unwrap();
        let connection = Rc::new(NativeAgentConnection(agent.clone()));

        let acp_thread = cx
            .update(|cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(""), cx)
            })
            .await
            .unwrap();
        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
        let thread = agent.read_with(cx, |agent, _| {
            agent.sessions.get(&session_id).unwrap().thread.clone()
        });

        // Ensure empty threads are not saved, even if they get mutated.
        let model = Arc::new(FakeLanguageModel::default());
        let summary_model = Arc::new(FakeLanguageModel::default());
        thread.update(cx, |thread, cx| {
            thread.set_model(model.clone(), cx);
            thread.set_summarization_model(Some(summary_model.clone()), cx);
        });
        cx.run_until_parked();
        assert_eq!(history_entries(&history_store, cx), vec![]);

        // Send a message that mentions `/a/b.md` via a resource link.
        let send = acp_thread.update(cx, |thread, cx| {
            thread.send(
                vec![
                    "What does ".into(),
                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
                        name: "b.md".into(),
                        uri: MentionUri::File {
                            abs_path: path!("/a/b.md").into(),
                        }
                        .to_uri()
                        .to_string(),
                        annotations: None,
                        description: None,
                        mime_type: None,
                        size: None,
                        title: None,
                        meta: None,
                    }),
                    " mean?".into(),
                ],
                cx,
            )
        });
        let send = cx.foreground_executor().spawn(send);
        cx.run_until_parked();

        // Complete the assistant response, then the summary used as the title.
        model.send_last_completion_stream_text_chunk("Lorem.");
        model.end_last_completion_stream();
        cx.run_until_parked();
        summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
        summary_model.end_last_completion_stream();

        send.await.unwrap();
        let uri = MentionUri::File {
            abs_path: path!("/a/b.md").into(),
        }
        .to_uri();
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });

        cx.run_until_parked();

        // Drop the ACP thread, which should cause the session to be dropped as well.
        cx.update(|_| {
            drop(thread);
            drop(acp_thread);
        });
        agent.read_with(cx, |agent, _| {
            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
        });

        // Ensure the thread can be reloaded from disk.
        assert_eq!(
            history_entries(&history_store, cx),
            vec![(
                HistoryEntryId::AcpThread(session_id.clone()),
                "Explaining /a/b.md".into()
            )]
        );
        let acp_thread = agent
            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
            .await
            .unwrap();
        // The reopened thread renders the same markdown as before it was dropped.
        acp_thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                formatdoc! {"
                    ## User

                    What does [@b.md]({uri}) mean?

                    ## Assistant

                    Lorem.

                "}
            )
        });
    }

    // Returns `(id, title)` for every entry in `history`, in `entries()` order.
    fn history_entries(
        history: &Entity<HistoryStore>,
        cx: &mut TestAppContext,
    ) -> Vec<(HistoryEntryId, String)> {
        history.read_with(cx, |history, _| {
            history
                .entries()
                .map(|e| (e.id(), e.title().to_string()))
                .collect::<Vec<_>>()
        })
    }

    // Installs the globals these tests rely on: a test settings store,
    // project/agent settings, language support, and a test language model
    // registry. `env_logger::try_init` may fail if a logger is already set,
    // so its result is ignored.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            agent_settings::init(cx);
            language::init(cx);
            LanguageModelRegistry::test(cx);
        });
    }
}