1mod connection;
2mod diff;
3mod mention;
4mod terminal;
5
6use agent_settings::AgentSettings;
7use collections::HashSet;
8pub use connection::*;
9pub use diff::*;
10use language::language_settings::FormatOnSave;
11pub use mention::*;
12use project::lsp_store::{FormatTrigger, LspFormatTarget};
13use serde::{Deserialize, Serialize};
14use settings::Settings as _;
15pub use terminal::*;
16
17use action_log::ActionLog;
18use agent_client_protocol as acp;
19use anyhow::{Context as _, Result, anyhow};
20use editor::Bias;
21use futures::{FutureExt, channel::oneshot, future::BoxFuture};
22use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
23use itertools::Itertools;
24use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
25use markdown::Markdown;
26use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
27use std::collections::HashMap;
28use std::error::Error;
29use std::fmt::{Formatter, Write};
30use std::ops::Range;
31use std::process::ExitStatus;
32use std::rc::Rc;
33use std::time::{Duration, Instant};
34use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
35use ui::App;
36use util::ResultExt;
37
38#[derive(Debug)]
39pub struct UserMessage {
40 pub id: Option<UserMessageId>,
41 pub content: ContentBlock,
42 pub chunks: Vec<acp::ContentBlock>,
43 pub checkpoint: Option<Checkpoint>,
44}
45
46#[derive(Debug)]
47pub struct Checkpoint {
48 git_checkpoint: GitStoreCheckpoint,
49 pub show: bool,
50}
51
52impl UserMessage {
53 fn to_markdown(&self, cx: &App) -> String {
54 let mut markdown = String::new();
55 if self
56 .checkpoint
57 .as_ref()
58 .is_some_and(|checkpoint| checkpoint.show)
59 {
60 writeln!(markdown, "## User (checkpoint)").unwrap();
61 } else {
62 writeln!(markdown, "## User").unwrap();
63 }
64 writeln!(markdown).unwrap();
65 writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
66 writeln!(markdown).unwrap();
67 markdown
68 }
69}
70
71#[derive(Debug, PartialEq)]
72pub struct AssistantMessage {
73 pub chunks: Vec<AssistantMessageChunk>,
74}
75
76impl AssistantMessage {
77 pub fn to_markdown(&self, cx: &App) -> String {
78 format!(
79 "## Assistant\n\n{}\n\n",
80 self.chunks
81 .iter()
82 .map(|chunk| chunk.to_markdown(cx))
83 .join("\n\n")
84 )
85 }
86}
87
88#[derive(Debug, PartialEq)]
89pub enum AssistantMessageChunk {
90 Message { block: ContentBlock },
91 Thought { block: ContentBlock },
92}
93
94impl AssistantMessageChunk {
95 pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
96 Self::Message {
97 block: ContentBlock::new(chunk.into(), language_registry, cx),
98 }
99 }
100
101 fn to_markdown(&self, cx: &App) -> String {
102 match self {
103 Self::Message { block } => block.to_markdown(cx).to_string(),
104 Self::Thought { block } => {
105 format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
106 }
107 }
108 }
109}
110
111#[derive(Debug)]
112pub enum AgentThreadEntry {
113 UserMessage(UserMessage),
114 AssistantMessage(AssistantMessage),
115 ToolCall(ToolCall),
116}
117
118impl AgentThreadEntry {
119 pub fn to_markdown(&self, cx: &App) -> String {
120 match self {
121 Self::UserMessage(message) => message.to_markdown(cx),
122 Self::AssistantMessage(message) => message.to_markdown(cx),
123 Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
124 }
125 }
126
127 pub fn user_message(&self) -> Option<&UserMessage> {
128 if let AgentThreadEntry::UserMessage(message) = self {
129 Some(message)
130 } else {
131 None
132 }
133 }
134
135 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
136 if let AgentThreadEntry::ToolCall(call) = self {
137 itertools::Either::Left(call.diffs())
138 } else {
139 itertools::Either::Right(std::iter::empty())
140 }
141 }
142
143 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
144 if let AgentThreadEntry::ToolCall(call) = self {
145 itertools::Either::Left(call.terminals())
146 } else {
147 itertools::Either::Right(std::iter::empty())
148 }
149 }
150
151 pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
152 if let AgentThreadEntry::ToolCall(ToolCall {
153 locations,
154 resolved_locations,
155 ..
156 }) = self
157 {
158 Some((
159 locations.get(ix)?.clone(),
160 resolved_locations.get(ix)?.clone()?,
161 ))
162 } else {
163 None
164 }
165 }
166}
167
168#[derive(Debug)]
169pub struct ToolCall {
170 pub id: acp::ToolCallId,
171 pub label: Entity<Markdown>,
172 pub kind: acp::ToolKind,
173 pub content: Vec<ToolCallContent>,
174 pub status: ToolCallStatus,
175 pub locations: Vec<acp::ToolCallLocation>,
176 pub resolved_locations: Vec<Option<AgentLocation>>,
177 pub raw_input: Option<serde_json::Value>,
178 pub raw_output: Option<serde_json::Value>,
179}
180
181impl ToolCall {
182 fn from_acp(
183 tool_call: acp::ToolCall,
184 status: ToolCallStatus,
185 language_registry: Arc<LanguageRegistry>,
186 cx: &mut App,
187 ) -> Self {
188 let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
189 first_line.to_owned() + "…"
190 } else {
191 tool_call.title
192 };
193 Self {
194 id: tool_call.id,
195 label: cx
196 .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
197 kind: tool_call.kind,
198 content: tool_call
199 .content
200 .into_iter()
201 .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
202 .collect(),
203 locations: tool_call.locations,
204 resolved_locations: Vec::default(),
205 status,
206 raw_input: tool_call.raw_input,
207 raw_output: tool_call.raw_output,
208 }
209 }
210
211 fn update_fields(
212 &mut self,
213 fields: acp::ToolCallUpdateFields,
214 language_registry: Arc<LanguageRegistry>,
215 cx: &mut App,
216 ) {
217 let acp::ToolCallUpdateFields {
218 kind,
219 status,
220 title,
221 content,
222 locations,
223 raw_input,
224 raw_output,
225 } = fields;
226
227 if let Some(kind) = kind {
228 self.kind = kind;
229 }
230
231 if let Some(status) = status {
232 self.status = status.into();
233 }
234
235 if let Some(title) = title {
236 self.label.update(cx, |label, cx| {
237 if let Some((first_line, _)) = title.split_once("\n") {
238 label.replace(first_line.to_owned() + "…", cx)
239 } else {
240 label.replace(title, cx);
241 }
242 });
243 }
244
245 if let Some(content) = content {
246 let new_content_len = content.len();
247 let mut content = content.into_iter();
248
249 // Reuse existing content if we can
250 for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
251 old.update_from_acp(new, language_registry.clone(), cx);
252 }
253 for new in content {
254 self.content.push(ToolCallContent::from_acp(
255 new,
256 language_registry.clone(),
257 cx,
258 ))
259 }
260 self.content.truncate(new_content_len);
261 }
262
263 if let Some(locations) = locations {
264 self.locations = locations;
265 }
266
267 if let Some(raw_input) = raw_input {
268 self.raw_input = Some(raw_input);
269 }
270
271 if let Some(raw_output) = raw_output {
272 if self.content.is_empty()
273 && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
274 {
275 self.content
276 .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
277 markdown,
278 }));
279 }
280 self.raw_output = Some(raw_output);
281 }
282 }
283
284 pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
285 self.content.iter().filter_map(|content| match content {
286 ToolCallContent::Diff(diff) => Some(diff),
287 ToolCallContent::ContentBlock(_) => None,
288 ToolCallContent::Terminal(_) => None,
289 })
290 }
291
292 pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
293 self.content.iter().filter_map(|content| match content {
294 ToolCallContent::Terminal(terminal) => Some(terminal),
295 ToolCallContent::ContentBlock(_) => None,
296 ToolCallContent::Diff(_) => None,
297 })
298 }
299
300 fn to_markdown(&self, cx: &App) -> String {
301 let mut markdown = format!(
302 "**Tool Call: {}**\nStatus: {}\n\n",
303 self.label.read(cx).source(),
304 self.status
305 );
306 for content in &self.content {
307 markdown.push_str(content.to_markdown(cx).as_str());
308 markdown.push_str("\n\n");
309 }
310 markdown
311 }
312
313 async fn resolve_location(
314 location: acp::ToolCallLocation,
315 project: WeakEntity<Project>,
316 cx: &mut AsyncApp,
317 ) -> Option<AgentLocation> {
318 let buffer = project
319 .update(cx, |project, cx| {
320 project
321 .project_path_for_absolute_path(&location.path, cx)
322 .map(|path| project.open_buffer(path, cx))
323 })
324 .ok()??;
325 let buffer = buffer.await.log_err()?;
326 let position = buffer
327 .update(cx, |buffer, _| {
328 if let Some(row) = location.line {
329 let snapshot = buffer.snapshot();
330 let column = snapshot.indent_size_for_line(row).len;
331 let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
332 snapshot.anchor_before(point)
333 } else {
334 Anchor::MIN
335 }
336 })
337 .ok()?;
338
339 Some(AgentLocation {
340 buffer: buffer.downgrade(),
341 position,
342 })
343 }
344
345 fn resolve_locations(
346 &self,
347 project: Entity<Project>,
348 cx: &mut App,
349 ) -> Task<Vec<Option<AgentLocation>>> {
350 let locations = self.locations.clone();
351 project.update(cx, |_, cx| {
352 cx.spawn(async move |project, cx| {
353 let mut new_locations = Vec::new();
354 for location in locations {
355 new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
356 }
357 new_locations
358 })
359 })
360 }
361}
362
363#[derive(Debug)]
364pub enum ToolCallStatus {
365 /// The tool call hasn't started running yet, but we start showing it to
366 /// the user.
367 Pending,
368 /// The tool call is waiting for confirmation from the user.
369 WaitingForConfirmation {
370 options: Vec<acp::PermissionOption>,
371 respond_tx: oneshot::Sender<acp::PermissionOptionId>,
372 },
373 /// The tool call is currently running.
374 InProgress,
375 /// The tool call completed successfully.
376 Completed,
377 /// The tool call failed.
378 Failed,
379 /// The user rejected the tool call.
380 Rejected,
381 /// The user canceled generation so the tool call was canceled.
382 Canceled,
383}
384
385impl From<acp::ToolCallStatus> for ToolCallStatus {
386 fn from(status: acp::ToolCallStatus) -> Self {
387 match status {
388 acp::ToolCallStatus::Pending => Self::Pending,
389 acp::ToolCallStatus::InProgress => Self::InProgress,
390 acp::ToolCallStatus::Completed => Self::Completed,
391 acp::ToolCallStatus::Failed => Self::Failed,
392 }
393 }
394}
395
396impl Display for ToolCallStatus {
397 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
398 write!(
399 f,
400 "{}",
401 match self {
402 ToolCallStatus::Pending => "Pending",
403 ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
404 ToolCallStatus::InProgress => "In Progress",
405 ToolCallStatus::Completed => "Completed",
406 ToolCallStatus::Failed => "Failed",
407 ToolCallStatus::Rejected => "Rejected",
408 ToolCallStatus::Canceled => "Canceled",
409 }
410 )
411 }
412}
413
414#[derive(Debug, PartialEq, Clone)]
415pub enum ContentBlock {
416 Empty,
417 Markdown { markdown: Entity<Markdown> },
418 ResourceLink { resource_link: acp::ResourceLink },
419}
420
421impl ContentBlock {
422 pub fn new(
423 block: acp::ContentBlock,
424 language_registry: &Arc<LanguageRegistry>,
425 cx: &mut App,
426 ) -> Self {
427 let mut this = Self::Empty;
428 this.append(block, language_registry, cx);
429 this
430 }
431
432 pub fn new_combined(
433 blocks: impl IntoIterator<Item = acp::ContentBlock>,
434 language_registry: Arc<LanguageRegistry>,
435 cx: &mut App,
436 ) -> Self {
437 let mut this = Self::Empty;
438 for block in blocks {
439 this.append(block, &language_registry, cx);
440 }
441 this
442 }
443
444 pub fn append(
445 &mut self,
446 block: acp::ContentBlock,
447 language_registry: &Arc<LanguageRegistry>,
448 cx: &mut App,
449 ) {
450 if matches!(self, ContentBlock::Empty)
451 && let acp::ContentBlock::ResourceLink(resource_link) = block
452 {
453 *self = ContentBlock::ResourceLink { resource_link };
454 return;
455 }
456
457 let new_content = self.block_string_contents(block);
458
459 match self {
460 ContentBlock::Empty => {
461 *self = Self::create_markdown_block(new_content, language_registry, cx);
462 }
463 ContentBlock::Markdown { markdown } => {
464 markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
465 }
466 ContentBlock::ResourceLink { resource_link } => {
467 let existing_content = Self::resource_link_md(&resource_link.uri);
468 let combined = format!("{}\n{}", existing_content, new_content);
469
470 *self = Self::create_markdown_block(combined, language_registry, cx);
471 }
472 }
473 }
474
475 fn create_markdown_block(
476 content: String,
477 language_registry: &Arc<LanguageRegistry>,
478 cx: &mut App,
479 ) -> ContentBlock {
480 ContentBlock::Markdown {
481 markdown: cx
482 .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
483 }
484 }
485
486 fn block_string_contents(&self, block: acp::ContentBlock) -> String {
487 match block {
488 acp::ContentBlock::Text(text_content) => text_content.text,
489 acp::ContentBlock::ResourceLink(resource_link) => {
490 Self::resource_link_md(&resource_link.uri)
491 }
492 acp::ContentBlock::Resource(acp::EmbeddedResource {
493 resource:
494 acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
495 uri,
496 ..
497 }),
498 ..
499 }) => Self::resource_link_md(&uri),
500 acp::ContentBlock::Image(image) => Self::image_md(&image),
501 acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
502 }
503 }
504
505 fn resource_link_md(uri: &str) -> String {
506 if let Some(uri) = MentionUri::parse(uri).log_err() {
507 uri.as_link().to_string()
508 } else {
509 uri.to_string()
510 }
511 }
512
513 fn image_md(_image: &acp::ImageContent) -> String {
514 "`Image`".into()
515 }
516
517 pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
518 match self {
519 ContentBlock::Empty => "",
520 ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
521 ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
522 }
523 }
524
525 pub fn markdown(&self) -> Option<&Entity<Markdown>> {
526 match self {
527 ContentBlock::Empty => None,
528 ContentBlock::Markdown { markdown } => Some(markdown),
529 ContentBlock::ResourceLink { .. } => None,
530 }
531 }
532
533 pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
534 match self {
535 ContentBlock::ResourceLink { resource_link } => Some(resource_link),
536 _ => None,
537 }
538 }
539}
540
541#[derive(Debug)]
542pub enum ToolCallContent {
543 ContentBlock(ContentBlock),
544 Diff(Entity<Diff>),
545 Terminal(Entity<Terminal>),
546}
547
548impl ToolCallContent {
549 pub fn from_acp(
550 content: acp::ToolCallContent,
551 language_registry: Arc<LanguageRegistry>,
552 cx: &mut App,
553 ) -> Self {
554 match content {
555 acp::ToolCallContent::Content { content } => {
556 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
557 }
558 acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
559 Diff::finalized(
560 diff.path,
561 diff.old_text,
562 diff.new_text,
563 language_registry,
564 cx,
565 )
566 })),
567 }
568 }
569
570 pub fn update_from_acp(
571 &mut self,
572 new: acp::ToolCallContent,
573 language_registry: Arc<LanguageRegistry>,
574 cx: &mut App,
575 ) {
576 let needs_update = match (&self, &new) {
577 (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
578 old_diff.read(cx).needs_update(
579 new_diff.old_text.as_deref().unwrap_or(""),
580 &new_diff.new_text,
581 cx,
582 )
583 }
584 _ => true,
585 };
586
587 if needs_update {
588 *self = Self::from_acp(new, language_registry, cx);
589 }
590 }
591
592 pub fn to_markdown(&self, cx: &App) -> String {
593 match self {
594 Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
595 Self::Diff(diff) => diff.read(cx).to_markdown(cx),
596 Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
597 }
598 }
599}
600
601#[derive(Debug, PartialEq)]
602pub enum ToolCallUpdate {
603 UpdateFields(acp::ToolCallUpdate),
604 UpdateDiff(ToolCallUpdateDiff),
605 UpdateTerminal(ToolCallUpdateTerminal),
606}
607
608impl ToolCallUpdate {
609 fn id(&self) -> &acp::ToolCallId {
610 match self {
611 Self::UpdateFields(update) => &update.id,
612 Self::UpdateDiff(diff) => &diff.id,
613 Self::UpdateTerminal(terminal) => &terminal.id,
614 }
615 }
616}
617
618impl From<acp::ToolCallUpdate> for ToolCallUpdate {
619 fn from(update: acp::ToolCallUpdate) -> Self {
620 Self::UpdateFields(update)
621 }
622}
623
624impl From<ToolCallUpdateDiff> for ToolCallUpdate {
625 fn from(diff: ToolCallUpdateDiff) -> Self {
626 Self::UpdateDiff(diff)
627 }
628}
629
630#[derive(Debug, PartialEq)]
631pub struct ToolCallUpdateDiff {
632 pub id: acp::ToolCallId,
633 pub diff: Entity<Diff>,
634}
635
636impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
637 fn from(terminal: ToolCallUpdateTerminal) -> Self {
638 Self::UpdateTerminal(terminal)
639 }
640}
641
642#[derive(Debug, PartialEq)]
643pub struct ToolCallUpdateTerminal {
644 pub id: acp::ToolCallId,
645 pub terminal: Entity<Terminal>,
646}
647
648#[derive(Debug, Default)]
649pub struct Plan {
650 pub entries: Vec<PlanEntry>,
651}
652
653#[derive(Debug)]
654pub struct PlanStats<'a> {
655 pub in_progress_entry: Option<&'a PlanEntry>,
656 pub pending: u32,
657 pub completed: u32,
658}
659
660impl Plan {
661 pub fn is_empty(&self) -> bool {
662 self.entries.is_empty()
663 }
664
665 pub fn stats(&self) -> PlanStats<'_> {
666 let mut stats = PlanStats {
667 in_progress_entry: None,
668 pending: 0,
669 completed: 0,
670 };
671
672 for entry in &self.entries {
673 match &entry.status {
674 acp::PlanEntryStatus::Pending => {
675 stats.pending += 1;
676 }
677 acp::PlanEntryStatus::InProgress => {
678 stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
679 }
680 acp::PlanEntryStatus::Completed => {
681 stats.completed += 1;
682 }
683 }
684 }
685
686 stats
687 }
688}
689
690#[derive(Debug)]
691pub struct PlanEntry {
692 pub content: Entity<Markdown>,
693 pub priority: acp::PlanEntryPriority,
694 pub status: acp::PlanEntryStatus,
695}
696
697impl PlanEntry {
698 pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
699 Self {
700 content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
701 priority: entry.priority,
702 status: entry.status,
703 }
704 }
705}
706
707#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
708pub struct TokenUsage {
709 pub max_tokens: u64,
710 pub used_tokens: u64,
711}
712
713impl TokenUsage {
714 pub fn ratio(&self) -> TokenUsageRatio {
715 #[cfg(debug_assertions)]
716 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
717 .unwrap_or("0.8".to_string())
718 .parse()
719 .unwrap();
720 #[cfg(not(debug_assertions))]
721 let warning_threshold: f32 = 0.8;
722
723 // When the maximum is unknown because there is no selected model,
724 // avoid showing the token limit warning.
725 if self.max_tokens == 0 {
726 TokenUsageRatio::Normal
727 } else if self.used_tokens >= self.max_tokens {
728 TokenUsageRatio::Exceeded
729 } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
730 TokenUsageRatio::Warning
731 } else {
732 TokenUsageRatio::Normal
733 }
734 }
735}
736
/// Severity bucket for [`TokenUsage`], as computed by [`TokenUsage::ratio`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the limit is unknown.
    Normal,
    /// At or above the warning threshold but under the limit.
    Warning,
    /// At or over the limit.
    Exceeded,
}
743
744#[derive(Debug, Clone)]
745pub struct RetryStatus {
746 pub last_error: SharedString,
747 pub attempt: usize,
748 pub max_attempts: usize,
749 pub started_at: Instant,
750 pub duration: Duration,
751}
752
753pub struct AcpThread {
754 title: SharedString,
755 entries: Vec<AgentThreadEntry>,
756 plan: Plan,
757 project: Entity<Project>,
758 action_log: Entity<ActionLog>,
759 shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
760 send_task: Option<Task<()>>,
761 connection: Rc<dyn AgentConnection>,
762 session_id: acp::SessionId,
763 token_usage: Option<TokenUsage>,
764 prompt_capabilities: acp::PromptCapabilities,
765 _observe_prompt_capabilities: Task<anyhow::Result<()>>,
766}
767
768#[derive(Debug)]
769pub enum AcpThreadEvent {
770 NewEntry,
771 TitleUpdated,
772 TokenUsageUpdated,
773 EntryUpdated(usize),
774 EntriesRemoved(Range<usize>),
775 ToolAuthorizationRequired,
776 Retry(RetryStatus),
777 Stopped,
778 Error,
779 LoadError(LoadError),
780 PromptCapabilitiesUpdated,
781}
782
783impl EventEmitter<AcpThreadEvent> for AcpThread {}
784
/// Coarse state of an [`AcpThread`], as reported by [`AcpThread::status`].
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No request is in flight.
    Idle,
    /// A request is in flight and blocked on a tool call confirmation.
    WaitingForToolConfirmation,
    /// A request is in flight.
    Generating,
}
791
792#[derive(Debug, Clone)]
793pub enum LoadError {
794 Unsupported {
795 command: SharedString,
796 current_version: SharedString,
797 minimum_version: SharedString,
798 },
799 FailedToInstall(SharedString),
800 Exited {
801 status: ExitStatus,
802 },
803 Other(SharedString),
804}
805
806impl Display for LoadError {
807 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
808 match self {
809 LoadError::Unsupported {
810 command: path,
811 current_version,
812 minimum_version,
813 } => {
814 write!(
815 f,
816 "version {current_version} from {path} is not supported (need at least {minimum_version})"
817 )
818 }
819 LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
820 LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
821 LoadError::Other(msg) => write!(f, "{msg}"),
822 }
823 }
824}
825
826impl Error for LoadError {}
827
828impl AcpThread {
829 pub fn new(
830 title: impl Into<SharedString>,
831 connection: Rc<dyn AgentConnection>,
832 project: Entity<Project>,
833 action_log: Entity<ActionLog>,
834 session_id: acp::SessionId,
835 mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
836 cx: &mut Context<Self>,
837 ) -> Self {
838 let prompt_capabilities = *prompt_capabilities_rx.borrow();
839 let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
840 loop {
841 let caps = prompt_capabilities_rx.recv().await?;
842 this.update(cx, |this, cx| {
843 this.prompt_capabilities = caps;
844 cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
845 })?;
846 }
847 });
848
849 Self {
850 action_log,
851 shared_buffers: Default::default(),
852 entries: Default::default(),
853 plan: Default::default(),
854 title: title.into(),
855 project,
856 send_task: None,
857 connection,
858 session_id,
859 token_usage: None,
860 prompt_capabilities,
861 _observe_prompt_capabilities: task,
862 }
863 }
864
865 pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
866 self.prompt_capabilities
867 }
868
869 pub fn connection(&self) -> &Rc<dyn AgentConnection> {
870 &self.connection
871 }
872
873 pub fn action_log(&self) -> &Entity<ActionLog> {
874 &self.action_log
875 }
876
877 pub fn project(&self) -> &Entity<Project> {
878 &self.project
879 }
880
881 pub fn title(&self) -> SharedString {
882 self.title.clone()
883 }
884
885 pub fn entries(&self) -> &[AgentThreadEntry] {
886 &self.entries
887 }
888
889 pub fn session_id(&self) -> &acp::SessionId {
890 &self.session_id
891 }
892
893 pub fn status(&self) -> ThreadStatus {
894 if self.send_task.is_some() {
895 if self.waiting_for_tool_confirmation() {
896 ThreadStatus::WaitingForToolConfirmation
897 } else {
898 ThreadStatus::Generating
899 }
900 } else {
901 ThreadStatus::Idle
902 }
903 }
904
905 pub fn token_usage(&self) -> Option<&TokenUsage> {
906 self.token_usage.as_ref()
907 }
908
909 pub fn has_pending_edit_tool_calls(&self) -> bool {
910 for entry in self.entries.iter().rev() {
911 match entry {
912 AgentThreadEntry::UserMessage(_) => return false,
913 AgentThreadEntry::ToolCall(
914 call @ ToolCall {
915 status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
916 ..
917 },
918 ) if call.diffs().next().is_some() => {
919 return true;
920 }
921 AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
922 }
923 }
924
925 false
926 }
927
928 pub fn used_tools_since_last_user_message(&self) -> bool {
929 for entry in self.entries.iter().rev() {
930 match entry {
931 AgentThreadEntry::UserMessage(..) => return false,
932 AgentThreadEntry::AssistantMessage(..) => continue,
933 AgentThreadEntry::ToolCall(..) => return true,
934 }
935 }
936
937 false
938 }
939
940 pub fn handle_session_update(
941 &mut self,
942 update: acp::SessionUpdate,
943 cx: &mut Context<Self>,
944 ) -> Result<(), acp::Error> {
945 match update {
946 acp::SessionUpdate::UserMessageChunk { content } => {
947 self.push_user_content_block(None, content, cx);
948 }
949 acp::SessionUpdate::AgentMessageChunk { content } => {
950 self.push_assistant_content_block(content, false, cx);
951 }
952 acp::SessionUpdate::AgentThoughtChunk { content } => {
953 self.push_assistant_content_block(content, true, cx);
954 }
955 acp::SessionUpdate::ToolCall(tool_call) => {
956 self.upsert_tool_call(tool_call, cx)?;
957 }
958 acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
959 self.update_tool_call(tool_call_update, cx)?;
960 }
961 acp::SessionUpdate::Plan(plan) => {
962 self.update_plan(plan, cx);
963 }
964 }
965 Ok(())
966 }
967
968 pub fn push_user_content_block(
969 &mut self,
970 message_id: Option<UserMessageId>,
971 chunk: acp::ContentBlock,
972 cx: &mut Context<Self>,
973 ) {
974 let language_registry = self.project.read(cx).languages().clone();
975 let entries_len = self.entries.len();
976
977 if let Some(last_entry) = self.entries.last_mut()
978 && let AgentThreadEntry::UserMessage(UserMessage {
979 id,
980 content,
981 chunks,
982 ..
983 }) = last_entry
984 {
985 *id = message_id.or(id.take());
986 content.append(chunk.clone(), &language_registry, cx);
987 chunks.push(chunk);
988 let idx = entries_len - 1;
989 cx.emit(AcpThreadEvent::EntryUpdated(idx));
990 } else {
991 let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
992 self.push_entry(
993 AgentThreadEntry::UserMessage(UserMessage {
994 id: message_id,
995 content,
996 chunks: vec![chunk],
997 checkpoint: None,
998 }),
999 cx,
1000 );
1001 }
1002 }
1003
1004 pub fn push_assistant_content_block(
1005 &mut self,
1006 chunk: acp::ContentBlock,
1007 is_thought: bool,
1008 cx: &mut Context<Self>,
1009 ) {
1010 let language_registry = self.project.read(cx).languages().clone();
1011 let entries_len = self.entries.len();
1012 if let Some(last_entry) = self.entries.last_mut()
1013 && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1014 {
1015 let idx = entries_len - 1;
1016 cx.emit(AcpThreadEvent::EntryUpdated(idx));
1017 match (chunks.last_mut(), is_thought) {
1018 (Some(AssistantMessageChunk::Message { block }), false)
1019 | (Some(AssistantMessageChunk::Thought { block }), true) => {
1020 block.append(chunk, &language_registry, cx)
1021 }
1022 _ => {
1023 let block = ContentBlock::new(chunk, &language_registry, cx);
1024 if is_thought {
1025 chunks.push(AssistantMessageChunk::Thought { block })
1026 } else {
1027 chunks.push(AssistantMessageChunk::Message { block })
1028 }
1029 }
1030 }
1031 } else {
1032 let block = ContentBlock::new(chunk, &language_registry, cx);
1033 let chunk = if is_thought {
1034 AssistantMessageChunk::Thought { block }
1035 } else {
1036 AssistantMessageChunk::Message { block }
1037 };
1038
1039 self.push_entry(
1040 AgentThreadEntry::AssistantMessage(AssistantMessage {
1041 chunks: vec![chunk],
1042 }),
1043 cx,
1044 );
1045 }
1046 }
1047
1048 fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1049 self.entries.push(entry);
1050 cx.emit(AcpThreadEvent::NewEntry);
1051 }
1052
1053 pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1054 self.connection.set_title(&self.session_id, cx).is_some()
1055 }
1056
1057 pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1058 if title != self.title {
1059 self.title = title.clone();
1060 cx.emit(AcpThreadEvent::TitleUpdated);
1061 if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1062 return set_title.run(title, cx);
1063 }
1064 }
1065 Task::ready(Ok(()))
1066 }
1067
1068 pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1069 self.token_usage = usage;
1070 cx.emit(AcpThreadEvent::TokenUsageUpdated);
1071 }
1072
1073 pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1074 cx.emit(AcpThreadEvent::Retry(status));
1075 }
1076
1077 pub fn update_tool_call(
1078 &mut self,
1079 update: impl Into<ToolCallUpdate>,
1080 cx: &mut Context<Self>,
1081 ) -> Result<()> {
1082 let update = update.into();
1083 let languages = self.project.read(cx).languages().clone();
1084
1085 let (ix, current_call) = self
1086 .tool_call_mut(update.id())
1087 .context("Tool call not found")?;
1088 match update {
1089 ToolCallUpdate::UpdateFields(update) => {
1090 let location_updated = update.fields.locations.is_some();
1091 current_call.update_fields(update.fields, languages, cx);
1092 if location_updated {
1093 self.resolve_locations(update.id, cx);
1094 }
1095 }
1096 ToolCallUpdate::UpdateDiff(update) => {
1097 current_call.content.clear();
1098 current_call
1099 .content
1100 .push(ToolCallContent::Diff(update.diff));
1101 }
1102 ToolCallUpdate::UpdateTerminal(update) => {
1103 current_call.content.clear();
1104 current_call
1105 .content
1106 .push(ToolCallContent::Terminal(update.terminal));
1107 }
1108 }
1109
1110 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1111
1112 Ok(())
1113 }
1114
1115 /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1116 pub fn upsert_tool_call(
1117 &mut self,
1118 tool_call: acp::ToolCall,
1119 cx: &mut Context<Self>,
1120 ) -> Result<(), acp::Error> {
1121 let status = tool_call.status.into();
1122 self.upsert_tool_call_inner(tool_call.into(), status, cx)
1123 }
1124
1125 /// Fails if id does not match an existing entry.
1126 pub fn upsert_tool_call_inner(
1127 &mut self,
1128 tool_call_update: acp::ToolCallUpdate,
1129 status: ToolCallStatus,
1130 cx: &mut Context<Self>,
1131 ) -> Result<(), acp::Error> {
1132 let language_registry = self.project.read(cx).languages().clone();
1133 let id = tool_call_update.id.clone();
1134
1135 if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1136 current_call.update_fields(tool_call_update.fields, language_registry, cx);
1137 current_call.status = status;
1138
1139 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1140 } else {
1141 let call =
1142 ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1143 self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1144 };
1145
1146 self.resolve_locations(id, cx);
1147 Ok(())
1148 }
1149
1150 fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1151 // The tool call we are looking for is typically the last one, or very close to the end.
1152 // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1153 self.entries
1154 .iter_mut()
1155 .enumerate()
1156 .rev()
1157 .find_map(|(index, tool_call)| {
1158 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1159 && &tool_call.id == id
1160 {
1161 Some((index, tool_call))
1162 } else {
1163 None
1164 }
1165 })
1166 }
1167
1168 pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1169 self.entries
1170 .iter()
1171 .enumerate()
1172 .rev()
1173 .find_map(|(index, tool_call)| {
1174 if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1175 && &tool_call.id == id
1176 {
1177 Some((index, tool_call))
1178 } else {
1179 None
1180 }
1181 })
1182 }
1183
1184 pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1185 let project = self.project.clone();
1186 let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1187 return;
1188 };
1189 let task = tool_call.resolve_locations(project, cx);
1190 cx.spawn(async move |this, cx| {
1191 let resolved_locations = task.await;
1192 this.update(cx, |this, cx| {
1193 let project = this.project.clone();
1194 let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1195 return;
1196 };
1197 if let Some(Some(location)) = resolved_locations.last() {
1198 project.update(cx, |project, cx| {
1199 if let Some(agent_location) = project.agent_location() {
1200 let should_ignore = agent_location.buffer == location.buffer
1201 && location
1202 .buffer
1203 .update(cx, |buffer, _| {
1204 let snapshot = buffer.snapshot();
1205 let old_position =
1206 agent_location.position.to_point(&snapshot);
1207 let new_position = location.position.to_point(&snapshot);
1208 // ignore this so that when we get updates from the edit tool
1209 // the position doesn't reset to the startof line
1210 old_position.row == new_position.row
1211 && old_position.column > new_position.column
1212 })
1213 .ok()
1214 .unwrap_or_default();
1215 if !should_ignore {
1216 project.set_agent_location(Some(location.clone()), cx);
1217 }
1218 }
1219 });
1220 }
1221 if tool_call.resolved_locations != resolved_locations {
1222 tool_call.resolved_locations = resolved_locations;
1223 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1224 }
1225 })
1226 })
1227 .detach();
1228 }
1229
1230 pub fn request_tool_call_authorization(
1231 &mut self,
1232 tool_call: acp::ToolCallUpdate,
1233 options: Vec<acp::PermissionOption>,
1234 cx: &mut Context<Self>,
1235 ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1236 let (tx, rx) = oneshot::channel();
1237
1238 if AgentSettings::get_global(cx).always_allow_tool_actions {
1239 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1240 // some tools would (incorrectly) continue to auto-accept.
1241 if let Some(allow_once_option) = options.iter().find_map(|option| {
1242 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1243 Some(option.id.clone())
1244 } else {
1245 None
1246 }
1247 }) {
1248 self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1249 return Ok(async {
1250 acp::RequestPermissionOutcome::Selected {
1251 option_id: allow_once_option,
1252 }
1253 }
1254 .boxed());
1255 }
1256 }
1257
1258 let status = ToolCallStatus::WaitingForConfirmation {
1259 options,
1260 respond_tx: tx,
1261 };
1262
1263 self.upsert_tool_call_inner(tool_call, status, cx)?;
1264 cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1265
1266 let fut = async {
1267 match rx.await {
1268 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1269 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1270 }
1271 }
1272 .boxed();
1273
1274 Ok(fut)
1275 }
1276
1277 pub fn authorize_tool_call(
1278 &mut self,
1279 id: acp::ToolCallId,
1280 option_id: acp::PermissionOptionId,
1281 option_kind: acp::PermissionOptionKind,
1282 cx: &mut Context<Self>,
1283 ) {
1284 let Some((ix, call)) = self.tool_call_mut(&id) else {
1285 return;
1286 };
1287
1288 let new_status = match option_kind {
1289 acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1290 ToolCallStatus::Rejected
1291 }
1292 acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1293 ToolCallStatus::InProgress
1294 }
1295 };
1296
1297 let curr_status = mem::replace(&mut call.status, new_status);
1298
1299 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1300 respond_tx.send(option_id).log_err();
1301 } else if cfg!(debug_assertions) {
1302 panic!("tried to authorize an already authorized tool call");
1303 }
1304
1305 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1306 }
1307
1308 /// Returns true if the last turn is awaiting tool authorization
1309 pub fn waiting_for_tool_confirmation(&self) -> bool {
1310 for entry in self.entries.iter().rev() {
1311 match &entry {
1312 AgentThreadEntry::ToolCall(call) => match call.status {
1313 ToolCallStatus::WaitingForConfirmation { .. } => return true,
1314 ToolCallStatus::Pending
1315 | ToolCallStatus::InProgress
1316 | ToolCallStatus::Completed
1317 | ToolCallStatus::Failed
1318 | ToolCallStatus::Rejected
1319 | ToolCallStatus::Canceled => continue,
1320 },
1321 AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1322 // Reached the beginning of the turn
1323 return false;
1324 }
1325 }
1326 }
1327 false
1328 }
1329
1330 pub fn plan(&self) -> &Plan {
1331 &self.plan
1332 }
1333
1334 pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1335 let new_entries_len = request.entries.len();
1336 let mut new_entries = request.entries.into_iter();
1337
1338 // Reuse existing markdown to prevent flickering
1339 for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1340 let PlanEntry {
1341 content,
1342 priority,
1343 status,
1344 } = old;
1345 content.update(cx, |old, cx| {
1346 old.replace(new.content, cx);
1347 });
1348 *priority = new.priority;
1349 *status = new.status;
1350 }
1351 for new in new_entries {
1352 self.plan.entries.push(PlanEntry::from_acp(new, cx))
1353 }
1354 self.plan.entries.truncate(new_entries_len);
1355
1356 cx.notify();
1357 }
1358
1359 fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1360 self.plan
1361 .entries
1362 .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1363 cx.notify();
1364 }
1365
1366 #[cfg(any(test, feature = "test-support"))]
1367 pub fn send_raw(
1368 &mut self,
1369 message: &str,
1370 cx: &mut Context<Self>,
1371 ) -> BoxFuture<'static, Result<()>> {
1372 self.send(
1373 vec![acp::ContentBlock::Text(acp::TextContent {
1374 text: message.to_string(),
1375 annotations: None,
1376 })],
1377 cx,
1378 )
1379 }
1380
    /// Starts a new turn by sending `message` to the agent.
    ///
    /// Inside the turn (see `run_turn`):
    /// 1. The message is pushed as a user entry.
    /// 2. A git checkpoint of the project is taken and attached to that entry, with
    ///    `show: false`.
    /// 3. The prompt is forwarded to the connection.
    ///
    /// A failure to take the checkpoint is logged and the entry is left without one.
    /// The message only receives a `UserMessageId` when the connection supports
    /// truncation, which is what `rewind` needs to locate it later.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            // Push the user entry before checkpointing so that `last_user_message`
            // below can attach the checkpoint to it.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1434
1435 pub fn can_resume(&self, cx: &App) -> bool {
1436 self.connection.resume(&self.session_id, cx).is_some()
1437 }
1438
1439 pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1440 self.run_turn(cx, async move |this, cx| {
1441 this.update(cx, |this, cx| {
1442 this.connection
1443 .resume(&this.session_id, cx)
1444 .map(|resume| resume.run(cx))
1445 })?
1446 .context("resuming a session is not supported")?
1447 .await
1448 })
1449 }
1450
    /// Runs one turn of the conversation, driven by `f`.
    ///
    /// Setup:
    /// - Completed plan entries are cleared.
    /// - Any running turn is canceled, and `f` only starts once that cancellation
    ///   has completed.
    /// - The task running `f` is stored in `send_task`.
    ///
    /// The returned future waits for `f`'s response, then:
    /// 1. Refreshes the last user message's checkpoint visibility.
    /// 2. Clears the project's agent location.
    /// 3. On an `Err` from `f`, emits `AcpThreadEvent::Error` and returns the error.
    /// 4. Otherwise (including when the response channel was dropped), removes the
    ///    last user message and everything after it if the agent refused, and emits
    ///    `AcpThreadEvent::Stopped`.
    ///
    /// NOTE(review): if `update_last_checkpoint` fails, the `?` below returns early
    /// without emitting `Error`/`Stopped` or clearing the agent location — confirm
    /// this is intended.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Let the previous turn finish canceling before starting this one.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means the send task was dropped before responding.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1517
1518 pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1519 let Some(send_task) = self.send_task.take() else {
1520 return Task::ready(());
1521 };
1522
1523 for entry in self.entries.iter_mut() {
1524 if let AgentThreadEntry::ToolCall(call) = entry {
1525 let cancel = matches!(
1526 call.status,
1527 ToolCallStatus::Pending
1528 | ToolCallStatus::WaitingForConfirmation { .. }
1529 | ToolCallStatus::InProgress
1530 );
1531
1532 if cancel {
1533 call.status = ToolCallStatus::Canceled;
1534 }
1535 }
1536 }
1537
1538 self.connection.cancel(&self.session_id, cx);
1539
1540 // Wait for the send task to complete
1541 cx.foreground_executor().spawn(send_task)
1542 }
1543
1544 /// Rewinds this thread to before the entry at `index`, removing it and all
1545 /// subsequent entries while reverting any changes made from that point.
1546 pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1547 let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1548 return Task::ready(Err(anyhow!("not supported")));
1549 };
1550 let Some(message) = self.user_message(&id) else {
1551 return Task::ready(Err(anyhow!("message not found")));
1552 };
1553
1554 let checkpoint = message
1555 .checkpoint
1556 .as_ref()
1557 .map(|c| c.git_checkpoint.clone());
1558
1559 let git_store = self.project.read(cx).git_store().clone();
1560 cx.spawn(async move |this, cx| {
1561 if let Some(checkpoint) = checkpoint {
1562 git_store
1563 .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1564 .await?;
1565 }
1566
1567 cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1568 this.update(cx, |this, cx| {
1569 if let Some((ix, _)) = this.user_message_mut(&id) {
1570 let range = ix..this.entries.len();
1571 this.entries.truncate(ix);
1572 cx.emit(AcpThreadEvent::EntriesRemoved(range));
1573 }
1574 })
1575 })
1576 }
1577
1578 fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1579 let git_store = self.project.read(cx).git_store().clone();
1580
1581 let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1582 if let Some(checkpoint) = message.checkpoint.as_ref() {
1583 checkpoint.git_checkpoint.clone()
1584 } else {
1585 return Task::ready(Ok(()));
1586 }
1587 } else {
1588 return Task::ready(Ok(()));
1589 };
1590
1591 let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1592 cx.spawn(async move |this, cx| {
1593 let new_checkpoint = new_checkpoint
1594 .await
1595 .context("failed to get new checkpoint")
1596 .log_err();
1597 if let Some(new_checkpoint) = new_checkpoint {
1598 let equal = git_store
1599 .update(cx, |git, cx| {
1600 git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1601 })?
1602 .await
1603 .unwrap_or(true);
1604 this.update(cx, |this, cx| {
1605 let (ix, message) = this.last_user_message().context("no user message")?;
1606 let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1607 checkpoint.show = !equal;
1608 cx.emit(AcpThreadEvent::EntryUpdated(ix));
1609 anyhow::Ok(())
1610 })??;
1611 }
1612
1613 Ok(())
1614 })
1615 }
1616
1617 fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1618 self.entries
1619 .iter_mut()
1620 .enumerate()
1621 .rev()
1622 .find_map(|(ix, entry)| {
1623 if let AgentThreadEntry::UserMessage(message) = entry {
1624 Some((ix, message))
1625 } else {
1626 None
1627 }
1628 })
1629 }
1630
1631 fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1632 self.entries.iter().find_map(|entry| {
1633 if let AgentThreadEntry::UserMessage(message) = entry {
1634 if message.id.as_ref() == Some(id) {
1635 Some(message)
1636 } else {
1637 None
1638 }
1639 } else {
1640 None
1641 }
1642 })
1643 }
1644
1645 fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1646 self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1647 if let AgentThreadEntry::UserMessage(message) = entry {
1648 if message.id.as_ref() == Some(id) {
1649 Some((ix, message))
1650 } else {
1651 None
1652 }
1653 } else {
1654 None
1655 }
1656 })
1657 }
1658
1659 pub fn read_text_file(
1660 &self,
1661 path: PathBuf,
1662 line: Option<u32>,
1663 limit: Option<u32>,
1664 reuse_shared_snapshot: bool,
1665 cx: &mut Context<Self>,
1666 ) -> Task<Result<String>> {
1667 let project = self.project.clone();
1668 let action_log = self.action_log.clone();
1669 cx.spawn(async move |this, cx| {
1670 let load = project.update(cx, |project, cx| {
1671 let path = project
1672 .project_path_for_absolute_path(&path, cx)
1673 .context("invalid path")?;
1674 anyhow::Ok(project.open_buffer(path, cx))
1675 });
1676 let buffer = load??.await?;
1677
1678 let snapshot = if reuse_shared_snapshot {
1679 this.read_with(cx, |this, _| {
1680 this.shared_buffers.get(&buffer.clone()).cloned()
1681 })
1682 .log_err()
1683 .flatten()
1684 } else {
1685 None
1686 };
1687
1688 let snapshot = if let Some(snapshot) = snapshot {
1689 snapshot
1690 } else {
1691 action_log.update(cx, |action_log, cx| {
1692 action_log.buffer_read(buffer.clone(), cx);
1693 })?;
1694 project.update(cx, |project, cx| {
1695 let position = buffer
1696 .read(cx)
1697 .snapshot()
1698 .anchor_before(Point::new(line.unwrap_or_default(), 0));
1699 project.set_agent_location(
1700 Some(AgentLocation {
1701 buffer: buffer.downgrade(),
1702 position,
1703 }),
1704 cx,
1705 );
1706 })?;
1707
1708 buffer.update(cx, |buffer, _| buffer.snapshot())?
1709 };
1710
1711 this.update(cx, |this, _| {
1712 let text = snapshot.text();
1713 this.shared_buffers.insert(buffer.clone(), snapshot);
1714 if line.is_none() && limit.is_none() {
1715 return Ok(text);
1716 }
1717 let limit = limit.unwrap_or(u32::MAX) as usize;
1718 let Some(line) = line else {
1719 return Ok(text.lines().take(limit).collect::<String>());
1720 };
1721
1722 let count = text.lines().count();
1723 if count < line as usize {
1724 anyhow::bail!("There are only {} lines", count);
1725 }
1726 Ok(text
1727 .lines()
1728 .skip(line as usize + 1)
1729 .take(limit)
1730 .collect::<String>())
1731 })?
1732 })
1733 }
1734
    /// Replaces the contents of the file at the absolute `path` with `content`, then
    /// saves it.
    ///
    /// How the buffer is updated:
    /// - `content` is diffed against the snapshot this thread last recorded for the
    ///   buffer in `shared_buffers` (see `read_text_file`). If there is none, the
    ///   buffer's current snapshot is used.
    /// - The resulting edits are applied through anchors into that snapshot, so
    ///   edits made to the buffer after the snapshot was taken are kept.
    ///
    /// Side effects:
    /// - The agent location moves to the end of the last edit.
    /// - The action log is told about the read and the edit.
    /// - The buffer is formatted before saving when its `format_on_save` setting is
    ///   not `Off`.
    ///
    /// NOTE(review): `shared_buffers` is not refreshed after this write, so a second
    /// write with no read in between diffs against the pre-write snapshot. Confirm
    /// this is intended.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent actually read, so the diff describes the
            // agent's intended changes rather than reverting concurrent edits.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff on the background executor; each edit is expressed as
            // an anchor range into `snapshot`.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting errors are logged and do not prevent the save below.
                format_task.await.log_err();

                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1831
1832 pub fn to_markdown(&self, cx: &App) -> String {
1833 self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1834 }
1835
1836 pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1837 cx.emit(AcpThreadEvent::LoadError(error));
1838 }
1839}
1840
1841fn markdown_for_raw_output(
1842 raw_output: &serde_json::Value,
1843 language_registry: &Arc<LanguageRegistry>,
1844 cx: &mut App,
1845) -> Option<Entity<Markdown>> {
1846 match raw_output {
1847 serde_json::Value::Null => None,
1848 serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1849 Markdown::new(
1850 value.to_string().into(),
1851 Some(language_registry.clone()),
1852 None,
1853 cx,
1854 )
1855 })),
1856 serde_json::Value::Number(value) => Some(cx.new(|cx| {
1857 Markdown::new(
1858 value.to_string().into(),
1859 Some(language_registry.clone()),
1860 None,
1861 cx,
1862 )
1863 })),
1864 serde_json::Value::String(value) => Some(cx.new(|cx| {
1865 Markdown::new(
1866 value.clone().into(),
1867 Some(language_registry.clone()),
1868 None,
1869 cx,
1870 )
1871 })),
1872 value => Some(cx.new(|cx| {
1873 Markdown::new(
1874 format!("```json\n{}\n```", value).into(),
1875 Some(language_registry.clone()),
1876 None,
1877 cx,
1878 )
1879 })),
1880 }
1881}
1882
1883#[cfg(test)]
1884mod tests {
1885 use super::*;
1886 use anyhow::anyhow;
1887 use futures::{channel::mpsc, future::LocalBoxFuture, select};
1888 use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1889 use indoc::indoc;
1890 use project::{FakeFs, Fs};
1891 use rand::Rng as _;
1892 use serde_json::json;
1893 use settings::SettingsStore;
1894 use smol::stream::StreamExt as _;
1895 use std::{
1896 any::Any,
1897 cell::RefCell,
1898 path::Path,
1899 rc::Rc,
1900 sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1901 time::Duration,
1902 };
1903 use util::path;
1904
1905 fn init_test(cx: &mut TestAppContext) {
1906 env_logger::try_init().ok();
1907 cx.update(|cx| {
1908 let settings_store = SettingsStore::test(cx);
1909 cx.set_global(settings_store);
1910 Project::init_settings(cx);
1911 language::init(cx);
1912 });
1913 }
1914
1915 #[gpui::test]
1916 async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1917 init_test(cx);
1918
1919 let fs = FakeFs::new(cx.executor());
1920 let project = Project::test(fs, [], cx).await;
1921 let connection = Rc::new(FakeAgentConnection::new());
1922 let thread = cx
1923 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1924 .await
1925 .unwrap();
1926
1927 // Test creating a new user message
1928 thread.update(cx, |thread, cx| {
1929 thread.push_user_content_block(
1930 None,
1931 acp::ContentBlock::Text(acp::TextContent {
1932 annotations: None,
1933 text: "Hello, ".to_string(),
1934 }),
1935 cx,
1936 );
1937 });
1938
1939 thread.update(cx, |thread, cx| {
1940 assert_eq!(thread.entries.len(), 1);
1941 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1942 assert_eq!(user_msg.id, None);
1943 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1944 } else {
1945 panic!("Expected UserMessage");
1946 }
1947 });
1948
1949 // Test appending to existing user message
1950 let message_1_id = UserMessageId::new();
1951 thread.update(cx, |thread, cx| {
1952 thread.push_user_content_block(
1953 Some(message_1_id.clone()),
1954 acp::ContentBlock::Text(acp::TextContent {
1955 annotations: None,
1956 text: "world!".to_string(),
1957 }),
1958 cx,
1959 );
1960 });
1961
1962 thread.update(cx, |thread, cx| {
1963 assert_eq!(thread.entries.len(), 1);
1964 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1965 assert_eq!(user_msg.id, Some(message_1_id));
1966 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1967 } else {
1968 panic!("Expected UserMessage");
1969 }
1970 });
1971
1972 // Test creating new user message after assistant message
1973 thread.update(cx, |thread, cx| {
1974 thread.push_assistant_content_block(
1975 acp::ContentBlock::Text(acp::TextContent {
1976 annotations: None,
1977 text: "Assistant response".to_string(),
1978 }),
1979 false,
1980 cx,
1981 );
1982 });
1983
1984 let message_2_id = UserMessageId::new();
1985 thread.update(cx, |thread, cx| {
1986 thread.push_user_content_block(
1987 Some(message_2_id.clone()),
1988 acp::ContentBlock::Text(acp::TextContent {
1989 annotations: None,
1990 text: "New user message".to_string(),
1991 }),
1992 cx,
1993 );
1994 });
1995
1996 thread.update(cx, |thread, cx| {
1997 assert_eq!(thread.entries.len(), 3);
1998 if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1999 assert_eq!(user_msg.id, Some(message_2_id));
2000 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2001 } else {
2002 panic!("Expected UserMessage at index 2");
2003 }
2004 });
2005 }
2006
2007 #[gpui::test]
2008 async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2009 init_test(cx);
2010
2011 let fs = FakeFs::new(cx.executor());
2012 let project = Project::test(fs, [], cx).await;
2013 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2014 |_, thread, mut cx| {
2015 async move {
2016 thread.update(&mut cx, |thread, cx| {
2017 thread
2018 .handle_session_update(
2019 acp::SessionUpdate::AgentThoughtChunk {
2020 content: "Thinking ".into(),
2021 },
2022 cx,
2023 )
2024 .unwrap();
2025 thread
2026 .handle_session_update(
2027 acp::SessionUpdate::AgentThoughtChunk {
2028 content: "hard!".into(),
2029 },
2030 cx,
2031 )
2032 .unwrap();
2033 })?;
2034 Ok(acp::PromptResponse {
2035 stop_reason: acp::StopReason::EndTurn,
2036 })
2037 }
2038 .boxed_local()
2039 },
2040 ));
2041
2042 let thread = cx
2043 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2044 .await
2045 .unwrap();
2046
2047 thread
2048 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2049 .await
2050 .unwrap();
2051
2052 let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2053 assert_eq!(
2054 output,
2055 indoc! {r#"
2056 ## User
2057
2058 Hello from Zed!
2059
2060 ## Assistant
2061
2062 <thinking>
2063 Thinking hard!
2064 </thinking>
2065
2066 "#}
2067 );
2068 }
2069
2070 #[gpui::test]
2071 async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2072 init_test(cx);
2073
2074 let fs = FakeFs::new(cx.executor());
2075 fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2076 .await;
2077 let project = Project::test(fs.clone(), [], cx).await;
2078 let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2079 let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2080 let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2081 move |_, thread, mut cx| {
2082 let read_file_tx = read_file_tx.clone();
2083 async move {
2084 let content = thread
2085 .update(&mut cx, |thread, cx| {
2086 thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2087 })
2088 .unwrap()
2089 .await
2090 .unwrap();
2091 assert_eq!(content, "one\ntwo\nthree\n");
2092 read_file_tx.take().unwrap().send(()).unwrap();
2093 thread
2094 .update(&mut cx, |thread, cx| {
2095 thread.write_text_file(
2096 path!("/tmp/foo").into(),
2097 "one\ntwo\nthree\nfour\nfive\n".to_string(),
2098 cx,
2099 )
2100 })
2101 .unwrap()
2102 .await
2103 .unwrap();
2104 Ok(acp::PromptResponse {
2105 stop_reason: acp::StopReason::EndTurn,
2106 })
2107 }
2108 .boxed_local()
2109 },
2110 ));
2111
2112 let (worktree, pathbuf) = project
2113 .update(cx, |project, cx| {
2114 project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2115 })
2116 .await
2117 .unwrap();
2118 let buffer = project
2119 .update(cx, |project, cx| {
2120 project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2121 })
2122 .await
2123 .unwrap();
2124
2125 let thread = cx
2126 .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2127 .await
2128 .unwrap();
2129
2130 let request = thread.update(cx, |thread, cx| {
2131 thread.send_raw("Extend the count in /tmp/foo", cx)
2132 });
2133 read_file_rx.await.ok();
2134 buffer.update(cx, |buffer, cx| {
2135 buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2136 });
2137 cx.run_until_parked();
2138 assert_eq!(
2139 buffer.read_with(cx, |buffer, _| buffer.text()),
2140 "zero\none\ntwo\nthree\nfour\nfive\n"
2141 );
2142 assert_eq!(
2143 String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2144 "zero\none\ntwo\nthree\nfour\nfive\n"
2145 );
2146 request.await.unwrap();
2147 }
2148
2149 #[gpui::test]
2150 async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2151 init_test(cx);
2152
2153 let fs = FakeFs::new(cx.executor());
2154 let project = Project::test(fs, [], cx).await;
2155 let id = acp::ToolCallId("test".into());
2156
2157 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2158 let id = id.clone();
2159 move |_, thread, mut cx| {
2160 let id = id.clone();
2161 async move {
2162 thread
2163 .update(&mut cx, |thread, cx| {
2164 thread.handle_session_update(
2165 acp::SessionUpdate::ToolCall(acp::ToolCall {
2166 id: id.clone(),
2167 title: "Label".into(),
2168 kind: acp::ToolKind::Fetch,
2169 status: acp::ToolCallStatus::InProgress,
2170 content: vec![],
2171 locations: vec![],
2172 raw_input: None,
2173 raw_output: None,
2174 }),
2175 cx,
2176 )
2177 })
2178 .unwrap()
2179 .unwrap();
2180 Ok(acp::PromptResponse {
2181 stop_reason: acp::StopReason::EndTurn,
2182 })
2183 }
2184 .boxed_local()
2185 }
2186 }));
2187
2188 let thread = cx
2189 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2190 .await
2191 .unwrap();
2192
2193 let request = thread.update(cx, |thread, cx| {
2194 thread.send_raw("Fetch https://example.com", cx)
2195 });
2196
2197 run_until_first_tool_call(&thread, cx).await;
2198
2199 thread.read_with(cx, |thread, _| {
2200 assert!(matches!(
2201 thread.entries[1],
2202 AgentThreadEntry::ToolCall(ToolCall {
2203 status: ToolCallStatus::InProgress,
2204 ..
2205 })
2206 ));
2207 });
2208
2209 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2210
2211 thread.read_with(cx, |thread, _| {
2212 assert!(matches!(
2213 &thread.entries[1],
2214 AgentThreadEntry::ToolCall(ToolCall {
2215 status: ToolCallStatus::Canceled,
2216 ..
2217 })
2218 ));
2219 });
2220
2221 thread
2222 .update(cx, |thread, cx| {
2223 thread.handle_session_update(
2224 acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2225 id,
2226 fields: acp::ToolCallUpdateFields {
2227 status: Some(acp::ToolCallStatus::Completed),
2228 ..Default::default()
2229 },
2230 }),
2231 cx,
2232 )
2233 })
2234 .unwrap();
2235
2236 request.await.unwrap();
2237
2238 thread.read_with(cx, |thread, _| {
2239 assert!(matches!(
2240 thread.entries[1],
2241 AgentThreadEntry::ToolCall(ToolCall {
2242 status: ToolCallStatus::Completed,
2243 ..
2244 })
2245 ));
2246 });
2247 }
2248
2249 #[gpui::test]
2250 async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2251 init_test(cx);
2252 let fs = FakeFs::new(cx.background_executor.clone());
2253 fs.insert_tree(path!("/test"), json!({})).await;
2254 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2255
2256 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2257 move |_, thread, mut cx| {
2258 async move {
2259 thread
2260 .update(&mut cx, |thread, cx| {
2261 thread.handle_session_update(
2262 acp::SessionUpdate::ToolCall(acp::ToolCall {
2263 id: acp::ToolCallId("test".into()),
2264 title: "Label".into(),
2265 kind: acp::ToolKind::Edit,
2266 status: acp::ToolCallStatus::Completed,
2267 content: vec![acp::ToolCallContent::Diff {
2268 diff: acp::Diff {
2269 path: "/test/test.txt".into(),
2270 old_text: None,
2271 new_text: "foo".into(),
2272 },
2273 }],
2274 locations: vec![],
2275 raw_input: None,
2276 raw_output: None,
2277 }),
2278 cx,
2279 )
2280 })
2281 .unwrap()
2282 .unwrap();
2283 Ok(acp::PromptResponse {
2284 stop_reason: acp::StopReason::EndTurn,
2285 })
2286 }
2287 .boxed_local()
2288 }
2289 }));
2290
2291 let thread = cx
2292 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2293 .await
2294 .unwrap();
2295
2296 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2297 .await
2298 .unwrap();
2299
2300 assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2301 }
2302
2303 #[gpui::test(iterations = 10)]
2304 async fn test_checkpoints(cx: &mut TestAppContext) {
2305 init_test(cx);
2306 let fs = FakeFs::new(cx.background_executor.clone());
2307 fs.insert_tree(
2308 path!("/test"),
2309 json!({
2310 ".git": {}
2311 }),
2312 )
2313 .await;
2314 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2315
2316 let simulate_changes = Arc::new(AtomicBool::new(true));
2317 let next_filename = Arc::new(AtomicUsize::new(0));
2318 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2319 let simulate_changes = simulate_changes.clone();
2320 let next_filename = next_filename.clone();
2321 let fs = fs.clone();
2322 move |request, thread, mut cx| {
2323 let fs = fs.clone();
2324 let simulate_changes = simulate_changes.clone();
2325 let next_filename = next_filename.clone();
2326 async move {
2327 if simulate_changes.load(SeqCst) {
2328 let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2329 fs.write(Path::new(&filename), b"").await?;
2330 }
2331
2332 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2333 panic!("expected text content block");
2334 };
2335 thread.update(&mut cx, |thread, cx| {
2336 thread
2337 .handle_session_update(
2338 acp::SessionUpdate::AgentMessageChunk {
2339 content: content.text.to_uppercase().into(),
2340 },
2341 cx,
2342 )
2343 .unwrap();
2344 })?;
2345 Ok(acp::PromptResponse {
2346 stop_reason: acp::StopReason::EndTurn,
2347 })
2348 }
2349 .boxed_local()
2350 }
2351 }));
2352 let thread = cx
2353 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2354 .await
2355 .unwrap();
2356
2357 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2358 .await
2359 .unwrap();
2360 thread.read_with(cx, |thread, cx| {
2361 assert_eq!(
2362 thread.to_markdown(cx),
2363 indoc! {"
2364 ## User (checkpoint)
2365
2366 Lorem
2367
2368 ## Assistant
2369
2370 LOREM
2371
2372 "}
2373 );
2374 });
2375 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2376
2377 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2378 .await
2379 .unwrap();
2380 thread.read_with(cx, |thread, cx| {
2381 assert_eq!(
2382 thread.to_markdown(cx),
2383 indoc! {"
2384 ## User (checkpoint)
2385
2386 Lorem
2387
2388 ## Assistant
2389
2390 LOREM
2391
2392 ## User (checkpoint)
2393
2394 ipsum
2395
2396 ## Assistant
2397
2398 IPSUM
2399
2400 "}
2401 );
2402 });
2403 assert_eq!(
2404 fs.files(),
2405 vec![
2406 Path::new(path!("/test/file-0")),
2407 Path::new(path!("/test/file-1"))
2408 ]
2409 );
2410
2411 // Checkpoint isn't stored when there are no changes.
2412 simulate_changes.store(false, SeqCst);
2413 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2414 .await
2415 .unwrap();
2416 thread.read_with(cx, |thread, cx| {
2417 assert_eq!(
2418 thread.to_markdown(cx),
2419 indoc! {"
2420 ## User (checkpoint)
2421
2422 Lorem
2423
2424 ## Assistant
2425
2426 LOREM
2427
2428 ## User (checkpoint)
2429
2430 ipsum
2431
2432 ## Assistant
2433
2434 IPSUM
2435
2436 ## User
2437
2438 dolor
2439
2440 ## Assistant
2441
2442 DOLOR
2443
2444 "}
2445 );
2446 });
2447 assert_eq!(
2448 fs.files(),
2449 vec![
2450 Path::new(path!("/test/file-0")),
2451 Path::new(path!("/test/file-1"))
2452 ]
2453 );
2454
2455 // Rewinding the conversation truncates the history and restores the checkpoint.
2456 thread
2457 .update(cx, |thread, cx| {
2458 let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2459 panic!("unexpected entries {:?}", thread.entries)
2460 };
2461 thread.rewind(message.id.clone().unwrap(), cx)
2462 })
2463 .await
2464 .unwrap();
2465 thread.read_with(cx, |thread, cx| {
2466 assert_eq!(
2467 thread.to_markdown(cx),
2468 indoc! {"
2469 ## User (checkpoint)
2470
2471 Lorem
2472
2473 ## Assistant
2474
2475 LOREM
2476
2477 "}
2478 );
2479 });
2480 assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2481 }
2482
2483 #[gpui::test]
2484 async fn test_refusal(cx: &mut TestAppContext) {
2485 init_test(cx);
2486 let fs = FakeFs::new(cx.background_executor.clone());
2487 fs.insert_tree(path!("/"), json!({})).await;
2488 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2489
2490 let refuse_next = Arc::new(AtomicBool::new(false));
2491 let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2492 let refuse_next = refuse_next.clone();
2493 move |request, thread, mut cx| {
2494 let refuse_next = refuse_next.clone();
2495 async move {
2496 if refuse_next.load(SeqCst) {
2497 return Ok(acp::PromptResponse {
2498 stop_reason: acp::StopReason::Refusal,
2499 });
2500 }
2501
2502 let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2503 panic!("expected text content block");
2504 };
2505 thread.update(&mut cx, |thread, cx| {
2506 thread
2507 .handle_session_update(
2508 acp::SessionUpdate::AgentMessageChunk {
2509 content: content.text.to_uppercase().into(),
2510 },
2511 cx,
2512 )
2513 .unwrap();
2514 })?;
2515 Ok(acp::PromptResponse {
2516 stop_reason: acp::StopReason::EndTurn,
2517 })
2518 }
2519 .boxed_local()
2520 }
2521 }));
2522 let thread = cx
2523 .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2524 .await
2525 .unwrap();
2526
2527 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2528 .await
2529 .unwrap();
2530 thread.read_with(cx, |thread, cx| {
2531 assert_eq!(
2532 thread.to_markdown(cx),
2533 indoc! {"
2534 ## User
2535
2536 hello
2537
2538 ## Assistant
2539
2540 HELLO
2541
2542 "}
2543 );
2544 });
2545
2546 // Simulate refusing the second message, ensuring the conversation gets
2547 // truncated to before sending it.
2548 refuse_next.store(true, SeqCst);
2549 cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2550 .await
2551 .unwrap();
2552 thread.read_with(cx, |thread, cx| {
2553 assert_eq!(
2554 thread.to_markdown(cx),
2555 indoc! {"
2556 ## User
2557
2558 hello
2559
2560 ## Assistant
2561
2562 HELLO
2563
2564 "}
2565 );
2566 });
2567 }
2568
2569 async fn run_until_first_tool_call(
2570 thread: &Entity<AcpThread>,
2571 cx: &mut TestAppContext,
2572 ) -> usize {
2573 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2574
2575 let subscription = cx.update(|cx| {
2576 cx.subscribe(thread, move |thread, _, cx| {
2577 for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2578 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2579 return tx.try_send(ix).unwrap();
2580 }
2581 }
2582 })
2583 });
2584
2585 select! {
2586 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2587 panic!("Timeout waiting for tool call")
2588 }
2589 ix = rx.next().fuse() => {
2590 drop(subscription);
2591 ix.unwrap()
2592 }
2593 }
2594 }
2595
2596 #[derive(Clone, Default)]
2597 struct FakeAgentConnection {
2598 auth_methods: Vec<acp::AuthMethod>,
2599 sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2600 on_user_message: Option<
2601 Rc<
2602 dyn Fn(
2603 acp::PromptRequest,
2604 WeakEntity<AcpThread>,
2605 AsyncApp,
2606 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2607 + 'static,
2608 >,
2609 >,
2610 }
2611
2612 impl FakeAgentConnection {
2613 fn new() -> Self {
2614 Self {
2615 auth_methods: Vec::new(),
2616 on_user_message: None,
2617 sessions: Arc::default(),
2618 }
2619 }
2620
2621 #[expect(unused)]
2622 fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2623 self.auth_methods = auth_methods;
2624 self
2625 }
2626
2627 fn on_user_message(
2628 mut self,
2629 handler: impl Fn(
2630 acp::PromptRequest,
2631 WeakEntity<AcpThread>,
2632 AsyncApp,
2633 ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2634 + 'static,
2635 ) -> Self {
2636 self.on_user_message.replace(Rc::new(handler));
2637 self
2638 }
2639 }
2640
2641 impl AgentConnection for FakeAgentConnection {
2642 fn auth_methods(&self) -> &[acp::AuthMethod] {
2643 &self.auth_methods
2644 }
2645
2646 fn new_thread(
2647 self: Rc<Self>,
2648 project: Entity<Project>,
2649 _cwd: &Path,
2650 cx: &mut App,
2651 ) -> Task<gpui::Result<Entity<AcpThread>>> {
2652 let session_id = acp::SessionId(
2653 rand::thread_rng()
2654 .sample_iter(&rand::distributions::Alphanumeric)
2655 .take(7)
2656 .map(char::from)
2657 .collect::<String>()
2658 .into(),
2659 );
2660 let action_log = cx.new(|_| ActionLog::new(project.clone()));
2661 let thread = cx.new(|cx| {
2662 AcpThread::new(
2663 "Test",
2664 self.clone(),
2665 project,
2666 action_log,
2667 session_id.clone(),
2668 watch::Receiver::constant(acp::PromptCapabilities {
2669 image: true,
2670 audio: true,
2671 embedded_context: true,
2672 }),
2673 cx,
2674 )
2675 });
2676 self.sessions.lock().insert(session_id, thread.downgrade());
2677 Task::ready(Ok(thread))
2678 }
2679
2680 fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2681 if self.auth_methods().iter().any(|m| m.id == method) {
2682 Task::ready(Ok(()))
2683 } else {
2684 Task::ready(Err(anyhow!("Invalid Auth Method")))
2685 }
2686 }
2687
2688 fn prompt(
2689 &self,
2690 _id: Option<UserMessageId>,
2691 params: acp::PromptRequest,
2692 cx: &mut App,
2693 ) -> Task<gpui::Result<acp::PromptResponse>> {
2694 let sessions = self.sessions.lock();
2695 let thread = sessions.get(¶ms.session_id).unwrap();
2696 if let Some(handler) = &self.on_user_message {
2697 let handler = handler.clone();
2698 let thread = thread.clone();
2699 cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2700 } else {
2701 Task::ready(Ok(acp::PromptResponse {
2702 stop_reason: acp::StopReason::EndTurn,
2703 }))
2704 }
2705 }
2706
2707 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2708 let sessions = self.sessions.lock();
2709 let thread = sessions.get(session_id).unwrap().clone();
2710
2711 cx.spawn(async move |cx| {
2712 thread
2713 .update(cx, |thread, cx| thread.cancel(cx))
2714 .unwrap()
2715 .await
2716 })
2717 .detach();
2718 }
2719
2720 fn truncate(
2721 &self,
2722 session_id: &acp::SessionId,
2723 _cx: &App,
2724 ) -> Option<Rc<dyn AgentSessionTruncate>> {
2725 Some(Rc::new(FakeAgentSessionEditor {
2726 _session_id: session_id.clone(),
2727 }))
2728 }
2729
2730 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2731 self
2732 }
2733 }
2734
2735 struct FakeAgentSessionEditor {
2736 _session_id: acp::SessionId,
2737 }
2738
2739 impl AgentSessionTruncate for FakeAgentSessionEditor {
2740 fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2741 Task::ready(Ok(()))
2742 }
2743 }
2744}